Internal change

PiperOrigin-RevId: 519013105

parent 8a55f11952
commit 712ea6f15b
@@ -75,6 +75,8 @@ cc_library(
     srcs = ["mediapipe_builtin_op_resolver.cc"],
     hdrs = ["mediapipe_builtin_op_resolver.h"],
     deps = [
+        "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
+        "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
         "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
         "//mediapipe/util/tflite/operations:max_pool_argmax",
         "//mediapipe/util/tflite/operations:max_unpooling",
@@ -15,6 +15,8 @@ limitations under the License.

 #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"

+#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
+#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
 #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
 #include "mediapipe/util/tflite/operations/max_pool_argmax.h"
 #include "mediapipe/util/tflite/operations/max_unpooling.h"
@@ -43,6 +45,10 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
             "Landmarks2TransformMatrix",
             mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
             /*version=*/2);
+  // For the LanguageDetector model.
+  AddCustom("NGramHash", mediapipe::tflite_operations::Register_NGRAM_HASH());
+  AddCustom("KmeansEmbeddingLookup",
+            mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
 }
 }  // namespace core
 }  // namespace tasks
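Note: with the two AddCustom calls above, any TFLite interpreter built against MediaPipeBuiltinOpResolver can resolve the LanguageDetector custom ops. A minimal sketch of that wiring (not part of this change; it assumes MediaPipeBuiltinOpResolver can be passed wherever a tflite::OpResolver is expected, and the helper function is illustrative):

#include <memory>

#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/model_builder.h"

// Builds an interpreter whose op resolver knows the NGramHash and
// KmeansEmbeddingLookup custom ops registered above.
std::unique_ptr<tflite::Interpreter> BuildInterpreter(
    const tflite::FlatBufferModel& model) {
  mediapipe::tasks::core::MediaPipeBuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(model, resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;  // Op resolution failed (e.g. an unknown custom op).
  }
  return interpreter;
}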
mediapipe/tasks/cc/text/language_detector/BUILD (new file, 38 lines)
@@ -0,0 +1,38 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "language_detector",
+    srcs = ["language_detector.cc"],
+    hdrs = ["language_detector.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/tasks/cc/components/containers:category",
+        "//mediapipe/tasks/cc/components/containers:classification_result",
+        "//mediapipe/tasks/cc/components/processors:classifier_options",
+        "//mediapipe/tasks/cc/core:base_options",
+        "//mediapipe/tasks/cc/core:base_task_api",
+        "//mediapipe/tasks/cc/core:task_api_factory",
+        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
+        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"

-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 namespace kmeans_embedding_lookup_op {

 namespace {
@@ -33,6 +33,10 @@ constexpr int kEncodingTable = 1;
 constexpr int kCodebook = 2;
 constexpr int kOutputLabel = 0;

+using ::tflite::GetInput;
+using ::tflite::GetOutput;
+using ::tflite::GetTensorData;
+
 }  // namespace

 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -142,4 +146,4 @@ TfLiteRegistration* Register_KmeansEmbeddingLookup() {
   return &r;
 }

-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
@@ -27,10 +27,10 @@ limitations under the License.

 #include "tensorflow/lite/kernels/register.h"

-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {

 TfLiteRegistration* Register_KmeansEmbeddingLookup();

-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations

 #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
@@ -12,14 +12,14 @@
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/test_util.h"

-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 namespace {

 using ::testing::ElementsAreArray;
 using ::tflite::ArrayFloatNear;

 // Helper class for testing the op.
-class KmeansEmbeddingLookupModel : public SingleOpModel {
+class KmeansEmbeddingLookupModel : public tflite::SingleOpModel {
  public:
   explicit KmeansEmbeddingLookupModel(
       std::initializer_list<int> input_shape,
@@ -27,7 +27,7 @@ class KmeansEmbeddingLookupModel : public SingleOpModel {
       std::initializer_list<int> codebook_shape,
       std::initializer_list<int> output_shape) {
     // Setup the model inputs and the interpreter.
-    output_ = AddOutput({TensorType_FLOAT32, output_shape});
+    output_ = AddOutput({tflite::TensorType_FLOAT32, output_shape});
     SetCustomOp("KmeansEmbeddingLookup", std::vector<uint8_t>(),
                 Register_KmeansEmbeddingLookup);
     BuildInterpreter({input_shape, encoding_table_shape, codebook_shape});
@@ -68,9 +68,9 @@ class KmeansEmbeddingLookupModel : public SingleOpModel {
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  private:
-  int input_ = AddInput(TensorType_INT32);
-  int encoding_table_ = AddInput(TensorType_UINT8);
-  int codebook_ = AddInput(TensorType_FLOAT32);
+  int input_ = AddInput(tflite::TensorType_INT32);
+  int encoding_table_ = AddInput(tflite::TensorType_UINT8);
+  int codebook_ = AddInput(tflite::TensorType_FLOAT32);
   int output_;
 };

@@ -173,4 +173,4 @@ TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) {
 }

 }  // namespace
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/string_util.h"

-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {

 namespace ngram_op {

@@ -217,21 +217,21 @@ void Free(TfLiteContext* context, void* buffer) {
 }

 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
+  TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel);
   TF_LITE_ENSURE(context, output != nullptr);
-  SetTensorToDynamic(output);
+  tflite::SetTensorToDynamic(output);
   return kTfLiteOk;
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
   TF_LITE_ENSURE_OK(
-      context,
-      params->PreprocessInput(GetInput(context, node, kInputMessage), context));
+      context, params->PreprocessInput(
+                   tflite::GetInput(context, node, kInputMessage), context));

-  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
+  TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel);
   TF_LITE_ENSURE(context, output != nullptr);
-  if (IsDynamicTensor(output)) {
+  if (tflite::IsDynamicTensor(output)) {
     TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
     output_size->data[0] = 1;
     output_size->data[1] = params->GetNumNGrams();
@@ -261,4 +261,4 @@ TfLiteRegistration* Register_NGRAM_HASH() {
   return &r;
 }

-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
@@ -18,10 +18,10 @@ limitations under the License.

 #include "tensorflow/lite/kernels/register.h"

-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {

 TfLiteRegistration* Register_NGRAM_HASH();

-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations

 #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
@@ -32,7 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/string_util.h"

-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 namespace {

 using ::flexbuffers::Builder;
@@ -42,7 +42,7 @@ using ::testing::ElementsAreArray;
 using ::testing::Message;

 // Helper class for testing the op.
-class NGramHashModel : public SingleOpModel {
+class NGramHashModel : public tflite::SingleOpModel {
  public:
   explicit NGramHashModel(const uint64_t seed,
                           const std::vector<int>& ngram_lengths,
@@ -71,7 +71,7 @@ class NGramHashModel : public SingleOpModel {
     }
     fbb.EndMap(start);
     fbb.Finish();
-    output_ = AddOutput({TensorType_INT32, {}});
+    output_ = AddOutput({tflite::TensorType_INT32, {}});
     SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
     BuildInterpreter({GetShape(input_)});
   }
@@ -100,7 +100,7 @@ class NGramHashModel : public SingleOpModel {
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  private:
-  int input_ = AddInput(TensorType_STRING);
+  int input_ = AddInput(tflite::TensorType_STRING);
   int output_;
 };

@@ -173,7 +173,7 @@ TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {

   NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
   for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
-    const string& testcase_input = testcase_inputs[test_idx];
+    const std::string& testcase_input = testcase_inputs[test_idx];
     m.Invoke(testcase_input);
     SCOPED_TRACE(Message() << "Where the testcases' input is: "
                            << testcase_input);
@@ -310,4 +310,4 @@ TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
 }

 }  // namespace
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
mediapipe/tasks/cc/text/language_detector/language_detector.cc (new file, 126 lines)
@@ -0,0 +1,126 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+#include <memory>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/core/task_api_factory.h"
+#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
+
+namespace mediapipe::tasks::text::language_detector {
+
+namespace {
+
+using ::mediapipe::tasks::components::containers::Category;
+using ::mediapipe::tasks::components::containers::ClassificationResult;
+using ::mediapipe::tasks::components::containers::Classifications;
+using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
+using ClassificationResultProto =
+    ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::tasks::text::text_classifier::proto::
+    TextClassifierGraphOptions;
+
+constexpr char kTextStreamName[] = "text_in";
+constexpr char kTextTag[] = "TEXT";
+constexpr char kClassificationsStreamName[] = "classifications_out";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kSubgraphTypeName[] =
+    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
+
+// Creates a MediaPipe graph config that only contains a single subgraph node of
+// type "TextClassifierGraph".
+CalculatorGraphConfig CreateGraphConfig(
+    std::unique_ptr<TextClassifierGraphOptions> options) {
+  api2::builder::Graph graph;
+  auto& subgraph = graph.AddNode(kSubgraphTypeName);
+  subgraph.GetOptions<TextClassifierGraphOptions>().Swap(options.get());
+  graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
+  subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
+      graph.Out(kClassificationsTag);
+  return graph.GetConfig();
+}
+
+// Converts the user-facing LanguageDetectorOptions struct to the internal
+// TextClassifierGraphOptions proto.
+std::unique_ptr<TextClassifierGraphOptions>
+ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) {
+  auto options_proto = std::make_unique<TextClassifierGraphOptions>();
+  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
+      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
+  options_proto->mutable_base_options()->Swap(base_options_proto.get());
+  auto classifier_options_proto =
+      std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
+          components::processors::ConvertClassifierOptionsToProto(
+              &(options->classifier_options)));
+  options_proto->mutable_classifier_options()->Swap(
+      classifier_options_proto.get());
+  return options_proto;
+}
+
+absl::StatusOr<LanguageDetectorResult>
+ExtractLanguageDetectorResultFromClassificationResult(
+    const ClassificationResult& classification_result) {
+  if (classification_result.classifications.size() != 1) {
+    return absl::InvalidArgumentError(
+        "The LanguageDetector TextClassifierGraph should have exactly one "
+        "classification head.");
+  }
+  const Classifications& languages_and_scores =
+      classification_result.classifications[0];
+  LanguageDetectorResult language_detector_result;
+  for (const Category& category : languages_and_scores.categories) {
+    if (!category.category_name.has_value()) {
+      return absl::InvalidArgumentError(
+          "LanguageDetector ClassificationResult has a missing language code.");
+    }
+    language_detector_result.push_back(
+        {.language_code = *category.category_name,
+         .probability = category.score});
+  }
+  return language_detector_result;
+}
+
+}  // namespace
+
+absl::StatusOr<std::unique_ptr<LanguageDetector>> LanguageDetector::Create(
+    std::unique_ptr<LanguageDetectorOptions> options) {
+  auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get());
+  return core::TaskApiFactory::Create<LanguageDetector,
+                                      TextClassifierGraphOptions>(
+      CreateGraphConfig(std::move(options_proto)),
+      std::move(options->base_options.op_resolver));
+}
+
+absl::StatusOr<LanguageDetectorResult> LanguageDetector::Detect(
+    absl::string_view text) {
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      runner_->Process(
+          {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
+  ClassificationResult classification_result =
+      ConvertToClassificationResult(output_packets[kClassificationsStreamName]
+                                        .Get<ClassificationResultProto>());
+  return ExtractLanguageDetectorResultFromClassificationResult(
+      classification_result);
+}
+
+}  // namespace mediapipe::tasks::text::language_detector
@@ -0,0 +1,84 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
+#include "mediapipe/tasks/cc/core/base_options.h"
+#include "mediapipe/tasks/cc/core/base_task_api.h"
+
+namespace mediapipe::tasks::text::language_detector {
+
+// A language code and its probability.
+struct LanguageDetectorPrediction {
+  // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
+  // "ja"-Latn for Japanese (romaji).
+  std::string language_code;
+
+  float probability;
+};
+
+// Task output.
+using LanguageDetectorResult = std::vector<LanguageDetectorPrediction>;
+
+// The options for configuring a MediaPipe LanguageDetector task.
+struct LanguageDetectorOptions {
+  // Base options for configuring MediaPipe Tasks, such as specifying the model
+  // file with metadata, accelerator options, op resolver, etc.
+  tasks::core::BaseOptions base_options;
+
+  // Options for configuring the classifier behavior, such as score threshold,
+  // number of results, etc.
+  components::processors::ClassifierOptions classifier_options;
+};
+
+// Predicts the language of an input text.
+//
+// This API expects a TFLite model with TFLite Model Metadata that
+// contains the mandatory (described below) input tensors, output tensor,
+// and the language codes in an AssociatedFile.
+//
+// Input tensors:
+//   (kTfLiteString)
+//    - 1 input tensor that is scalar or has shape [1] containing the input
+//      string.
+// Output tensor:
+//   (kTfLiteFloat32)
+//    - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
+class LanguageDetector : core::BaseTaskApi {
+ public:
+  using BaseTaskApi::BaseTaskApi;
+
+  // Creates a LanguageDetector instance from the provided `options`.
+  static absl::StatusOr<std::unique_ptr<LanguageDetector>> Create(
+      std::unique_ptr<LanguageDetectorOptions> options);
+
+  // Predicts the language of the input `text`.
+  absl::StatusOr<LanguageDetectorResult> Detect(absl::string_view text);
+
+  // Shuts down the LanguageDetector instance when all the work is done.
+  absl::Status Close() { return runner_->Close(); }
+};
+
+}  // namespace mediapipe::tasks::text::language_detector
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
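Note: a minimal usage sketch of the API declared above (not part of this change; the wrapper function, namespace alias, and model path are illustrative placeholders, while the option fields mirror the ones exercised by the test below):

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"

namespace ld = ::mediapipe::tasks::text::language_detector;

// Creates a detector, runs it on `text`, and shuts it down.
absl::StatusOr<ld::LanguageDetectorResult> DetectLanguage(
    absl::string_view text) {
  auto options = std::make_unique<ld::LanguageDetectorOptions>();
  options->base_options.model_asset_path =
      "language_detector.tflite";  // Placeholder path.
  options->classifier_options.score_threshold = 0.3;

  absl::StatusOr<std::unique_ptr<ld::LanguageDetector>> detector =
      ld::LanguageDetector::Create(std::move(options));
  if (!detector.ok()) return detector.status();

  absl::StatusOr<ld::LanguageDetectorResult> result = (*detector)->Detect(text);
  absl::Status close_status = (*detector)->Close();
  if (!close_status.ok()) return close_status;
  return result;
}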
@@ -0,0 +1,163 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+#include <cmath>
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/flags/flag.h"
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe::tasks::text::language_detector {
+namespace {
+
+using ::mediapipe::file::JoinPath;
+using ::testing::HasSubstr;
+using ::testing::Optional;
+
+constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
+constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
+constexpr char kLanguageDetector[] = "language_detector.tflite";
+
+constexpr float kTolerance = 0.000001;
+
+std::string GetFullPath(absl::string_view file_name) {
+  return JoinPath("./", kTestDataDirectory, file_name);
+}
+
+absl::Status MatchesLanguageDetectorResult(
+    const LanguageDetectorResult& expected,
+    const LanguageDetectorResult& actual, float tolerance) {
+  if (expected.size() != actual.size()) {
+    return absl::FailedPreconditionError(absl::Substitute(
+        "Expected $0 predictions, but got $1", expected.size(), actual.size()));
+  }
+  for (int i = 0; i < expected.size(); ++i) {
+    if (expected[i].language_code != actual[i].language_code) {
+      return absl::FailedPreconditionError(absl::Substitute(
+          "Expected prediction $0 to have language_code $1, but got $2", i,
+          expected[i].language_code, actual[i].language_code));
+    }
+    if (std::abs(expected[i].probability - actual[i].probability) > tolerance) {
+      return absl::FailedPreconditionError(absl::Substitute(
+          "Expected prediction $0 to have probability $1, but got $2", i,
+          expected[i].probability, actual[i].probability));
+    }
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace
+
+class LanguageDetectorTest : public tflite_shims::testing::Test {};
+
+TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
+  absl::StatusOr<std::unique_ptr<LanguageDetector>> language_detector =
+      LanguageDetector::Create(std::move(options));
+
+  EXPECT_EQ(language_detector.status().code(), absl::StatusCode::kNotFound);
+  EXPECT_THAT(language_detector.status().message(),
+              HasSubstr("Unable to open file at"));
+  EXPECT_THAT(language_detector.status().GetPayload(kMediaPipeTasksPayload),
+              Optional(absl::Cord(absl::StrCat(
+                  MediaPipeTasksStatus::kRunnerInitializationError))));
+}
+
+TEST_F(LanguageDetectorTest, TestL2CModel) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.score_threshold = 0.3;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(
+      LanguageDetectorResult result_en,
+      language_detector->Detect("To be, or not to be, that is the question"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "en", .probability = 0.999856}}, result_en,
+      kTolerance));
+  MP_ASSERT_OK_AND_ASSIGN(
+      LanguageDetectorResult result_fr,
+      language_detector->Detect(
+          "Il y a beaucoup de bouches qui parlent et fort peu "
+          "de têtes qui pensent."));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "fr", .probability = 0.999781}}, result_fr,
+      kTolerance));
+  MP_ASSERT_OK_AND_ASSIGN(
+      LanguageDetectorResult result_ru,
+      language_detector->Detect("это какой-то английский язык"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "ru", .probability = 0.993362}}, result_ru,
+      kTolerance));
+}
+
+TEST_F(LanguageDetectorTest, TestMultiplePredictions) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.score_threshold = 0.3;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_mixed,
+                          language_detector->Detect("分久必合合久必分"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "zh", .probability = 0.505424},
+       {.language_code = "ja", .probability = 0.481617}},
+      result_mixed, kTolerance));
+}
+
+TEST_F(LanguageDetectorTest, TestAllowList) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.category_allowlist = {"ja"};
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_ja,
+                          language_detector->Detect("分久必合合久必分"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "ja", .probability = 0.481617}}, result_ja,
+      kTolerance));
+}
+
+TEST_F(LanguageDetectorTest, TestDenyList) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.score_threshold = 0.3;
+  options->classifier_options.category_denylist = {"ja"};
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_zh,
+                          language_detector->Detect("分久必合合久必分"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "zh", .probability = 0.505424}}, result_zh,
+      kTolerance));
+}
+
+}  // namespace mediapipe::tasks::text::language_detector