Internal change
PiperOrigin-RevId: 519013105
This commit is contained in:
		
							parent
							
								
									8a55f11952
								
							
						
					
					
						commit
						712ea6f15b
					
				|  | @ -75,6 +75,8 @@ cc_library( | |||
|     srcs = ["mediapipe_builtin_op_resolver.cc"], | ||||
|     hdrs = ["mediapipe_builtin_op_resolver.h"], | ||||
|     deps = [ | ||||
|         "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", | ||||
|         "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", | ||||
|         "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", | ||||
|         "//mediapipe/util/tflite/operations:max_pool_argmax", | ||||
|         "//mediapipe/util/tflite/operations:max_unpooling", | ||||
|  |  | |||
|  | @ -15,6 +15,8 @@ limitations under the License. | |||
| 
 | ||||
| #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" | ||||
| #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" | ||||
| #include "mediapipe/util/tflite/operations/max_pool_argmax.h" | ||||
| #include "mediapipe/util/tflite/operations/max_unpooling.h" | ||||
|  | @ -43,6 +45,10 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { | |||
|       "Landmarks2TransformMatrix", | ||||
|       mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), | ||||
|       /*version=*/2); | ||||
|   // For the LanguageDetector model.
 | ||||
|   AddCustom("NGramHash", mediapipe::tflite_operations::Register_NGRAM_HASH()); | ||||
|   AddCustom("KmeansEmbeddingLookup", | ||||
|             mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); | ||||
| } | ||||
| }  // namespace core
 | ||||
| }  // namespace tasks
 | ||||
|  |  | |||
							
								
								
									
										38
									
								
								mediapipe/tasks/cc/text/language_detector/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mediapipe/tasks/cc/text/language_detector/BUILD
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,38 @@ | |||
| # Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| package(default_visibility = ["//mediapipe/tasks:internal"]) | ||||
| 
 | ||||
