Open-source the regex_preprocessing_calculator

PiperOrigin-RevId: 481256045
This commit is contained in:
MediaPipe Team 2022-10-14 16:16:23 -07:00 committed by Copybara-Service
parent eb52b72707
commit 5f3d5728e8
3 changed files with 238 additions and 0 deletions

View File

@ -199,6 +199,41 @@ cc_library(
alwayslink = 1,
)
mediapipe_proto_library(
name = "regex_preprocessor_calculator_proto",
srcs = ["regex_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "regex_preprocessor_calculator",
srcs = ["regex_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":regex_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
mediapipe_proto_library(
name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"],

View File

@ -0,0 +1,174 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
// Preprocesses input text into one int32 input tensor for a text model using
// a RegexTokenizer.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the text model. Used to extract the metadata
// to construct the RegexTokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor which is the text model's input tensor.
// Depending on the tokenizer metadata, the tensor may start with
// the id of the tokenizer's <START> token. The following tensor values will
// be the ids of the tokens of the input text. Any out-of-vocab tokens will
// have the id of the <UNKNOWN> token. The tensor will be padded with the
// <PAD> token id to have size equal to the max sequence length for the text
// model.
//
// Example:
// node {
// calculator: "RegexPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
// max_seq_len: 256
// }
// }
// }
class RegexPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
// The max sequence length accepted by the text model.
int max_seq_len_ = 0;
};
absl::Status RegexPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::TensorMetadata* tensor_metadata =
metadata_extractor->GetInputTensorMetadata(0);
if (tensor_metadata == nullptr) {
return absl::InvalidArgumentError("No tensor metadata found");
}
ASSIGN_OR_RETURN(
const auto* tokenizer_metadata,
metadata_extractor->FindFirstProcessUnit(
*tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
if (tokenizer_metadata == nullptr) {
return absl::InvalidArgumentError("No tokenizer metadata found");
}
const tflite::RegexTokenizerOptions* regex_tokenizer_options =
tokenizer_metadata->options_as<tflite::RegexTokenizerOptions>();
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
regex_tokenizer_options, metadata_extractor));
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
max_seq_len_ = options.max_seq_len();
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(kTextIn(cc).Get());
int unknown_token_id = 0;
tokenizer_->GetUnknownToken(&unknown_token_id);
int pad_token_id = 0;
tokenizer_->GetPadToken(&pad_token_id);
std::vector<int> input_tokens(max_seq_len_, pad_token_id);
int start_token_id = 0;
int input_token_index = 0;
if (tokenizer_->GetStartToken(&start_token_id)) {
input_tokens[0] = start_token_id;
input_token_index = 1;
}
for (int i = 0; (i < tokenizer_result.subwords.size()) &&
(input_token_index < max_seq_len_);
++i, ++input_token_index) {
const std::string& token = tokenizer_result.subwords[i];
int token_id = 0;
if (tokenizer_->LookupId(token, &token_id)) {
input_tokens[input_token_index] = token_id;
} else {
input_tokens[input_token_index] = unknown_token_id;
}
}
// |<-------sentence_length-------->|
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
// <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's
// not found in the tokenizer vocab.
std::vector<Tensor> result;
result.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
kTensorsOut(cc).Send(std::move(result));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message RegexPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional RegexPreprocessorCalculatorOptions ext = 463716697;
}
// The maximum input sequence length for the calculator's text model.
optional int32 max_seq_len = 1;
}