Internal refactoring for TextEmbedder.

PiperOrigin-RevId: 493766612
This commit is contained in:
MediaPipe Team 2022-12-07 18:54:34 -08:00 committed by Copybara-Service
parent a0efcb47f2
commit 700c7b4b22
9 changed files with 375 additions and 106 deletions

View File

@ -150,9 +150,12 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/utils:text_model_utils",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -60,10 +60,16 @@ mediapipe_proto_library(
], ],
) )
mediapipe_proto_library(
name = "text_model_type_proto",
srcs = ["text_model_type.proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto", name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"], srcs = ["text_preprocessing_graph_options.proto"],
deps = [ deps = [
":text_model_type_proto",
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
], ],

View File

@ -0,0 +1,31 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.components.processors.proto;
message TextModelType {
// TFLite text models supported by MediaPipe tasks.
enum ModelType {
UNSPECIFIED_MODEL = 0;
// A BERT-based model.
BERT_MODEL = 1;
// A model expecting input passed through a regex-based tokenizer.
REGEX_MODEL = 2;
// A model taking a string tensor input.
STRING_MODEL = 3;
}
}

View File

@ -18,25 +18,16 @@ syntax = "proto2";
package mediapipe.tasks.components.processors.proto; package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto";
message TextPreprocessingGraphOptions { message TextPreprocessingGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional TextPreprocessingGraphOptions ext = 476978751; optional TextPreprocessingGraphOptions ext = 476978751;
} }
// The type of text preprocessor required for the TFLite model. optional TextModelType.ModelType model_type = 1;
enum PreprocessorType {
UNSPECIFIED_PREPROCESSOR = 0;
// Used for the BertPreprocessorCalculator.
BERT_PREPROCESSOR = 1;
// Used for the RegexPreprocessorCalculator.
REGEX_PREPROCESSOR = 2;
// Used for the TextToTensorCalculator.
STRING_PREPROCESSOR = 3;
}
optional PreprocessorType preprocessor_type = 1;
// The maximum input sequence length for the TFLite model. Used with // The maximum input sequence length for the TFLite model. Used with
// BERT_PREPROCESSOR and REGEX_PREPROCESSOR. // BERT_MODEL and REGEX_MODEL.
optional int32 max_seq_len = 2; optional int32 max_seq_len = 2;
} }

View File

