diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 3d01639ce..fd07045c6 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -75,6 +75,8 @@ cc_library( srcs = ["mediapipe_builtin_op_resolver.cc"], hdrs = ["mediapipe_builtin_op_resolver.h"], deps = [ + "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", + "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:max_pool_argmax", "//mediapipe/util/tflite/operations:max_unpooling", diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index be8dec684..4046601fd 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -15,6 +15,8 @@ limitations under the License. #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" #include "mediapipe/util/tflite/operations/max_unpooling.h" @@ -43,6 +45,10 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { "Landmarks2TransformMatrix", mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), /*version=*/2); + // For the LanguageDetector model. 
+ AddCustom("NGramHash", mediapipe::tflite_operations::Register_NGRAM_HASH()); + AddCustom("KmeansEmbeddingLookup", + mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); } } // namespace core } // namespace tasks diff --git a/mediapipe/tasks/cc/text/language_detector/BUILD b/mediapipe/tasks/cc/text/language_detector/BUILD new file mode 100644 index 000000000..57b9c7b51 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/BUILD @@ -0,0 +1,38 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# NOTE(review): language_detector_test.cc is added next to this file, but no +# cc_test target for it is declared here; confirm the test is built elsewhere. +cc_library( + name = "language_detector", + srcs = ["language_detector.cc"], + hdrs = ["language_detector.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:category", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc index 2ab3ed74d..2c9b7a172 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -namespace tflite::ops::custom { +namespace mediapipe::tflite_operations { namespace kmeans_embedding_lookup_op { namespace { @@ -33,6 +33,10 @@ constexpr int kEncodingTable = 1; constexpr int kCodebook = 2; constexpr int kOutputLabel = 0; +using ::tflite::GetInput; +using ::tflite::GetOutput; +using ::tflite::GetTensorData; + } // namespace TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -142,4 +146,4 @@ TfLiteRegistration* Register_KmeansEmbeddingLookup() { return &r; } -} // namespace tflite::ops::custom +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h index 99025b1f6..31dd4abbd 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h @@ -27,10 +27,10 @@ limitations under the License. 
#include "tensorflow/lite/kernels/register.h" -namespace tflite::ops::custom { +namespace mediapipe::tflite_operations { TfLiteRegistration* Register_KmeansEmbeddingLookup(); -} // namespace tflite::ops::custom +} // namespace mediapipe::tflite_operations #endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc index 7bfcb93b9..f1ee661d4 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup_test.cc @@ -12,14 +12,14 @@ #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/test_util.h" -namespace tflite::ops::custom { +namespace mediapipe::tflite_operations { namespace { using ::testing::ElementsAreArray; using ::tflite::ArrayFloatNear; // Helper class for testing the op. -class KmeansEmbeddingLookupModel : public SingleOpModel { +class KmeansEmbeddingLookupModel : public tflite::SingleOpModel { public: explicit KmeansEmbeddingLookupModel( std::initializer_list input_shape, @@ -27,7 +27,7 @@ class KmeansEmbeddingLookupModel : public SingleOpModel { std::initializer_list codebook_shape, std::initializer_list output_shape) { // Setup the model inputs and the interpreter. 
- output_ = AddOutput({TensorType_FLOAT32, output_shape}); + output_ = AddOutput({tflite::TensorType_FLOAT32, output_shape}); SetCustomOp("KmeansEmbeddingLookup", std::vector(), Register_KmeansEmbeddingLookup); BuildInterpreter({input_shape, encoding_table_shape, codebook_shape}); @@ -68,9 +68,9 @@ class KmeansEmbeddingLookupModel : public SingleOpModel { std::vector GetOutputShape() { return GetTensorShape(output_); } private: - int input_ = AddInput(TensorType_INT32); - int encoding_table_ = AddInput(TensorType_UINT8); - int codebook_ = AddInput(TensorType_FLOAT32); + int input_ = AddInput(tflite::TensorType_INT32); + int encoding_table_ = AddInput(tflite::TensorType_UINT8); + int codebook_ = AddInput(tflite::TensorType_FLOAT32); int output_; }; @@ -173,4 +173,4 @@ TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) { } } // namespace -} // namespace tflite::ops::custom +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc index 738fa1128..efe39a01f 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/string_util.h" -namespace tflite::ops::custom { +namespace mediapipe::tflite_operations { namespace ngram_op { @@ -217,21 +217,21 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel); TF_LITE_ENSURE(context, output != nullptr); - SetTensorToDynamic(output); + tflite::SetTensorToDynamic(output); return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { NGramHashParams* params = reinterpret_cast(node->user_data); TF_LITE_ENSURE_OK( - context, - params->PreprocessInput(GetInput(context, node, kInputMessage), context)); + context, params->PreprocessInput( + tflite::GetInput(context, node, kInputMessage), context)); - TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel); TF_LITE_ENSURE(context, output != nullptr); - if (IsDynamicTensor(output)) { + if (tflite::IsDynamicTensor(output)) { TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); output_size->data[0] = 1; output_size->data[1] = params->GetNumNGrams(); @@ -261,4 +261,4 @@ TfLiteRegistration* Register_NGRAM_HASH() { return &r; } -} // namespace tflite::ops::custom +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h index a061357bd..c32e91c62 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h @@ -18,10 +18,10 @@ limitations under the License. 
#include "tensorflow/lite/kernels/register.h" -namespace tflite::ops::custom { +namespace mediapipe::tflite_operations { TfLiteRegistration* Register_NGRAM_HASH(); -} // namespace tflite::ops::custom +} // namespace mediapipe::tflite_operations #endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc index 28d2dea6e..06af5f971 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "tensorflow/lite/model.h" #include "tensorflow/lite/string_util.h" -namespace tflite::ops::custom { +namespace mediapipe::tflite_operations { namespace { using ::flexbuffers::Builder; @@ -42,7 +42,7 @@ using ::testing::ElementsAreArray; using ::testing::Message; // Helper class for testing the op. 
-class NGramHashModel : public SingleOpModel { +class NGramHashModel : public tflite::SingleOpModel { public: explicit NGramHashModel(const uint64_t seed, const std::vector& ngram_lengths, @@ -71,7 +71,7 @@ class NGramHashModel : public SingleOpModel { } fbb.EndMap(start); fbb.Finish(); - output_ = AddOutput({TensorType_INT32, {}}); + output_ = AddOutput({tflite::TensorType_INT32, {}}); SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH); BuildInterpreter({GetShape(input_)}); } @@ -100,7 +100,7 @@ class NGramHashModel : public SingleOpModel { std::vector GetOutputShape() { return GetTensorShape(output_); } private: - int input_ = AddInput(TensorType_STRING); + int input_ = AddInput(tflite::TensorType_STRING); int output_; }; @@ -173,7 +173,7 @@ TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) { NGramHashModel m(kSeed, ngram_lengths, vocab_sizes); for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) { - const string& testcase_input = testcase_inputs[test_idx]; + const std::string& testcase_input = testcase_inputs[test_idx]; m.Invoke(testcase_input); SCOPED_TRACE(Message() << "Where the testcases' input is: " << testcase_input); @@ -310,4 +310,4 @@ TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) { } } // namespace -} // namespace tflite::ops::custom +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector.cc b/mediapipe/tasks/cc/text/language_detector/language_detector.cc new file mode 100644 index 000000000..e3841211b --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/language_detector.cc @@ -0,0 +1,126 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" + +namespace mediapipe::tasks::text::language_detector { + +namespace { + +using ::mediapipe::tasks::components::containers::Category; +using ::mediapipe::tasks::components::containers::ClassificationResult; +using ::mediapipe::tasks::components::containers::Classifications; +using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::text::text_classifier::proto:: + TextClassifierGraphOptions; + +constexpr char kTextStreamName[] = "text_in"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kClassificationsStreamName[] = "classifications_out"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// type "TextClassifierGraph". 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kSubgraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); + subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> + graph.Out(kClassificationsTag); + return graph.GetConfig(); +} + +// Converts the user-facing LanguageDetectorOptions struct to the internal +// TextClassifierGraphOptions proto. +std::unique_ptr +ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + auto classifier_options_proto = + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( + &(options->classifier_options))); + options_proto->mutable_classifier_options()->Swap( + classifier_options_proto.get()); + return options_proto; +} + +// Flattens the single classification head of `classification_result` into a +// LanguageDetectorResult, copying each category name into `language_code` and +// each category score into `probability`. Returns an InvalidArgument error if +// there is not exactly one classification head or if a category has no name. +absl::StatusOr +ExtractLanguageDetectorResultFromClassificationResult( + const ClassificationResult& classification_result) { + if (classification_result.classifications.size() != 1) { + return absl::InvalidArgumentError( + "The LanguageDetector TextClassifierGraph should have exactly one " + "classification head."); + } + const Classifications& languages_and_scores = + classification_result.classifications[0]; + LanguageDetectorResult language_detector_result; + for (const Category& category : languages_and_scores.categories) { + if (!category.category_name.has_value()) { + return absl::InvalidArgumentError( + "LanguageDetector ClassificationResult has a missing language code."); + } + language_detector_result.push_back( + {.language_code = *category.category_name, + .probability = category.score}); + } + return language_detector_result; +} + +} // namespace + 
+// Builds the graph options proto from `options`, wraps it in the single-node +// graph config and hands that config, together with the op resolver moved out +// of `options->base_options`, to TaskApiFactory. +absl::StatusOr> LanguageDetector::Create( + std::unique_ptr options) { + auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get()); + return core::TaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver)); +} + +// Feeds a std::string copy of `text` to the runner on the text input stream, +// converts the classification proto read from the classifications output +// stream, and flattens it into a LanguageDetectorResult. A non-OK status from +// the runner is returned to the caller by ASSIGN_OR_RETURN. +absl::StatusOr LanguageDetector::Detect( + absl::string_view text) { + ASSIGN_OR_RETURN( + auto output_packets, + runner_->Process( + {{kTextStreamName, MakePacket(std::string(text))}})); + ClassificationResult classification_result = + ConvertToClassificationResult(output_packets[kClassificationsStreamName] + .Get()); + return ExtractLanguageDetectorResultFromClassificationResult( + classification_result); +} + +} // namespace mediapipe::tasks::text::language_detector diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector.h b/mediapipe/tasks/cc/text/language_detector/language_detector.h new file mode 100644 index 000000000..bbe58dedf --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/language_detector.h @@ -0,0 +1,84 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" + +namespace mediapipe::tasks::text::language_detector { + +// A language code and its probability. +struct LanguageDetectorPrediction { + // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek, + // "ja-Latn" for Japanese (romaji). + std::string language_code; + + // Probability assigned to `language_code`; filled from the score of the + // corresponding classification category. + float probability; +}; + +// Task output: one prediction per category returned by the classifier. +using LanguageDetectorResult = std::vector; + +// The options for configuring a MediaPipe LanguageDetector task. +struct LanguageDetectorOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + components::processors::ClassifierOptions classifier_options; +}; + +// Predicts the language of an input text. +// +// This API expects a TFLite model with TFLite Model Metadata that +// contains the mandatory (described below) input tensors, output tensor, +// and the language codes in an AssociatedFile. +// +// Input tensors: +// (kTfLiteString) +// - 1 input tensor that is scalar or has shape [1] containing the input +// string. +// Output tensor: +// (kTfLiteFloat32) +// - 1 output tensor of shape `[1 x N]` where `N` is the number of languages. 
+class LanguageDetector : core::BaseTaskApi { + // BaseTaskApi is inherited privately (the default for `class`), so only the + // members declared below are reachable through a LanguageDetector. + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a LanguageDetector instance from the provided `options`. The op + // resolver stored in `options->base_options` is moved into the created task. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Predicts the language of the input `text`. On success the result holds one + // prediction per category returned by the classifier. Returns an + // InvalidArgument error if the classification result does not have exactly + // one classification head or if a category lacks a language code. + absl::StatusOr Detect(absl::string_view text); + + // Shuts down the LanguageDetector instance when all the work is done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace mediapipe::tasks::text::language_detector + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc new file mode 100644 index 000000000..92dc493e0 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::language_detector { +namespace { + +using ::mediapipe::file::JoinPath; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; +constexpr char kLanguageDetector[] = "language_detector.tflite"; + +constexpr float kTolerance = 0.000001; + +// Joins "./", the test data directory and `file_name` into a single path. +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +// Compares two results element by element: the sizes and language codes must +// be equal and the probabilities must differ by at most `tolerance`. Returns a +// FailedPrecondition status describing the first mismatch, or OK otherwise. +absl::Status MatchesLanguageDetectorResult( + const LanguageDetectorResult& expected, + const LanguageDetectorResult& actual, float tolerance) { + if (expected.size() != actual.size()) { + return absl::FailedPreconditionError(absl::Substitute( + "Expected $0 predictions, but got $1", expected.size(), actual.size())); + } + for (int i = 0; i < expected.size(); ++i) { + if (expected[i].language_code != actual[i].language_code) { + return absl::FailedPreconditionError(absl::Substitute( + "Expected prediction $0 to have language_code $1, but got $2", i, + expected[i].language_code, actual[i].language_code)); + } + if (std::abs(expected[i].probability - actual[i].probability) > tolerance) { + return absl::FailedPreconditionError(absl::Substitute( + 
"Expected prediction $0 to have probability $1, but got $2", i, + expected[i].probability, actual[i].probability)); + } + } + return absl::OkStatus(); +} + +} // namespace + +class LanguageDetectorTest : public tflite_shims::testing::Test {}; + +// Creating a detector from a nonexistent model path is expected to fail with +// kNotFound and to carry the kRunnerInitializationError payload. +TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kInvalidModelPath); + absl::StatusOr> language_detector = + LanguageDetector::Create(std::move(options)); + + EXPECT_EQ(language_detector.status().code(), absl::StatusCode::kNotFound); + EXPECT_THAT(language_detector.status().message(), + HasSubstr("Unable to open file at")); + EXPECT_THAT(language_detector.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +// With a 0.3 score threshold, each of the English, French and Russian inputs +// is expected to yield exactly one prediction with the listed probability. +TEST_F(LanguageDetectorTest, TestL2CModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kLanguageDetector); + options->classifier_options.score_threshold = 0.3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, + LanguageDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + LanguageDetectorResult result_en, + language_detector->Detect("To be, or not to be, that is the question")); + MP_EXPECT_OK(MatchesLanguageDetectorResult( + {{.language_code = "en", .probability = 0.999856}}, result_en, + kTolerance)); + MP_ASSERT_OK_AND_ASSIGN( + LanguageDetectorResult result_fr, + language_detector->Detect( + "Il y a beaucoup de bouches qui parlent et fort peu " + "de têtes qui pensent.")); + MP_EXPECT_OK(MatchesLanguageDetectorResult( + {{.language_code = "fr", .probability = 0.999781}}, result_fr, + kTolerance)); + MP_ASSERT_OK_AND_ASSIGN( + LanguageDetectorResult result_ru, + language_detector->Detect("это какой-то английский язык")); + MP_EXPECT_OK(MatchesLanguageDetectorResult( + {{.language_code = "ru", .probability = 0.993362}}, result_ru, + kTolerance)); +} + 
+// With a 0.3 score threshold the same input is expected to yield two +// predictions, "zh" followed by "ja". +TEST_F(LanguageDetectorTest, TestMultiplePredictions) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kLanguageDetector); + options->classifier_options.score_threshold = 0.3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, + LanguageDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_mixed, + language_detector->Detect("分久必合合久必分")); + MP_EXPECT_OK(MatchesLanguageDetectorResult( + {{.language_code = "zh", .probability = 0.505424}, + {.language_code = "ja", .probability = 0.481617}}, + result_mixed, kTolerance)); +} + +// With the allowlist set to {"ja"} (and no score threshold) only the "ja" +// prediction is expected. +TEST_F(LanguageDetectorTest, TestAllowList) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kLanguageDetector); + options->classifier_options.category_allowlist = {"ja"}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, + LanguageDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_ja, + language_detector->Detect("分久必合合久必分")); + MP_EXPECT_OK(MatchesLanguageDetectorResult( + {{.language_code = "ja", .probability = 0.481617}}, result_ja, + kTolerance)); +} + +// With "ja" on the denylist and a 0.3 score threshold only the "zh" prediction +// is expected. +TEST_F(LanguageDetectorTest, TestDenyList) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kLanguageDetector); + options->classifier_options.score_threshold = 0.3; + options->classifier_options.category_denylist = {"ja"}; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, + LanguageDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_zh, + language_detector->Detect("分久必合合久必分")); + MP_EXPECT_OK(MatchesLanguageDetectorResult( + {{.language_code = "zh", .probability = 0.505424}}, result_zh, + kTolerance)); +} + +} // namespace mediapipe::tasks::text::language_detector