Open-source the regex_preprocessing_calculator
PiperOrigin-RevId: 481256045
This commit is contained in:
parent
eb52b72707
commit
5f3d5728e8
|
@ -199,6 +199,41 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Proto library for regex_preprocessor_calculator.proto, which defines the
# options message of the RegexPreprocessorCalculator.
mediapipe_proto_library(
    name = "regex_preprocessor_calculator_proto",
    srcs = ["regex_preprocessor_calculator.proto"],
    visibility = ["//mediapipe/framework:mediapipe_internal"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)
|
||||||
|
|
||||||
|
# Calculator library built from regex_preprocessor_calculator.cc.
cc_library(
    name = "regex_preprocessor_calculator",
    srcs = ["regex_preprocessor_calculator.cc"],
    visibility = ["//mediapipe/framework:mediapipe_internal"],
    deps = [
        # Options proto generated from the sibling proto library rule.
        ":regex_preprocessor_calculator_cc_proto",
        # MediaPipe framework.
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
        # Model metadata and tokenizer support.
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        # Abseil.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "inference_calculator_proto",
|
name = "inference_calculator_proto",
|
||||||
srcs = ["inference_calculator.proto"],
|
srcs = ["inference_calculator.proto"],
|
||||||
|
|
174
mediapipe/calculators/tensor/regex_preprocessor_calculator.cc
Normal file
174
mediapipe/calculators/tensor/regex_preprocessor_calculator.cc
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
|
||||||
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
|
|
||||||
|
// Preprocesses input text into one int32 input tensor for a text model using
|
||||||
|
// a RegexTokenizer.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// TEXT - std::string
|
||||||
|
// The input text.
|
||||||
|
// Side Inputs:
|
||||||
|
// METADATA_EXTRACTOR - ModelMetadataExtractor
|
||||||
|
// The metadata extractor for the text model. Used to extract the metadata
|
||||||
|
// to construct the RegexTokenizer.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// TENSORS - std::vector<Tensor>
|
||||||
|
// Vector containing a single Tensor which is the text model's input tensor.
|
||||||
|
// Depending on the tokenizer metadata, the tensor may start with
|
||||||
|
// the id of the tokenizer's <START> token. The following tensor values will
|
||||||
|
// be the ids of the tokens of the input text. Any out-of-vocab tokens will
|
||||||
|
// have the id of the <UNKNOWN> token. The tensor will be padded with the
|
||||||
|
// <PAD> token id to have size equal to the max sequence length for the text
|
||||||
|
// model.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "RegexPreprocessorCalculator"
|
||||||
|
// input_stream: "TEXT:text"
|
||||||
|
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||||
|
// output_stream: "TENSORS:tensors"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
|
||||||
|
// max_seq_len: 256
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class RegexPreprocessorCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||||
|
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
|
||||||
|
"METADATA_EXTRACTOR"};
|
||||||
|
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
|
||||||
|
// The max sequence length accepted by the text model.
|
||||||
|
int max_seq_len_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Validates the calculator options: `max_seq_len` must be explicitly set and
// strictly greater than zero. Returns OkStatus when both checks pass;
// otherwise the failing RET_CHECK returns an error status.
absl::Status RegexPreprocessorCalculator::UpdateContract(
    CalculatorContract* cc) {
  const mediapipe::RegexPreprocessorCalculatorOptions& options =
      cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();

  // Presence is checked separately from the value so that an unset field
  // produces its own error message.
  RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
  RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";

  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// Builds the RegexTokenizer from the model metadata side packet and caches
// the configured maximum sequence length.
//
// Returns InvalidArgumentError when there is no metadata for input tensor 0
// or no RegexTokenizer process unit on it, and propagates any error from the
// process-unit lookup or from tokenizer creation.
absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
  const ModelMetadataExtractor& extractor = kMetadataExtractorSideIn(cc).Get();

  // The tokenizer description is looked up on the model's first input tensor.
  const tflite::TensorMetadata* input_metadata =
      extractor.GetInputTensorMetadata(0);
  if (input_metadata == nullptr) {
    return absl::InvalidArgumentError("No tensor metadata found");
  }

  ASSIGN_OR_RETURN(
      const auto* process_unit,
      extractor.FindFirstProcessUnit(
          *input_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
  if (process_unit == nullptr) {
    return absl::InvalidArgumentError("No tokenizer metadata found");
  }

  ASSIGN_OR_RETURN(
      tokenizer_,
      tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
          process_unit->options_as<tflite::RegexTokenizerOptions>(),
          &extractor));

  max_seq_len_ = cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>()
                     .max_seq_len();
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// Tokenizes the incoming text and sends a single int32 tensor of token ids.
//
// The tensor always has `max_seq_len_` elements. If the tokenizer reports a
// <START> token its id comes first, followed by one id per token of the input
// text (the <UNKNOWN> id for tokens that LookupId() does not resolve). All
// remaining positions hold the <PAD> id. Tokens that do not fit into
// `max_seq_len_` elements are dropped.
absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
  const tasks::text::tokenizers::TokenizerResult tokenizer_result =
      tokenizer_->Tokenize(kTextIn(cc).Get());

  // Both ids default to 0; the lookups' return values are deliberately not
  // checked, matching the original best-effort behaviour.
  int unknown_token_id = 0;
  tokenizer_->GetUnknownToken(&unknown_token_id);
  int pad_token_id = 0;
  tokenizer_->GetPadToken(&pad_token_id);

  // Staged as int32_t so the element type matches the kInt32 tensor buffer the
  // ids are byte-copied into below, independent of the width of `int`.
  std::vector<int32_t> input_tokens(max_seq_len_, pad_token_id);

  // Index of the next element to fill. UpdateContract() rejects
  // max_seq_len <= 0, so element 0 exists when the <START> id is written.
  std::size_t next_index = 0;
  int start_token_id = 0;
  if (tokenizer_->GetStartToken(&start_token_id)) {
    input_tokens[0] = start_token_id;
    next_index = 1;
  }

  for (const std::string& token : tokenizer_result.subwords) {
    if (next_index >= input_tokens.size()) {
      break;  // The tensor is full; the remaining tokens are truncated.
    }
    int token_id = 0;
    const bool in_vocab = tokenizer_->LookupId(token, &token_id);
    input_tokens[next_index] = in_vocab ? token_id : unknown_token_id;
    ++next_index;
  }

  //                              |<-------sentence_length-------->|
  // input_tensor                 <START>, t1, t2... <PAD>, <PAD>...
  // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's
  // not found in the tokenizer vocab.
  std::vector<Tensor> result;
  result.push_back(
      {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
  std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
              input_tokens.data(), input_tokens.size() * sizeof(int32_t));
  kTensorsOut(cc).Send(std::move(result));
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,29 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
// Options for the RegexPreprocessorCalculator.
message RegexPreprocessorCalculatorOptions {
  // Registers this message as an extension of mediapipe.CalculatorOptions.
  extend mediapipe.CalculatorOptions {
    optional RegexPreprocessorCalculatorOptions ext = 463716697;
  }

  // The maximum input sequence length for the calculator's text model.
  optional int32 max_seq_len = 1;
}
|
Loading…
Reference in New Issue
Block a user