Internal MediaPipe Tasks change.

PiperOrigin-RevId: 519878741
This commit is contained in:
MediaPipe Team 2023-03-27 17:53:08 -07:00 committed by Copybara-Service
parent 2d553a57e8
commit 5f8831660f
7 changed files with 119 additions and 31 deletions

View File

@ -237,7 +237,9 @@ cc_library(
cc_test( cc_test(
name = "bert_preprocessor_calculator_test", name = "bert_preprocessor_calculator_test",
srcs = ["bert_preprocessor_calculator_test.cc"], srcs = ["bert_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"], data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
],
linkopts = ["-ldl"], linkopts = ["-ldl"],
deps = [ deps = [
":bert_preprocessor_calculator", ":bert_preprocessor_calculator",
@ -250,7 +252,7 @@ cc_test(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_sentencepiece//src:sentencepiece_processor", "@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep
], ],
) )
@ -300,7 +302,7 @@ cc_test(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_sentencepiece//src:sentencepiece_processor", "@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep
], ],
) )

View File

@ -121,31 +121,39 @@ class BertPreprocessorCalculator : public Node {
private: private:
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_; std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
// The max sequence length accepted by the BERT model. // The max sequence length accepted by the BERT model if its input tensors
// are static.
int bert_max_seq_len_ = 2; int bert_max_seq_len_ = 2;
// Indices of the three input tensors for the BERT model. They should form the // Indices of the three input tensors for the BERT model. They should form the
// set {0, 1, 2}. // set {0, 1, 2}.
int input_ids_tensor_index_ = 0; int input_ids_tensor_index_ = 0;
int segment_ids_tensor_index_ = 1; int segment_ids_tensor_index_ = 1;
int input_masks_tensor_index_ = 2; int input_masks_tensor_index_ = 2;
// Whether the model's input tensor shapes are dynamic.
bool has_dynamic_input_tensors_ = false;
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens. // Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and // This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
// clips the vector of tokens to have length at most `bert_max_seq_len_`. // clips the vector of tokens to have length at most `bert_max_seq_len_` if
// the input tensors are static.
std::vector<std::string> TokenizeInputText(absl::string_view input_text); std::vector<std::string> TokenizeInputText(absl::string_view input_text);
// Processes the `input_tokens` to generate the three input tensors for the // Processes the `input_tokens` to generate the three input tensors of size
// BERT model. // `tensor_size` for the BERT model.
std::vector<Tensor> GenerateInputTensors( std::vector<Tensor> GenerateInputTensors(
const std::vector<std::string>& input_tokens); const std::vector<std::string>& input_tokens, int tensor_size);
}; };
absl::Status BertPreprocessorCalculator::UpdateContract( absl::Status BertPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) { CalculatorContract* cc) {
const auto& options = const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>(); cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
if (options.has_dynamic_input_tensors()) {
return absl::OkStatus();
} else {
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required"; RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
RET_CHECK_GE(options.bert_max_seq_len(), 2) RET_CHECK_GE(options.bert_max_seq_len(), 2)
<< "bert_max_seq_len must be at least 2"; << "bert_max_seq_len must be at least 2";
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -178,12 +186,17 @@ absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
const auto& options = const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>(); cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
bert_max_seq_len_ = options.bert_max_seq_len(); bert_max_seq_len_ = options.bert_max_seq_len();
has_dynamic_input_tensors_ = options.has_dynamic_input_tensors();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) { absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
kTensorsOut(cc).Send( int tensor_size = bert_max_seq_len_;
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get()))); std::vector<std::string> input_tokens = TokenizeInputText(kTextIn(cc).Get());
if (has_dynamic_input_tensors_) {
tensor_size = input_tokens.size();
}
kTensorsOut(cc).Send(GenerateInputTensors(input_tokens, tensor_size));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -197,8 +210,11 @@ std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
// Offset by 2 to account for [CLS] and [SEP] // Offset by 2 to account for [CLS] and [SEP]
int input_tokens_size = int input_tokens_size =
std::min(bert_max_seq_len_, static_cast<int>(tokenizer_result.subwords.size()) + 2;
static_cast<int>(tokenizer_result.subwords.size()) + 2); // For static shapes, truncate the input tokens to `bert_max_seq_len_`.
if (!has_dynamic_input_tensors_) {
input_tokens_size = std::min(bert_max_seq_len_, input_tokens_size);
}
std::vector<std::string> input_tokens; std::vector<std::string> input_tokens;
input_tokens.reserve(input_tokens_size); input_tokens.reserve(input_tokens_size);
input_tokens.push_back(std::string(kClassifierToken)); input_tokens.push_back(std::string(kClassifierToken));
@ -210,16 +226,16 @@ std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
} }
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors( std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
const std::vector<std::string>& input_tokens) { const std::vector<std::string>& input_tokens, int tensor_size) {
std::vector<int32_t> input_ids(bert_max_seq_len_, 0); std::vector<int32_t> input_ids(tensor_size, 0);
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0); std::vector<int32_t> segment_ids(tensor_size, 0);
std::vector<int32_t> input_masks(bert_max_seq_len_, 0); std::vector<int32_t> input_masks(tensor_size, 0);
// Convert tokens back into ids and set mask // Convert tokens back into ids and set mask
for (int i = 0; i < input_tokens.size(); ++i) { for (int i = 0; i < input_tokens.size(); ++i) {
tokenizer_->LookupId(input_tokens[i], &input_ids[i]); tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_masks[i] = 1; input_masks[i] = 1;
} }
// |<--------bert_max_seq_len_--------->| // |<-----------tensor_size------------>|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0 // segment_ids 0 0 0... 0 0 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0 // input_masks 1 1 1... 1 1 0 0... 0
@ -228,7 +244,7 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
input_tensors.reserve(kNumInputTensorsForBert); input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) { for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back( input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})}); {Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
} }
std::memcpy(input_tensors[input_ids_tensor_index_] std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView() .GetCpuWriteView()

View File

@ -24,6 +24,10 @@ message BertPreprocessorCalculatorOptions {
optional BertPreprocessorCalculatorOptions ext = 462509271; optional BertPreprocessorCalculatorOptions ext = 462509271;
} }
// The maximum input sequence length for the calculator's BERT model. // The maximum input sequence length for the calculator's BERT model. Used
// if the model's input tensors have static shape.
optional int32 bert_max_seq_len = 1; optional int32 bert_max_seq_len = 1;
// Whether the BERT model's input tensors have dynamic shape.
optional bool has_dynamic_input_tensors = 2;
} }

