Internal refactoring for TextEmbedder.
PiperOrigin-RevId: 493766612
parent a0efcb47f2
commit 700c7b4b22
@@ -150,9 +150,12 @@ cc_library(
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:tensor",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "//mediapipe/tasks/cc/text/utils:text_model_utils",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
@@ -60,10 +60,16 @@ mediapipe_proto_library(
     ],
 )
 
+mediapipe_proto_library(
+    name = "text_model_type_proto",
+    srcs = ["text_model_type.proto"],
+)
+
 mediapipe_proto_library(
     name = "text_preprocessing_graph_options_proto",
     srcs = ["text_preprocessing_graph_options.proto"],
     deps = [
+        ":text_model_type_proto",
         "//mediapipe/framework:calculator_options_proto",
         "//mediapipe/framework:calculator_proto",
     ],
@@ -0,0 +1,31 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe.tasks.components.processors.proto;
+
+message TextModelType {
+  // TFLite text models supported by MediaPipe tasks.
+  enum ModelType {
+    UNSPECIFIED_MODEL = 0;
+    // A BERT-based model.
+    BERT_MODEL = 1;
+    // A model expecting input passed through a regex-based tokenizer.
+    REGEX_MODEL = 2;
+    // A model taking a string tensor input.
+    STRING_MODEL = 3;
+  }
+}
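The new enum surfaces in C++ through the generated text_model_type.pb.h header. As a minimal sketch (not part of this commit; the helper name is hypothetical), an exhaustive switch over the generated values looks like:

#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"

using ::mediapipe::tasks::components::processors::proto::TextModelType;

// Hypothetical helper: maps a ModelType value to a short display name.
const char* ModelTypeName(TextModelType::ModelType model_type) {
  switch (model_type) {
    case TextModelType::BERT_MODEL:
      return "BERT";
    case TextModelType::REGEX_MODEL:
      return "regex";
    case TextModelType::STRING_MODEL:
      return "string";
    case TextModelType::UNSPECIFIED_MODEL:
      return "unspecified";
  }
  return "unspecified";  // Unreachable for valid proto2 enum values.
}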
@@ -18,25 +18,16 @@ syntax = "proto2";
 package mediapipe.tasks.components.processors.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/tasks/cc/components/processors/proto/text_model_type.proto";
 
 message TextPreprocessingGraphOptions {
   extend mediapipe.CalculatorOptions {
     optional TextPreprocessingGraphOptions ext = 476978751;
   }
 
-  // The type of text preprocessor required for the TFLite model.
-  enum PreprocessorType {
-    UNSPECIFIED_PREPROCESSOR = 0;
-    // Used for the BertPreprocessorCalculator.
-    BERT_PREPROCESSOR = 1;
-    // Used for the RegexPreprocessorCalculator.
-    REGEX_PREPROCESSOR = 2;
-    // Used for the TextToTensorCalculator.
-    STRING_PREPROCESSOR = 3;
-  }
-  optional PreprocessorType preprocessor_type = 1;
+  optional TextModelType.ModelType model_type = 1;
 
   // The maximum input sequence length for the TFLite model. Used with
-  // BERT_PREPROCESSOR and REGEX_PREPROCESSOR.
+  // BERT_MODEL and REGEX_MODEL.
   optional int32 max_seq_len = 2;
 }
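The options message is now configured through the generated setters for model_type and max_seq_len. A minimal sketch (not part of the commit; the factory function below is hypothetical), assuming only the standard proto-generated API:

#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"

using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::components::processors::proto::
    TextPreprocessingGraphOptions;

// Hypothetical example: builds options for a BERT-style model.
TextPreprocessingGraphOptions MakeBertPreprocessingOptions(int max_seq_len) {
  TextPreprocessingGraphOptions options;
  options.set_model_type(TextModelType::BERT_MODEL);
  options.set_max_seq_len(max_seq_len);  // Used with BERT_MODEL and REGEX_MODEL.
  return options;
}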
@@ -25,15 +25,14 @@ limitations under the License.
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/subgraph.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
 
-namespace mediapipe {
-namespace tasks {
-namespace components {
-namespace processors {
-
+namespace mediapipe::tasks::components::processors {
 namespace {
 
 using ::mediapipe::api2::Input;
@@ -42,91 +41,35 @@ using ::mediapipe::api2::SideInput;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::SideSource;
 using ::mediapipe::api2::builder::Source;
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
 using ::mediapipe::tasks::components::processors::proto::
     TextPreprocessingGraphOptions;
 using ::mediapipe::tasks::core::ModelResources;
 using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+using ::mediapipe::tasks::text::utils::GetModelType;
 
 constexpr char kTextTag[] = "TEXT";
 constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
 constexpr char kTensorsTag[] = "TENSORS";
 
-constexpr int kNumInputTensorsForBert = 3;
-constexpr int kNumInputTensorsForRegex = 1;
-
-// Gets the name of the MediaPipe calculator associated with
-// `preprocessor_type`.
-absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
-    TextPreprocessingGraphOptions::PreprocessorType preprocessor_type) {
-  switch (preprocessor_type) {
-    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+// Gets the name of the MediaPipe preprocessor calculator associated with
+// `model_type`.
+absl::StatusOr<std::string> GetCalculatorNameFromModelType(
+    TextModelType::ModelType model_type) {
+  switch (model_type) {
+    case TextModelType::UNSPECIFIED_MODEL:
       return CreateStatusWithPayload(
-          absl::StatusCode::kInvalidArgument, "Unspecified preprocessor type",
+          absl::StatusCode::kInvalidArgument, "Unspecified model type",
           MediaPipeTasksStatus::kInvalidArgumentError);
-    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
+    case TextModelType::BERT_MODEL:
       return "BertPreprocessorCalculator";
-    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
+    case TextModelType::REGEX_MODEL:
       return "RegexPreprocessorCalculator";
-    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
+    case TextModelType::STRING_MODEL:
      return "TextToTensorCalculator";
   }
 }
 
-// Determines the PreprocessorType for the model based on its metadata as well
-// as its input tensors' type and count. Returns an error if there is no
-// compatible preprocessor.
-absl::StatusOr<TextPreprocessingGraphOptions::PreprocessorType>
-GetPreprocessorType(const ModelResources& model_resources) {
-  const tflite::SubGraph& model_graph =
-      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
-  bool all_int32_tensors =
-      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
-        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
-      });
-  bool all_string_tensors =
-      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
-        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
-      });
-  if (!all_int32_tensors && !all_string_tensors) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "All input tensors should have type int32 or all should have type "
-        "string",
-        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
-  }
-  if (all_string_tensors) {
-    return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
-  }
-
-  // Otherwise, all tensors should have type int32
-  const ModelMetadataExtractor* metadata_extractor =
-      model_resources.GetMetadataExtractor();
-  if (metadata_extractor->GetModelMetadata() == nullptr ||
-      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Text models with int32 input tensors require TFLite Model "
-        "Metadata but none was found",
-        MediaPipeTasksStatus::kMetadataNotFoundError);
-  }
-
-  if (model_graph.inputs()->size() == kNumInputTensorsForBert) {
-    return TextPreprocessingGraphOptions::BERT_PREPROCESSOR;
-  }
-
-  if (model_graph.inputs()->size() == kNumInputTensorsForRegex) {
-    return TextPreprocessingGraphOptions::REGEX_PREPROCESSOR;
-  }
-
-  return CreateStatusWithPayload(
-      absl::StatusCode::kInvalidArgument,
-      absl::Substitute("Models with int32 input tensors should take exactly $0 "
-                       "or $1 input tensors, but found $2",
-                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
-                       model_graph.inputs()->size()),
-      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
-}
-
 // Returns the maximum input sequence length accepted by the TFLite
 // model that owns `model graph` or returns an error if the model's input
 // tensors' shape is invalid for text preprocessing. This util assumes that the
@@ -181,17 +124,16 @@ absl::Status ConfigureTextPreprocessingGraph(
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
 
-  ASSIGN_OR_RETURN(
-      TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
-      GetPreprocessorType(model_resources));
-  options.set_preprocessor_type(preprocessor_type);
-  switch (preprocessor_type) {
-    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+  ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
+                   GetModelType(model_resources));
+  options.set_model_type(model_type);
+  switch (model_type) {
+    case TextModelType::UNSPECIFIED_MODEL:
+    case TextModelType::STRING_MODEL: {
       break;
     }
-    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
-    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+    case TextModelType::BERT_MODEL:
+    case TextModelType::REGEX_MODEL: {
       ASSIGN_OR_RETURN(
           int max_seq_len,
           GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
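For readers unfamiliar with the macro, ASSIGN_OR_RETURN above unwraps an absl::StatusOr and returns early on error; a rough hand-written equivalent of the new call site (a sketch, not the macro's literal expansion) is:

absl::StatusOr<TextModelType::ModelType> model_type_or =
    GetModelType(model_resources);
if (!model_type_or.ok()) {
  return model_type_or.status();
}
TextModelType::ModelType model_type = *model_type_or;
options.set_model_type(model_type);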
@@ -239,23 +181,22 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
   absl::StatusOr<Source<std::vector<Tensor>>> BuildTextPreprocessing(
       const TextPreprocessingGraphOptions& options, Source<std::string> text_in,
       SideSource<ModelMetadataExtractor> metadata_extractor_in, Graph& graph) {
-    ASSIGN_OR_RETURN(
-        std::string preprocessor_name,
-        GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
+    ASSIGN_OR_RETURN(std::string preprocessor_name,
+                     GetCalculatorNameFromModelType(options.model_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
-    switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
-      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+    switch (options.model_type()) {
+      case TextModelType::UNSPECIFIED_MODEL:
+      case TextModelType::STRING_MODEL: {
         break;
       }
-      case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
+      case TextModelType::BERT_MODEL: {
         text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
            .set_bert_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
             text_preprocessor.SideIn(kMetadataExtractorTag);
         break;
       }
-      case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      case TextModelType::REGEX_MODEL: {
         text_preprocessor.GetOptions<RegexPreprocessorCalculatorOptions>()
             .set_max_seq_len(options.max_seq_len());
         metadata_extractor_in >>
@@ -270,7 +211,4 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
 REGISTER_MEDIAPIPE_GRAPH(
     ::mediapipe::tasks::components::processors::TextPreprocessingGraph);
 
-}  // namespace processors
-}  // namespace components
-}  // namespace tasks
-}  // namespace mediapipe
+}  // namespace mediapipe::tasks::components::processors
@@ -43,3 +43,43 @@ cc_test(
         "@com_google_absl//absl/container:node_hash_map",
     ],
 )
+
+cc_library(
+    name = "text_model_utils",
+    srcs = ["text_model_utils.cc"],
+    hdrs = ["text_model_utils.h"],
+    deps = [
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+)
+
+cc_test(
+    name = "text_model_utils_test",
+    srcs = ["text_model_utils_test.cc"],
+    data = [
+        "//mediapipe/tasks/testdata/text:bert_text_classifier_models",
+        "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
+        "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
+        "//mediapipe/tasks/testdata/text:text_classifier_models",
+    ],
+    deps = [
+        ":text_model_utils",
+        "//mediapipe/framework/deps:file_path",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
+    ],
+)
mediapipe/tasks/cc/text/utils/text_model_utils.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace mediapipe::tasks::text::utils {
+namespace {
+
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
+using ::mediapipe::tasks::core::ModelResources;
+using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+
+constexpr int kNumInputTensorsForBert = 3;
+constexpr int kNumInputTensorsForRegex = 1;
+constexpr int kNumInputTensorsForStringPreprocessor = 1;
+
+// Determines the ModelType for a model with int32 input tensors based
+// on the number of input tensors. Returns an error if there is missing metadata
+// or an invalid number of input tensors.
+absl::StatusOr<TextModelType::ModelType> GetIntTensorModelType(
+    const ModelResources& model_resources, int num_input_tensors) {
+  const ModelMetadataExtractor* metadata_extractor =
+      model_resources.GetMetadataExtractor();
+  if (metadata_extractor->GetModelMetadata() == nullptr ||
+      metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Text models with int32 input tensors require TFLite Model "
+        "Metadata but none was found",
+        MediaPipeTasksStatus::kMetadataNotFoundError);
+  }
+
+  if (num_input_tensors == kNumInputTensorsForBert) {
+    return TextModelType::BERT_MODEL;
+  }
+
+  if (num_input_tensors == kNumInputTensorsForRegex) {
+    return TextModelType::REGEX_MODEL;
+  }
+
+  return CreateStatusWithPayload(
+      absl::StatusCode::kInvalidArgument,
+      absl::Substitute("Models with int32 input tensors should take exactly $0 "
+                       "or $1 input tensors, but found $2",
+                       kNumInputTensorsForBert, kNumInputTensorsForRegex,
+                       num_input_tensors),
+      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
+}
+
+// Determines the ModelType for a model with string input tensors based
+// on the number of input tensors. Returns an error if there is an invalid
+// number of input tensors.
+absl::StatusOr<TextModelType::ModelType> GetStringTensorModelType(
+    const ModelResources& model_resources, int num_input_tensors) {
+  if (num_input_tensors == kNumInputTensorsForStringPreprocessor) {
+    return TextModelType::STRING_MODEL;
+  }
+
+  return CreateStatusWithPayload(
+      absl::StatusCode::kInvalidArgument,
+      absl::Substitute("Models with string input tensors should take exactly "
+                       "$0 tensors, but found $1",
+                       kNumInputTensorsForStringPreprocessor,
+                       num_input_tensors),
+      MediaPipeTasksStatus::kInvalidNumInputTensorsError);
+}
+}  // namespace
+
+absl::StatusOr<TextModelType::ModelType> GetModelType(
+    const ModelResources& model_resources) {
+  const tflite::SubGraph& model_graph =
+      *(*model_resources.GetTfLiteModel()->subgraphs())[0];
+  bool all_int32_tensors =
+      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
+        return (*model_graph.tensors())[i]->type() == tflite::TensorType_INT32;
+      });
+  bool all_string_tensors =
+      absl::c_all_of(*model_graph.inputs(), [&model_graph](int i) {
+        return (*model_graph.tensors())[i]->type() == tflite::TensorType_STRING;
+      });
+  if (!all_int32_tensors && !all_string_tensors) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "All input tensors should have type int32 or all should have type "
+        "string",
+        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
+  }
+  if (all_string_tensors) {
+    return GetStringTensorModelType(model_resources,
+                                    model_graph.inputs()->size());
+  }
+
+  // Otherwise, all tensors should have type int32
+  return GetIntTensorModelType(model_resources, model_graph.inputs()->size());
+}
+
+}  // namespace mediapipe::tasks::text::utils
mediapipe/tasks/cc/text/utils/text_model_utils.h (new file, 33 lines)
@@ -0,0 +1,33 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
+
+#include "absl/status/statusor.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+
+namespace mediapipe::tasks::text::utils {
+
+// Determines the ModelType for the model based on its metadata as well
+// as its input tensors' type and count. Returns an error if there is no
+// compatible model type.
+absl::StatusOr<components::processors::proto::TextModelType::ModelType>
+GetModelType(const core::ModelResources& model_resources);
+
+}  // namespace mediapipe::tasks::text::utils
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_TEXT_MODEL_UTILS_H_
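A short usage sketch for the new utility (not part of the commit; the wrapper below is hypothetical and assumes an already-loaded ModelResources, which the test that follows builds via ModelResources::Create):

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"

// Hypothetical helper: reports whether the loaded model needs BERT preprocessing.
absl::StatusOr<bool> NeedsBertPreprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources) {
  using ::mediapipe::tasks::components::processors::proto::TextModelType;
  absl::StatusOr<TextModelType::ModelType> model_type =
      mediapipe::tasks::text::utils::GetModelType(model_resources);
  if (!model_type.ok()) {
    return model_type.status();
  }
  return *model_type == TextModelType::BERT_MODEL;
}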
mediapipe/tasks/cc/text/utils/text_model_utils_test.cc (new file, 108 lines)
@@ -0,0 +1,108 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe::tasks::text::utils {
+
+namespace {
+
+using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
+using ::mediapipe::tasks::core::ModelResources;
+using ::mediapipe::tasks::core::proto::ExternalFile;
+
+constexpr absl::string_view kTestModelResourcesTag = "test_model_resources";
+
+constexpr absl::string_view kTestDataDirectory =
+    "/mediapipe/tasks/testdata/text/";
+// Classification model with BERT preprocessing.
+constexpr absl::string_view kBertClassifierPath = "bert_text_classifier.tflite";
+// Embedding model with BERT preprocessing.
+constexpr absl::string_view kMobileBert =
+    "mobilebert_embedding_with_metadata.tflite";
+// Classification model with regex preprocessing.
+constexpr absl::string_view kRegexClassifierPath =
+    "test_model_text_classifier_with_regex_tokenizer.tflite";
+// Embedding model with regex preprocessing.
+constexpr absl::string_view kRegexOneEmbeddingModel =
+    "regex_one_embedding_with_metadata.tflite";
+// Classification model that takes a string tensor and outputs a bool tensor.
+constexpr absl::string_view kStringToBoolModelPath =
+    "test_model_text_classifier_bool_output.tflite";
+
+std::string GetFullPath(absl::string_view file_name) {
+  return JoinPath("./", kTestDataDirectory, file_name);
+}
+
+absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
+    absl::string_view file_name) {
+  auto model_file = std::make_unique<ExternalFile>();
+  model_file->set_file_name(GetFullPath(file_name));
+  ASSIGN_OR_RETURN(auto model_resources,
+                   ModelResources::Create(std::string(kTestModelResourcesTag),
+                                          std::move(model_file)));
+  return GetModelType(*model_resources);
+}
+
+}  // namespace
+
+class TextModelUtilsTest : public tflite_shims::testing::Test {};
+
+TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kBertClassifierPath));
+  ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, BertEmbedderModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type, GetModelTypeFromFile(kMobileBert));
+  ASSERT_EQ(model_type, TextModelType::BERT_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, RegexClassifierModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kRegexClassifierPath));
+  ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, RegexEmbedderModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kRegexOneEmbeddingModel));
+  ASSERT_EQ(model_type, TextModelType::REGEX_MODEL);
+}
+
+TEST_F(TextModelUtilsTest, StringInputModelTest) {
+  MP_ASSERT_OK_AND_ASSIGN(auto model_type,
+                          GetModelTypeFromFile(kStringToBoolModelPath));
+  ASSERT_EQ(model_type, TextModelType::STRING_MODEL);
+}
+
+}  // namespace mediapipe::tasks::text::utils