From f3d069175cd60b98adf3113e68f2b7c3d1039b6f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Aug 2023 13:06:24 -0700 Subject: [PATCH] Add C++ converters for C Text Classifier API PiperOrigin-RevId: 559519880 --- mediapipe/tasks/c/components/containers/BUILD | 9 +- .../tasks/c/components/containers/category.cc | 30 ++++++ .../tasks/c/components/containers/category.h | 15 ++- .../containers/classification_result.cc | 57 +++++++++++ .../containers/classification_result.h | 15 ++- mediapipe/tasks/c/components/processors/BUILD | 2 + .../processors/classifier_options.cc | 42 +++++++++ .../processors/classifier_options.h | 10 ++ mediapipe/tasks/c/core/BUILD | 2 + mediapipe/tasks/c/core/base_options.cc | 29 ++++++ mediapipe/tasks/c/core/base_options.h | 13 +++ mediapipe/tasks/c/text/text_classifier/BUILD | 2 + .../c/text/text_classifier/text_classifier.cc | 94 +++++++++++++++++++ .../c/text/text_classifier/text_classifier.h | 7 +- 14 files changed, 320 insertions(+), 7 deletions(-) create mode 100644 mediapipe/tasks/c/components/containers/category.cc create mode 100644 mediapipe/tasks/c/components/containers/classification_result.cc create mode 100644 mediapipe/tasks/c/components/processors/classifier_options.cc create mode 100644 mediapipe/tasks/c/core/base_options.cc create mode 100644 mediapipe/tasks/c/text/text_classifier/text_classifier.cc diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD index 4d1f190bb..0f55d18d7 100644 --- a/mediapipe/tasks/c/components/containers/BUILD +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -1,5 +1,3 @@ -# TODO: describe this package. - # Copyright 2022 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,10 +18,17 @@ licenses(["notice"]) cc_library( name = "category", + srcs = ["category.cc"], hdrs = ["category.h"], + deps = ["//mediapipe/tasks/cc/components/containers:category"], ) cc_library( name = "classification_result", + srcs = ["classification_result.cc"], hdrs = ["classification_result.h"], + deps = [ + ":category", + "//mediapipe/tasks/cc/components/containers:classification_result", + ], ) diff --git a/mediapipe/tasks/c/components/containers/category.cc b/mediapipe/tasks/c/components/containers/category.cc new file mode 100644 index 000000000..2311f6372 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/category.cc @@ -0,0 +1,30 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/category.h" + +namespace mediapie::tasks::c::components::containers { + +void CppConvertToCategory(mediapipe::tasks::components::containers::Category in, + Category* out) { + out->index = in.index; + out->score = in.score; + out->category_name = + in.category_name.has_value() ? in.category_name->c_str() : nullptr; + out->display_name = + in.display_name.has_value() ? in.display_name->c_str() : nullptr; +} + +} // namespace mediapie::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/category.h b/mediapipe/tasks/c/components/containers/category.h index 565dd65fe..c83140af6 100644 --- a/mediapipe/tasks/c/components/containers/category.h +++ b/mediapipe/tasks/c/components/containers/category.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ #define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ +#include "mediapipe/tasks/cc/components/containers/category.h" + +extern "C" { // Defines a single classification result. // // The label maps packed into the TFLite Model Metadata [1] are used to populate @@ -32,11 +35,19 @@ struct Category { // The optional ID for the category, read from the label map packed in the // TFLite Model Metadata if present. Not necessarily human-readable. - char* category_name; + const char* category_name; // The optional human-readable name for the category, read from the label map // packed in the TFLite Model Metadata if present. - char* display_name; + const char* display_name; }; +} + +namespace mediapie::tasks::c::components::containers { + +void CppConvertToCategory(mediapipe::tasks::components::containers::Category in, + Category* out); + +} // namespace mediapie::tasks::c::components::containers #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ diff --git a/mediapipe/tasks/c/components/containers/classification_result.cc b/mediapipe/tasks/c/components/containers/classification_result.cc new file mode 100644 index 000000000..4e6b1036e --- /dev/null +++ b/mediapipe/tasks/c/components/containers/classification_result.cc @@ -0,0 +1,57 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/classification_result.h" + +#include "mediapipe/tasks/c/components/containers/category.h" + +namespace mediapipe::tasks::c::components::containers { + +namespace { +using mediapie::tasks::c::components::containers::CppConvertToCategory; +} // namespace + +void CppConvertToClassificationResult( + mediapipe::tasks::components::containers::ClassificationResult in, + ClassificationResult* out) { + out->has_timestamp_ms = in.timestamp_ms.has_value(); + if (out->has_timestamp_ms) { + out->timestamp_ms = in.timestamp_ms.value(); + } + + out->classifications_count = in.classifications.size(); + out->classifications = new Classifications[out->classifications_count]; + + for (uint32_t i = 0; i <= out->classifications_count; ++i) { + auto classification_in = in.classifications[i]; + auto classification_out = out->classifications[i]; + + classification_out.categories_count = classification_in.categories.size(); + classification_out.categories = + new Category[classification_out.categories_count]; + for (uint32_t j = 0; j <= classification_out.categories_count; ++j) { + CppConvertToCategory(classification_in.categories[j], + &(classification_out.categories[j])); + } + + classification_out.head_index = classification_in.head_index; + classification_out.head_name = + classification_in.head_name.has_value() + ? classification_in.head_name.value().c_str() + : nullptr; + } +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/classification_result.h b/mediapipe/tasks/c/components/containers/classification_result.h index 540ab4464..77ec4ba80 100644 --- a/mediapipe/tasks/c/components/containers/classification_result.h +++ b/mediapipe/tasks/c/components/containers/classification_result.h @@ -19,6 +19,10 @@ limitations under the License. #include #include +#include "mediapipe/tasks/cc/components/containers/classification_result.h" + +extern "C" { + // Defines classification results for a given classifier head. struct Classifications { // The array of predicted categories, usually sorted by descending scores, @@ -35,7 +39,7 @@ struct Classifications { // Metadata [1] if present. This is useful for multi-head models. // // [1]: https://www.tensorflow.org/lite/convert/metadata - char* head_name; + const char* head_name; }; // Defines classification results of a model. @@ -56,5 +60,14 @@ struct ClassificationResult { // Specifies whether the timestamp contains a valid value. bool has_timestamp_ms; }; +} + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToClassificationResult( + mediapipe::tasks::components::containers::ClassificationResult in, + ClassificationResult* out); + +} // namespace mediapipe::tasks::c::components::containers #endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ diff --git a/mediapipe/tasks/c/components/processors/BUILD b/mediapipe/tasks/c/components/processors/BUILD index 24d3a181e..397e149de 100644 --- a/mediapipe/tasks/c/components/processors/BUILD +++ b/mediapipe/tasks/c/components/processors/BUILD @@ -18,5 +18,7 @@ licenses(["notice"]) cc_library( name = "classifier_options", + srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], + deps = ["//mediapipe/tasks/cc/components/processors:classifier_options"], ) diff --git a/mediapipe/tasks/c/components/processors/classifier_options.cc b/mediapipe/tasks/c/components/processors/classifier_options.cc new file mode 100644 index 000000000..7c84e7a03 --- /dev/null +++ b/mediapipe/tasks/c/components/processors/classifier_options.cc @@ -0,0 +1,42 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/processors/classifier_options.h" + +#include +#include + +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" + +namespace mediapie::c::components::processors { + +void CppConvertToClassifierOptions( + ClassifierOptions in, + mediapipe::tasks::components::processors::ClassifierOptions* out) { + out->display_names_locale = in.display_names_locale; + out->max_results = in.max_results; + out->score_threshold = in.score_threshold; + out->category_allowlist = + std::vector(in.category_allowlist_count); + for (uint32_t i = 0; i < in.category_allowlist_count; ++i) { + out->category_allowlist[i] = in.category_allowlist[i]; + } + out->category_denylist = std::vector(in.category_denylist_count); + for (uint32_t i = 0; i < in.category_denylist_count; ++i) { + out->category_denylist[i] = in.category_denylist[i]; + } +} + +} // namespace mediapie::c::components::processors diff --git a/mediapipe/tasks/c/components/processors/classifier_options.h b/mediapipe/tasks/c/components/processors/classifier_options.h index 4cce2ce69..781974331 100644 --- a/mediapipe/tasks/c/components/processors/classifier_options.h +++ b/mediapipe/tasks/c/components/processors/classifier_options.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" + // Classifier options for MediaPipe C classification Tasks. struct ClassifierOptions { // The locale to use for display names specified through the TFLite Model @@ -48,4 +50,12 @@ struct ClassifierOptions { uint32_t category_denylist_count; }; +namespace mediapipe::tasks::c::components::processors { + +void CppConvertToClassifierOptions( + ClassifierOptions in, + mediapipe::tasks::components::processors::ClassifierOptions* out); + +} // namespace mediapipe::tasks::c::components::processors + #endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ diff --git a/mediapipe/tasks/c/core/BUILD b/mediapipe/tasks/c/core/BUILD index 60d10857f..adf6c81af 100644 --- a/mediapipe/tasks/c/core/BUILD +++ b/mediapipe/tasks/c/core/BUILD @@ -18,5 +18,7 @@ licenses(["notice"]) cc_library( name = "base_options", + srcs = ["base_options.cc"], hdrs = ["base_options.h"], + deps = ["//mediapipe/tasks/cc/core:base_options"], ) diff --git a/mediapipe/tasks/c/core/base_options.cc b/mediapipe/tasks/c/core/base_options.cc new file mode 100644 index 000000000..d8fcfdb9e --- /dev/null +++ b/mediapipe/tasks/c/core/base_options.cc @@ -0,0 +1,29 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/core/base_options.h" + +#include "mediapipe/tasks/cc/core/base_options.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToBaseOptions(BaseOptions in, + mediapipe::tasks::core::BaseOptions* out) { + out->model_asset_buffer = + std::make_unique(in.model_asset_buffer); + out->model_asset_path = in.model_asset_path; +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/core/base_options.h b/mediapipe/tasks/c/core/base_options.h index f5f6b0318..1707c9fad 100644 --- a/mediapipe/tasks/c/core/base_options.h +++ b/mediapipe/tasks/c/core/base_options.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ #define MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ +#include "mediapipe/tasks/cc/core/base_options.h" + +extern "C" { + // Base options for MediaPipe C Tasks. struct BaseOptions { // The model asset file contents as a string. @@ -25,4 +29,13 @@ struct BaseOptions { char* model_asset_path; }; +} // extern C + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToBaseOptions(BaseOptions in, + mediapipe::tasks::core::BaseOptions* out); + +} // namespace mediapipe::tasks::c::components::containers + #endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ diff --git a/mediapipe/tasks/c/text/text_classifier/BUILD b/mediapipe/tasks/c/text/text_classifier/BUILD index 0402689c7..93ea468db 100644 --- a/mediapipe/tasks/c/text/text_classifier/BUILD +++ b/mediapipe/tasks/c/text/text_classifier/BUILD @@ -18,11 +18,13 @@ licenses(["notice"]) cc_library( name = "text_classifier", + srcs = ["text_classifier.cc"], hdrs = ["text_classifier.h"], visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/c/components/containers:classification_result", "//mediapipe/tasks/c/components/processors:classifier_options", "//mediapipe/tasks/c/core:base_options", + "//mediapipe/tasks/cc/text/text_classifier", ], ) diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.cc b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc new file mode 100644 index 000000000..b88a66bc4 --- /dev/null +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.cc @@ -0,0 +1,94 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h" + +#include + +#include "mediapipe/tasks/c/components/containers/classification_result.h" +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/c/core/base_options.h" +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +namespace mediapipe::tasks::c::text::text_classifier { + +namespace { + +using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions; +using ::mediapipe::tasks::c::components::containers:: + CppConvertToClassificationResult; +using ::mediapipe::tasks::c::components::processors:: + CppConvertToClassifierOptions; +using ::mediapipe::tasks::text::text_classifier::TextClassifier; +} // namespace + +TextClassifier* CppTextClassifierCreate(TextClassifierOptions options) { + auto cpp_options = std::make_unique< + ::mediapipe::tasks::text::text_classifier::TextClassifierOptions>(); + + CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); + CppConvertToClassifierOptions(options.classifier_options, + &cpp_options->classifier_options); + + auto classifier = TextClassifier::Create(std::move(cpp_options)); + if (!classifier.ok()) { + LOG(ERROR) << "Failed to create TextClassifier: " << classifier.status(); + return nullptr; + } + return classifier->release(); +} + +bool CppTextClassifierClassify(void* classifier, char* utf8_str, + TextClassifierResult* result) { + auto cpp_classifier = static_cast(classifier); + auto cpp_result = cpp_classifier->Classify(utf8_str); + if (!cpp_result.ok()) { + LOG(ERROR) << "Classification failed: " << cpp_result.status(); + return false; + } + CppConvertToClassificationResult(*cpp_result, result); + return true; +} + +void CppTextClassifierClose(void* classifier) { + auto cpp_classifier = static_cast(classifier); + auto result = cpp_classifier->Close(); + if (!result.ok()) { + LOG(ERROR) << "Failed to close TextClassifier: " << result; + } + delete cpp_classifier; +} + +} // namespace mediapipe::tasks::c::text::text_classifier + +extern "C" { + +void* text_classifier_create(struct TextClassifierOptions options) { + return mediapipe::tasks::c::text::text_classifier::CppTextClassifierCreate( + options); +} + +bool text_classifier_classify(void* classifier, char* utf8_str, + TextClassifierResult* result) { + return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify( + classifier, utf8_str, result); +} + +void text_classifier_close(void* classifier) { + mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose( + classifier); +} + +} // extern "C" diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.h b/mediapipe/tasks/c/text/text_classifier/text_classifier.h index 7439644b8..9ec9682dc 100644 --- a/mediapipe/tasks/c/text/text_classifier/text_classifier.h +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.h @@ -20,6 +20,7 @@ limitations under the License. #include "mediapipe/tasks/c/components/processors/classifier_options.h" #include "mediapipe/tasks/c/core/base_options.h" +extern "C" { typedef ClassificationResult TextClassifierResult; // The options for configuring a MediaPipe text classifier task. @@ -37,10 +38,12 @@ struct TextClassifierOptions { void* text_classifier_create(struct TextClassifierOptions options); // Performs classification on the input `text`. -TextClassifierResult text_classifier_classify(void* classifier, - char* utf8_text); +bool text_classifier_classify(void* classifier, char* utf8_str, + TextClassifierResult* result); // Shuts down the TextClassifier when all the work is done. Frees all memory. void text_classifier_close(void* classifier); +} // extern C + #endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_