From 58fa1e2ec358d1b8068d5b2c79d3dddceae99685 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 23 Mar 2023 03:27:55 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 518813508 --- mediapipe/tasks/cc/core/BUILD | 2 - .../cc/core/mediapipe_builtin_op_resolver.cc | 6 - .../tasks/cc/text/language_detector/BUILD | 38 ---- .../language_detector/language_detector.cc | 126 -------------- .../language_detector/language_detector.h | 84 --------- .../language_detector_test.cc | 163 ------------------ 6 files changed, 419 deletions(-) delete mode 100644 mediapipe/tasks/cc/text/language_detector/BUILD delete mode 100644 mediapipe/tasks/cc/text/language_detector/language_detector.cc delete mode 100644 mediapipe/tasks/cc/text/language_detector/language_detector.h delete mode 100644 mediapipe/tasks/cc/text/language_detector/language_detector_test.cc diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index fd07045c6..3d01639ce 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -75,8 +75,6 @@ cc_library( srcs = ["mediapipe_builtin_op_resolver.cc"], hdrs = ["mediapipe_builtin_op_resolver.h"], deps = [ - "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", - "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:max_pool_argmax", "//mediapipe/util/tflite/operations:max_unpooling", diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index 5b4df930e..be8dec684 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -15,8 +15,6 @@ limitations under the License. #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" -#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" -#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" #include "mediapipe/util/tflite/operations/max_unpooling.h" @@ -45,10 +43,6 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { "Landmarks2TransformMatrix", mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), /*version=*/2); - // For the LanguageDetector model. - AddCustom("NGramHash", ::tflite::ops::custom::Register_NGRAM_HASH()); - AddCustom("KmeansEmbeddingLookup", - ::tflite::ops::custom::Register_KmeansEmbeddingLookup()); } } // namespace core } // namespace tasks diff --git a/mediapipe/tasks/cc/text/language_detector/BUILD b/mediapipe/tasks/cc/text/language_detector/BUILD deleted file mode 100644 index 57b9c7b51..000000000 --- a/mediapipe/tasks/cc/text/language_detector/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -cc_library( - name = "language_detector", - srcs = ["language_detector.cc"], - hdrs = ["language_detector.h"], - visibility = ["//visibility:public"], - deps = [ - "//mediapipe/framework/api2:builder", - "//mediapipe/tasks/cc/components/containers:category", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", - "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector.cc b/mediapipe/tasks/cc/text/language_detector/language_detector.cc deleted file mode 100644 index e3841211b..000000000 --- a/mediapipe/tasks/cc/text/language_detector/language_detector.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" - -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mediapipe/framework/api2/builder.h" -#include "mediapipe/tasks/cc/components/containers/category.h" -#include "mediapipe/tasks/cc/components/containers/classification_result.h" -#include "mediapipe/tasks/cc/core/task_api_factory.h" -#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" - -namespace mediapipe::tasks::text::language_detector { - -namespace { - -using ::mediapipe::tasks::components::containers::Category; -using ::mediapipe::tasks::components::containers::ClassificationResult; -using ::mediapipe::tasks::components::containers::Classifications; -using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; -using ClassificationResultProto = - ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::text::text_classifier::proto:: - TextClassifierGraphOptions; - -constexpr char kTextStreamName[] = "text_in"; -constexpr char kTextTag[] = "TEXT"; -constexpr char kClassificationsStreamName[] = "classifications_out"; -constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; -constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; - -// Creates a MediaPipe graph config that only contains a single subgraph node of -// type "TextClassifierGraph". -CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options) { - api2::builder::Graph graph; - auto& subgraph = graph.AddNode(kSubgraphTypeName); - subgraph.GetOptions().Swap(options.get()); - graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); - subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> - graph.Out(kClassificationsTag); - return graph.GetConfig(); -} - -// Converts the user-facing LanguageDetectorOptions struct to the internal -// TextClassifierGraphOptions proto. -std::unique_ptr -ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) { - auto options_proto = std::make_unique(); - auto base_options_proto = std::make_unique( - tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); - options_proto->mutable_base_options()->Swap(base_options_proto.get()); - auto classifier_options_proto = - std::make_unique( - components::processors::ConvertClassifierOptionsToProto( - &(options->classifier_options))); - options_proto->mutable_classifier_options()->Swap( - classifier_options_proto.get()); - return options_proto; -} - -absl::StatusOr -ExtractLanguageDetectorResultFromClassificationResult( - const ClassificationResult& classification_result) { - if (classification_result.classifications.size() != 1) { - return absl::InvalidArgumentError( - "The LanguageDetector TextClassifierGraph should have exactly one " - "classification head."); - } - const Classifications& languages_and_scores = - classification_result.classifications[0]; - LanguageDetectorResult language_detector_result; - for (const Category& category : languages_and_scores.categories) { - if (!category.category_name.has_value()) { - return absl::InvalidArgumentError( - "LanguageDetector ClassificationResult has a missing language code."); - } - language_detector_result.push_back( - {.language_code = *category.category_name, - .probability = category.score}); - } - return language_detector_result; -} - -} // namespace - -absl::StatusOr> LanguageDetector::Create( - std::unique_ptr options) { - auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get()); - return core::TaskApiFactory::Create( - CreateGraphConfig(std::move(options_proto)), - std::move(options->base_options.op_resolver)); -} - -absl::StatusOr LanguageDetector::Detect( - absl::string_view text) { - ASSIGN_OR_RETURN( - auto output_packets, - runner_->Process( - {{kTextStreamName, MakePacket(std::string(text))}})); - ClassificationResult classification_result = - ConvertToClassificationResult(output_packets[kClassificationsStreamName] - .Get()); - return ExtractLanguageDetectorResultFromClassificationResult( - classification_result); -} - -} // namespace mediapipe::tasks::text::language_detector diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector.h b/mediapipe/tasks/cc/text/language_detector/language_detector.h deleted file mode 100644 index bbe58dedf..000000000 --- a/mediapipe/tasks/cc/text/language_detector/language_detector.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ -#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ - -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "mediapipe/tasks/cc/components/processors/classifier_options.h" -#include "mediapipe/tasks/cc/core/base_options.h" -#include "mediapipe/tasks/cc/core/base_task_api.h" - -namespace mediapipe::tasks::text::language_detector { - -// A language code and its probability. -struct LanguageDetectorPrediction { - // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek, - // "ja"-Latn for Japanese (romaji). - std::string language_code; - - float probability; -}; - -// Task output. -using LanguageDetectorResult = std::vector; - -// The options for configuring a MediaPipe LanguageDetector task. -struct LanguageDetectorOptions { - // Base options for configuring MediaPipe Tasks, such as specifying the model - // file with metadata, accelerator options, op resolver, etc. - tasks::core::BaseOptions base_options; - - // Options for configuring the classifier behavior, such as score threshold, - // number of results, etc. - components::processors::ClassifierOptions classifier_options; -}; - -// Predicts the language of an input text. -// -// This API expects a TFLite model with TFLite Model Metadata that -// contains the mandatory (described below) input tensors, output tensor, -// and the language codes in an AssociatedFile. -// -// Input tensors: -// (kTfLiteString) -// - 1 input tensor that is scalar or has shape [1] containing the input -// string. -// Output tensor: -// (kTfLiteFloat32) -// - 1 output tensor of shape`[1 x N]` where `N` is the number of languages. -class LanguageDetector : core::BaseTaskApi { - public: - using BaseTaskApi::BaseTaskApi; - - // Creates a LanguageDetector instance from the provided `options`. - static absl::StatusOr> Create( - std::unique_ptr options); - - // Predicts the language of the input `text`. - absl::StatusOr Detect(absl::string_view text); - - // Shuts down the LanguageDetector instance when all the work is done. - absl::Status Close() { return runner_->Close(); } -}; - -} // namespace mediapipe::tasks::text::language_detector - -#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc deleted file mode 100644 index 92dc493e0..000000000 --- a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/text/language_detector/language_detector.h" - -#include -#include -#include -#include -#include - -#include "absl/flags/flag.h" -#include "absl/status/status.h" -#include "absl/strings/cord.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/strings/substitute.h" -#include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/gmock.h" -#include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/status_matchers.h" -#include "mediapipe/tasks/cc/common.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" - -namespace mediapipe::tasks::text::language_detector { -namespace { - -using ::mediapipe::file::JoinPath; -using ::testing::HasSubstr; -using ::testing::Optional; - -constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; -constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; -constexpr char kLanguageDetector[] = "language_detector.tflite"; - -constexpr float kTolerance = 0.000001; - -std::string GetFullPath(absl::string_view file_name) { - return JoinPath("./", kTestDataDirectory, file_name); -} - -absl::Status MatchesLanguageDetectorResult( - const LanguageDetectorResult& expected, - const LanguageDetectorResult& actual, float tolerance) { - if (expected.size() != actual.size()) { - return absl::FailedPreconditionError(absl::Substitute( - "Expected $0 predictions, but got $1", expected.size(), actual.size())); - } - for (int i = 0; i < expected.size(); ++i) { - if (expected[i].language_code != actual[i].language_code) { - return absl::FailedPreconditionError(absl::Substitute( - "Expected prediction $0 to have language_code $1, but got $2", i, - expected[i].language_code, actual[i].language_code)); - } - if (std::abs(expected[i].probability - actual[i].probability) > tolerance) { - return absl::FailedPreconditionError(absl::Substitute( - "Expected prediction $0 to have probability $1, but got $2", i, - expected[i].probability, actual[i].probability)); - } - } - return absl::OkStatus(); -} - -} // namespace - -class LanguageDetectorTest : public tflite_shims::testing::Test {}; - -TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kInvalidModelPath); - absl::StatusOr> language_detector = - LanguageDetector::Create(std::move(options)); - - EXPECT_EQ(language_detector.status().code(), absl::StatusCode::kNotFound); - EXPECT_THAT(language_detector.status().message(), - HasSubstr("Unable to open file at")); - EXPECT_THAT(language_detector.status().GetPayload(kMediaPipeTasksPayload), - Optional(absl::Cord(absl::StrCat( - MediaPipeTasksStatus::kRunnerInitializationError)))); -} - -TEST_F(LanguageDetectorTest, TestL2CModel) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kLanguageDetector); - options->classifier_options.score_threshold = 0.3; - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, - LanguageDetector::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN( - LanguageDetectorResult result_en, - language_detector->Detect("To be, or not to be, that is the question")); - MP_EXPECT_OK(MatchesLanguageDetectorResult( - {{.language_code = "en", .probability = 0.999856}}, result_en, - kTolerance)); - MP_ASSERT_OK_AND_ASSIGN( - LanguageDetectorResult result_fr, - language_detector->Detect( - "Il y a beaucoup de bouches qui parlent et fort peu " - "de têtes qui pensent.")); - MP_EXPECT_OK(MatchesLanguageDetectorResult( - {{.language_code = "fr", .probability = 0.999781}}, result_fr, - kTolerance)); - MP_ASSERT_OK_AND_ASSIGN( - LanguageDetectorResult result_ru, - language_detector->Detect("это какой-то английский язык")); - MP_EXPECT_OK(MatchesLanguageDetectorResult( - {{.language_code = "ru", .probability = 0.993362}}, result_ru, - kTolerance)); -} - -TEST_F(LanguageDetectorTest, TestMultiplePredictions) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kLanguageDetector); - options->classifier_options.score_threshold = 0.3; - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, - LanguageDetector::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_mixed, - language_detector->Detect("分久必合合久必分")); - MP_EXPECT_OK(MatchesLanguageDetectorResult( - {{.language_code = "zh", .probability = 0.505424}, - {.language_code = "ja", .probability = 0.481617}}, - result_mixed, kTolerance)); -} - -TEST_F(LanguageDetectorTest, TestAllowList) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kLanguageDetector); - options->classifier_options.category_allowlist = {"ja"}; - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, - LanguageDetector::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_ja, - language_detector->Detect("分久必合合久必分")); - MP_EXPECT_OK(MatchesLanguageDetectorResult( - {{.language_code = "ja", .probability = 0.481617}}, result_ja, - kTolerance)); -} - -TEST_F(LanguageDetectorTest, TestDenyList) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kLanguageDetector); - options->classifier_options.score_threshold = 0.3; - options->classifier_options.category_denylist = {"ja"}; - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr language_detector, - LanguageDetector::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_zh, - language_detector->Detect("分久必合合久必分")); - MP_EXPECT_OK(MatchesLanguageDetectorResult( - {{.language_code = "zh", .probability = 0.505424}}, result_zh, - kTolerance)); -} - -} // namespace mediapipe::tasks::text::language_detector