From 882ec323f0cda4d4b6056dd3b96ee95989385e55 Mon Sep 17 00:00:00 2001 From: Kinar Date: Fri, 6 Oct 2023 11:39:23 -0700 Subject: [PATCH] Added files for the Language Detector C API and tests --- mediapipe/tasks/c/components/containers/BUILD | 15 +++ .../containers/language_detection_result.h | 46 +++++++ .../language_detection_result_converter.cc | 55 ++++++++ .../language_detection_result_converter.h | 32 +++++ .../tasks/c/text/language_detector/BUILD | 85 ++++++++++++ .../language_detector/language_detector.cc | 126 ++++++++++++++++++ .../language_detector/language_detector.h | 73 ++++++++++ .../language_detector_test.cc | 87 ++++++++++++ 8 files changed, 519 insertions(+) create mode 100644 mediapipe/tasks/c/components/containers/language_detection_result.h create mode 100644 mediapipe/tasks/c/components/containers/language_detection_result_converter.cc create mode 100644 mediapipe/tasks/c/components/containers/language_detection_result_converter.h create mode 100644 mediapipe/tasks/c/text/language_detector/BUILD create mode 100644 mediapipe/tasks/c/text/language_detector/language_detector.cc create mode 100644 mediapipe/tasks/c/text/language_detector/language_detector.h create mode 100644 mediapipe/tasks/c/text/language_detector/language_detector_test.cc diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD index 0d89c820e..c4697a9dd 100644 --- a/mediapipe/tasks/c/components/containers/BUILD +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -98,3 +98,18 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "language_detection_result", + hdrs = ["language_detection_result.h"], +) + +cc_library( + name = "language_detection_result_converter", + srcs = ["language_detection_result_converter.cc"], + hdrs = ["language_detection_result_converter.h"], + deps = [ + ":language_detection_result", + "//mediapipe/tasks/cc/text/language_detector:language_detector", + ], +) diff --git 
a/mediapipe/tasks/c/components/containers/language_detection_result.h b/mediapipe/tasks/c/components/containers/language_detection_result.h new file mode 100644 index 000000000..a2b7ee9db --- /dev/null +++ b/mediapipe/tasks/c/components/containers/language_detection_result.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_H_ + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// A language code and its probability. +struct LanguageDetectorPrediction { + // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek, + // "ja-Latn" for Japanese (romaji). + char* language_code; + + float probability; +}; + +// Task output. +struct LanguageDetectorResult { + struct LanguageDetectorPrediction* predictions; + + // Keep the count of predictions. 
+ uint32_t predictions_count; +}; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_H_ diff --git a/mediapipe/tasks/c/components/containers/language_detection_result_converter.cc b/mediapipe/tasks/c/components/containers/language_detection_result_converter.cc new file mode 100644 index 000000000..e9a7b8bab --- /dev/null +++ b/mediapipe/tasks/c/components/containers/language_detection_result_converter.cc @@ -0,0 +1,55 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h" + +#include <cstdlib> + +#include "mediapipe/tasks/c/components/containers/language_detection_result.h" +#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToLanguageDetectionResult( + const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in, + LanguageDetectorResult* out) { + out->predictions_count = in.size(); + out->predictions = + out->predictions_count + ? 
new LanguageDetectorPrediction[out->predictions_count] + : nullptr; + + for (uint32_t i = 0; i < out->predictions_count; ++i) { + const auto& language_detection_prediction_in = in[i]; + auto& language_detection_prediction_out = out->predictions[i]; + language_detection_prediction_out.probability = + language_detection_prediction_in.probability; + language_detection_prediction_out.language_code = + strdup(language_detection_prediction_in.language_code.c_str()); + } +} + +void CppCloseLanguageDetectionResult(LanguageDetectorResult* in) { + for (uint32_t i = 0; i < in->predictions_count; ++i) { + auto& prediction_in = in->predictions[i]; // Reference, not a copy. + free(prediction_in.language_code); + prediction_in.language_code = nullptr; + } + delete[] in->predictions; + in->predictions = nullptr; + in->predictions_count = 0; // Makes a repeated close a no-op. +} + +} // namespace mediapipe::tasks::c::components::containers diff --git a/mediapipe/tasks/c/components/containers/language_detection_result_converter.h b/mediapipe/tasks/c/components/containers/language_detection_result_converter.h new file mode 100644 index 000000000..74535de7f --- /dev/null +++ b/mediapipe/tasks/c/components/containers/language_detection_result_converter.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_ + +#include "mediapipe/tasks/c/components/containers/language_detection_result.h" +#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" + +namespace mediapipe::tasks::c::components::containers { + +void CppConvertToLanguageDetectionResult( + const mediapipe::tasks::text::language_detector::LanguageDetectorResult& in, + LanguageDetectorResult* out); + +void CppCloseLanguageDetectionResult(LanguageDetectorResult* in); + +} // namespace mediapipe::tasks::c::components::containers + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_LANGUAGE_DETECTION_RESULT_CONVERTER_H_ diff --git a/mediapipe/tasks/c/text/language_detector/BUILD b/mediapipe/tasks/c/text/language_detector/BUILD new file mode 100644 index 000000000..cd711b696 --- /dev/null +++ b/mediapipe/tasks/c/text/language_detector/BUILD @@ -0,0 +1,85 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "language_detector_lib", + srcs = ["language_detector.cc"], + hdrs = ["language_detector.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/c/components/containers:language_detection_result", + "//mediapipe/tasks/c/components/containers:language_detection_result_converter", + "//mediapipe/tasks/c/components/processors:classifier_options", + "//mediapipe/tasks/c/components/processors:classifier_options_converter", + "//mediapipe/tasks/c/core:base_options", + "//mediapipe/tasks/c/core:base_options_converter", + "//mediapipe/tasks/cc/text/language_detector", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \ +# //mediapipe/tasks/c/text/language_detector:liblanguage_detector.so +cc_binary( + name = "liblanguage_detector.so", + linkopts = [ + "-Wl,-soname=liblanguage_detector.so", + "-fvisibility=hidden", + ], + linkshared = True, + tags = [ + "manual", + "nobuilder", + "notap", + ], + deps = [":language_detector_lib"], +) + +# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \ +# //mediapipe/tasks/c/text/language_detector:liblanguage_detector.dylib +cc_binary( + name = "liblanguage_detector.dylib", + linkopts = [ + "-Wl,-install_name,liblanguage_detector.dylib", + "-fvisibility=hidden", + ], + linkshared = True, + tags = [ + "manual", + "nobuilder", + "notap", + ], + deps = [":language_detector_lib"], +) + +cc_test( + name = "language_detector_test", + srcs = ["language_detector_test.cc"], + data = ["//mediapipe/tasks/testdata/text:language_detector"], + linkstatic = 1, + deps = [ + ":language_detector_lib", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest", + "@com_google_absl//absl/flags:flag", + 
"@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/mediapipe/tasks/c/text/language_detector/language_detector.cc b/mediapipe/tasks/c/text/language_detector/language_detector.cc new file mode 100644 index 000000000..c6ea750c1 --- /dev/null +++ b/mediapipe/tasks/c/text/language_detector/language_detector.cc @@ -0,0 +1,126 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/c/text/language_detector/language_detector.h" + +#include <memory> +#include <utility> + +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "mediapipe/tasks/c/components/containers/language_detection_result_converter.h" +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h" +#include "mediapipe/tasks/c/core/base_options.h" +#include "mediapipe/tasks/c/core/base_options_converter.h" +#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" + +namespace mediapipe::tasks::c::text::language_detector { + +namespace { + +using ::mediapipe::tasks::c::components::containers:: + CppCloseLanguageDetectionResult; +using ::mediapipe::tasks::c::components::containers:: + CppConvertToLanguageDetectionResult; +using ::mediapipe::tasks::c::components::processors:: + CppConvertToClassifierOptions; +using 
::mediapipe::tasks::c::core::CppConvertToBaseOptions; +using ::mediapipe::tasks::text::language_detector::LanguageDetector; + +int CppProcessError(absl::Status status, char** error_msg) { + if (error_msg) { + *error_msg = strdup(status.ToString().c_str()); + } + return status.raw_code(); +} + +} // namespace + +LanguageDetector* CppLanguageDetectorCreate( + const LanguageDetectorOptions& options, char** error_msg) { + auto cpp_options = std::make_unique< + ::mediapipe::tasks::text::language_detector::LanguageDetectorOptions>(); + + CppConvertToBaseOptions(options.base_options, &cpp_options->base_options); + CppConvertToClassifierOptions(options.classifier_options, + &cpp_options->classifier_options); + + auto detector = LanguageDetector::Create(std::move(cpp_options)); + if (!detector.ok()) { + ABSL_LOG(ERROR) << "Failed to create LanguageDetector: " + << detector.status(); + CppProcessError(detector.status(), error_msg); + return nullptr; + } + return detector->release(); +} + +int CppLanguageDetectorDetect(void* detector, const char* utf8_str, + LanguageDetectorResult* result, + char** error_msg) { + auto cpp_detector = static_cast<LanguageDetector*>(detector); + auto cpp_result = cpp_detector->Detect(utf8_str); + if (!cpp_result.ok()) { + ABSL_LOG(ERROR) << "Language Detection failed: " << cpp_result.status(); + return CppProcessError(cpp_result.status(), error_msg); + } + + CppConvertToLanguageDetectionResult(*cpp_result, result); + return 0; +} + +void CppLanguageDetectorCloseResult(LanguageDetectorResult* result) { + CppCloseLanguageDetectionResult(result); +} + +int CppLanguageDetectorClose(void* detector, char** error_msg) { + auto cpp_detector = static_cast<LanguageDetector*>(detector); + auto result = cpp_detector->Close(); + if (!result.ok()) { + ABSL_LOG(ERROR) << "Failed to close LanguageDetector: " << result; + return CppProcessError(result, error_msg); + } + delete cpp_detector; + return 0; +} + +} // namespace mediapipe::tasks::c::text::language_detector + +extern "C" { + +void* 
language_detector_create(struct LanguageDetectorOptions* options, + char** error_msg) { + return mediapipe::tasks::c::text::language_detector:: + CppLanguageDetectorCreate(*options, error_msg); +} + +int language_detector_detect(void* detector, const char* utf8_str, + LanguageDetectorResult* result, char** error_msg) { + return mediapipe::tasks::c::text::language_detector:: + CppLanguageDetectorDetect(detector, utf8_str, result, error_msg); +} + +void language_detector_close_result(LanguageDetectorResult* result) { + mediapipe::tasks::c::text::language_detector::CppLanguageDetectorCloseResult( + result); +} + +int language_detector_close(void* detector, char** error_msg) { + return mediapipe::tasks::c::text::language_detector::CppLanguageDetectorClose( + detector, error_msg); +} + +} // extern "C" diff --git a/mediapipe/tasks/c/text/language_detector/language_detector.h b/mediapipe/tasks/c/text/language_detector/language_detector.h new file mode 100644 index 000000000..d19fc8ca1 --- /dev/null +++ b/mediapipe/tasks/c/text/language_detector/language_detector.h @@ -0,0 +1,73 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ +#define MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ + +#include "mediapipe/tasks/c/components/containers/language_detection_result.h" +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/c/core/base_options.h" + +#ifndef MP_EXPORT +#define MP_EXPORT __attribute__((visibility("default"))) +#endif // MP_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +// The options for configuring a MediaPipe language detector task. +struct LanguageDetectorOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + + // Options for configuring the detector behavior, such as score threshold, + // number of results, etc. + struct ClassifierOptions classifier_options; +}; + +// Creates a LanguageDetector from the provided `options`. +// Returns a pointer to the language detector on success. +// If an error occurs, returns `nullptr` and sets the error parameter to an +// error message (if `error_msg` is not nullptr). You must free the memory +// allocated for the error message with `free()`. +MP_EXPORT void* language_detector_create( + struct LanguageDetectorOptions* options, char** error_msg = nullptr); + +// Performs language detection on the input `utf8_str`. Returns `0` on success. +// If an error occurs, returns an error code and sets the error parameter to an +// error message (if `error_msg` is not nullptr). You must free the memory +// allocated for the error message with `free()`. +MP_EXPORT int language_detector_detect(void* detector, const char* utf8_str, + LanguageDetectorResult* result, + char** error_msg = nullptr); + +// Frees the memory allocated inside a LanguageDetectorResult result. Does not +// free the result pointer itself. 
+MP_EXPORT void language_detector_close_result(LanguageDetectorResult* result); + +// Shuts down the LanguageDetector when all the work is done. Frees all memory. +// If an error occurs, returns an error code and sets the error parameter to an +// error message (if `error_msg` is not nullptr). You must free the memory +// allocated for the error message with `free()`. +MP_EXPORT int language_detector_close(void* detector, + char** error_msg = nullptr); + +#ifdef __cplusplus +} // extern C +#endif + +#endif // MEDIAPIPE_TASKS_C_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ diff --git a/mediapipe/tasks/c/text/language_detector/language_detector_test.cc b/mediapipe/tasks/c/text/language_detector/language_detector_test.cc new file mode 100644 index 000000000..b8653e616 --- /dev/null +++ b/mediapipe/tasks/c/text/language_detector/language_detector_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/c/text/language_detector/language_detector.h" + +#include <cstdlib> +#include <string> + +#include "absl/flags/flag.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace { + +using ::mediapipe::file::JoinPath; +using testing::HasSubstr; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kTestLanguageDetectorModelPath[] = "language_detector.tflite"; +constexpr char kTestString[] = + "Il y a beaucoup de bouches qui parlent et fort peu " + "de têtes qui pensent."; +constexpr float kPrecision = 1e-6; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +TEST(LanguageDetectorTest, SmokeTest) { + std::string model_path = GetFullPath(kTestLanguageDetectorModelPath); + LanguageDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ model_path.c_str()}, + /* classifier_options= */ + {/* display_names_locale= */ nullptr, + /* max_results= */ -1, + /* score_threshold= */ 0.0, + /* category_allowlist= */ nullptr, + /* category_allowlist_count= */ 0, + /* category_denylist= */ nullptr, + /* category_denylist_count= */ 0}, + }; + + void* detector = language_detector_create(&options); + ASSERT_NE(detector, nullptr); + + LanguageDetectorResult result; + ASSERT_EQ(language_detector_detect(detector, kTestString, &result), 0); + EXPECT_EQ(std::string(result.predictions[0].language_code), "fr"); + EXPECT_NEAR(result.predictions[0].probability, 0.999781, kPrecision); + + language_detector_close_result(&result); + language_detector_close(detector); +} + +TEST(LanguageDetectorTest, ErrorHandling) { + // It is an error to set neither the asset buffer nor the path. 
+ LanguageDetectorOptions options = { + /* base_options= */ {/* model_asset_buffer= */ nullptr, + /* model_asset_path= */ nullptr}, + /* classifier_options= */ {}, + }; + + char* error_msg = nullptr; + void* detector = language_detector_create(&options, &error_msg); + EXPECT_EQ(detector, nullptr); + + EXPECT_THAT(error_msg, HasSubstr("INVALID_ARGUMENT")); + + free(error_msg); +} + +} // namespace