Internal refactoring for TextEmbedder.
PiperOrigin-RevId: 493766612
parent a0efcb47f2
commit 700c7b4b22
@@ -150,9 +150,12 @@ cc_library(
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/text/utils:text_model_utils",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
@@ -60,10 +60,16 @@ mediapipe_proto_library(
    ],
)

mediapipe_proto_library(
    name = "text_model_type_proto",
    srcs = ["text_model_type.proto"],
)

mediapipe_proto_library(
    name = "text_preprocessing_graph_options_proto",
    srcs = ["text_preprocessing_graph_options.proto"],
    deps = [
        ":text_model_type_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
@@ -0,0 +1,31 @@ mediapipe/tasks/cc/components/processors/proto/text_model_type.proto
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.components.processors.proto;

message TextModelType {
  // TFLite text models supported by MediaPipe tasks.
  enum ModelType {
    UNSPECIFIED_MODEL = 0;
    // A BERT-based model.
    BERT_MODEL = 1;
    // A model expecting input passed through a regex-based tokenizer.
    REGEX_MODEL = 2;
    // A model taking a string tensor input.
    STRING_MODEL = 3;
  }
}
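For reference, a minimal sketch of how the generated enum constants read from C++; the helper name ModelTypeName below is illustrative only and is not part of this change.

#include <string>

#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"

namespace {
using ::mediapipe::tasks::components::processors::proto::TextModelType;

// Illustrative helper: maps the generated enum values to short labels.
std::string ModelTypeName(TextModelType::ModelType model_type) {
  switch (model_type) {
    case TextModelType::BERT_MODEL:
      return "BERT";
    case TextModelType::REGEX_MODEL:
      return "regex";
    case TextModelType::STRING_MODEL:
      return "string";
    case TextModelType::UNSPECIFIED_MODEL:
    default:
      return "unspecified";
  }
}
}  // namespace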
@@ -18,25 +18,16 @@ syntax = "proto2";
 package mediapipe.tasks.components.processors.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto";
 
 message TextPreprocessingGraphOptions {
   extend mediapipe.CalculatorOptions {
     optional TextPreprocessingGraphOptions ext = 476978751;
   }
 
-  // The type of text preprocessor required for the TFLite model.
-  enum PreprocessorType {
-    UNSPECIFIED_PREPROCESSOR = 0;
-    // Used for the BertPreprocessorCalculator.
-    BERT_PREPROCESSOR = 1;
-    // Used for the RegexPreprocessorCalculator.
-    REGEX_PREPROCESSOR = 2;
-    // Used for the TextToTensorCalculator.
-    STRING_PREPROCESSOR = 3;
-  }
-  optional PreprocessorType preprocessor_type = 1;
+  optional TextModelType.ModelType model_type = 1;
 
   // The maximum input sequence length for the TFLite model. Used with
-  // BERT_PREPROCESSOR and REGEX_PREPROCESSOR.
+  // BERT_MODEL and REGEX_MODEL.
   optional int32 max_seq_len = 2;
 }
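A minimal sketch of how a caller might populate the renamed field, assuming a valid ModelResources instance; it mirrors the ConfigureTextPreprocessingGraph change further down and is not itself part of this commit (the function name FillPreprocessingOptions is hypothetical).

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

// Sketch only: detects the model type and stores it on the graph options.
absl::Status FillPreprocessingOptions(
    const mediapipe::tasks::core::ModelResources& model_resources,
    mediapipe::tasks::components::processors::proto::
        TextPreprocessingGraphOptions& options) {
  auto model_type =
      mediapipe::tasks::text::utils::GetModelType(model_resources);
  if (!model_type.ok()) return model_type.status();
  options.set_model_type(*model_type);
  return absl::OkStatus();
}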
@@ -25,15 +25,14 @@ limitations under the License.
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/subgraph.h"
 #include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
 
-namespace mediapipe {
-namespace tasks {
-namespace components {
-namespace processors {
+namespace mediapipe::tasks::components::processors {
 namespace {
 
 using ::mediapipe::api2::Input;
@@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::SideSource;
 using ::mediapipe::api2::builder::Source;
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
 using ::mediapipe::tasks::components::processors::proto::
     TextPreprocessingGraphOptions;
 using ::mediapipe::tasks::core::ModelResources;
 using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+using ::mediapipe::tasks::text::utils::GetModelType;
 
 constexpr char kTextTag[] = "TEXT";
 constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
 constexpr char kTensorsTag[] = "TENSORS";
 
-constexpr int kNumInputTensorsForBert = 3;
-constexpr int kNumInputTensorsForRegex = 1;
-
-// Gets the name of the MediaPipe calculator associated with
-// `preprocessor_type`.
-absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
-    TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) {
-  switch (preprocessor_type) {
-    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+// Gets the name of the MediaPipe preprocessor calculator associated with
+// `model_type`.
+absl::StatusOr<std::string> GetCalculatorNameFromModelType(
+    TextModelType::ModelType model_type) {
+  switch (model_type) {
+    case TextModelType::UNSPECIFIED_MODEL:
       return CreateStatusWithPayload(
-          absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type",
+          absl::StatusCode::kInvalidArgument, "Unspecified model type",
           MediaPipeTasksStatus::kInvalidArgumentError);
-    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
+    case TextModelType::BERT_MODEL:
       return "BertPreprocessorCalculator";
-    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
+    case TextModelType::REGEX_MODEL:
       return "RegexPreprocessorCalculator";
-    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
+    case TextModelType::STRING_MODEL:
       return "TextToTensorCalculator";
   }
 }
 
-// Determines the PreprocessorType for the model based on its metadata as well
-// as its input tensors' type and count. Returns an error if there is no
-// compatible preprocessor.
-absl::StatusOr<TextPreprocessingGraphOptions::PreprocessorType>
-GetPreprocessorType(const ModelResources& model_resources) {
-  const tflite::SubGraph& model_graph =
-      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
-  bool all_int32_tensors =
-      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
-        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
-      });
-  bool all_string_tensors =
-      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
-        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
-      });
-  if (!all_int32_tensors && !all_string_tensors) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "All input tensors should have type int32 or all should have type "
-        "string",
-        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
-  }
-  if (all_string_tensors) {
-    return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
-  }
-
-  // Otherwise, all tensors should have type int32
-  const ModelMetadataExtractor* metadata_extractor =
-      model_resources.GetMetadataExtractor();
-  if (metadata_extractor->GetModelMetadata() == nullptr ||
-      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Text models with int32 input tensors require TFLite Model "
-        "Metadata but none was found",
-        MediaPipeTasksStatus::kMetadataNotFoundError);
-  }
-
-  if (model_graph.inputs()->size() == kNumInputTensorsForBert) {
-    return TextPreprocessingGraphOptions::BERT_PREPROCESSOR;
-  }
-
-  if (model_graph.inputs()->size() == kNumInputTensorsForRegex) {
-    return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR;
-  }
-
-  return CreateStatusWithPayload(
-      absl::StatusCode::kInvalidArgument,
-      absl::Substitute("Models with int32 input tensors should take exactly $0 "
-                       "or $1 input tensors, but found $2",
-                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
-                       model_graph.inputs()->size()),
-      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
-}
-
 // Returns the maximum input sequence length accepted by the TFLite
 // model that owns `model graph` or returns an error if the model's input
 // tensors' shape is invalid for text preprocessing. This util assumes that the
@@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph(
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
 
-  ASSIGN_OR_RETURN(
-      TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
-      GetPreprocessorType(model_resources));
-  options.set_preprocessor_type(preprocessor_type);
-  switch (preprocessor_type) {
-    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+  ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
+                   GetModelType(model_resources));
+  options.set_model_type(model_type);
+  switch (model_type) {
+    case TextModelType::UNSPECIFIED_MODEL:
+    case TextModelType::STRING_MODEL: {
       break;
     }
-    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+    case TextModelType::BERT_MODEL:
+    case TextModelType::REGEX_MODEL: {
       ASSIGN_OR_RETURN(
           int max_seq_len,
           GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
@@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
   absl::StatusOr<Source<std::vector<Tensor>>> BuildTextPreprocessing(
       const TextPreprocessingGraphOptions& options, Source<std::string> text_in,
       SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) {
-    ASSIGN_OR_RETURN(
-        std::string preprocessor_name,
-        GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
+    ASSIGN_OR_RETURN(std::string preprocessor_name,
+                     GetCalculatorNameFromModelType(options.model_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
-    switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+    switch (options.model_type()) {
+      case TextModelType::UNSPECIFIED_MODEL:
+      case TextModelType::STRING_MODEL: {
         break;
       }
-      case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
+      case TextModelType::BERT_MODEL: {
         text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
             .set_bert_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
             text_preprocessor.SideIn(kMetadataExtractorTag);
         break;
       }
-      case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      case TextModelType::REGEX_MODEL: {
         text_preprocessor.GetOptions<RegexPreprocessorCalculatorOptions>()
             .set_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
@@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
 REGISTER_MEDIAPIPE_GRAPH(
     ::mediapipe::tasks::components::processors::TextPreprocessingGraph);
 
-}  // namespace processors
-}  // namespace components
-}  // namespace tasks
-}  // namespace mediapipe
+}  // namespace mediapipe::tasks::components::processors
@@ -43,3 +43,43 @@ cc_test(
        "@com_google_absl//absl/container:node_hash_map",
    ],
)

cc_library(
    name = "text_model_utils",
    srcs = ["text_model_utils.cc"],
    hdrs = ["text_model_utils.h"],
    deps = [
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
    ],
)

cc_test(
    name = "text_model_utils_test",
    srcs = ["text_model_utils_test.cc"],
    data = [
        "//mediapipe/tasks/testdata/text:bert_text_classifier_models",
        "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
        "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
        "//mediapipe/tasks/testdata/text:text_classifier_models",
    ],
    deps = [
        ":text_model_utils",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
    ],
)
mediapipe/tasks/cc/text/utils/text_model_utils.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mediapipe::tasks::text::utils {
namespace {

using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;

constexpr int kNumInputTensorsForBert = 3;
constexpr int kNumInputTensorsForRegex = 1;
constexpr int kNumInputTensorsForStringPreprocessor = 1;

// Determines the ModelType for a model with int32 input tensors based on the
// number of input tensors. Returns an error if there is missing metadata or
// an invalid number of input tensors.
absl::StatusOr<TextModelType::ModelType> GetIntTensorModelType(
    const ModelResources& model_resources, int num_input_tensors) {
  const ModelMetadataExtractor* metadata_extractor =
      model_resources.GetMetadataExtractor();
  if (metadata_extractor->GetModelMetadata() == nullptr ||
      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Text models with int32 input tensors require TFLite Model "
        "Metadata but none was found",
        MediaPipeTasksStatus::kMetadataNotFoundError);
  }

  if (num_input_tensors == kNumInputTensorsForBert) {
    return TextModelType::BERT_MODEL;
  }

  if (num_input_tensors == kNumInputTensorsForRegex) {
    return TextModelType::REGEX_MODEL;
  }

  return CreateStatusWithPayload(
      absl::StatusCode::kInvalidArgument,
      absl::Substitute("Models with int32 input tensors should take exactly $0 "
                       "or $1 input tensors, but found $2",
                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
                       num_input_tensors),
      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
}

// Determines the ModelType for a model with string input tensors based on the
// number of input tensors. Returns an error if there is an invalid number of
// input tensors.
absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
    const ModelResources& model_resources, int num_input_tensors) {
  if (num_input_tensors == kNumInputTensorsForStringPreprocessor) {
    return TextModelType::STRING_MODEL;
  }

  return CreateStatusWithPayload(
      absl::StatusCode::kInvalidArgument,
      absl::Substitute("Models with string input tensors should take exactly "
                       "$0 tensors, but found $1",
                       kNumInputTensorsForStringPreprocessor,
                       num_input_tensors),
      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
}
}  // namespace

absl::StatusOr<TextModelType::ModelType> GetModelType(
    const ModelResources& model_resources) {
  const tflite::SubGraph& model_graph =
      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
  bool all_int32_tensors =
      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
      });
  bool all_string_tensors =
      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
      });
  if (!all_int32_tensors && !all_string_tensors) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "All input tensors should have type int32 or all should have type "
        "string",
        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
  }
  if (all_string_tensors) {
    return GetStringTensorModelType(model_resources,
                                    model_graph.inputs()->size());
  }

  // Otherwise, all tensors should have type int32
  return GetIntTensorModelType(model_resources, model_graph.inputs()->size());
}

}  // namespace mediapipe::tasks::text::utils
mediapipe/tasks/cc/text/utils/text_model_utils.h (new file, 33 lines)
@@ -0,0 +1,33 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

namespace mediapipe::tasks::text::utils {

// Determines the ModelType for the model based on its metadata as well
// as its input tensors' type and count. Returns an error if there is no
// compatible model type.
absl::StatusOr<components::processors::proto::TextModelType::ModelType>
GetModelType(const core::ModelResources& model_resources);

}  // namespace mediapipe::tasks::text::utils

#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
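Typical call pattern for the new utility, sketched under the assumption that a ModelResources was built from an ExternalFile elsewhere (the test file below does exactly this); the helper name RequiresMetadata is illustrative, not part of the change.

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

using ::mediapipe::tasks::components::processors::proto::TextModelType;

// Sketch only: branch on the detected model type or surface the error.
absl::StatusOr<bool> RequiresMetadata(
    const mediapipe::tasks::core::ModelResources& model_resources) {
  absl::StatusOr<TextModelType::ModelType> model_type =
      mediapipe::tasks::text::utils::GetModelType(model_resources);
  if (!model_type.ok()) return model_type.status();
  // BERT and regex models rely on TFLite Model Metadata for their tokenizers.
  return *model_type == TextModelType::BERT_MODEL ||
         *model_type == TextModelType::REGEX_MODEL;
}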
mediapipe/tasks/cc/text/utils/text_model_utils_test.cc (new file, 108 lines)
@@ -0,0 +1,108 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

#include <memory>
#include <string>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

namespace mediapipe::tasks::text::utils {

namespace {

using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::core::proto::ExternalFile;

constexpr absl::string_view kTestModelResourcesTag = "test_model_resources";

constexpr absl::string_view kTestDataDirectory =
    "/mediapipe/tasks/testdata/text/";
// Classification model with BERT preprocessing.
constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite";
// Embedding model with BERT preprocessing.
constexpr absl::string_view kMobileBert =
    "mobilebert_embedding_with_metadata.tflite";
// Classification model with regex preprocessing.
constexpr absl::string_view kRegexClassifierPath =
    "test_model_text_classifier_with_regex_tokenizer.tflite";
// Embedding model with regex preprocessing.
constexpr absl::string_view kRegexOneEmbeddingModel =
    "regex_one_embedding_with_metadata.tflite";
// Classification model that takes a string tensor and outputs a bool tensor.
constexpr absl::string_view kStringToBoolModelPath =
    "test_model_text_classifier_bool_output.tflite";

std::string GetFullPath(absl::string_view file_name) {
  return JoinPath("./", kTestDataDirectory, file_name);
}

absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
    absl::string_view file_name) {
  auto model_file = std::make_unique<ExternalFile>();
  model_file->set_file_name(GetFullPath(file_name));
  ASSIGN_OR_RETURN(auto model_resources,
                   ModelResources::Create(std::string(kTestModelResourcesTag),
                                          std::move(model_file)));
  return GetModelType(*model_resources);
}

}  // namespace

class TextModelUtilsTest : public tflite_shims::testing::Test {};

TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
                          GetModelTypeFromFile(kBertClassifierPath));
  ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
}

TEST_F(TextModelUtilsTest, BertEmbedderModelTest) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert));
  ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
}

TEST_F(TextModelUtilsTest, RegexClassifierModelTest) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
                          GetModelTypeFromFile(kRegexClassifierPath));
  ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
}

TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
                          GetModelTypeFromFile(kRegexOneEmbeddingModel));
  ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
}

TEST_F(TextModelUtilsTest, StringInputModelTest) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
                          GetModelTypeFromFile(kStringToBoolModelPath));
  ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
}

}  // namespace mediapipe::tasks::text::utils