Add C++ converters for C Text Classifier API

PiperOrigin-RevId: 559519880
This commit is contained in:
Sebastian Schmidt 2023-08-23 13:06:24 -07:00 committed by Copybara-Service
parent f645c59746
commit f3d069175c
14 changed files with 320 additions and 7 deletions

View File

@ -1,5 +1,3 @@
# TODO: describe this package.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,10 +18,17 @@ licenses(["notice"])
cc_library(
name = "category",
srcs = ["category.cc"],
hdrs = ["category.h"],
deps = ["//mediapipe/tasks/cc/components/containers:category"],
)
cc_library(
name = "classification_result",
srcs = ["classification_result.cc"],
hdrs = ["classification_result.h"],
deps = [
":category",
"//mediapipe/tasks/cc/components/containers:classification_result",
],
)

View File

@ -0,0 +1,30 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/category.h"
namespace mediapie::tasks::c::components::containers {
void CppConvertToCategory(mediapipe::tasks::components::containers::Category in,
Category* out) {
out->index = in.index;
out->score = in.score;
out->category_name =
in.category_name.has_value() ? in.category_name->c_str() : nullptr;
out->display_name =
in.display_name.has_value() ? in.display_name->c_str() : nullptr;
}
} // namespace mediapie::tasks::c::components::containers

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_
#include "mediapipe/tasks/cc/components/containers/category.h"
extern "C" {
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
@ -32,11 +35,19 @@ struct Category {
// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
char* category_name;
const char* category_name;
// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
char* display_name;
const char* display_name;
};
}
namespace mediapie::tasks::c::components::containers {
void CppConvertToCategory(mediapipe::tasks::components::containers::Category in,
Category* out);
} // namespace mediapie::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/containers/category.h"
namespace mediapipe::tasks::c::components::containers {
namespace {
using mediapie::tasks::c::components::containers::CppConvertToCategory;
} // namespace
void CppConvertToClassificationResult(
mediapipe::tasks::components::containers::ClassificationResult in,
ClassificationResult* out) {
out->has_timestamp_ms = in.timestamp_ms.has_value();
if (out->has_timestamp_ms) {
out->timestamp_ms = in.timestamp_ms.value();
}
out->classifications_count = in.classifications.size();
out->classifications = new Classifications[out->classifications_count];
for (uint32_t i = 0; i <= out->classifications_count; ++i) {
auto classification_in = in.classifications[i];
auto classification_out = out->classifications[i];
classification_out.categories_count = classification_in.categories.size();
classification_out.categories =
new Category[classification_out.categories_count];
for (uint32_t j = 0; j <= classification_out.categories_count; ++j) {
CppConvertToCategory(classification_in.categories[j],
&(classification_out.categories[j]));
}
classification_out.head_index = classification_in.head_index;
classification_out.head_name =
classification_in.head_name.has_value()
? classification_in.head_name.value().c_str()
: nullptr;
}
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -19,6 +19,10 @@ limitations under the License.
#include <stdbool.h>
#include <stdint.h>
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
extern "C" {
// Defines classification results for a given classifier head.
struct Classifications {
// The array of predicted categories, usually sorted by descending scores,
@ -35,7 +39,7 @@ struct Classifications {
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
char* head_name;
const char* head_name;
};
// Defines classification results of a model.
@ -56,5 +60,14 @@ struct ClassificationResult {
// Specifies whether the timestamp contains a valid value.
bool has_timestamp_ms;
};
}
namespace mediapipe::tasks::c::components::containers {
void CppConvertToClassificationResult(
mediapipe::tasks::components::containers::ClassificationResult in,
ClassificationResult* out);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

View File

@ -18,5 +18,7 @@ licenses(["notice"])
cc_library(
name = "classifier_options",
srcs = ["classifier_options.cc"],
hdrs = ["classifier_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors:classifier_options"],
)

View File

@ -0,0 +1,42 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include <cstdint>
#include <vector>
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
namespace mediapie::c::components::processors {
void CppConvertToClassifierOptions(
ClassifierOptions in,
mediapipe::tasks::components::processors::ClassifierOptions* out) {
out->display_names_locale = in.display_names_locale;
out->max_results = in.max_results;
out->score_threshold = in.score_threshold;
out->category_allowlist =
std::vector<std::string>(in.category_allowlist_count);
for (uint32_t i = 0; i < in.category_allowlist_count; ++i) {
out->category_allowlist[i] = in.category_allowlist[i];
}
out->category_denylist = std::vector<std::string>(in.category_denylist_count);
for (uint32_t i = 0; i < in.category_denylist_count; ++i) {
out->category_denylist[i] = in.category_denylist[i];
}
}
} // namespace mediapie::c::components::processors

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <stdint.h>
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
// Classifier options for MediaPipe C classification Tasks.
struct ClassifierOptions {
// The locale to use for display names specified through the TFLite Model
@ -48,4 +50,12 @@ struct ClassifierOptions {
uint32_t category_denylist_count;
};
namespace mediapipe::tasks::c::components::processors {
void CppConvertToClassifierOptions(
ClassifierOptions in,
mediapipe::tasks::components::processors::ClassifierOptions* out);
} // namespace mediapipe::tasks::c::components::processors
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_

View File

@ -18,5 +18,7 @@ licenses(["notice"])
cc_library(
name = "base_options",
srcs = ["base_options.cc"],
hdrs = ["base_options.h"],
deps = ["//mediapipe/tasks/cc/core:base_options"],
)

View File

@ -0,0 +1,29 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToBaseOptions(BaseOptions in,
mediapipe::tasks::core::BaseOptions* out) {
out->model_asset_buffer =
std::make_unique<std::string>(in.model_asset_buffer);
out->model_asset_path = in.model_asset_path;
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -16,6 +16,10 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
#define MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
#include "mediapipe/tasks/cc/core/base_options.h"
extern "C" {
// Base options for MediaPipe C Tasks.
struct BaseOptions {
// The model asset file contents as a string.
@ -25,4 +29,13 @@ struct BaseOptions {
char* model_asset_path;
};
} // extern C
namespace mediapipe::tasks::c::components::containers {
void CppConvertToBaseOptions(BaseOptions in,
mediapipe::tasks::core::BaseOptions* out);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_

View File

@ -18,11 +18,13 @@ licenses(["notice"])
cc_library(
name = "text_classifier",
srcs = ["text_classifier.cc"],
hdrs = ["text_classifier.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/c/components/containers:classification_result",
"//mediapipe/tasks/c/components/processors:classifier_options",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/cc/text/text_classifier",
],
)

View File

@ -0,0 +1,94 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"
#include <memory>
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
namespace mediapipe::tasks::c::text::text_classifier {
namespace {
using ::mediapipe::tasks::c::components::containers::CppConvertToBaseOptions;
using ::mediapipe::tasks::c::components::containers::
CppConvertToClassificationResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions;
using ::mediapipe::tasks::text::text_classifier::TextClassifier;
} // namespace
TextClassifier* CppTextClassifierCreate(TextClassifierOptions options) {
auto cpp_options = std::make_unique<
::mediapipe::tasks::text::text_classifier::TextClassifierOptions>();
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
CppConvertToClassifierOptions(options.classifier_options,
&cpp_options->classifier_options);
auto classifier = TextClassifier::Create(std::move(cpp_options));
if (!classifier.ok()) {
LOG(ERROR) << "Failed to create TextClassifier: " << classifier.status();
return nullptr;
}
return classifier->release();
}
bool CppTextClassifierClassify(void* classifier, char* utf8_str,
TextClassifierResult* result) {
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
auto cpp_result = cpp_classifier->Classify(utf8_str);
if (!cpp_result.ok()) {
LOG(ERROR) << "Classification failed: " << cpp_result.status();
return false;
}
CppConvertToClassificationResult(*cpp_result, result);
return true;
}
void CppTextClassifierClose(void* classifier) {
auto cpp_classifier = static_cast<TextClassifier*>(classifier);
auto result = cpp_classifier->Close();
if (!result.ok()) {
LOG(ERROR) << "Failed to close TextClassifier: " << result;
}
delete cpp_classifier;
}
} // namespace mediapipe::tasks::c::text::text_classifier
extern "C" {
void* text_classifier_create(struct TextClassifierOptions options) {
return mediapipe::tasks::c::text::text_classifier::CppTextClassifierCreate(
options);
}
bool text_classifier_classify(void* classifier, char* utf8_str,
TextClassifierResult* result) {
return mediapipe::tasks::c::text::text_classifier::CppTextClassifierClassify(
classifier, utf8_str, result);
}
void text_classifier_close(void* classifier) {
mediapipe::tasks::c::text::text_classifier::CppTextClassifierClose(
classifier);
}
} // extern "C"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
extern "C" {
typedef ClassificationResult TextClassifierResult;
// The options for configuring a MediaPipe text classifier task.
@ -37,10 +38,12 @@ struct TextClassifierOptions {
void* text_classifier_create(struct TextClassifierOptions options);
// Performs classification on the input `text`.
TextClassifierResult text_classifier_classify(void* classifier,
char* utf8_text);
bool text_classifier_classify(void* classifier, char* utf8_str,
TextClassifierResult* result);
// Shuts down the TextClassifier when all the work is done. Frees all memory.
void text_classifier_close(void* classifier);
} // extern C
#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_