Open-source bert_preprocessor_calculator
PiperOrigin-RevId: 481246966
parent 42543f7ad6
commit ca28a19822
@@ -161,6 +161,44 @@ cc_test(
    ],
)

mediapipe_proto_library(
    name = "bert_preprocessor_calculator_proto",
    srcs = ["bert_preprocessor_calculator.proto"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "bert_preprocessor_calculator",
    srcs = ["bert_preprocessor_calculator.cc"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        ":bert_preprocessor_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer",
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "inference_calculator_proto",
    srcs = ["inference_calculator.proto"],
mediapipe/calculators/tensor/bert_preprocessor_calculator.cc (new file, 251 lines)
@@ -0,0 +1,251 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;

constexpr int kNumInputTensorsForBert = 3;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr absl::string_view kInputIdsTensorName = "ids";
constexpr absl::string_view kInputMasksTensorName = "mask";
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
constexpr absl::string_view kClassifierToken = "[CLS]";
constexpr absl::string_view kSeparatorToken = "[SEP]";

// Preprocesses input text into three int32 input tensors for a BERT model
// using a tokenizer.
// The associated BERT model is expected to contain input tensors with names:
//
// Tensor           | Metadata Name
// ---------------- | --------------
// IDs              | "ids"
// Segment IDs      | "segment_ids"
// Mask             | "mask"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have names corresponding to the above
// metadata names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// This calculator is currently configured for the TextClassifier Task but it
// will eventually be generalized for other Text Tasks.
// TODO: Handle preprocessing for other Text Tasks too.
//
// Inputs:
//   TEXT - std::string
//     The input text.
// Side Inputs:
//   METADATA_EXTRACTOR - ModelMetadataExtractor
//     The metadata extractor for the BERT model. Used to determine the order
//     of the three input Tensors for the BERT model and to extract the
//     metadata to construct the tokenizer.
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing the three input Tensors for the BERT model:
//       (1): the token ids of the tokenized input string. A classifier token
//            ("[CLS]") will be prepended to the input tokens and a separator
//            token ("[SEP]") will be appended to the input tokens.
//       (2): the segment ids, which are all 0 for now but will have different
//            values to distinguish between different sentences in the input
//            text for other Text tasks.
//       (3): the input mask ids, which are 1 at each of the input token
//            indices and 0 elsewhere.
//     The Tensors will have size equal to the max sequence length for the
//     BERT model.
//
// Example:
// node {
//   calculator: "BertPreprocessorCalculator"
//   input_stream: "TEXT:text"
//   input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
//   output_stream: "TENSORS:tensors"
//   options {
//     [mediapipe.BertPreprocessorCalculatorOptions.ext] {
//       bert_max_seq_len: 128
//     }
//   }
// }
class BertPreprocessorCalculator : public Node {
 public:
  static constexpr Input<std::string> kTextIn{"TEXT"};
  static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
      "METADATA_EXTRACTOR"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};

  MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);

  static absl::Status UpdateContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
  // The max sequence length accepted by the BERT model.
  int bert_max_seq_len_ = 2;
  // Indices of the three input tensors for the BERT model. They should form
  // the set {0, 1, 2}.
  int input_ids_tensor_index_ = 0;
  int segment_ids_tensor_index_ = 1;
  int input_masks_tensor_index_ = 2;

  // Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
  // This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
  // clips the vector of tokens to have length at most `bert_max_seq_len_`.
  std::vector<std::string> TokenizeInputText(absl::string_view input_text);
  // Processes the `input_tokens` to generate the three input tensors for the
  // BERT model.
  std::vector<Tensor> GenerateInputTensors(
      const std::vector<std::string>& input_tokens);
};

absl::Status BertPreprocessorCalculator::UpdateContract(
    CalculatorContract* cc) {
  const auto& options =
      cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
  RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
  RET_CHECK_GE(options.bert_max_seq_len(), 2)
      << "bert_max_seq_len must be at least 2";
  return absl::OkStatus();
}

absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
  const ModelMetadataExtractor* metadata_extractor =
      &kMetadataExtractorSideIn(cc).Get();
  const tflite::ProcessUnit* tokenizer_metadata =
      metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
  ASSIGN_OR_RETURN(tokenizer_,
                   tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
                       tokenizer_metadata, metadata_extractor));

  auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
  input_ids_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kInputIdsTensorName);
  segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kSegmentIdsTensorName);
  input_masks_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kInputMasksTensorName);
  absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
                                             segment_ids_tensor_index_,
                                             input_masks_tensor_index_};
  if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
        input_ids_tensor_index_, segment_ids_tensor_index_,
        input_masks_tensor_index_));
  }

  const auto& options =
      cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
  bert_max_seq_len_ = options.bert_max_seq_len();
  return absl::OkStatus();
}

absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
  kTensorsOut(cc).Send(
      GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
  return absl::OkStatus();
}

std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
    absl::string_view input_text) {
  std::string processed_input = std::string(input_text);
  absl::AsciiStrToLower(&processed_input);

  tasks::text::tokenizers::TokenizerResult tokenizer_result =
      tokenizer_->Tokenize(processed_input);

  // Offset by 2 to account for [CLS] and [SEP].
  int input_tokens_size =
      std::min(bert_max_seq_len_,
               static_cast<int>(tokenizer_result.subwords.size()) + 2);
  std::vector<std::string> input_tokens;
  input_tokens.reserve(input_tokens_size);
  input_tokens.push_back(std::string(kClassifierToken));
  for (int i = 0; i < input_tokens_size - 2; ++i) {
    input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
  }
  input_tokens.push_back(std::string(kSeparatorToken));
  return input_tokens;
}

std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
    const std::vector<std::string>& input_tokens) {
  std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
  std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
  std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
  // Convert tokens back into ids and set mask.
  for (int i = 0; i < input_tokens.size(); ++i) {
    tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
    input_masks[i] = 1;
  }
  //              |<-----------bert_max_seq_len_----------->|
  // input_ids    [CLS] s1  s2 ... sn [SEP]  0   0  ...   0
  // segment_ids    0    0   0 ...  0    0   0   0  ...   0
  // input_masks    1    1   1 ...  1    1   0   0  ...   0

  std::vector<Tensor> input_tensors;
  input_tensors.reserve(kNumInputTensorsForBert);
  for (int i = 0; i < kNumInputTensorsForBert; ++i) {
    input_tensors.push_back(
        {Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
  }
  std::memcpy(input_tensors[input_ids_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<int32_t>(),
              input_ids.data(), input_ids.size() * sizeof(int32_t));
  std::memcpy(input_tensors[segment_ids_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<int32_t>(),
              segment_ids.data(), segment_ids.size() * sizeof(int32_t));
  std::memcpy(input_tensors[input_masks_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<int32_t>(),
              input_masks.data(), input_masks.size() * sizeof(int32_t));
  return input_tensors;
}

MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);

}  // namespace api2
}  // namespace mediapipe
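For orientation, the snippet below sketches how the new calculator could be driven end to end once the targets above are linked in. It is only a sketch, not part of this change: the helper name is made up, error handling is reduced to the status macros, and it assumes a serialized BERT TFLite model with metadata plus ModelMetadataExtractor::CreateFromModelBuffer for building the side packet.

// Sketch only: `model_blob` is assumed to hold a BERT TFLite model with
// metadata; the graph text mirrors the example in the header comment above.
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

absl::Status RunBertPreprocessor(const std::string& model_blob,
                                 const std::string& text) {
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        input_stream: "text"
        input_side_packet: "metadata_extractor"
        output_stream: "tensors"
        node {
          calculator: "BertPreprocessorCalculator"
          input_stream: "TEXT:text"
          input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
          output_stream: "TENSORS:tensors"
          options {
            [mediapipe.BertPreprocessorCalculatorOptions.ext] {
              bert_max_seq_len: 128
            }
          }
        }
      )pb");

  // Build the metadata extractor that the calculator consumes as a side input.
  ASSIGN_OR_RETURN(
      std::unique_ptr<mediapipe::tasks::metadata::ModelMetadataExtractor>
          extractor,
      mediapipe::tasks::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
          model_blob.data(), model_blob.size()));

  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  std::vector<mediapipe::Packet> output_packets;
  MP_RETURN_IF_ERROR(graph.ObserveOutputStream(
      "tensors", [&output_packets](const mediapipe::Packet& packet) {
        output_packets.push_back(packet);
        return absl::OkStatus();
      }));
  MP_RETURN_IF_ERROR(graph.StartRun(
      {{"metadata_extractor", mediapipe::Adopt(extractor.release())}}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text",
      mediapipe::MakePacket<std::string>(text).At(mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseInputStream("text"));
  MP_RETURN_IF_ERROR(graph.WaitUntilDone());
  // output_packets[0] now holds the three int32 tensors (ids, segment ids,
  // mask), each of length bert_max_seq_len, in the order given by the model
  // metadata.
  return absl::OkStatus();
}

The only wiring beyond the header-comment example is the metadata extractor side packet and an observer on the TENSORS stream.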
mediapipe/calculators/tensor/bert_preprocessor_calculator.proto (new file, 29 lines)
@@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message BertPreprocessorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional BertPreprocessorCalculatorOptions ext = 462509271;
  }

  // The maximum input sequence length for the calculator's BERT model.
  optional int32 bert_max_seq_len = 1;
}
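When a graph config is assembled in C++ rather than from a text proto, the same option can be populated through the generated extension. A minimal sketch, assuming the generated header from the new proto target; the helper name is illustrative:

#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"

// Builds a BertPreprocessorCalculator node equivalent to the text-proto
// example in the calculator's header comment.
mediapipe::CalculatorGraphConfig::Node MakeBertPreprocessorNode(
    int bert_max_seq_len) {
  mediapipe::CalculatorGraphConfig::Node node;
  node.set_calculator("BertPreprocessorCalculator");
  node.add_input_stream("TEXT:text");
  node.add_input_side_packet("METADATA_EXTRACTOR:metadata_extractor");
  node.add_output_stream("TENSORS:tensors");
  node.mutable_options()
      ->MutableExtension(mediapipe::BertPreprocessorCalculatorOptions::ext)
      ->set_bert_max_seq_len(bert_max_seq_len);
  return node;
}

UpdateContract rejects configs where bert_max_seq_len is unset or smaller than 2, so callers should pass at least 2 (128 in the header-comment example).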