diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 18107463c..6b63403f7 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -199,6 +199,41 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "regex_preprocessor_calculator_proto", + srcs = ["regex_preprocessor_calculator.proto"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "regex_preprocessor_calculator", + srcs = ["regex_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + ":regex_preprocessor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer", + "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc b/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc new file mode 100644 index 000000000..92a5f0266 --- /dev/null +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc @@ -0,0 +1,174 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace api2 { + +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +// Preprocesses input text into one int32 input tensor for a text model using +// a RegexTokenizer. +// +// Inputs: +// TEXT - std::string +// The input text. +// Side Inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the text model. Used to extract the metadata +// to construct the RegexTokenizer. +// +// Outputs: +// TENSORS - std::vector +// Vector containing a single Tensor which is the text model's input tensor. +// Depending on the tokenizer metadata, the tensor may start with +// the id of the tokenizer's token. The following tensor values will +// be the ids of the tokens of the input text. Any out-of-vocab tokens will +// have the id of the token. The tensor will be padded with the +// token id to have size equal to the max sequence length for the text +// model. +// +// Example: +// node { +// calculator: "RegexPreprocessorCalculator" +// input_stream: "TEXT:text" +// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" +// output_stream: "TENSORS:tensors" +// options { +// [mediapipe.RegexPreprocessorCalculatorOptions.ext] { +// max_seq_len: 256 +// } +// } +// } +class RegexPreprocessorCalculator : public Node { + public: + static constexpr Input kTextIn{"TEXT"}; + static constexpr SideInput kMetadataExtractorSideIn{ + "METADATA_EXTRACTOR"}; + static constexpr Output> kTensorsOut{"TENSORS"}; + + MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut); + + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::unique_ptr tokenizer_; + // The max sequence length accepted by the text model. + int max_seq_len_ = 0; +}; + +absl::Status RegexPreprocessorCalculator::UpdateContract( + CalculatorContract* cc) { + const auto& options = + cc->Options(); + RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required"; + RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive"; + return absl::OkStatus(); +} + +absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) { + const ModelMetadataExtractor* metadata_extractor = + &kMetadataExtractorSideIn(cc).Get(); + const tflite::TensorMetadata* tensor_metadata = + metadata_extractor->GetInputTensorMetadata(0); + if (tensor_metadata == nullptr) { + return absl::InvalidArgumentError("No tensor metadata found"); + } + + ASSIGN_OR_RETURN( + const auto* tokenizer_metadata, + metadata_extractor->FindFirstProcessUnit( + *tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions)); + if (tokenizer_metadata == nullptr) { + return absl::InvalidArgumentError("No tokenizer metadata found"); + } + const tflite::RegexTokenizerOptions* regex_tokenizer_options = + tokenizer_metadata->options_as(); + ASSIGN_OR_RETURN(tokenizer_, + tasks::text::tokenizers::CreateRegexTokenizerFromOptions( + regex_tokenizer_options, metadata_extractor)); + + const auto& options = + cc->Options(); + max_seq_len_ = options.max_seq_len(); + return absl::OkStatus(); +} + +absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) { + tasks::text::tokenizers::TokenizerResult tokenizer_result = + tokenizer_->Tokenize(kTextIn(cc).Get()); + + int unknown_token_id = 0; + tokenizer_->GetUnknownToken(&unknown_token_id); + int pad_token_id = 0; + tokenizer_->GetPadToken(&pad_token_id); + + std::vector input_tokens(max_seq_len_, pad_token_id); + int start_token_id = 0; + int input_token_index = 0; + if (tokenizer_->GetStartToken(&start_token_id)) { + input_tokens[0] = start_token_id; + input_token_index = 1; + } + + for (int i = 0; (i < tokenizer_result.subwords.size()) && + (input_token_index < max_seq_len_); + ++i, ++input_token_index) { + const std::string& token = tokenizer_result.subwords[i]; + int token_id = 0; + if (tokenizer_->LookupId(token, &token_id)) { + input_tokens[input_token_index] = token_id; + } else { + input_tokens[input_token_index] = unknown_token_id; + } + } + + // |<-------sentence_length-------->| + // input_tensor , t1, t2... , ... + // is optional, t1, t2... will be replaced by if it's + // not found in the tokenizer vocab. + std::vector result; + result.push_back( + {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})}); + std::memcpy(result[0].GetCpuWriteView().buffer(), + input_tokens.data(), input_tokens.size() * sizeof(int32_t)); + kTensorsOut(cc).Send(std::move(result)); + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto b/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto new file mode 100644 index 000000000..793067a80 --- /dev/null +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto @@ -0,0 +1,29 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message RegexPreprocessorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional RegexPreprocessorCalculatorOptions ext = 463716697; + } + + // The maximum input sequence length for the calculator's text model. + optional int32 max_seq_len = 1; +}