Internal MediaPipe Tasks change.
PiperOrigin-RevId: 519878741
This commit is contained in:
parent
2d553a57e8
commit
5f8831660f
|
@ -237,7 +237,9 @@ cc_library(
|
|||
cc_test(
|
||||
name = "bert_preprocessor_calculator_test",
|
||||
srcs = ["bert_preprocessor_calculator_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||
],
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":bert_preprocessor_calculator",
|
||||
|
@ -250,7 +252,7 @@ cc_test(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -300,7 +302,7 @@ cc_test(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -121,31 +121,39 @@ class BertPreprocessorCalculator : public Node {
|
|||
|
||||
private:
|
||||
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
|
||||
// The max sequence length accepted by the BERT model.
|
||||
// The max sequence length accepted by the BERT model if its input tensors
|
||||
// are static.
|
||||
int bert_max_seq_len_ = 2;
|
||||
// Indices of the three input tensors for the BERT model. They should form the
|
||||
// set {0, 1, 2}.
|
||||
int input_ids_tensor_index_ = 0;
|
||||
int segment_ids_tensor_index_ = 1;
|
||||
int input_masks_tensor_index_ = 2;
|
||||
// Whether the model's input tensor shapes are dynamic.
|
||||
bool has_dynamic_input_tensors_ = false;
|
||||
|
||||
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
|
||||
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
|
||||
// clips the vector of tokens to have length at most `bert_max_seq_len_`.
|
||||
// clips the vector of tokens to have length at most `bert_max_seq_len_` if
|
||||
// the input tensors are static.
|
||||
std::vector<std::string> TokenizeInputText(absl::string_view input_text);
|
||||
// Processes the `input_tokens` to generate the three input tensors for the
|
||||
// BERT model.
|
||||
// Processes the `input_tokens` to generate the three input tensors of size
|
||||
// `tensor_size` for the BERT model.
|
||||
std::vector<Tensor> GenerateInputTensors(
|
||||
const std::vector<std::string>& input_tokens);
|
||||
const std::vector<std::string>& input_tokens, int tensor_size);
|
||||
};
|
||||
|
||||
absl::Status BertPreprocessorCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
|
||||
if (options.has_dynamic_input_tensors()) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
|
||||
RET_CHECK_GE(options.bert_max_seq_len(), 2)
|
||||
<< "bert_max_seq_len must be at least 2";
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -178,12 +186,17 @@ absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
|
|||
const auto& options =
|
||||
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
|
||||
bert_max_seq_len_ = options.bert_max_seq_len();
|
||||
has_dynamic_input_tensors_ = options.has_dynamic_input_tensors();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
|
||||
kTensorsOut(cc).Send(
|
||||
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
|
||||
int tensor_size = bert_max_seq_len_;
|
||||
std::vector<std::string> input_tokens = TokenizeInputText(kTextIn(cc).Get());
|
||||
if (has_dynamic_input_tensors_) {
|
||||
tensor_size = input_tokens.size();
|
||||
}
|
||||
kTensorsOut(cc).Send(GenerateInputTensors(input_tokens, tensor_size));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -197,8 +210,11 @@ std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
|
|||
|
||||
// Offset by 2 to account for [CLS] and [SEP]
|
||||
int input_tokens_size =
|
||||
std::min(bert_max_seq_len_,
|
||||
static_cast<int>(tokenizer_result.subwords.size()) + 2);
|
||||
static_cast<int>(tokenizer_result.subwords.size()) + 2;
|
||||
// For static shapes, truncate the input tokens to `bert_max_seq_len_`.
|
||||
if (!has_dynamic_input_tensors_) {
|
||||
input_tokens_size = std::min(bert_max_seq_len_, input_tokens_size);
|
||||
}
|
||||
std::vector<std::string> input_tokens;
|
||||
input_tokens.reserve(input_tokens_size);
|
||||
input_tokens.push_back(std::string(kClassifierToken));
|
||||
|
@ -210,16 +226,16 @@ std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
|
|||
}
|
||||
|
||||
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
|
||||
const std::vector<std::string>& input_tokens) {
|
||||
std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
|
||||
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
|
||||
std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
|
||||
const std::vector<std::string>& input_tokens, int tensor_size) {
|
||||
std::vector<int32_t> input_ids(tensor_size, 0);
|
||||
std::vector<int32_t> segment_ids(tensor_size, 0);
|
||||
std::vector<int32_t> input_masks(tensor_size, 0);
|
||||
// Convert tokens back into ids and set mask
|
||||
for (int i = 0; i < input_tokens.size(); ++i) {
|
||||
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
|
||||
input_masks[i] = 1;
|
||||
}
|
||||
// |<--------bert_max_seq_len_--------->|
|
||||
// |<-----------tensor_size------------>|
|
||||
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
|
||||
// segment_ids 0 0 0... 0 0 0 0... 0
|
||||
// input_masks 1 1 1... 1 1 0 0... 0
|
||||
|
@ -228,7 +244,7 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
|
|||
input_tensors.reserve(kNumInputTensorsForBert);
|
||||
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
|
||||
input_tensors.push_back(
|
||||
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
|
||||
{Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
|
||||
}
|
||||
std::memcpy(input_tensors[input_ids_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
|
|
|
@ -24,6 +24,10 @@ message BertPreprocessorCalculatorOptions {
|
|||
optional BertPreprocessorCalculatorOptions ext = 462509271;
|
||||
}
|
||||
|
||||
// The maximum input sequence length for the calculator's BERT model.
|
||||
// The maximum input sequence length for the calculator's BERT model. Used
|
||||
// if the model's input tensors have static shape.
|
||||
optional int32 bert_max_seq_len = 1;
|
||||
|
||||
// Whether the BERT model's input tensors have dynamic shape.
|
||||
optional bool has_dynamic_input_tensors = 2;
|
||||
}
|
||||
|
|
|
@ -43,7 +43,8 @@ constexpr absl::string_view kTestModelPath =
|
|||
"mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
|
||||
absl::string_view text, absl::string_view model_path) {
|
||||
absl::string_view text, absl::string_view model_path,
|
||||
bool has_dynamic_input_tensors = false, int tensor_size = kBertMaxSeqLen) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(R"(
|
||||
input_stream: "text"
|
||||
|
@ -56,11 +57,12 @@ absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
|
|||
options {
|
||||
[mediapipe.BertPreprocessorCalculatorOptions.ext] {
|
||||
bert_max_seq_len: $0
|
||||
has_dynamic_input_tensors: $1
|
||||
}
|
||||
}
|
||||
}
|
||||
)",
|
||||
kBertMaxSeqLen));
|
||||
tensor_size, has_dynamic_input_tensors));
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensors", &graph_config, &output_packets);
|
||||
|
||||
|
@ -92,13 +94,13 @@ absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
|
|||
}
|
||||
|
||||
std::vector<std::vector<int>> results;
|
||||
for (int i = 0; i < kNumInputTensorsForBert; i++) {
|
||||
for (int i = 0; i < tensor_vec.size(); i++) {
|
||||
const Tensor& tensor = tensor_vec[i];
|
||||
if (tensor.element_type() != Tensor::ElementType::kInt32) {
|
||||
return absl::InvalidArgumentError("Expected tensor element type kInt32");
|
||||
}
|
||||
auto* buffer = tensor.GetCpuReadView().buffer<int>();
|
||||
std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen);
|
||||
std::vector<int> buffer_view(buffer, buffer + tensor_size);
|
||||
results.push_back(buffer_view);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
|
||||
|
|
|
@ -30,4 +30,8 @@ message TextPreprocessingGraphOptions {
|
|||
// The maximum input sequence length for the TFLite model. Used with
|
||||
// BERT_MODEL and REGEX_MODEL.
|
||||
optional int32 max_seq_len = 2;
|
||||
|
||||
// The model's input tensors are dynamic rather than static.
|
||||
// Used with BERT_MODEL.
|
||||
optional bool has_dynamic_input_tensors = 3;
|
||||
}
|
||||
|
|
|
@ -114,6 +114,60 @@ absl::StatusOr<int> GetMaxSeqLen(const tflite::SubGraph& model_graph) {
|
|||
}
|
||||
return max_seq_len;
|
||||
}
|
||||
|
||||
// Determines whether the TFLite model for `model_graph` has input tensors with
|
||||
// dynamic shape rather than static shape or returns an error if the input
|
||||
// tensors have invalid shape signatures. This util assumes that the model has
|
||||
// the correct input tensors type and count for the BertPreprocessorCalculator.
|
||||
absl::StatusOr<bool> HasDynamicInputTensors(
|
||||
const tflite::SubGraph& model_graph) {
|
||||
const flatbuffers::Vector<int32_t>& input_indices = *model_graph.inputs();
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>&
|
||||
model_tensors = *model_graph.tensors();
|
||||
|
||||
// Static input tensors may have undefined shape signatures.
|
||||
if (absl::c_all_of(input_indices, [&model_tensors](int i) {
|
||||
return model_tensors[i]->shape_signature() == nullptr;
|
||||
})) {
|
||||
return false;
|
||||
} else if (absl::c_any_of(input_indices, [&model_tensors](int i) {
|
||||
return model_tensors[i]->shape_signature() == nullptr;
|
||||
})) {
|
||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||
"Input tensors contain a mix of defined and "
|
||||
"undefined shape signatures.");
|
||||
}
|
||||
|
||||
for (int i : input_indices) {
|
||||
const tflite::Tensor* tensor = model_tensors[i];
|
||||
if (tensor->shape_signature()->size() != 2) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::Substitute(
|
||||
"Model should take 2-D shape signatures, got dimension: $0",
|
||||
tensor->shape_signature()->size()),
|
||||
MediaPipeTasksStatus::kInvalidInputTensorDimensionsError);
|
||||
}
|
||||
}
|
||||
|
||||
// For dynamic input tensors, the shape_signature entry corresponding to the
|
||||
// input size is -1.
|
||||
if (absl::c_all_of(input_indices, [&model_tensors](int i) {
|
||||
return (*model_tensors[i]->shape_signature())[1] != -1;
|
||||
})) {
|
||||
return false;
|
||||
} else if (absl::c_all_of(input_indices, [&model_tensors](int i) {
|
||||
return (*model_tensors[i]->shape_signature())[1] == -1;
|
||||
})) {
|
||||
return true;
|
||||
} else {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Input tensors contain a mix of static and dynamic shapes.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::Status ConfigureTextPreprocessingGraph(
|
||||
|
@ -128,6 +182,8 @@ absl::Status ConfigureTextPreprocessingGraph(
|
|||
|
||||
ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
|
||||
GetModelType(model_resources));
|
||||
const tflite::SubGraph& model_graph =
|
||||
*(*model_resources.GetTfLiteModel()->subgraphs())[0];
|
||||
options.set_model_type(model_type);
|
||||
switch (model_type) {
|
||||
case TextModelType::UNSPECIFIED_MODEL:
|
||||
|
@ -137,13 +193,15 @@ absl::Status ConfigureTextPreprocessingGraph(
|
|||
}
|
||||
case TextModelType::BERT_MODEL:
|
||||
case TextModelType::REGEX_MODEL: {
|
||||
ASSIGN_OR_RETURN(
|
||||
int max_seq_len,
|
||||
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
|
||||
ASSIGN_OR_RETURN(int max_seq_len, GetMaxSeqLen(model_graph));
|
||||
options.set_max_seq_len(max_seq_len);
|
||||
}
|
||||
}
|
||||
|
||||
if (model_type == TextModelType::BERT_MODEL) {
|
||||
ASSIGN_OR_RETURN(bool has_dynamic_input_tensors,
|
||||
HasDynamicInputTensors(model_graph));
|
||||
options.set_has_dynamic_input_tensors(has_dynamic_input_tensors);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -200,6 +258,8 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
|
|||
case TextModelType::BERT_MODEL: {
|
||||
text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
|
||||
.set_bert_max_seq_len(options.max_seq_len());
|
||||
text_preprocessor.GetOptions<BertPreprocessorCalculatorOptions>()
|
||||
.set_has_dynamic_input_tensors(options.has_dynamic_input_tensors());
|
||||
metadata_extractor_in >>
|
||||
text_preprocessor.SideIn(kMetadataExtractorTag);
|
||||
break;
|
||||
|
|
|
@ -49,6 +49,7 @@ using ::testing::HasSubstr;
|
|||
using ::testing::Optional;
|
||||
|
||||
constexpr int kMaxSeqLen = 128;
|
||||
const float kPrecision = 1e-6;
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
|
||||
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
|
||||
|
@ -66,7 +67,6 @@ std::string GetFullPath(absl::string_view file_name) {
|
|||
// TODO: create shared matcher for ClassificationResult.
|
||||
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
|
||||
const TextClassifierResult& expected) {
|
||||
const float kPrecision = 1e-6;
|
||||
ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
|
||||
for (int i = 0; i < actual.classifications.size(); ++i) {
|
||||
const Classifications& a = actual.classifications[i];
|
||||
|
|
Loading…
Reference in New Issue
Block a user