Open-sources a LanguageDetector C++ API.
PiperOrigin-RevId: 518758730
This commit is contained in:
parent
1a7be8a4c1
commit
eac6348fd3
|
@ -75,6 +75,8 @@ cc_library(
|
||||||
srcs = ["mediapipe_builtin_op_resolver.cc"],
|
srcs = ["mediapipe_builtin_op_resolver.cc"],
|
||||||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
|
||||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||||
|
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
||||||
#include "mediapipe/util/tflite/operations/max_unpooling.h"
|
#include "mediapipe/util/tflite/operations/max_unpooling.h"
|
||||||
|
@ -43,6 +45,10 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
||||||
"Landmarks2TransformMatrix",
|
"Landmarks2TransformMatrix",
|
||||||
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
|
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
|
||||||
/*version=*/2);
|
/*version=*/2);
|
||||||
|
// For the LanguageDetector model.
|
||||||
|
AddCustom("NGramHash", ::tflite::ops::custom::Register_NGRAM_HASH());
|
||||||
|
AddCustom("KmeansEmbeddingLookup",
|
||||||
|
::tflite::ops::custom::Register_KmeansEmbeddingLookup());
|
||||||
}
|
}
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
38
mediapipe/tasks/cc/text/language_detector/BUILD
Normal file
38
mediapipe/tasks/cc/text/language_detector/BUILD
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# C++ LanguageDetector task API (language_detector.h / language_detector.cc).
cc_library(
    name = "language_detector",
    srcs = ["language_detector.cc"],
    hdrs = ["language_detector.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework/api2:builder",
        "//mediapipe/tasks/cc/components/containers:category",
        "//mediapipe/tasks/cc/components/containers:classification_result",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:task_api_factory",
        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)
|
126
mediapipe/tasks/cc/text/language_detector/language_detector.cc
Normal file
126
mediapipe/tasks/cc/text/language_detector/language_detector.cc
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::Category;
|
||||||
|
using ::mediapipe::tasks::components::containers::ClassificationResult;
|
||||||
|
using ::mediapipe::tasks::components::containers::Classifications;
|
||||||
|
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
|
||||||
|
using ClassificationResultProto =
|
||||||
|
::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
using ::mediapipe::tasks::text::text_classifier::proto::
|
||||||
|
TextClassifierGraphOptions;
|
||||||
|
|
||||||
|
constexpr char kTextStreamName[] = "text_in";
|
||||||
|
constexpr char kTextTag[] = "TEXT";
|
||||||
|
constexpr char kClassificationsStreamName[] = "classifications_out";
|
||||||
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
|
constexpr char kSubgraphTypeName[] =
|
||||||
|
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
|
|
||||||
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
|
// type "TextClassifierGraph".
|
||||||
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
|
std::unique_ptr<TextClassifierGraphOptions> options) {
|
||||||
|
api2::builder::Graph graph;
|
||||||
|
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
|
subgraph.GetOptions<TextClassifierGraphOptions>().Swap(options.get());
|
||||||
|
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
|
||||||
|
subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
|
||||||
|
graph.Out(kClassificationsTag);
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the user-facing LanguageDetectorOptions struct to the internal
|
||||||
|
// TextClassifierGraphOptions proto.
|
||||||
|
std::unique_ptr<TextClassifierGraphOptions>
|
||||||
|
ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) {
|
||||||
|
auto options_proto = std::make_unique<TextClassifierGraphOptions>();
|
||||||
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
auto classifier_options_proto =
|
||||||
|
std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
|
||||||
|
components::processors::ConvertClassifierOptionsToProto(
|
||||||
|
&(options->classifier_options)));
|
||||||
|
options_proto->mutable_classifier_options()->Swap(
|
||||||
|
classifier_options_proto.get());
|
||||||
|
return options_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts a ClassificationResult into a LanguageDetectorResult.
//
// Returns InvalidArgumentError unless the result has exactly one
// classification head, or if any category in that head has no category name
// (which is used as the language code). Otherwise returns one prediction per
// category, in the same order, pairing the category name with its score.
absl::StatusOr<LanguageDetectorResult>
ExtractLanguageDetectorResultFromClassificationResult(
    const ClassificationResult& classification_result) {
  const auto& heads = classification_result.classifications;
  if (heads.size() != 1) {
    return absl::InvalidArgumentError(
        "The LanguageDetector TextClassifierGraph should have exactly one "
        "classification head.");
  }

  LanguageDetectorResult predictions;
  for (const Category& entry : heads[0].categories) {
    if (!entry.category_name.has_value()) {
      return absl::InvalidArgumentError(
          "LanguageDetector ClassificationResult has a missing language code.");
    }
    LanguageDetectorPrediction prediction;
    prediction.language_code = *entry.category_name;
    prediction.probability = entry.score;
    predictions.push_back(std::move(prediction));
  }
  return predictions;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Creates a LanguageDetector from `options`.
//
// The options are converted to a TextClassifierGraphOptions proto, wrapped in
// a single-node graph config, and handed to the task API factory together
// with the op resolver taken from `options->base_options`.
//
// Returns InvalidArgumentError if `options` is null; otherwise returns
// whatever the task API factory returns.
absl::StatusOr<std::unique_ptr<LanguageDetector>> LanguageDetector::Create(
    std::unique_ptr<LanguageDetectorOptions> options) {
  // Guard against a null unique_ptr: everything below dereferences `options`.
  if (options == nullptr) {
    return absl::InvalidArgumentError(
        "LanguageDetectorOptions must not be null.");
  }
  auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get());
  return core::TaskApiFactory::Create<LanguageDetector,
                                      TextClassifierGraphOptions>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver));
}
|
||||||
|
|
||||||
|
// Runs the graph on `text` and converts the classifications output packet
// into a LanguageDetectorResult. Errors from the task runner are propagated
// to the caller.
absl::StatusOr<LanguageDetectorResult> LanguageDetector::Detect(
    absl::string_view text) {
  // The input packet owns its own copy of the text.
  std::string input_text(text);
  ASSIGN_OR_RETURN(
      auto output_packets,
      runner_->Process({{kTextStreamName,
                         MakePacket<std::string>(std::move(input_text))}}));

  const ClassificationResultProto& result_proto =
      output_packets[kClassificationsStreamName]
          .Get<ClassificationResultProto>();
  ClassificationResult classification_result =
      ConvertToClassificationResult(result_proto);
  return ExtractLanguageDetectorResultFromClassificationResult(
      classification_result);
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector
|
|
@ -0,0 +1,84 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
||||||
|
|
||||||
|
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector {
|
||||||
|
|
||||||
|
// A language code and its probability.
struct LanguageDetectorPrediction {
  // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
  // "ja-Latn" for Japanese (romaji).
  std::string language_code;

  // Probability assigned to `language_code`. Zero-initialized so that a
  // default-constructed prediction never holds an indeterminate value.
  float probability = 0.0f;
};
|
||||||
|
|
||||||
|
// Task output.
|
||||||
|
using LanguageDetectorResult = std::vector<LanguageDetectorPrediction>;
|
||||||
|
|
||||||
|
// The options for configuring a MediaPipe LanguageDetector task.
|
||||||
|
struct LanguageDetectorOptions {
|
||||||
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
// file with metadata, accelerator options, op resolver, etc.
|
||||||
|
tasks::core::BaseOptions base_options;
|
||||||
|
|
||||||
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
|
// number of results, etc.
|
||||||
|
components::processors::ClassifierOptions classifier_options;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Predicts the language of an input text.
|
||||||
|
//
|
||||||
|
// This API expects a TFLite model with TFLite Model Metadata that
|
||||||
|
// contains the mandatory (described below) input tensors, output tensor,
|
||||||
|
// and the language codes in an AssociatedFile.
|
||||||
|
//
|
||||||
|
// Input tensors:
|
||||||
|
// (kTfLiteString)
|
||||||
|
// - 1 input tensor that is scalar or has shape [1] containing the input
|
||||||
|
// string.
|
||||||
|
// Output tensor:
|
||||||
|
// (kTfLiteFloat32)
|
||||||
|
// - 1 output tensor of shape `[1 x N]`, where `N` is the number of languages.
|
||||||
|
class LanguageDetector : core::BaseTaskApi {
|
||||||
|
public:
|
||||||
|
using BaseTaskApi::BaseTaskApi;
|
||||||
|
|
||||||
|
// Creates a LanguageDetector instance from the provided `options`.
|
||||||
|
static absl::StatusOr<std::unique_ptr<LanguageDetector>> Create(
|
||||||
|
std::unique_ptr<LanguageDetectorOptions> options);
|
||||||
|
|
||||||
|
// Predicts the language of the input `text`.
|
||||||
|
absl::StatusOr<LanguageDetectorResult> Detect(absl::string_view text);
|
||||||
|
|
||||||
|
// Shuts down the LanguageDetector instance when all the work is done.
|
||||||
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
|
|
@ -0,0 +1,163 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/cord.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||||
|
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
|
||||||
|
constexpr char kLanguageDetector[] = "language_detector.tflite";
|
||||||
|
|
||||||
|
constexpr float kTolerance = 0.000001;
|
||||||
|
|
||||||
|
// Joins "./", the text test-data directory and `file_name` into one path.
std::string GetFullPath(absl::string_view file_name) {
  std::string full_path = JoinPath("./", kTestDataDirectory, file_name);
  return full_path;
}
|
||||||
|
|
||||||
|
// Compares two LanguageDetectorResults element by element.
//
// Returns OkStatus when both results have the same number of predictions and
// each pair has an identical language_code and probabilities that differ by
// at most `tolerance`. Otherwise returns a FailedPreconditionError describing
// the first mismatch found.
absl::Status MatchesLanguageDetectorResult(
    const LanguageDetectorResult& expected,
    const LanguageDetectorResult& actual, float tolerance) {
  if (expected.size() != actual.size()) {
    return absl::FailedPreconditionError(absl::Substitute(
        "Expected $0 predictions, but got $1", expected.size(), actual.size()));
  }
  // Unsigned index so the comparison against size() is not signed/unsigned.
  for (std::size_t i = 0; i < expected.size(); ++i) {
    if (expected[i].language_code != actual[i].language_code) {
      return absl::FailedPreconditionError(absl::Substitute(
          "Expected prediction $0 to have language_code $1, but got $2", i,
          expected[i].language_code, actual[i].language_code));
    }
    if (std::abs(expected[i].probability - actual[i].probability) > tolerance) {
      return absl::FailedPreconditionError(absl::Substitute(
          "Expected prediction $0 to have probability $1, but got $2", i,
          expected[i].probability, actual[i].probability));
    }
  }
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Test fixture for the LanguageDetector tests below; adds no state of its own
// on top of the tflite_shims test base class.
class LanguageDetectorTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
|
// Creating a detector from a model path that does not exist must fail with
// kNotFound, an "Unable to open file at" message, and a
// runner-initialization-error payload.
TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
  auto detector_options = std::make_unique<LanguageDetectorOptions>();
  detector_options->base_options.model_asset_path =
      GetFullPath(kInvalidModelPath);
  absl::StatusOr<std::unique_ptr<LanguageDetector>> detector =
      LanguageDetector::Create(std::move(detector_options));

  const absl::Status& status = detector.status();
  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(status.message(), HasSubstr("Unable to open file at"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
|
||||||
|
|
||||||
|
// With a 0.3 score threshold, English, French and Russian inputs each yield
// exactly one prediction with the expected language code and probability.
TEST_F(LanguageDetectorTest, TestL2CModel) {
  auto detector_options = std::make_unique<LanguageDetectorOptions>();
  detector_options->base_options.model_asset_path =
      GetFullPath(kLanguageDetector);
  detector_options->classifier_options.score_threshold = 0.3;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LanguageDetector> detector,
      LanguageDetector::Create(std::move(detector_options)));

  // English.
  MP_ASSERT_OK_AND_ASSIGN(
      LanguageDetectorResult english,
      detector->Detect("To be, or not to be, that is the question"));
  MP_EXPECT_OK(MatchesLanguageDetectorResult(
      {{.language_code = "en", .probability = 0.999856}}, english,
      kTolerance));

  // French.
  MP_ASSERT_OK_AND_ASSIGN(
      LanguageDetectorResult french,
      detector->Detect("Il y a beaucoup de bouches qui parlent et fort peu "
                       "de têtes qui pensent."));
  MP_EXPECT_OK(MatchesLanguageDetectorResult(
      {{.language_code = "fr", .probability = 0.999781}}, french,
      kTolerance));

  // Russian.
  MP_ASSERT_OK_AND_ASSIGN(
      LanguageDetectorResult russian,
      detector->Detect("это какой-то английский язык"));
  MP_EXPECT_OK(MatchesLanguageDetectorResult(
      {{.language_code = "ru", .probability = 0.993362}}, russian,
      kTolerance));
}
|
||||||
|
|
||||||
|
// With a 0.3 score threshold, the test input yields two predictions: "zh"
// followed by "ja".
TEST_F(LanguageDetectorTest, TestMultiplePredictions) {
  auto detector_options = std::make_unique<LanguageDetectorOptions>();
  detector_options->base_options.model_asset_path =
      GetFullPath(kLanguageDetector);
  detector_options->classifier_options.score_threshold = 0.3;
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LanguageDetector> detector,
      LanguageDetector::Create(std::move(detector_options)));

  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult mixed,
                          detector->Detect("分久必合合久必分"));
  MP_EXPECT_OK(MatchesLanguageDetectorResult(
      {{.language_code = "zh", .probability = 0.505424},
       {.language_code = "ja", .probability = 0.481617}},
      mixed, kTolerance));
}
|
||||||
|
|
||||||
|
// With the category allowlist set to {"ja"}, only the "ja" prediction is
// returned for the test input.
TEST_F(LanguageDetectorTest, TestAllowList) {
  auto detector_options = std::make_unique<LanguageDetectorOptions>();
  detector_options->base_options.model_asset_path =
      GetFullPath(kLanguageDetector);
  detector_options->classifier_options.category_allowlist = {"ja"};
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LanguageDetector> detector,
      LanguageDetector::Create(std::move(detector_options)));

  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult japanese_only,
                          detector->Detect("分久必合合久必分"));
  MP_EXPECT_OK(MatchesLanguageDetectorResult(
      {{.language_code = "ja", .probability = 0.481617}}, japanese_only,
      kTolerance));
}
|
||||||
|
|
||||||
|
// With a 0.3 score threshold and "ja" on the category denylist, only the "zh"
// prediction is returned for the test input.
TEST_F(LanguageDetectorTest, TestDenyList) {
  auto detector_options = std::make_unique<LanguageDetectorOptions>();
  detector_options->base_options.model_asset_path =
      GetFullPath(kLanguageDetector);
  detector_options->classifier_options.score_threshold = 0.3;
  detector_options->classifier_options.category_denylist = {"ja"};
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LanguageDetector> detector,
      LanguageDetector::Create(std::move(detector_options)));

  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult chinese_only,
                          detector->Detect("分久必合合久必分"));
  MP_EXPECT_OK(MatchesLanguageDetectorResult(
      {{.language_code = "zh", .probability = 0.505424}}, chinese_only,
      kTolerance));
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector
|
Loading…
Reference in New Issue
Block a user