@ -25,15 +25,14 @@ limitations under the License.
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/subgraph.h" #include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
namespace mediapipe { namespace mediapipe::tasks::components::processors {
namespace tasks {
namespace components {
namespace processors {
namespace { namespace {
using ::mediapipe::api2::Input; using ::mediapipe::api2::Input;
@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::SideSource;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::components::processors::proto:: using ::mediapipe::tasks::components::processors::proto::
TextPreprocessingGraphOptions; TextPreprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::mediapipe::tasks::text::utils::GetModelType;
constexpr char kTextTag[] = "TEXT"; constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr int kNumInputTensorsForBert = 3; // Gets the name of the MediaPipe preprocessor calculator associated with
constexpr int kNumInputTensorsForRegex = 1; // `model_type`.
absl::StatusOr<std::string> GetCalculatorNameFromModelType(
// Gets the name of the MediaPipe calculator associated with TextModelType::ModelType model_type) {
// `preprocessor_type`. switch (model_type) {
absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType( case TextModelType::UNSPECIFIED_MODEL:
TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) {
switch (preprocessor_type) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", absl::StatusCode::kInvalidArgument, "Unspecified model type",
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: case TextModelType::BERT_MODEL:
return "BertPreprocessorCalculator"; return "BertPreprocessorCalculator";
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: case TextModelType::REGEX_MODEL:
return "RegexPreprocessorCalculator"; return "RegexPreprocessorCalculator";
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: case TextModelType::STRING_MODEL:
return "TextToTensorCalculator"; return "TextToTensorCalculator";
} }
} }
// Determines the PreprocessorType for the model based on its metadata as well
// as its input tensors' type and count. Returns an error if there is no
// compatible preprocessor.
absl::StatusOr<TextPreprocessingGraphOptions::PreprocessorType>
GetPreprocessorType(const ModelResources& model_resources) {
const tflite::SubGraph& model_graph =
*(*model_resources.GetTfLiteModel()->subgraphs())[0];
bool all_int32_tensors =
absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
});
bool all_string_tensors =
absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
});
if (!all_int32_tensors && !all_string_tensors) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"All input tensors should have type int32 or all should have type "
"string",
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
}
if (all_string_tensors) {
return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
}
// Otherwise, all tensors should have type int32
const ModelMetadataExtractor* metadata_extractor =
model_resources.GetMetadataExtractor();
if (metadata_extractor->GetModelMetadata() == nullptr ||
metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Text models with int32 input tensors require TFLite Model "
"Metadata but none was found",
MediaPipeTasksStatus::kMetadataNotFoundError);
}
if (model_graph.inputs()->size() == kNumInputTensorsForBert) {
return TextPreprocessingGraphOptions::BERT_PREPROCESSOR;
}
if (model_graph.inputs()->size() == kNumInputTensorsForRegex) {
return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR;
}
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute("Models with int32 input tensors should take exactly $0 "
"or $1 input tensors, but found $2",
kNumInputTensorsForBert, kNumInputTensorsForRegex,
model_graph.inputs()->size()),
MediaPipeTasksStatus::kInvalidNumInputTensorsError);
}
// Returns the maximum input sequence length accepted by the TFLite // Returns the maximum input sequence length accepted by the TFLite
// model that owns `model graph` or returns an error if the model's input // model that owns `model graph` or returns an error if the model's input
// tensors' shape is invalid for text preprocessing. This util assumes that the // tensors' shape is invalid for text preprocessing. This util assumes that the
@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph(
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, GetModelType(model_resources));
GetPreprocessorType(model_resources)); options.set_model_type(model_type);
options.set_preprocessor_type(preprocessor_type); switch (model_type) {
switch (preprocessor_type) { case TextModelType::UNSPECIFIED_MODEL:
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: case TextModelType::STRING_MODEL: {
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break; break;
} }
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: case TextModelType::BERT_MODEL:
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { case TextModelType::REGEX_MODEL: {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
int max_seq_len, int max_seq_len,
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
absl::StatusOr<Source<std::vector<Tensor>>> BuildTextPreprocessing( absl::StatusOr<Source<std::vector<Tensor>>> BuildTextPreprocessing(
const TextPreprocessingGraphOptions& options, Source<std::string> text_in, const TextPreprocessingGraphOptions& options, Source<std::string> text_in,
SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) { SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(std::string preprocessor_name,
std::string preprocessor_name, GetCalculatorNameFromModelType(options.model_type()));
GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
auto& text_preprocessor = graph.AddNode(preprocessor_name); auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.preprocessor_type()) { switch (options.model_type()) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: case TextModelType::UNSPECIFIED_MODEL:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { case TextModelType::STRING_MODEL: {
break; break;
} }
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { case TextModelType::BERT_MODEL: {
text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>() text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
.set_bert_max_seq_len(options.max_seq_len()); .set_bert_max_seq_len(options.max_seq_len());
metadata_extractor_in >> metadata_extractor_in >>
text_preprocessor.SideIn(kMetadataExtractorTag); text_preprocessor.SideIn(kMetadataExtractorTag);
break; break;
} }
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { case TextModelType::REGEX_MODEL: {
text_preprocessor.GetOptions<RegexPreprocessorCalculatorOptions>() text_preprocessor.GetOptions<RegexPreprocessorCalculatorOptions>()
.set_max_seq_len(options.max_seq_len()); .set_max_seq_len(options.max_seq_len());
metadata_extractor_in >> metadata_extractor_in >>
@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::processors::TextPreprocessingGraph); ::mediapipe::tasks::components::processors::TextPreprocessingGraph);
} // namespace processors } // namespace mediapipe::tasks::components::processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -43,3 +43,43 @@ cc_test(
"@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_map",
], ],
) )
cc_library(
name = "text_model_utils",
srcs = ["text_model_utils.cc"],
hdrs = ["text_model_utils.h"],
deps = [
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
)
cc_test(
name = "text_model_utils_test",
srcs = ["text_model_utils_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
deps = [
":text_model_utils",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)

View File

@ -0,0 +1,119 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe::tasks::text::utils {
namespace {
using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kNumInputTensorsForRegex = 1;
constexpr int kNumInputTensorsForStringPreprocessor = 1;
// Determines the ModelType for a model with int32 input tensors based
// on the number of input tensors. Returns an error if there is missing metadata
// or an invalid number of input tensors.
absl::StatusOr<TextModelType::ModelType> GetIntTensorModelType(
const ModelResources& model_resources, int num_input_tensors) {
const ModelMetadataExtractor* metadata_extractor =
model_resources.GetMetadataExtractor();
if (metadata_extractor->GetModelMetadata() == nullptr ||
metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Text models with int32 input tensors require TFLite Model "
"Metadata but none was found",
MediaPipeTasksStatus::kMetadataNotFoundError);
}
if (num_input_tensors == kNumInputTensorsForBert) {
return TextModelType::BERT_MODEL;
}
if (num_input_tensors == kNumInputTensorsForRegex) {
return TextModelType::REGEX_MODEL;
}
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute("Models with int32 input tensors should take exactly $0 "
"or $1 input tensors, but found $2",
kNumInputTensorsForBert, kNumInputTensorsForRegex,
num_input_tensors),
MediaPipeTasksStatus::kInvalidNumInputTensorsError);
}
// Determines the ModelType for a model with string input tensors based
// on the number of input tensors. Returns an error if there is an invalid
// number of input tensors.
absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
const ModelResources& model_resources, int num_input_tensors) {
if (num_input_tensors == kNumInputTensorsForStringPreprocessor) {
return TextModelType::STRING_MODEL;
}
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute("Models with string input tensors should take exactly "
"$0 tensors, but found $1",
kNumInputTensorsForStringPreprocessor,
num_input_tensors),
MediaPipeTasksStatus::kInvalidNumInputTensorsError);
}
} // namespace
absl::StatusOr<TextModelType::ModelType> GetModelType(
const ModelResources& model_resources) {
const tflite::SubGraph& model_graph =
*(*model_resources.GetTfLiteModel()->subgraphs())[0];
bool all_int32_tensors =
absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
});
bool all_string_tensors =
absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
});
if (!all_int32_tensors && !all_string_tensors) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"All input tensors should have type int32 or all should have type "
"string",
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
}
if (all_string_tensors) {
return GetStringTensorModelType(model_resources,
model_graph.inputs()->size());
}
// Otherwise, all tensors should have type int32
return GetIntTensorModelType(model_resources, model_graph.inputs()->size());
}
} // namespace mediapipe::tasks::text::utils

