diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 4fce8bf8e..18107463c 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -161,6 +161,44 @@ cc_test( ], ) +mediapipe_proto_library( + name = "bert_preprocessor_calculator_proto", + srcs = ["bert_preprocessor_calculator.proto"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "bert_preprocessor_calculator", + srcs = ["bert_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + ":bert_preprocessor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/tokenizers:tokenizer", + "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc b/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc new file mode 100644 index 000000000..d464c3929 --- /dev/null +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc @@ -0,0 +1,251 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace api2 { + +using ::mediapipe::tasks::core::FindTensorIndexByMetadataName; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kTokenizerProcessUnitIndex = 0; +constexpr absl::string_view kInputIdsTensorName = "ids"; +constexpr absl::string_view kInputMasksTensorName = "mask"; +constexpr absl::string_view kSegmentIdsTensorName = "segment_ids"; +constexpr absl::string_view kClassifierToken = "[CLS]"; +constexpr absl::string_view kSeparatorToken = "[SEP]"; + +// Preprocesses input text into three int32 input tensors for a BERT model using +// a tokenizer. +// The associated BERT model is expected to contain input tensors with names: +// +// Tensor | Metadata Name +// ---------------- | -------------- +// IDs | "ids" +// Segment IDs | "segment_ids" +// Mask | "mask" +// +// This calculator will return an error if the model does not have three input +// tensors or if the tensors do not have names corresponding to the above +// metadata names in some order. Additional details regarding these input +// tensors are given in the Calculator "Outputs" section below. +// +// This calculator is currently configured for the TextClassifier Task but it +// will eventually be generalized for other Text Tasks. +// TODO: Handle preprocessing for other Text Tasks too. +// +// Inputs: +// TEXT - std::string +// The input text. +// Side Inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the BERT model. Used to determine the order of +// the three input Tensors for the BERT model and to extract the metadata to +// construct the tokenizer. +// +// Outputs: +// TENSORS - std::vector +// Vector containing the three input Tensors for the BERT model: +// (1): the token ids of the tokenized input string. A classifier token +// ("[CLS]") will be prepended to the input tokens and a separator +// token ("[SEP]") will be appended to the input tokens. +// (2): the segment ids, which are all 0 for now but will have different +// values to distinguish between different sentences in the input +// text for other Text tasks. +// (3): the input mask ids, which are 1 at each of the input token indices +// and 0 elsewhere. +// The Tensors will have size equal to the max sequence length for the BERT +// model. +// +// Example: +// node { +// calculator: "BertPreprocessorCalculator" +// input_stream: "TEXT:text" +// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" +// output_stream: "TENSORS:tensors" +// options { +// [mediapipe.BertPreprocessorCalculatorOptions.ext] { +// bert_max_seq_len: 128 +// } +// } +// } +class BertPreprocessorCalculator : public Node { + public: + static constexpr Input kTextIn{"TEXT"}; + static constexpr SideInput kMetadataExtractorSideIn{ + "METADATA_EXTRACTOR"}; + static constexpr Output> kTensorsOut{"TENSORS"}; + + MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut); + + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::unique_ptr tokenizer_; + // The max sequence length accepted by the BERT model. + int bert_max_seq_len_ = 2; + // Indices of the three input tensors for the BERT model. They should form the + // set {0, 1, 2}. + int input_ids_tensor_index_ = 0; + int segment_ids_tensor_index_ = 1; + int input_masks_tensor_index_ = 2; + + // Applies `tokenizer_` to the `input_text` to generate a vector of tokens. + // This util prepends "[CLS]" and appends "[SEP]" to the input tokens and + // clips the vector of tokens to have length at most `bert_max_seq_len_`. + std::vector TokenizeInputText(absl::string_view input_text); + // Processes the `input_tokens` to generate the three input tensors for the + // BERT model. + std::vector GenerateInputTensors( + const std::vector& input_tokens); +}; + +absl::Status BertPreprocessorCalculator::UpdateContract( + CalculatorContract* cc) { + const auto& options = + cc->Options(); + RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required"; + RET_CHECK_GE(options.bert_max_seq_len(), 2) + << "bert_max_seq_len must be at least 2"; + return absl::OkStatus(); +} + +absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) { + const ModelMetadataExtractor* metadata_extractor = + &kMetadataExtractorSideIn(cc).Get(); + const tflite::ProcessUnit* tokenizer_metadata = + metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex); + ASSIGN_OR_RETURN(tokenizer_, + tasks::text::tokenizers::CreateTokenizerFromProcessUnit( + tokenizer_metadata, metadata_extractor)); + + auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata(); + input_ids_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kInputIdsTensorName); + segment_ids_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kSegmentIdsTensorName); + input_masks_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kInputMasksTensorName); + absl::flat_hash_set tensor_indices = {input_ids_tensor_index_, + segment_ids_tensor_index_, + input_masks_tensor_index_}; + if (tensor_indices != absl::flat_hash_set({0, 1, 2})) { + return absl::InvalidArgumentError(absl::Substitute( + "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}", + input_ids_tensor_index_, segment_ids_tensor_index_, + input_masks_tensor_index_)); + } + + const auto& options = + cc->Options(); + bert_max_seq_len_ = options.bert_max_seq_len(); + return absl::OkStatus(); +} + +absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) { + kTensorsOut(cc).Send( + GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get()))); + return absl::OkStatus(); +} + +std::vector BertPreprocessorCalculator::TokenizeInputText( + absl::string_view input_text) { + std::string processed_input = std::string(input_text); + absl::AsciiStrToLower(&processed_input); + + tasks::text::tokenizers::TokenizerResult tokenizer_result = + tokenizer_->Tokenize(processed_input); + + // Offset by 2 to account for [CLS] and [SEP] + int input_tokens_size = + std::min(bert_max_seq_len_, + static_cast(tokenizer_result.subwords.size()) + 2); + std::vector input_tokens; + input_tokens.reserve(input_tokens_size); + input_tokens.push_back(std::string(kClassifierToken)); + for (int i = 0; i < input_tokens_size - 2; ++i) { + input_tokens.push_back(std::move(tokenizer_result.subwords[i])); + } + input_tokens.push_back(std::string(kSeparatorToken)); + return input_tokens; +} + +std::vector BertPreprocessorCalculator::GenerateInputTensors( + const std::vector& input_tokens) { + std::vector input_ids(bert_max_seq_len_, 0); + std::vector segment_ids(bert_max_seq_len_, 0); + std::vector input_masks(bert_max_seq_len_, 0); + // Convert tokens back into ids and set mask + for (int i = 0; i < input_tokens.size(); ++i) { + tokenizer_->LookupId(input_tokens[i], &input_ids[i]); + input_masks[i] = 1; + } + // |<--------bert_max_seq_len_--------->| + // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 + // segment_ids 0 0 0... 0 0 0 0... 0 + // input_masks 1 1 1... 1 1 0 0... 0 + + std::vector input_tensors; + input_tensors.reserve(kNumInputTensorsForBert); + for (int i = 0; i < kNumInputTensorsForBert; ++i) { + input_tensors.push_back( + {Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})}); + } + std::memcpy(input_tensors[input_ids_tensor_index_] + .GetCpuWriteView() + .buffer(), + input_ids.data(), input_ids.size() * sizeof(int32_t)); + std::memcpy(input_tensors[segment_ids_tensor_index_] + .GetCpuWriteView() + .buffer(), + segment_ids.data(), segment_ids.size() * sizeof(int32_t)); + std::memcpy(input_tensors[input_masks_tensor_index_] + .GetCpuWriteView() + .buffer(), + input_masks.data(), input_masks.size() * sizeof(int32_t)); + return input_tensors; +} + +MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto b/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto new file mode 100644 index 000000000..72b569143 --- /dev/null +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto @@ -0,0 +1,29 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message BertPreprocessorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional BertPreprocessorCalculatorOptions ext = 462509271; + } + + // The maximum input sequence length for the calculator's BERT model. + optional int32 bert_max_seq_len = 1; +}