View File

@ -43,7 +43,8 @@ constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/bert_text_classifier.tflite"; "mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator( absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
absl::string_view text, absl::string_view model_path) { absl::string_view text, absl::string_view model_path,
bool has_dynamic_input_tensors = false, int tensor_size = kBertMaxSeqLen) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>( auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"( absl::Substitute(R"(
input_stream: "text" input_stream: "text"
@ -56,11 +57,12 @@ absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
options { options {
[mediapipe.BertPreprocessorCalculatorOptions.ext] { [mediapipe.BertPreprocessorCalculatorOptions.ext] {
bert_max_seq_len: $0 bert_max_seq_len: $0
has_dynamic_input_tensors: $1
} }
} }
} }
)", )",
kBertMaxSeqLen)); tensor_size, has_dynamic_input_tensors));
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets); tool::AddVectorSink("tensors", &graph_config, &output_packets);
@ -92,13 +94,13 @@ absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
} }
std::vector<std::vector<int>> results; std::vector<std::vector<int>> results;
for (int i = 0; i < kNumInputTensorsForBert; i++) { for (int i = 0; i < tensor_vec.size(); i++) {
const Tensor& tensor = tensor_vec[i]; const Tensor& tensor = tensor_vec[i];
if (tensor.element_type() != Tensor::ElementType::kInt32) { if (tensor.element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32"); return absl::InvalidArgumentError("Expected tensor element type kInt32");
} }
auto* buffer = tensor.GetCpuReadView().buffer<int>(); auto* buffer = tensor.GetCpuReadView().buffer<int>();
std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen); std::vector<int> buffer_view(buffer, buffer + tensor_size);
results.push_back(buffer_view); results.push_back(buffer_view);
} }
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources()); MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());

View File

@ -30,4 +30,8 @@ message TextPreprocessingGraphOptions {
// The maximum input sequence length for the TFLite model. Used with // The maximum input sequence length for the TFLite model. Used with
// BERT_MODEL and REGEX_MODEL. // BERT_MODEL and REGEX_MODEL.
optional int32 max_seq_len = 2; optional int32 max_seq_len = 2;
// The model's input tensors are dynamic rather than static.
// Used with BERT_MODEL.
optional bool has_dynamic_input_tensors = 3;
} }

View File