View File

@ -0,0 +1,33 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe::tasks::text::utils {
// Determines the ModelType for the model based on its metadata as well
// as its input tensors' type and count. Returns an error if there is no
// compatible model type.
absl::StatusOr<components::processors::proto::TextModelType::ModelType>
GetModelType(const core::ModelResources& model_resources);
} // namespace mediapipe::tasks::text::utils
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_

View File

@ -0,0 +1,108 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
#include <memory>
#include <string>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe::tasks::text::utils {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::core::proto::ExternalFile;
constexpr absl::string_view kTestModelResourcesTag = "test_model_resources";
constexpr absl::string_view kTestDataDirectory =
"/mediapipe/tasks/testdata/text/";
// Classification model with BERT preprocessing.
constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite";
// Embedding model with BERT preprocessing.
constexpr absl::string_view kMobileBert =
"mobilebert_embedding_with_metadata.tflite";
// Classification model with regex preprocessing.
constexpr absl::string_view kRegexClassifierPath =
"test_model_text_classifier_with_regex_tokenizer.tflite";
// Embedding model with regex preprocessing.
constexpr absl::string_view kRegexOneEmbeddingModel =
"regex_one_embedding_with_metadata.tflite";
// Classification model that takes a string tensor and outputs a bool tensor.
constexpr absl::string_view kStringToBoolModelPath =
"test_model_text_classifier_bool_output.tflite";
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
absl::string_view file_name) {
auto model_file = std::make_unique<ExternalFile>();
model_file->set_file_name(GetFullPath(file_name));
ASSIGN_OR_RETURN(auto model_resources,
ModelResources::Create(std::string(kTestModelResourcesTag),
std::move(model_file)));
return GetModelType(*model_resources);
}
} // namespace
class TextModelUtilsTest : public tflite_shims::testing::Test {};
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
GetModelTypeFromFile(kBertClassifierPath));
ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
}
TEST_F(TextModelUtilsTest, BertEmbedderModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert));
ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
}
TEST_F(TextModelUtilsTest, RegexClassifierModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
GetModelTypeFromFile(kRegexClassifierPath));
ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
}
TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
GetModelTypeFromFile(kRegexOneEmbeddingModel));
ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
}
TEST_F(TextModelUtilsTest, StringInputModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
GetModelTypeFromFile(kStringToBoolModelPath));
ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
}
} // namespace mediapipe::tasks::text::utils