diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 185bf231b..cec44a9e3 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -150,9 +150,12 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/utils:text_model_utils", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index f48c4bad8..816ba47e3 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -60,10 +60,16 @@ mediapipe_proto_library( ], ) +mediapipe_proto_library( + name = "text_model_type_proto", + srcs = ["text_model_type.proto"], +) + mediapipe_proto_library( name = "text_preprocessing_graph_options_proto", srcs = ["text_preprocessing_graph_options.proto"], deps = [ + ":text_model_type_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", ], diff --git a/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto new file mode 100644 index 000000000..7ffc0db07 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/text_model_type.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +message TextModelType { + // TFLite text models supported by MediaPipe tasks. + enum ModelType { + UNSPECIFIED_MODEL = 0; + // A BERT-based model. + BERT_MODEL = 1; + // A model expecting input passed through a regex-based tokenizer. + REGEX_MODEL = 2; + // A model taking a string tensor input. + STRING_MODEL = 3; + } +} diff --git a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index a67cfd8a9..b610f7757 100644 --- a/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -18,25 +18,16 @@ syntax = "proto2"; package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto"; message TextPreprocessingGraphOptions { extend mediapipe.CalculatorOptions { optional TextPreprocessingGraphOptions ext = 476978751; } - // The type of text preprocessor required for the TFLite model. - enum PreprocessorType { - UNSPECIFIED_PREPROCESSOR = 0; - // Used for the BertPreprocessorCalculator. - BERT_PREPROCESSOR = 1; - // Used for the RegexPreprocessorCalculator. - REGEX_PREPROCESSOR = 2; - // Used for the TextToTensorCalculator. - STRING_PREPROCESSOR = 3; - } - optional PreprocessorType preprocessor_type = 1; + optional TextModelType.ModelType model_type = 1; // The maximum input sequence length for the TFLite model. Used with - // BERT_PREPROCESSOR and REGEX_PREPROCESSOR. + // BERT_MODEL and REGEX_MODEL. optional int32 max_seq_len = 2; } diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index de16375bd..f6c15c441 100644 --- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -25,15 +25,14 @@ limitations under the License. #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" -namespace mediapipe { -namespace tasks { -namespace components { -namespace processors { - +namespace mediapipe::tasks::components::processors { namespace { using ::mediapipe::api2::Input; @@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::processors::proto::TextModelType; using ::mediapipe::tasks::components::processors::proto:: TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -constexpr int kNumInputTensorsForBert = 3; -constexpr int kNumInputTensorsForRegex = 1; - -// Gets the name of the MediaPipe calculator associated with -// `preprocessor_type`. -absl::StatusOr GetCalculatorNameFromPreprocessorType( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) { - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: +// Gets the name of the MediaPipe preprocessor calculator associated with +// `model_type`. +absl::StatusOr GetCalculatorNameFromModelType( + TextModelType::ModelType model_type) { + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type", + absl::StatusCode::kInvalidArgument, "Unspecified model type", MediaPipeTasksStatus::kInvalidArgumentError); - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextModelType::BERT_MODEL: return "BertPreprocessorCalculator"; - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: + case TextModelType::REGEX_MODEL: return "RegexPreprocessorCalculator"; - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + case TextModelType::STRING_MODEL: return "TextToTensorCalculator"; } } -// Determines the PreprocessorType for the model based on its metadata as well -// as its input tensors' type and count. Returns an error if there is no -// compatible preprocessor. -absl::StatusOr -GetPreprocessorType(const ModelResources& model_resources) { - const tflite::SubGraph& model_graph = - *(*model_resources.GetTfLiteModel()->subgraphs())[0]; - bool all_int32_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; - }); - bool all_string_tensors = - absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { - return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; - }); - if (!all_int32_tensors && !all_string_tensors) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "All input tensors should have type int32 or all should have type " - "string", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); - } - if (all_string_tensors) { - return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; - } - - // Otherwise, all tensors should have type int32 - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); - if (metadata_extractor->GetModelMetadata() == nullptr || - metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Text models with int32 input tensors require TFLite Model " - "Metadata but none was found", - MediaPipeTasksStatus::kMetadataNotFoundError); - } - - if (model_graph.inputs()->size() == kNumInputTensorsForBert) { - return TextPreprocessingGraphOptions::BERT_PREPROCESSOR; - } - - if (model_graph.inputs()->size() == kNumInputTensorsForRegex) { - return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR; - } - - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::Substitute("Models with int32 input tensors should take exactly $0 " - "or $1 input tensors, but found $2", - kNumInputTensorsForBert, kNumInputTensorsForRegex, - model_graph.inputs()->size()), - MediaPipeTasksStatus::kInvalidNumInputTensorsError); -} - // Returns the maximum input sequence length accepted by the TFLite // model that owns `model graph` or returns an error if the model's input // tensors' shape is invalid for text preprocessing. This util assumes that the @@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph( MediaPipeTasksStatus::kInvalidArgumentError); } - ASSIGN_OR_RETURN( - TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, - GetPreprocessorType(model_resources)); - options.set_preprocessor_type(preprocessor_type); - switch (preprocessor_type) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + ASSIGN_OR_RETURN(TextModelType::ModelType model_type, + GetModelType(model_resources)); + options.set_model_type(model_type); + switch (model_type) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::BERT_MODEL: + case TextModelType::REGEX_MODEL: { ASSIGN_OR_RETURN( int max_seq_len, GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); @@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph { absl::StatusOr>> BuildTextPreprocessing( const TextPreprocessingGraphOptions& options, Source text_in, SideSource metadata_extractor_in, Graph& graph) { - ASSIGN_OR_RETURN( - std::string preprocessor_name, - GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); + ASSIGN_OR_RETURN(std::string preprocessor_name, + GetCalculatorNameFromModelType(options.model_type())); auto& text_preprocessor = graph.AddNode(preprocessor_name); - switch (options.preprocessor_type()) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: - case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + switch (options.model_type()) { + case TextModelType::UNSPECIFIED_MODEL: + case TextModelType::STRING_MODEL: { break; } - case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { + case TextModelType::BERT_MODEL: { text_preprocessor.GetOptions() .set_bert_max_seq_len(options.max_seq_len()); metadata_extractor_in >> text_preprocessor.SideIn(kMetadataExtractorTag); break; } - case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + case TextModelType::REGEX_MODEL: { text_preprocessor.GetOptions() .set_max_seq_len(options.max_seq_len()); metadata_extractor_in >> @@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph { REGISTER_MEDIAPIPE_GRAPH( ::mediapipe::tasks::components::processors::TextPreprocessingGraph); -} // namespace processors -} // namespace components -} // namespace tasks -} // namespace mediapipe +} // namespace mediapipe::tasks::components::processors diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD index 710e8a984..092a7d450 100644 --- a/mediapipe/tasks/cc/text/utils/BUILD +++ b/mediapipe/tasks/cc/text/utils/BUILD @@ -43,3 +43,43 @@ cc_test( "@com_google_absl//absl/container:node_hash_map", ], ) + +cc_library( + name = "text_model_utils", + srcs = ["text_model_utils.cc"], + hdrs = ["text_model_utils.h"], + deps = [ + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "text_model_utils_test", + srcs = ["text_model_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_model_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.cc b/mediapipe/tasks/cc/text/utils/text_model_utils.cc new file mode 100644 index 000000000..9d0005ec1 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.cc @@ -0,0 +1,119 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe::tasks::text::utils { +namespace { + +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kNumInputTensorsForRegex = 1; +constexpr int kNumInputTensorsForStringPreprocessor = 1; + +// Determines the ModelType for a model with int32 input tensors based +// on the number of input tensors. Returns an error if there is missing metadata +// or an invalid number of input tensors. +absl::StatusOr GetIntTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + const ModelMetadataExtractor* metadata_extractor = + model_resources.GetMetadataExtractor(); + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Text models with int32 input tensors require TFLite Model " + "Metadata but none was found", + MediaPipeTasksStatus::kMetadataNotFoundError); + } + + if (num_input_tensors == kNumInputTensorsForBert) { + return TextModelType::BERT_MODEL; + } + + if (num_input_tensors == kNumInputTensorsForRegex) { + return TextModelType::REGEX_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with int32 input tensors should take exactly $0 " + "or $1 input tensors, but found $2", + kNumInputTensorsForBert, kNumInputTensorsForRegex, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} + +// Determines the ModelType for a model with string input tensors based +// on the number of input tensors. Returns an error if there is an invalid +// number of input tensors. +absl::StatusOr GetStringTensorModelType( + const ModelResources& model_resources, int num_input_tensors) { + if (num_input_tensors == kNumInputTensorsForStringPreprocessor) { + return TextModelType::STRING_MODEL; + } + + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute("Models with string input tensors should take exactly " + "$0 tensors, but found $1", + kNumInputTensorsForStringPreprocessor, + num_input_tensors), + MediaPipeTasksStatus::kInvalidNumInputTensorsError); +} +} // namespace + +absl::StatusOr GetModelType( + const ModelResources& model_resources) { + const tflite::SubGraph& model_graph = + *(*model_resources.GetTfLiteModel()->subgraphs())[0]; + bool all_int32_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32; + }); + bool all_string_tensors = + absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) { + return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING; + }); + if (!all_int32_tensors && !all_string_tensors) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "All input tensors should have type int32 or all should have type " + "string", + MediaPipeTasksStatus::kInvalidInputTensorTypeError); + } + if (all_string_tensors) { + return GetStringTensorModelType(model_resources, + model_graph.inputs()->size()); + } + + // Otherwise, all tensors should have type int32 + return GetIntTensorModelType(model_resources, model_graph.inputs()->size()); +} + +} // namespace mediapipe::tasks::text::utils diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils.h b/mediapipe/tasks/cc/text/utils/text_model_utils.h new file mode 100644 index 000000000..da8783d33 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ + +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" + +namespace mediapipe::tasks::text::utils { + +// Determines the ModelType for the model based on its metadata as well +// as its input tensors' type and count. Returns an error if there is no +// compatible model type. +absl::StatusOr +GetModelType(const core::ModelResources& model_resources); + +} // namespace mediapipe::tasks::text::utils + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc new file mode 100644 index 000000000..c02f8eca5 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::utils { + +namespace { + +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::processors::proto::TextModelType; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::proto::ExternalFile; + +constexpr absl::string_view kTestModelResourcesTag = "test_model_resources"; + +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/text/"; +// Classification model with BERT preprocessing. +constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite"; +// Embedding model with BERT preprocessing. +constexpr absl::string_view kMobileBert = + "mobilebert_embedding_with_metadata.tflite"; +// Classification model with regex preprocessing. +constexpr absl::string_view kRegexClassifierPath = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +// Embedding model with regex preprocessing. +constexpr absl::string_view kRegexOneEmbeddingModel = + "regex_one_embedding_with_metadata.tflite"; +// Classification model that takes a string tensor and outputs a bool tensor. +constexpr absl::string_view kStringToBoolModelPath = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +absl::StatusOr GetModelTypeFromFile( + absl::string_view file_name) { + auto model_file = std::make_unique(); + model_file->set_file_name(GetFullPath(file_name)); + ASSIGN_OR_RETURN(auto model_resources, + ModelResources::Create(std::string(kTestModelResourcesTag), + std::move(model_file))); + return GetModelType(*model_resources); +} + +} // namespace + +class TextModelUtilsTest : public tflite_shims::testing::Test {}; + +TEST_F(TextModelUtilsTest, BertClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kBertClassifierPath)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, BertEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert)); + ASSERT_EQ(model_type, TextModelType::BERT_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexClassifierModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexClassifierPath)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kRegexOneEmbeddingModel)); + ASSERT_EQ(model_type, TextModelType::REGEX_MODEL); +} + +TEST_F(TextModelUtilsTest, StringInputModelTest) { + MP_ASSERT_OK_AND_ASSIGN(auto model_type, + GetModelTypeFromFile(kStringToBoolModelPath)); + ASSERT_EQ(model_type, TextModelType::STRING_MODEL); +} + +} // namespace mediapipe::tasks::text::utils