Internal change

PiperOrigin-RevId: 519013105

parent 8a55f11952
commit 712ea6f15b
@@ -75,6 +75,8 @@ cc_library(
     srcs = ["mediapipe_builtin_op_resolver.cc"],
     hdrs = ["mediapipe_builtin_op_resolver.h"],
     deps = [
+        "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
+        "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
         "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
         "//mediapipe/util/tflite/operations:max_pool_argmax",
         "//mediapipe/util/tflite/operations:max_unpooling",

@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
 
+#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
+#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
 #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
 #include "mediapipe/util/tflite/operations/max_pool_argmax.h"
 #include "mediapipe/util/tflite/operations/max_unpooling.h"

@@ -43,6 +45,10 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
       "Landmarks2TransformMatrix",
       mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
       /*version=*/2);
+  // For the LanguageDetector model.
+  AddCustom("NGramHash", mediapipe::tflite_operations::Register_NGRAM_HASH());
+  AddCustom("KmeansEmbeddingLookup",
+            mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
 }
 }  // namespace core
 }  // namespace tasks
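The hunks above wire the new custom ops into MediaPipeBuiltinOpResolver via AddCustom. As an illustrative sketch that is not part of this commit, the same two ops could be attached to a standalone TFLite resolver; the helper name RegisterLanguageDetectorCustomOps is hypothetical, while the op names and Register_* functions come from this change.

// Illustrative sketch (not part of this commit): attaching the
// LanguageDetector custom ops to a standalone TFLite op resolver.
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include "tensorflow/lite/mutable_op_resolver.h"

// Hypothetical helper: op names must match the custom op names stored in the
// .tflite model used by the LanguageDetector task.
void RegisterLanguageDetectorCustomOps(tflite::MutableOpResolver& resolver) {
  resolver.AddCustom("NGramHash",
                     mediapipe::tflite_operations::Register_NGRAM_HASH());
  resolver.AddCustom(
      "KmeansEmbeddingLookup",
      mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
}

The resulting resolver can then be handed to tflite::InterpreterBuilder for a model that uses these ops, which is what MediaPipeBuiltinOpResolver does implicitly for MediaPipe Tasks.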
mediapipe/tasks/cc/text/language_detector/BUILD (new file, 38 lines)
@@ -0,0 +1,38 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "language_detector",
+    srcs = ["language_detector.cc"],
+    hdrs = ["language_detector.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/tasks/cc/components/containers:category",
+        "//mediapipe/tasks/cc/components/containers:classification_result",
+        "//mediapipe/tasks/cc/components/processors:classifier_options",
+        "//mediapipe/tasks/cc/core:base_options",
+        "//mediapipe/tasks/cc/core:base_task_api",
+        "//mediapipe/tasks/cc/core:task_api_factory",
+        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
+        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 namespace kmeans_embedding_lookup_op {
 
 namespace {
@@ -33,6 +33,10 @@ constexpr int kEncodingTable = 1;
 constexpr int kCodebook = 2;
 constexpr int kOutputLabel = 0;
 
+using ::tflite::GetInput;
+using ::tflite::GetOutput;
+using ::tflite::GetTensorData;
+
 }  // namespace
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -142,4 +146,4 @@ TfLiteRegistration* Register_KmeansEmbeddingLookup() {
   return &r;
 }
 
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
@@ -27,10 +27,10 @@ limitations under the License.
 
 #include "tensorflow/lite/kernels/register.h"
 
-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 
 TfLiteRegistration* Register_KmeansEmbeddingLookup();
 
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
 
 #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
@@ -12,14 +12,14 @@
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/test_util.h"
 
-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 namespace {
 
 using ::testing::ElementsAreArray;
 using ::tflite::ArrayFloatNear;
 
 // Helper class for testing the op.
-class KmeansEmbeddingLookupModel : public SingleOpModel {
+class KmeansEmbeddingLookupModel : public tflite::SingleOpModel {
  public:
   explicit KmeansEmbeddingLookupModel(
       std::initializer_list<int> input_shape,
@@ -27,7 +27,7 @@ class KmeansEmbeddingLookupModel : public SingleOpModel {
       std::initializer_list<int> codebook_shape,
       std::initializer_list<int> output_shape) {
     // Setup the model inputs and the interpreter.
-    output_ = AddOutput({TensorType_FLOAT32, output_shape});
+    output_ = AddOutput({tflite::TensorType_FLOAT32, output_shape});
     SetCustomOp("KmeansEmbeddingLookup", std::vector<uint8_t>(),
                 Register_KmeansEmbeddingLookup);
     BuildInterpreter({input_shape, encoding_table_shape, codebook_shape});
@@ -68,9 +68,9 @@ class KmeansEmbeddingLookupModel : public SingleOpModel {
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
 
  private:
-  int input_ = AddInput(TensorType_INT32);
-  int encoding_table_ = AddInput(TensorType_UINT8);
-  int codebook_ = AddInput(TensorType_FLOAT32);
+  int input_ = AddInput(tflite::TensorType_INT32);
+  int encoding_table_ = AddInput(tflite::TensorType_UINT8);
+  int codebook_ = AddInput(tflite::TensorType_FLOAT32);
   int output_;
 };
 
@@ -173,4 +173,4 @@ TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) {
 }
 
 }  // namespace
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/string_util.h"
 
-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 
 namespace ngram_op {
 
@@ -217,21 +217,21 @@ void Free(TfLiteContext* context, void* buffer) {
 }
 
 TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
-  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
+  TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel);
   TF_LITE_ENSURE(context, output != nullptr);
-  SetTensorToDynamic(output);
+  tflite::SetTensorToDynamic(output);
   return kTfLiteOk;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
   TF_LITE_ENSURE_OK(
-      context,
-      params->PreprocessInput(GetInput(context, node, kInputMessage), context));
+      context, params->PreprocessInput(
+                   tflite::GetInput(context, node, kInputMessage), context));
 
-  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
+  TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel);
   TF_LITE_ENSURE(context, output != nullptr);
-  if (IsDynamicTensor(output)) {
+  if (tflite::IsDynamicTensor(output)) {
     TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
     output_size->data[0] = 1;
     output_size->data[1] = params->GetNumNGrams();
@@ -261,4 +261,4 @@ TfLiteRegistration* Register_NGRAM_HASH() {
   return &r;
 }
 
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
@@ -18,10 +18,10 @@ limitations under the License.
 
 #include "tensorflow/lite/kernels/register.h"
 
-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 
 TfLiteRegistration* Register_NGRAM_HASH();
 
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
 
 #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
@@ -32,7 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/string_util.h"
 
-namespace tflite::ops::custom {
+namespace mediapipe::tflite_operations {
 namespace {
 
 using ::flexbuffers::Builder;
@@ -42,7 +42,7 @@ using ::testing::ElementsAreArray;
 using ::testing::Message;
 
 // Helper class for testing the op.
-class NGramHashModel : public SingleOpModel {
+class NGramHashModel : public tflite::SingleOpModel {
  public:
   explicit NGramHashModel(const uint64_t seed,
                           const std::vector<int>& ngram_lengths,
@@ -71,7 +71,7 @@ class NGramHashModel : public SingleOpModel {
     }
     fbb.EndMap(start);
     fbb.Finish();
-    output_ = AddOutput({TensorType_INT32, {}});
+    output_ = AddOutput({tflite::TensorType_INT32, {}});
     SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
     BuildInterpreter({GetShape(input_)});
   }
@@ -100,7 +100,7 @@ class NGramHashModel : public SingleOpModel {
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
 
  private:
-  int input_ = AddInput(TensorType_STRING);
+  int input_ = AddInput(tflite::TensorType_STRING);
   int output_;
 };
 
@@ -173,7 +173,7 @@ TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
 
   NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
   for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
-    const string& testcase_input = testcase_inputs[test_idx];
+    const std::string& testcase_input = testcase_inputs[test_idx];
     m.Invoke(testcase_input);
     SCOPED_TRACE(Message() << "Where the testcases' input is: "
                            << testcase_input);
@@ -310,4 +310,4 @@ TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
 }
 
 }  // namespace
-}  // namespace tflite::ops::custom
+}  // namespace mediapipe::tflite_operations
mediapipe/tasks/cc/text/language_detector/language_detector.cc (new file, 126 lines)
@@ -0,0 +1,126 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+#include <memory>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
+#include "mediapipe/tasks/cc/core/task_api_factory.h"
+#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
+
+namespace mediapipe::tasks::text::language_detector {
+
+namespace {
+
+using ::mediapipe::tasks::components::containers::Category;
+using ::mediapipe::tasks::components::containers::ClassificationResult;
+using ::mediapipe::tasks::components::containers::Classifications;
+using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
+using ClassificationResultProto =
+    ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::tasks::text::text_classifier::proto::
+    TextClassifierGraphOptions;
+
+constexpr char kTextStreamName[] = "text_in";
+constexpr char kTextTag[] = "TEXT";
+constexpr char kClassificationsStreamName[] = "classifications_out";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kSubgraphTypeName[] =
+    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
+
+// Creates a MediaPipe graph config that only contains a single subgraph node of
+// type "TextClassifierGraph".
+CalculatorGraphConfig CreateGraphConfig(
+    std::unique_ptr<TextClassifierGraphOptions> options) {
+  api2::builder::Graph graph;
+  auto& subgraph = graph.AddNode(kSubgraphTypeName);
+  subgraph.GetOptions<TextClassifierGraphOptions>().Swap(options.get());
+  graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
+  subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
+      graph.Out(kClassificationsTag);
+  return graph.GetConfig();
+}
+
+// Converts the user-facing LanguageDetectorOptions struct to the internal
+// TextClassifierGraphOptions proto.
+std::unique_ptr<TextClassifierGraphOptions>
+ConvertLanguageDetectorOptionsToProto(LanguageDetectorOptions* options) {
+  auto options_proto = std::make_unique<TextClassifierGraphOptions>();
+  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
+      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
+  options_proto->mutable_base_options()->Swap(base_options_proto.get());
+  auto classifier_options_proto =
+      std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
+          components::processors::ConvertClassifierOptionsToProto(
+              &(options->classifier_options)));
+  options_proto->mutable_classifier_options()->Swap(
+      classifier_options_proto.get());
+  return options_proto;
+}
+
+absl::StatusOr<LanguageDetectorResult>
+ExtractLanguageDetectorResultFromClassificationResult(
+    const ClassificationResult& classification_result) {
+  if (classification_result.classifications.size() != 1) {
+    return absl::InvalidArgumentError(
+        "The LanguageDetector TextClassifierGraph should have exactly one "
+        "classification head.");
+  }
+  const Classifications& languages_and_scores =
+      classification_result.classifications[0];
+  LanguageDetectorResult language_detector_result;
+  for (const Category& category : languages_and_scores.categories) {
+    if (!category.category_name.has_value()) {
+      return absl::InvalidArgumentError(
+          "LanguageDetector ClassificationResult has a missing language code.");
+    }
+    language_detector_result.push_back(
+        {.language_code = *category.category_name,
+         .probability = category.score});
+  }
+  return language_detector_result;
+}
+
+}  // namespace
+
+absl::StatusOr<std::unique_ptr<LanguageDetector>> LanguageDetector::Create(
+    std::unique_ptr<LanguageDetectorOptions> options) {
+  auto options_proto = ConvertLanguageDetectorOptionsToProto(options.get());
+  return core::TaskApiFactory::Create<LanguageDetector,
+                                      TextClassifierGraphOptions>(
+      CreateGraphConfig(std::move(options_proto)),
+      std::move(options->base_options.op_resolver));
+}
+
+absl::StatusOr<LanguageDetectorResult> LanguageDetector::Detect(
+    absl::string_view text) {
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      runner_->Process(
+          {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
+  ClassificationResult classification_result =
+      ConvertToClassificationResult(output_packets[kClassificationsStreamName]
+                                        .Get<ClassificationResultProto>());
+  return ExtractLanguageDetectorResultFromClassificationResult(
+      classification_result);
+}
+
+}  // namespace mediapipe::tasks::text::language_detector
mediapipe/tasks/cc/text/language_detector/language_detector.h (new file, 84 lines)
@@ -0,0 +1,84 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
+#include "mediapipe/tasks/cc/core/base_options.h"
+#include "mediapipe/tasks/cc/core/base_task_api.h"
+
+namespace mediapipe::tasks::text::language_detector {
+
+// A language code and its probability.
+struct LanguageDetectorPrediction {
+  // An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
+  // "ja"-Latn for Japanese (romaji).
+  std::string language_code;
+
+  float probability;
+};
+
+// Task output.
+using LanguageDetectorResult = std::vector<LanguageDetectorPrediction>;
+
+// The options for configuring a MediaPipe LanguageDetector task.
+struct LanguageDetectorOptions {
+  // Base options for configuring MediaPipe Tasks, such as specifying the model
+  // file with metadata, accelerator options, op resolver, etc.
+  tasks::core::BaseOptions base_options;
+
+  // Options for configuring the classifier behavior, such as score threshold,
+  // number of results, etc.
+  components::processors::ClassifierOptions classifier_options;
+};
+
+// Predicts the language of an input text.
+//
+// This API expects a TFLite model with TFLite Model Metadata that
+// contains the mandatory (described below) input tensors, output tensor,
+// and the language codes in an AssociatedFile.
+//
+// Input tensors:
+//   (kTfLiteString)
+//   - 1 input tensor that is scalar or has shape [1] containing the input
+//     string.
+// Output tensor:
+//   (kTfLiteFloat32)
+//   - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
+class LanguageDetector : core::BaseTaskApi {
+ public:
+  using BaseTaskApi::BaseTaskApi;
+
+  // Creates a LanguageDetector instance from the provided `options`.
+  static absl::StatusOr<std::unique_ptr<LanguageDetector>> Create(
+      std::unique_ptr<LanguageDetectorOptions> options);
+
+  // Predicts the language of the input `text`.
+  absl::StatusOr<LanguageDetectorResult> Detect(absl::string_view text);
+
+  // Shuts down the LanguageDetector instance when all the work is done.
+  absl::Status Close() { return runner_->Close(); }
+};
+
+}  // namespace mediapipe::tasks::text::language_detector
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_LANGUAGE_DETECTOR_H_
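To make the new public API concrete, here is a minimal usage sketch that is not part of this commit. The model file name is a placeholder, the option values mirror the tests that follow, and error handling is reduced to a status check.

// Illustrative sketch (not part of this commit): creating a LanguageDetector
// and detecting the language of a string.
#include <iostream>
#include <memory>
#include <utility>

#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"

namespace ld = mediapipe::tasks::text::language_detector;

int main() {
  auto options = std::make_unique<ld::LanguageDetectorOptions>();
  // Placeholder path; point this at a LanguageDetector .tflite model.
  options->base_options.model_asset_path = "language_detector.tflite";
  options->classifier_options.score_threshold = 0.3;

  auto detector = ld::LanguageDetector::Create(std::move(options));
  if (!detector.ok()) {
    std::cerr << detector.status().message() << std::endl;
    return 1;
  }

  auto result =
      (*detector)->Detect("To be, or not to be, that is the question");
  if (result.ok()) {
    for (const ld::LanguageDetectorPrediction& prediction : *result) {
      std::cout << prediction.language_code << ": " << prediction.probability
                << std::endl;
    }
  }
  return (*detector)->Close().ok() ? 0 : 1;
}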
@@ -0,0 +1,163 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/language_detector/language_detector.h"
+
+#include <cmath>
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/flags/flag.h"
+#include "absl/status/status.h"
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe::tasks::text::language_detector {
+namespace {
+
+using ::mediapipe::file::JoinPath;
+using ::testing::HasSubstr;
+using ::testing::Optional;
+
+constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
+constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
+constexpr char kLanguageDetector[] = "language_detector.tflite";
+
+constexpr float kTolerance = 0.000001;
+
+std::string GetFullPath(absl::string_view file_name) {
+  return JoinPath("./", kTestDataDirectory, file_name);
+}
+
+absl::Status MatchesLanguageDetectorResult(
+    const LanguageDetectorResult& expected,
+    const LanguageDetectorResult& actual, float tolerance) {
+  if (expected.size() != actual.size()) {
+    return absl::FailedPreconditionError(absl::Substitute(
+        "Expected $0 predictions, but got $1", expected.size(), actual.size()));
+  }
+  for (int i = 0; i < expected.size(); ++i) {
+    if (expected[i].language_code != actual[i].language_code) {
+      return absl::FailedPreconditionError(absl::Substitute(
+          "Expected prediction $0 to have language_code $1, but got $2", i,
+          expected[i].language_code, actual[i].language_code));
+    }
+    if (std::abs(expected[i].probability - actual[i].probability) > tolerance) {
+      return absl::FailedPreconditionError(absl::Substitute(
+          "Expected prediction $0 to have probability $1, but got $2", i,
+          expected[i].probability, actual[i].probability));
+    }
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace
+
+class LanguageDetectorTest : public tflite_shims::testing::Test {};
+
+TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
+  absl::StatusOr<std::unique_ptr<LanguageDetector>> language_detector =
+      LanguageDetector::Create(std::move(options));
+
+  EXPECT_EQ(language_detector.status().code(), absl::StatusCode::kNotFound);
+  EXPECT_THAT(language_detector.status().message(),
+              HasSubstr("Unable to open file at"));
+  EXPECT_THAT(language_detector.status().GetPayload(kMediaPipeTasksPayload),
+              Optional(absl::Cord(absl::StrCat(
+                  MediaPipeTasksStatus::kRunnerInitializationError))));
+}
+
+TEST_F(LanguageDetectorTest, TestL2CModel) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.score_threshold = 0.3;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(
+      LanguageDetectorResult result_en,
+      language_detector->Detect("To be, or not to be, that is the question"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "en", .probability = 0.999856}}, result_en,
+      kTolerance));
+  MP_ASSERT_OK_AND_ASSIGN(
+      LanguageDetectorResult result_fr,
+      language_detector->Detect(
+          "Il y a beaucoup de bouches qui parlent et fort peu "
+          "de têtes qui pensent."));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "fr", .probability = 0.999781}}, result_fr,
+      kTolerance));
+  MP_ASSERT_OK_AND_ASSIGN(
+      LanguageDetectorResult result_ru,
+      language_detector->Detect("это какой-то английский язык"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "ru", .probability = 0.993362}}, result_ru,
+      kTolerance));
+}
+
+TEST_F(LanguageDetectorTest, TestMultiplePredictions) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.score_threshold = 0.3;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_mixed,
+                          language_detector->Detect("分久必合合久必分"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "zh", .probability = 0.505424},
+       {.language_code = "ja", .probability = 0.481617}},
+      result_mixed, kTolerance));
+}
+
+TEST_F(LanguageDetectorTest, TestAllowList) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.category_allowlist = {"ja"};
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_ja,
+                          language_detector->Detect("分久必合合久必分"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "ja", .probability = 0.481617}}, result_ja,
+      kTolerance));
+}
+
+TEST_F(LanguageDetectorTest, TestDenyList) {
+  auto options = std::make_unique<LanguageDetectorOptions>();
+  options->base_options.model_asset_path = GetFullPath(kLanguageDetector);
+  options->classifier_options.score_threshold = 0.3;
+  options->classifier_options.category_denylist = {"ja"};
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageDetector> language_detector,
+                          LanguageDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(LanguageDetectorResult result_zh,
+                          language_detector->Detect("分久必合合久必分"));
+  MP_EXPECT_OK(MatchesLanguageDetectorResult(
+      {{.language_code = "zh", .probability = 0.505424}}, result_zh,
+      kTolerance));
+}
+
+}  // namespace mediapipe::tasks::text::language_detector