Adds a preprocessor for Universal Sentence Encoder models.

PiperOrigin-RevId: 481293992
This commit is contained in:
MediaPipe Team 2022-10-14 21:36:40 -07:00 committed by Copybara-Service
parent 5f3d5728e8
commit 17202af6f7
4 changed files with 290 additions and 0 deletions

View File

@ -0,0 +1,167 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <cstring>
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr absl::string_view kQueryTextMetadataName = "inp_text";
constexpr absl::string_view kResponseContextMetadataName = "res_context";
constexpr absl::string_view kResponseTextMetadataName = "res_text";
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
// Preprocesses input text into three kTfLiteString input tensors for a
// Universal Sentence Encoder (USE) model.
//
// The associated USE model is expected to contain input tensors with metadata
// names:
//
// Tensor | Metadata Name
// ---------------- | ------------------
// Query text | "inp_text"
// Response context | "res_context"
// Response text | "res_text"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have metadata names corresponding to the
// above names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// Inputs:
// TEXT - std::string
// The text to be embedded.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the USE model. Used to determine the order of
// the three input Tensors for the USE model.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the USE model. The tensors
// fit a question-answering setting and store a query text, a response
// context, and a response text. This calculator will just be preprocessing
// a single input text that will be stored in the response text tensor. The
// query text and response context tensors will store empty strings.
//
// Example:
// node {
// calculator: "UniversalSentenceEncoderPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// }
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Indices of the three input tensors for the USE model. They should form the
// set {0, 1, 2}.
int query_text_tensor_index_ = 0;
int response_context_tensor_index_ = 1;
int response_text_tensor_index_ = 2;
// Tensor shapes for the model's input tensors.
// The query text and response context tensors will only hold the empty
// string, so their tensors will have shape [0], but the Universal Sentence
// Encoder model's input signature requires them to be present. The response
// text tensor will store the embedding text and have shape
// [embedding_text_len].
std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
};
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
query_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kQueryTextMetadataName);
response_context_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseContextMetadataName);
response_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseTextMetadataName);
absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
{query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_});
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_));
}
return absl::OkStatus();
}
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
CalculatorContext* cc) {
absl::string_view text = kTextIn(cc).Get();
const int text_len = static_cast<int>(text.length());
tensor_shapes_[response_text_tensor_index_] = text_len;
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
}
std::memcpy(
input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_context_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_text_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
text.data(), text_len * sizeof(char));
kTensorsOut(cc).Send(std::move(input_tensors));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,111 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::IsOkAndHolds;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/"
"universal_sentence_encoder_qa_with_metadata.tflite";
absl::StatusOr<std::vector<std::string>>
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "UniversalSentenceEncoderPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer =
tasks::core::LoadBinaryContent(kTestModelPath.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(),
kNumInputTensorsForUniversalSentenceEncoder));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
}
std::vector<std::string> results;
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
results.push_back(
{tensor_vec[i].GetCpuReadView().buffer<char>(),
static_cast<size_t>(tensor_vec[i].shape().num_elements())});
}
return results;
}
TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
ASSERT_THAT(
RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
}
} // namespace
} // namespace mediapipe

View File

@ -30,6 +30,7 @@ mediapipe_files(srcs = [
"mobilebert_with_metadata.tflite", "mobilebert_with_metadata.tflite",
"test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite",
"universal_sentence_encoder_qa_with_metadata.tflite",
]) ])
exports_files(srcs = [ exports_files(srcs = [
@ -89,3 +90,8 @@ filegroup(
name = "mobilebert_embedding_model", name = "mobilebert_embedding_model",
srcs = ["mobilebert_embedding_with_metadata.tflite"], srcs = ["mobilebert_embedding_with_metadata.tflite"],
) )
filegroup(
name = "universal_sentence_encoder_qa",
data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
)

View File

@ -736,6 +736,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/two_heads.tflite?generation=1661875968723352"], urls = ["https://storage.googleapis.com/mediapipe-assets/two_heads.tflite?generation=1661875968723352"],
) )
http_file(
name = "com_google_mediapipe_universal_sentence_encoder_qa_with_metadata_tflite",
sha256 = "82c2d0450aa458adbec2f78eff33cfbf2a41b606b44246726ab67373926e32bc",
urls = ["https://storage.googleapis.com/mediapipe-assets/universal_sentence_encoder_qa_with_metadata.tflite?generation=1665445919252005"],
)
http_file( http_file(
name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt", name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923", sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",