| licenses(["notice"]) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "language_detector", | ||||
|     srcs = ["language_detector.cc"], | ||||
|     hdrs = ["language_detector.h"], | ||||
|     visibility = ["//visibility:public"], | ||||
|     deps = [ | ||||
|         "//mediapipe/framework/api2:builder", | ||||
|         "//mediapipe/tasks/cc/components/containers:category", | ||||
|         "//mediapipe/tasks/cc/components/containers:classification_result", | ||||
|         "//mediapipe/tasks/cc/components/processors:classifier_options", | ||||
|         "//mediapipe/tasks/cc/core:base_options", | ||||
|         "//mediapipe/tasks/cc/core:base_task_api", | ||||
|         "//mediapipe/tasks/cc/core:task_api_factory", | ||||
|         "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", | ||||
|         "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@com_google_absl//absl/strings", | ||||
|     ], | ||||
| ) | ||||
|  | @ -23,7 +23,7 @@ limitations under the License. | |||
| #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" | ||||
| #include "tensorflow/lite/kernels/kernel_util.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace mediapipe::tflite_operations { | ||||
| namespace kmeans_embedding_lookup_op { | ||||
| 
 | ||||
| namespace { | ||||
|  | @ -33,6 +33,10 @@ constexpr int kEncodingTable = 1; | |||
| constexpr int kCodebook = 2; | ||||
| constexpr int kOutputLabel = 0; | ||||
| 
 | ||||
| using ::tflite::GetInput; | ||||
| using ::tflite::GetOutput; | ||||
| using ::tflite::GetTensorData; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { | ||||
|  | @ -142,4 +146,4 @@ TfLiteRegistration* Register_KmeansEmbeddingLookup() { | |||
|   return &r; | ||||
| } | ||||
| 
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
|  |  | |||
|  | @ -27,10 +27,10 @@ limitations under the License. | |||
| 
 | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace mediapipe::tflite_operations { | ||||
| 
 | ||||
| TfLiteRegistration* Register_KmeansEmbeddingLookup(); | ||||
| 
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
 | ||||
|  |  | |||
|  | @ -12,14 +12,14 @@ | |||
| #include "tensorflow/lite/interpreter.h" | ||||
| #include "tensorflow/lite/kernels/test_util.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace mediapipe::tflite_operations { | ||||
| namespace { | ||||
| 
 | ||||
| using ::testing::ElementsAreArray; | ||||
| using ::tflite::ArrayFloatNear; | ||||
| 
 | ||||
| // Helper class for testing the op.
 | ||||
| class KmeansEmbeddingLookupModel : public SingleOpModel { | ||||
| class KmeansEmbeddingLookupModel : public tflite::SingleOpModel { | ||||
|  public: | ||||
|   explicit KmeansEmbeddingLookupModel( | ||||
|       std::initializer_list<int> input_shape, | ||||
|  | @ -27,7 +27,7 @@ class KmeansEmbeddingLookupModel : public SingleOpModel { | |||
|       std::initializer_list<int> codebook_shape, | ||||
|       std::initializer_list<int> output_shape) { | ||||
|     // Setup the model inputs and the interpreter.
 | ||||
|     output_ = AddOutput({TensorType_FLOAT32, output_shape}); | ||||
|     output_ = AddOutput({tflite::TensorType_FLOAT32, output_shape}); | ||||
|     SetCustomOp("KmeansEmbeddingLookup", std::vector<uint8_t>(), | ||||
|                 Register_KmeansEmbeddingLookup); | ||||
|     BuildInterpreter({input_shape, encoding_table_shape, codebook_shape}); | ||||
|  | @ -68,9 +68,9 @@ class KmeansEmbeddingLookupModel : public SingleOpModel { | |||
|   std::vector<int> GetOutputShape() { return GetTensorShape(output_); } | ||||
| 
 | ||||
|  private: | ||||
|   int input_ = AddInput(TensorType_INT32); | ||||
|   int encoding_table_ = AddInput(TensorType_UINT8); | ||||
|   int codebook_ = AddInput(TensorType_FLOAT32); | ||||
|   int input_ = AddInput(tflite::TensorType_INT32); | ||||
|   int encoding_table_ = AddInput(tflite::TensorType_UINT8); | ||||
|   int codebook_ = AddInput(tflite::TensorType_FLOAT32); | ||||
|   int output_; | ||||
| }; | ||||
| 
 | ||||
|  | @ -173,4 +173,4 @@ TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) { | |||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
|  |  | |||
|  | @ -25,7 +25,7 @@ limitations under the License. | |||
| #include "tensorflow/lite/kernels/kernel_util.h" | ||||
| #include "tensorflow/lite/string_util.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace mediapipe::tflite_operations { | ||||
| 
 | ||||
| namespace ngram_op { | ||||
| 
 | ||||
|  | @ -217,21 +217,21 @@ void Free(TfLiteContext* context, void* buffer) { | |||
| } | ||||
| 
 | ||||
| TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { | ||||
|   TfLiteTensor* output = GetOutput(context, node, kOutputLabel); | ||||
|   TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel); | ||||
|   TF_LITE_ENSURE(context, output != nullptr); | ||||
|   SetTensorToDynamic(output); | ||||
|   tflite::SetTensorToDynamic(output); | ||||
|   return kTfLiteOk; | ||||
| } | ||||
| 
 | ||||
| TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { | ||||
|   NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data); | ||||
|   TF_LITE_ENSURE_OK( | ||||
|       context, | ||||
|       params->PreprocessInput(GetInput(context, node, kInputMessage), context)); | ||||
|       context, params->PreprocessInput( | ||||
|                    tflite::GetInput(context, node, kInputMessage), context)); | ||||
| 
 | ||||
|   TfLiteTensor* output = GetOutput(context, node, kOutputLabel); | ||||
|   TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel); | ||||
|   TF_LITE_ENSURE(context, output != nullptr); | ||||
|   if (IsDynamicTensor(output)) { | ||||
|   if (tflite::IsDynamicTensor(output)) { | ||||
|     TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); | ||||
|     output_size->data[0] = 1; | ||||
|     output_size->data[1] = params->GetNumNGrams(); | ||||
|  | @ -261,4 +261,4 @@ TfLiteRegistration* Register_NGRAM_HASH() { | |||
|   return &r; | ||||
| } | ||||
| 
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
|  |  | |||
|  | @ -18,10 +18,10 @@ limitations under the License. | |||
| 
 | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace mediapipe::tflite_operations { | ||||
| 
 | ||||
| TfLiteRegistration* Register_NGRAM_HASH(); | ||||
| 
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
 | ||||
|  |  | |||
|  | @ -32,7 +32,7 @@ limitations under the License. | |||
| #include "tensorflow/lite/model.h" | ||||
| #include "tensorflow/lite/string_util.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace mediapipe::tflite_operations { | ||||
| namespace { | ||||
| 
 | ||||
| using ::flexbuffers::Builder; | ||||
|  | @ -42,7 +42,7 @@ using ::testing::ElementsAreArray; | |||
| using ::testing::Message; | ||||
| 
 | ||||
| // Helper class for testing the op.
 | ||||
| class NGramHashModel : public SingleOpModel { | ||||
| class NGramHashModel : public tflite::SingleOpModel { | ||||
|  public: | ||||
|   explicit NGramHashModel(const uint64_t seed, | ||||
|                           const std::vector<int>& ngram_lengths, | ||||
|  | @ -71,7 +71,7 @@ class NGramHashModel : public SingleOpModel { | |||
|     } | ||||
|     fbb.EndMap(start); | ||||
|     fbb.Finish(); | ||||
|     output_ = AddOutput({TensorType_INT32, {}}); | ||||
|     output_ = AddOutput({tflite::TensorType_INT32, {}}); | ||||
|     SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH); | ||||
|     BuildInterpreter({GetShape(input_)}); | ||||
|   } | ||||
|  | @ -100,7 +100,7 @@ class NGramHashModel : public SingleOpModel { | |||
|   std::vector<int> GetOutputShape() { return GetTensorShape(output_); } | ||||
| 
 | ||||
|  private: | ||||
|   int input_ = AddInput(TensorType_STRING); | ||||
|   int input_ = AddInput(tflite::TensorType_STRING); | ||||
|   int output_; | ||||
| }; | ||||
| 
 | ||||
|  | @ -173,7 +173,7 @@ TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) { | |||
| 
 | ||||
|   NGramHashModel m(kSeed, ngram_lengths, vocab_sizes); | ||||
|   for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) { | ||||
|     const string& testcase_input = testcase_inputs[test_idx]; | ||||
|     const std::string& testcase_input = testcase_inputs[test_idx]; | ||||
|     m.Invoke(testcase_input); | ||||
|     SCOPED_TRACE(Message() << "Where the testcases' input is: " | ||||
|                            << testcase_input); | ||||
|  | @ -310,4 +310,4 @@ TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) { | |||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
|  |  | |||
							
								
								
									
										126
									
								
								mediapipe/tasks/cc/text/language_detector/language_detector.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								mediapipe/tasks/cc/text/language_detector/language_detector.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,126 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/language_detector/language_detector.h" | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "mediapipe/framework/api2/builder.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/category.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||
| #include "mediapipe/tasks/cc/core/task_api_factory.h" | ||||
| #include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::language_detector { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::tasks::components::containers::Category; | ||||
| using ::mediapipe::tasks::components::containers::ClassificationResult; | ||||
| using ::mediapipe::tasks::components::containers::Classifications; | ||||
| using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; | ||||
| using ClassificationResultProto = | ||||
|     ::mediapipe::tasks::components::containers::proto::ClassificationResult; | ||||
| using ::mediapipe::tasks::text::text_classifier::proto:: | ||||
|     TextClassifierGraphOptions; | ||||
| 
 | ||||
| constexpr char kTextStreamName[] = "text_in"; | ||||
| constexpr char kTextTag[] = "TEXT"; | ||||
| constexpr char kClassificationsStreamName[] = "classifications_out"; | ||||
| constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | ||||
| constexpr char kSubgraphTypeName[] = | ||||
|     "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; | ||||
| 
 | ||||
| // Creates a MediaPipe graph config that only contains a single subgraph node of
 | ||||
| // type "TextClassifierGraph".
 | ||||
| CalculatorGraphConfig CreateGraphConfig( | ||||
|     std::unique_ptr<TextClassifierGraphOptions> options) { | ||||
|   api2::builder::Graph graph; | ||||
|   auto& subgraph = graph.AddNode(kSubgraphTypeName); | ||||
|   subgraph.GetOptions<TextClassifierGraphOptions>().Swap(options.get()); | ||||
|   graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); | ||||
|   subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> | ||||
|       graph.Out(kClassificationsTag); | ||||
|   return graph.GetConfig(); | ||||
| } | ||||
| 
 | ||||
| // Converts the user-facing LanguageDetectorOptions struct to the internal
 | ||||
| // TextClassifierGraphOptions proto.
 | ||||
| std::unique_ptr<TextClassifierGraphOptions> | ||||
| ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) { | ||||
|   auto options_proto = std::make_unique<TextClassifierGraphOptions>(); | ||||
|   auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( | ||||
|       tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); | ||||
|   options_proto->mutable_base_options()->Swap(base_options_proto.get()); | ||||
|   auto classifier_options_proto = | ||||
|       std::make_unique<tasks::components::processors::proto::ClassifierOptions>( | ||||
|           components::processors::ConvertClassifierOptionsToProto( | ||||
|               &(options->classifier_options))); | ||||
|   options_proto->mutable_classifier_options()->Swap( | ||||
|       classifier_options_proto.get()); | ||||
|   return options_proto; | ||||
| } | ||||
| 
 | ||||
| absl::StatusOr<LanguageDetectorResult> | ||||
| ExtractLanguageDetectorResultFromClassificationResult( | ||||
|     const ClassificationResult& classification_result) { | ||||
|   if (classification_result.classifications.size() != 1) { | ||||
|     return absl::InvalidArgumentError( | ||||
|         "The LanguageDetector TextClassifierGraph should have exactly one " | ||||
|         "classification head."); | ||||
|   } | ||||
|   const Classifications& languages_and_scores = | ||||
|       classification_result.classifications[0]; | ||||
|   LanguageDetectorResult language_detector_result; | ||||
|   for (const Category& category : languages_and_scores.categories) { | ||||
|     if (!category.category_name.has_value()) { | ||||
|       return absl::InvalidArgumentError( | ||||
|           "LanguageDetector ClassificationResult has a missing language code."); | ||||
|     } | ||||
|     language_detector_result.push_back( | ||||
|         {.language_code = *category.category_name, | ||||
|          .probability = category.score}); | ||||
|   } | ||||
|   return language_detector_result; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| absl::StatusOr<std::unique_ptr<LanguageDetector>> LanguageDetector::Create( | ||||
|     std::unique_ptr<LanguageDetectorOptions> options) { | ||||
|   auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get()); | ||||
|   return core::TaskApiFactory::Create<LanguageDetector, | ||||
|                                       TextClassifierGraphOptions>( | ||||
|       CreateGraphConfig(std::move(options_proto)), | ||||
|       std::move(options->base_options.op_resolver)); | ||||
| } | ||||
| 
 | ||||
| absl::StatusOr<LanguageDetectorResult> LanguageDetector::Detect( | ||||
|     absl::string_view text) { | ||||
|   ASSIGN_OR_RETURN( | ||||
|       auto output_packets, | ||||
|       runner_->Process( | ||||
|           {{kTextStreamName, MakePacket<std::string>(std::string(text))}})); | ||||
|   ClassificationResult classification_result = | ||||
|       ConvertToClassificationResult(output_packets[kClassificationsStreamName] | ||||
|                                         .Get<ClassificationResultProto>()); | ||||
|   return ExtractLanguageDetectorResultFromClassificationResult( | ||||
|       classification_result); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::text::language_detector
 | ||||
|  | @ -0,0 +1,84 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | ||||
| #include "mediapipe/tasks/cc/core/base_options.h" | ||||
| #include "mediapipe/tasks/cc/core/base_task_api.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::language_detector { | ||||
| 
 | ||||
| // A language code and its probability.
 | ||||
| struct LanguageDetectorPrediction { | ||||
|   // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
 | ||||
|   // "ja"-Latn for Japanese (romaji).
 | ||||
|   std::string language_code; | ||||
| 
 | ||||
|   float probability; | ||||
| }; | ||||
| 
 | ||||
| // Task output.
 | ||||
| using LanguageDetectorResult = std::vector<LanguageDetectorPrediction>; | ||||
| 
 | ||||
| // The options for configuring a MediaPipe LanguageDetector task.
 | ||||
| struct LanguageDetectorOptions { | ||||
|   // Base options for configuring MediaPipe Tasks, such as specifying the model
 | ||||
|   // file with metadata, accelerator options, op resolver, etc.
 | ||||
|   tasks::core::BaseOptions base_options; | ||||
| 
 | ||||
|   // Options for configuring the classifier behavior, such as score threshold,
 | ||||
|   // number of results, etc.
 | ||||
|   components::processors::ClassifierOptions classifier_options; | ||||
| }; | ||||
| 
 | ||||
| // Predicts the language of an input text.
 | ||||
| //
 | ||||
| // This API expects a TFLite model with TFLite Model Metadata that
 | ||||
| // contains the mandatory (described below) input tensors, output tensor,
 | ||||
| // and the language codes in an AssociatedFile.
 | ||||
| //
 | ||||
| // Input tensors:
 | ||||
| //   (kTfLiteString)
 | ||||
| //    - 1 input tensor that is scalar or has shape [1] containing the input
 | ||||
| //      string.
 | ||||
| // Output tensor:
 | ||||
| //   (kTfLiteFloat32)
 | ||||
| //    - 1 output tensor of shape`[1 x N]` where `N` is the number of languages.
 | ||||
| class LanguageDetector : core::BaseTaskApi { | ||||
|  public: | ||||
|   using BaseTaskApi::BaseTaskApi; | ||||
| 
 | ||||
|   // Creates a LanguageDetector instance from the provided `options`.
 | ||||
|   static absl::StatusOr<std::unique_ptr<LanguageDetector>> Create( | ||||
|       std::unique_ptr<LanguageDetectorOptions> options); | ||||
| 
 | ||||
|   // Predicts the language of the input `text`.
 | ||||
|   absl::StatusOr<LanguageDetectorResult> Detect(absl::string_view text); | ||||
| 
 | ||||
|   // Shuts down the LanguageDetector instance when all the work is done.
 | ||||
|   absl::Status Close() { return runner_->Close(); } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::text::language_detector
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
 | ||||
|  | @ -0,0 +1,163 @@ | |||
| /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/language_detector/language_detector.h" | ||||
| 
 | ||||
| #include <cmath> | ||||
| #include <cstdlib> | ||||
| #include <memory> | ||||
| #include <string> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "absl/flags/flag.h" | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/strings/cord.h" | ||||
| #include "absl/strings/str_cat.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "absl/strings/substitute.h" | ||||
| #include "mediapipe/framework/deps/file_path.h" | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/tasks/cc/common.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::language_detector { | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::file::JoinPath; | ||||
| using ::testing::HasSubstr; | ||||
| using ::testing::Optional; | ||||
| 
 | ||||
| constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; | ||||
| constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; | ||||
| constexpr char kLanguageDetector[] = "language_detector.tflite"; | ||||
| 
 | ||||
| constexpr float kTolerance = 0.000001; | ||||
| 
 | ||||
| std::string GetFullPath(absl::string_view file_name) { | ||||
|   return JoinPath("./", kTestDataDirectory, file_name); | ||||
| } | ||||
| 
 | ||||
| absl::Status MatchesLanguageDetectorResult( | ||||
|     const LanguageDetectorResult& expected, | ||||
|     const LanguageDetectorResult& actual, float tolerance) { | ||||
|   if (expected.size() != actual.size()) { | ||||
|     return absl::FailedPreconditionError(absl::Substitute( | ||||
|         "Expected $0 predictions, but got $1", expected.size(), actual.size())); | ||||
|   } | ||||
|   for (int i = 0; i < expected.size(); ++i) { | ||||
|     if (expected[i].language_code != actual[i].language_code) { | ||||
|       return absl::FailedPreconditionError(absl::Substitute( | ||||
|           "Expected prediction $0 to have language_code $1, but got $2", i, | ||||
|           expected[i].language_code, actual[i].language_code)); | ||||
|     } | ||||
|     if (std::abs(expected[i].probability - actual[i].probability) > tolerance) { | ||||
|       return absl::FailedPreconditionError(absl::Substitute( | ||||
|           "Expected prediction $0 to have probability $1, but got $2", i, | ||||
|           expected[i].probability, actual[i].probability)); | ||||
|     } | ||||
|   } | ||||
|   return absl::OkStatus(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class LanguageDetectorTest : public tflite_shims::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) { | ||||
|   auto options = std::make_unique<LanguageDetectorOptions>(); | ||||
|   options->base_options.model_asset_path = GetFullPath(kInvalidModelPath); | ||||
|   absl::StatusOr<std::unique_ptr<LanguageDetector>> language_detector = | ||||
|       LanguageDetector::Create(std::move(options)); | ||||
| 
 | ||||
|   EXPECT_EQ(language_detector.status().code(), absl::StatusCode::kNotFound); | ||||
|   EXPECT_THAT(language_detector.status().message(), | ||||
|               HasSubstr("Unable to open file at")); | ||||
|   EXPECT_THAT(language_detector.status().GetPayload(kMediaPipeTasksPayload), | ||||
|               Optional(absl::Cord(absl::StrCat( | ||||
|                   MediaPipeTasksStatus::kRunnerInitializationError)))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(LanguageDetectorTest, TestL2CModel) { | ||||
|   auto options = std::make_unique<LanguageDetectorOptions>(); | ||||
|   options->base_options.model_asset_path = GetFullPath(kLanguageDetector); | ||||
|   options->classifier_options.score_threshold = 0.3; | ||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector, | ||||
|                           LanguageDetector::Create(std::move(options))); | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       LanguageDetectorResult result_en, | ||||
|       language_detector->Detect("To be, or not to be, that is the question")); | ||||
|   MP_EXPECT_OK(MatchesLanguageDetectorResult( | ||||
|       {{.language_code = "en", .probability = 0.999856}}, result_en, | ||||
|       kTolerance)); | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       LanguageDetectorResult result_fr, | ||||
|       language_detector->Detect( | ||||
|           "Il y a beaucoup de bouches qui parlent et fort peu " | ||||
|           "de têtes qui pensent.")); | ||||
|   MP_EXPECT_OK(MatchesLanguageDetectorResult( | ||||
|       {{.language_code = "fr", .probability = 0.999781}}, result_fr, | ||||
|       kTolerance)); | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       LanguageDetectorResult result_ru, | ||||
|       language_detector->Detect("это какой-то английский язык")); | ||||
|   MP_EXPECT_OK(MatchesLanguageDetectorResult( | ||||
|       {{.language_code = "ru", .probability = 0.993362}}, result_ru, | ||||
|       kTolerance)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(LanguageDetectorTest, TestMultiplePredictions) { | ||||
|   auto options = std::make_unique<LanguageDetectorOptions>(); | ||||
|   options->base_options.model_asset_path = GetFullPath(kLanguageDetector); | ||||
|   options->classifier_options.score_threshold = 0.3; | ||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector, | ||||
|                           LanguageDetector::Create(std::move(options))); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_mixed, | ||||
|                           language_detector->Detect("分久必合合久必分")); | ||||
|   MP_EXPECT_OK(MatchesLanguageDetectorResult( | ||||
|       {{.language_code = "zh", .probability = 0.505424}, | ||||
|        {.language_code = "ja", .probability = 0.481617}}, | ||||
|       result_mixed, kTolerance)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(LanguageDetectorTest, TestAllowList) { | ||||
|   auto options = std::make_unique<LanguageDetectorOptions>(); | ||||
|   options->base_options.model_asset_path = GetFullPath(kLanguageDetector); | ||||
|   options->classifier_options.category_allowlist = {"ja"}; | ||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector, | ||||
|                           LanguageDetector::Create(std::move(options))); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_ja, | ||||
|                           language_detector->Detect("分久必合合久必分")); | ||||
|   MP_EXPECT_OK(MatchesLanguageDetectorResult( | ||||
|       {{.language_code = "ja", .probability = 0.481617}}, result_ja, | ||||
|       kTolerance)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(LanguageDetectorTest, TestDenyList) { | ||||
|   auto options = std::make_unique<LanguageDetectorOptions>(); | ||||
|   options->base_options.model_asset_path = GetFullPath(kLanguageDetector); | ||||
|   options->classifier_options.score_threshold = 0.3; | ||||
|   options->classifier_options.category_denylist = {"ja"}; | ||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector, | ||||
|                           LanguageDetector::Create(std::move(options))); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_zh, | ||||
|                           language_detector->Detect("分久必合合久必分")); | ||||
|   MP_EXPECT_OK(MatchesLanguageDetectorResult( | ||||
|       {{.language_code = "zh", .probability = 0.505424}}, result_zh, | ||||
|       kTolerance)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tasks::text::language_detector
 | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user