@ -114,6 +114,60 @@ absl::StatusOr<int> GetMaxSeqLen(const tflite::SubGraph& model_graph) {
} }
return max_seq_len; return max_seq_len;
} }
// Determines whether the TFLite model for `model_graph` has input tensors with
// dynamic shape rather than static shape or returns an error if the input
// tensors have invalid shape signatures. This util assumes that the model has
// the correct input tensors type and count for the BertPreprocessorCalculator.
absl::StatusOr<bool> HasDynamicInputTensors(
const tflite::SubGraph& model_graph) {
const flatbuffers::Vector<int32_t>& input_indices = *model_graph.inputs();
const flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>&
model_tensors = *model_graph.tensors();
// Static input tensors may have undefined shape signatures.
if (absl::c_all_of(input_indices, [&model_tensors](int i) {
return model_tensors[i]->shape_signature() == nullptr;
})) {
return false;
} else if (absl::c_any_of(input_indices, [&model_tensors](int i) {
return model_tensors[i]->shape_signature() == nullptr;
})) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"Input tensors contain a mix of defined and "
"undefined shape signatures.");
}
for (int i : input_indices) {
const tflite::Tensor* tensor = model_tensors[i];
if (tensor->shape_signature()->size() != 2) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute(
"Model should take 2-D shape signatures, got dimension: $0",
tensor->shape_signature()->size()),
MediaPipeTasksStatus::kInvalidInputTensorDimensionsError);
}
}
// For dynamic input tensors, the shape_signature entry corresponding to the
// input size is -1.
if (absl::c_all_of(input_indices, [&model_tensors](int i) {
return (*model_tensors[i]->shape_signature())[1] != -1;
})) {
return false;
} else if (absl::c_all_of(input_indices, [&model_tensors](int i) {
return (*model_tensors[i]->shape_signature())[1] == -1;
})) {
return true;
} else {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Input tensors contain a mix of static and dynamic shapes.");
}
return false;
}
} // namespace } // namespace
absl::Status ConfigureTextPreprocessingGraph( absl::Status ConfigureTextPreprocessingGraph(
@ -128,6 +182,8 @@ absl::Status ConfigureTextPreprocessingGraph(
ASSIGN_OR_RETURN(TextModelType::ModelType model_type, ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
GetModelType(model_resources)); GetModelType(model_resources));
const tflite::SubGraph& model_graph =
*(*model_resources.GetTfLiteModel()->subgraphs())[0];
options.set_model_type(model_type); options.set_model_type(model_type);
switch (model_type) { switch (model_type) {
case TextModelType::UNSPECIFIED_MODEL: case TextModelType::UNSPECIFIED_MODEL:
@ -137,13 +193,15 @@ absl::Status ConfigureTextPreprocessingGraph(
} }
case TextModelType::BERT_MODEL: case TextModelType::BERT_MODEL:
case TextModelType::REGEX_MODEL: { case TextModelType::REGEX_MODEL: {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(int max_seq_len, GetMaxSeqLen(model_graph));
int max_seq_len,
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
options.set_max_seq_len(max_seq_len); options.set_max_seq_len(max_seq_len);
} }
} }
if (model_type == TextModelType::BERT_MODEL) {
ASSIGN_OR_RETURN(bool has_dynamic_input_tensors,
HasDynamicInputTensors(model_graph));
options.set_has_dynamic_input_tensors(has_dynamic_input_tensors);
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -200,6 +258,8 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
case TextModelType::BERT_MODEL: { case TextModelType::BERT_MODEL: {
text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>() text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
.set_bert_max_seq_len(options.max_seq_len()); .set_bert_max_seq_len(options.max_seq_len());
text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
.set_has_dynamic_input_tensors(options.has_dynamic_input_tensors());
metadata_extractor_in >> metadata_extractor_in >>
text_preprocessor.SideIn(kMetadataExtractorTag); text_preprocessor.SideIn(kMetadataExtractorTag);
break; break;

View File

@ -49,6 +49,7 @@ using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
constexpr int kMaxSeqLen = 128; constexpr int kMaxSeqLen = 128;
const float kPrecision = 1e-6;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
@ -66,7 +67,6 @@ std::string GetFullPath(absl::string_view file_name) {
// TODO: create shared matcher for ClassificationResult. // TODO: create shared matcher for ClassificationResult.
void ExpectApproximatelyEqual(const TextClassifierResult& actual, void ExpectApproximatelyEqual(const TextClassifierResult& actual,
const TextClassifierResult& expected) { const TextClassifierResult& expected) {
const float kPrecision = 1e-6;
ASSERT_EQ(actual.classifications.size(), expected.classifications.size()); ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
for (int i = 0; i < actual.classifications.size(); ++i) { for (int i = 0; i < actual.classifications.size(); ++i) {
const Classifications& a = actual.classifications[i]; const Classifications& a = actual.classifications[i];