Internal change
PiperOrigin-RevId: 518813508
This commit is contained in:
parent
eac6348fd3
commit
58fa1e2ec3
|
@ -75,8 +75,6 @@ cc_library(
|
|||
srcs = ["mediapipe_builtin_op_resolver.cc"],
|
||||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
|
||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
|
|
|
@ -15,8 +15,6 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
||||
#include "mediapipe/util/tflite/operations/max_unpooling.h"
|
||||
|
@ -45,10 +43,6 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
|||
"Landmarks2TransformMatrix",
|
||||
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
|
||||
/*version=*/2);
|
||||
// For the LanguageDetector model.
|
||||
AddCustom("NGramHash", ::tflite::ops::custom::Register_NGRAM_HASH());
|
||||
AddCustom("KmeansEmbeddingLookup",
|
||||
::tflite::ops::custom::Register_KmeansEmbeddingLookup());
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace tasks
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "language_detector",
|
||||
srcs = ["language_detector.cc"],
|
||||
hdrs = ["language_detector.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/tasks/cc/components/containers:category",
|
||||
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
|
@ -1,126 +0,0 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::text::language_detector {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::Category;
|
||||
using ::mediapipe::tasks::components::containers::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::Classifications;
|
||||
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
|
||||
using ClassificationResultProto =
|
||||
::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::text::text_classifier::proto::
|
||||
TextClassifierGraphOptions;
|
||||
|
||||
constexpr char kTextStreamName[] = "text_in";
|
||||
constexpr char kTextTag[] = "TEXT";
|
||||
constexpr char kClassificationsStreamName[] = "classifications_out";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||
|
||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// type "TextClassifierGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<TextClassifierGraphOptions> options) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
subgraph.GetOptions<TextClassifierGraphOptions>().Swap(options.get());
|
||||
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
|
||||
subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
|
||||
graph.Out(kClassificationsTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Converts the user-facing LanguageDetectorOptions struct to the internal
|
||||
// TextClassifierGraphOptions proto.
|
||||
std::unique_ptr<TextClassifierGraphOptions>
|
||||
ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) {
|
||||
auto options_proto = std::make_unique<TextClassifierGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
auto classifier_options_proto =
|
||||
std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
|
||||
components::processors::ConvertClassifierOptionsToProto(
|
||||
&(options->classifier_options)));
|
||||
options_proto->mutable_classifier_options()->Swap(
|
||||
classifier_options_proto.get());
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
absl::StatusOr<LanguageDetectorResult>
|
||||
ExtractLanguageDetectorResultFromClassificationResult(
|
||||
const ClassificationResult& classification_result) {
|
||||
if (classification_result.classifications.size() != 1) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The LanguageDetector TextClassifierGraph should have exactly one "
|
||||
"classification head.");
|
||||
}
|
||||
const Classifications& languages_and_scores =
|
||||
classification_result.classifications[0];
|
||||
LanguageDetectorResult language_detector_result;
|
||||
for (const Category& category : languages_and_scores.categories) {
|
||||
if (!category.category_name.has_value()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"LanguageDetector ClassificationResult has a missing language code.");
|
||||
}
|
||||
language_detector_result.push_back(
|
||||
{.language_code = *category.category_name,
|
||||
.probability = category.score});
|
||||
}
|
||||
return language_detector_result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<LanguageDetector>> LanguageDetector::Create(
|
||||
std::unique_ptr<LanguageDetectorOptions> options) {
|
||||
auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get());
|
||||
return core::TaskApiFactory::Create<LanguageDetector,
|
||||
TextClassifierGraphOptions>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver));
|
||||
}
|
||||
|
||||
absl::StatusOr<LanguageDetectorResult> LanguageDetector::Detect(
|
||||
absl::string_view text) {
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
runner_->Process(
|
||||
{{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
|
||||
ClassificationResult classification_result =
|
||||
ConvertToClassificationResult(output_packets[kClassificationsStreamName]
|
||||
.Get<ClassificationResultProto>());
|
||||
return ExtractLanguageDetectorResultFromClassificationResult(
|
||||
classification_result);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::text::language_detector
|
|
@ -1,84 +0,0 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
|
||||
namespace mediapipe::tasks::text::language_detector {
|
||||
|
||||
// A language code and its probability.
|
||||
struct LanguageDetectorPrediction {
|
||||
// An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
|
||||
// "ja"-Latn for Japanese (romaji).
|
||||
std::string language_code;
|
||||
|
||||
float probability;
|
||||
};
|
||||
|
||||
// Task output.
|
||||
using LanguageDetectorResult = std::vector<LanguageDetectorPrediction>;
|
||||
|
||||
// The options for configuring a MediaPipe LanguageDetector task.
|
||||
struct LanguageDetectorOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
components::processors::ClassifierOptions classifier_options;
|
||||
};
|
||||
|
||||
// Predicts the language of an input text.
|
||||
//
|
||||
// This API expects a TFLite model with TFLite Model Metadata that
|
||||
// contains the mandatory (described below) input tensors, output tensor,
|
||||
// and the language codes in an AssociatedFile.
|
||||
//
|
||||
// Input tensors:
|
||||
// (kTfLiteString)
|
||||
// - 1 input tensor that is scalar or has shape [1] containing the input
|
||||
// string.
|
||||
// Output tensor:
|
||||
// (kTfLiteFloat32)
|
||||
// - 1 output tensor of shape`[1 x N]` where `N` is the number of languages.
|
||||
class LanguageDetector : core::BaseTaskApi {
|
||||
public:
|
||||
using BaseTaskApi::BaseTaskApi;
|
||||
|
||||
// Creates a LanguageDetector instance from the provided `options`.
|
||||
static absl::StatusOr<std::unique_ptr<LanguageDetector>> Create(
|
||||
std::unique_ptr<LanguageDetectorOptions> options);
|
||||
|
||||
// Predicts the language of the input `text`.
|
||||
absl::StatusOr<LanguageDetectorResult> Detect(absl::string_view text);
|
||||
|
||||
// Shuts down the LanguageDetector instance when all the work is done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace mediapipe::tasks::text::language_detector
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
|
@ -1,163 +0,0 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/cord.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::language_detector {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
|
||||
constexpr char kLanguageDetector[] = "language_detector.tflite";
|
||||
|
||||
constexpr float kTolerance = 0.000001;
|
||||
|
||||
std::string GetFullPath(absl::string_view file_name) {
|
||||
return JoinPath("./", kTestDataDirectory, file_name);
|
||||
}
|
||||
|
||||
absl::Status MatchesLanguageDetectorResult(
|
||||
const LanguageDetectorResult& expected,
|
||||
const LanguageDetectorResult& actual, float tolerance) {
|
||||
if (expected.size() != actual.size()) {
|
||||
return absl::FailedPreconditionError(absl::Substitute(
|
||||
"Expected $0 predictions, but got $1", expected.size(), actual.size()));
|
||||
}
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
if (expected[i].language_code != actual[i].language_code) {
|
||||
return absl::FailedPreconditionError(absl::Substitute(
|
||||
"Expected prediction $0 to have language_code $1, but got $2", i,
|
||||
expected[i].language_code, actual[i].language_code));
|
||||
}
|
||||
if (std::abs(expected[i].probability - actual[i].probability) > tolerance) {
|
||||
return absl::FailedPreconditionError(absl::Substitute(
|
||||
"Expected prediction $0 to have probability $1, but got $2", i,
|
||||
expected[i].probability, actual[i].probability));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class LanguageDetectorTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
|
||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
|
||||
absl::StatusOr<std::unique_ptr<LanguageDetector>> language_detector =
|
||||
LanguageDetector::Create(std::move(options));
|
||||
|
||||
EXPECT_EQ(language_detector.status().code(), absl::StatusCode::kNotFound);
|
||||
EXPECT_THAT(language_detector.status().message(),
|
||||
HasSubstr("Unable to open file at"));
|
||||
EXPECT_THAT(language_detector.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||
}
|
||||
|
||||
TEST_F(LanguageDetectorTest, TestL2CModel) {
|
||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
|
||||
options->classifier_options.score_threshold = 0.3;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
|
||||
LanguageDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
LanguageDetectorResult result_en,
|
||||
language_detector->Detect("To be, or not to be, that is the question"));
|
||||
MP_EXPECT_OK(MatchesLanguageDetectorResult(
|
||||
{{.language_code = "en", .probability = 0.999856}}, result_en,
|
||||
kTolerance));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
LanguageDetectorResult result_fr,
|
||||
language_detector->Detect(
|
||||
"Il y a beaucoup de bouches qui parlent et fort peu "
|
||||
"de têtes qui pensent."));
|
||||
MP_EXPECT_OK(MatchesLanguageDetectorResult(
|
||||
{{.language_code = "fr", .probability = 0.999781}}, result_fr,
|
||||
kTolerance));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
LanguageDetectorResult result_ru,
|
||||
language_detector->Detect("это какой-то английский язык"));
|
||||
MP_EXPECT_OK(MatchesLanguageDetectorResult(
|
||||
{{.language_code = "ru", .probability = 0.993362}}, result_ru,
|
||||
kTolerance));
|
||||
}
|
||||
|
||||
TEST_F(LanguageDetectorTest, TestMultiplePredictions) {
|
||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
|
||||
options->classifier_options.score_threshold = 0.3;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
|
||||
LanguageDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_mixed,
|
||||
language_detector->Detect("分久必合合久必分"));
|
||||
MP_EXPECT_OK(MatchesLanguageDetectorResult(
|
||||
{{.language_code = "zh", .probability = 0.505424},
|
||||
{.language_code = "ja", .probability = 0.481617}},
|
||||
result_mixed, kTolerance));
|
||||
}
|
||||
|
||||
TEST_F(LanguageDetectorTest, TestAllowList) {
|
||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
|
||||
options->classifier_options.category_allowlist = {"ja"};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
|
||||
LanguageDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_ja,
|
||||
language_detector->Detect("分久必合合久必分"));
|
||||
MP_EXPECT_OK(MatchesLanguageDetectorResult(
|
||||
{{.language_code = "ja", .probability = 0.481617}}, result_ja,
|
||||
kTolerance));
|
||||
}
|
||||
|
||||
TEST_F(LanguageDetectorTest, TestDenyList) {
|
||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
|
||||
options->classifier_options.score_threshold = 0.3;
|
||||
options->classifier_options.category_denylist = {"ja"};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
|
||||
LanguageDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_zh,
|
||||
language_detector->Detect("分久必合合久必分"));
|
||||
MP_EXPECT_OK(MatchesLanguageDetectorResult(
|
||||
{{.language_code = "zh", .probability = 0.505424}}, result_zh,
|
||||
kTolerance));
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::text::language_detector
|
Loading…
Reference in New Issue
Block a user