Open-source bert_preprocessor_calculator

PiperOrigin-RevId: 481246966
This commit is contained in:
MediaPipe Team 2022-10-14 15:29:55 -07:00 committed by Copybara-Service
parent 42543f7ad6
commit ca28a19822
3 changed files with 318 additions and 0 deletions

View File

@ -161,6 +161,44 @@ cc_test(
], ],
) )
mediapipe_proto_library(
name = "bert_preprocessor_calculator_proto",
srcs = ["bert_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bert_preprocessor_calculator",
srcs = ["bert_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":bert_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
mediapipe_proto_library( mediapipe_proto_library(
name = "inference_calculator_proto", name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"], srcs = ["inference_calculator.proto"],

View File

@ -0,0 +1,251 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr absl::string_view kInputIdsTensorName = "ids";
constexpr absl::string_view kInputMasksTensorName = "mask";
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
constexpr absl::string_view kClassifierToken = "[CLS]";
constexpr absl::string_view kSeparatorToken = "[SEP]";
// Preprocesses input text into three int32 input tensors for a BERT model using
// a tokenizer.
// The associated BERT model is expected to contain input tensors with names:
//
// Tensor | Metadata Name
// ---------------- | --------------
// IDs | "ids"
// Segment IDs | "segment_ids"
// Mask | "mask"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have names corresponding to the above
// metadata names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// This calculator is currently configured for the TextClassifier Task but it
// will eventually be generalized for other Text Tasks.
// TODO: Handle preprocessing for other Text Tasks too.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the BERT model. Used to determine the order of
// the three input Tensors for the BERT model and to extract the metadata to
// construct the tokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the BERT model:
// (1): the token ids of the tokenized input string. A classifier token
// ("[CLS]") will be prepended to the input tokens and a separator
// token ("[SEP]") will be appended to the input tokens.
// (2): the segment ids, which are all 0 for now but will have different
// values to distinguish between different sentences in the input
// text for other Text tasks.
// (3): the input mask ids, which are 1 at each of the input token indices
// and 0 elsewhere.
// The Tensors will have size equal to the max sequence length for the BERT
// model.
//
// Example:
// node {
// calculator: "BertPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.BertPreprocessorCalculatorOptions.ext] {
// bert_max_seq_len: 128
// }
// }
// }
class BertPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
// The max sequence length accepted by the BERT model.
int bert_max_seq_len_ = 2;
// Indices of the three input tensors for the BERT model. They should form the
// set {0, 1, 2}.
int input_ids_tensor_index_ = 0;
int segment_ids_tensor_index_ = 1;
int input_masks_tensor_index_ = 2;
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
// clips the vector of tokens to have length at most `bert_max_seq_len_`.
std::vector<std::string> TokenizeInputText(absl::string_view input_text);
// Processes the `input_tokens` to generate the three input tensors for the
// BERT model.
std::vector<Tensor> GenerateInputTensors(
const std::vector<std::string>& input_tokens);
};
absl::Status BertPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
RET_CHECK_GE(options.bert_max_seq_len(), 2)
<< "bert_max_seq_len must be at least 2";
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::ProcessUnit* tokenizer_metadata =
metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
tokenizer_metadata, metadata_extractor));
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
input_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputIdsTensorName);
segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kSegmentIdsTensorName);
input_masks_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputMasksTensorName);
absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
segment_ids_tensor_index_,
input_masks_tensor_index_};
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
input_ids_tensor_index_, segment_ids_tensor_index_,
input_masks_tensor_index_));
}
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
bert_max_seq_len_ = options.bert_max_seq_len();
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
kTensorsOut(cc).Send(
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
return absl::OkStatus();
}
std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
absl::string_view input_text) {
std::string processed_input = std::string(input_text);
absl::AsciiStrToLower(&processed_input);
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(processed_input);
// Offset by 2 to account for [CLS] and [SEP]
int input_tokens_size =
std::min(bert_max_seq_len_,
static_cast<int>(tokenizer_result.subwords.size()) + 2);
std::vector<std::string> input_tokens;
input_tokens.reserve(input_tokens_size);
input_tokens.push_back(std::string(kClassifierToken));
for (int i = 0; i < input_tokens_size - 2; ++i) {
input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
}
input_tokens.push_back(std::string(kSeparatorToken));
return input_tokens;
}
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
const std::vector<std::string>& input_tokens) {
std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
// Convert tokens back into ids and set mask
for (int i = 0; i < input_tokens.size(); ++i) {
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_masks[i] = 1;
}
// |<--------bert_max_seq_len_--------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
}
std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_ids.data(), input_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[segment_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
segment_ids.data(), segment_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[input_masks_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_masks.data(), input_masks.size() * sizeof(int32_t));
return input_tensors;
}
MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BertPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BertPreprocessorCalculatorOptions ext = 462509271;
}
// The maximum input sequence length for the calculator's BERT model.
optional int32 bert_max_seq_len = 1;
}