Adds a preprocessor for Universal Sentence Encoder models.
PiperOrigin-RevId: 481293992
parent 5f3d5728e8
commit 17202af6f7
@@ -0,0 +1,167 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <array>
#include <cstring>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;

constexpr absl::string_view kQueryTextMetadataName = "inp_text";
constexpr absl::string_view kResponseContextMetadataName = "res_context";
constexpr absl::string_view kResponseTextMetadataName = "res_text";

constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;

// Preprocesses input text into three kTfLiteString input tensors for a
// Universal Sentence Encoder (USE) model.
//
// The associated USE model is expected to contain input tensors with metadata
// names:
//
//   Tensor           | Metadata Name
//   ---------------- | ------------------
//   Query text       | "inp_text"
//   Response context | "res_context"
//   Response text    | "res_text"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have metadata names corresponding to the
// above names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// Inputs:
//   TEXT - std::string
//     The text to be embedded.
// Side Inputs:
//   METADATA_EXTRACTOR - ModelMetadataExtractor
//     The metadata extractor for the USE model. Used to determine the order
//     of the three input Tensors for the USE model.
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing the three input Tensors for the USE model. The
//     tensors fit a question-answering setting and store a query text, a
//     response context, and a response text. This calculator will just be
//     preprocessing a single input text that will be stored in the response
//     text tensor. The query text and response context tensors will store
//     empty strings.
//
// Example:
//   node {
//     calculator: "UniversalSentenceEncoderPreprocessorCalculator"
//     input_stream: "TEXT:text"
//     input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
//     output_stream: "TENSORS:tensors"
//   }
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
 public:
  static constexpr Input<std::string> kTextIn{"TEXT"};
  static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
      "METADATA_EXTRACTOR"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};

  MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Indices of the three input tensors for the USE model. They should form
  // the set {0, 1, 2}.
  int query_text_tensor_index_ = 0;
  int response_context_tensor_index_ = 1;
  int response_text_tensor_index_ = 2;

  // Tensor shapes for the model's input tensors.
  // The query text and response context tensors will only hold the empty
  // string, so their tensors will have shape [0], but the Universal Sentence
  // Encoder model's input signature requires them to be present. The response
  // text tensor will store the embedding text and have shape
  // [embedding_text_len]. Zero-initialized so that the query text and
  // response context entries are guaranteed to be 0.
  std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_ =
      {0, 0, 0};
};

absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
    CalculatorContext* cc) {
  const ModelMetadataExtractor* metadata_extractor =
      &kMetadataExtractorSideIn(cc).Get();
  auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
  query_text_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kQueryTextMetadataName);
  response_context_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kResponseContextMetadataName);
  response_text_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kResponseTextMetadataName);

  absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
      {query_text_tensor_index_, response_context_tensor_index_,
       response_text_tensor_index_});
  if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
        query_text_tensor_index_, response_context_tensor_index_,
        response_text_tensor_index_));
  }
  return absl::OkStatus();
}

absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
    CalculatorContext* cc) {
  absl::string_view text = kTextIn(cc).Get();
  const int text_len = static_cast<int>(text.length());
  tensor_shapes_[response_text_tensor_index_] = text_len;

  std::vector<Tensor> input_tensors;
  input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
    input_tensors.push_back(
        {Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
  }

  std::memcpy(
      input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
      "", 0);
  std::memcpy(input_tensors[response_context_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<char>(),
              "", 0);
  std::memcpy(input_tensors[response_text_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<char>(),
              text.data(), text_len * sizeof(char));
  kTensorsOut(cc).Send(std::move(input_tensors));
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);

}  // namespace api2
}  // namespace mediapipe
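For reference, a minimal consumer-side sketch (not part of this commit) of decoding the calculator's TENSORS output back into strings. The helper name and the `tensors` argument are illustrative; it only uses the Tensor accessors already exercised by the unit test below.

#include <string>
#include <vector>

#include "mediapipe/framework/formats/tensor.h"

// Reads each kChar tensor back into a std::string. The response text ends up
// at whichever index the model's metadata assigned to "res_text"; the query
// text and response context entries decode to empty strings.
std::vector<std::string> DecodeUseInputTensors(
    const std::vector<mediapipe::Tensor>& tensors) {
  std::vector<std::string> strings;
  strings.reserve(tensors.size());
  for (const mediapipe::Tensor& tensor : tensors) {
    strings.emplace_back(
        tensor.GetCpuReadView().buffer<char>(),
        static_cast<size_t>(tensor.shape().num_elements()));
  }
  return strings;
}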
@@ -0,0 +1,111 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace {

using ::mediapipe::IsOkAndHolds;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;

constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;

constexpr absl::string_view kTestModelPath =
    "mediapipe/tasks/testdata/text/"
    "universal_sentence_encoder_qa_with_metadata.tflite";

absl::StatusOr<std::vector<std::string>>
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    input_stream: "text"
    output_stream: "tensors"
    node {
      calculator: "UniversalSentenceEncoderPreprocessorCalculator"
      input_stream: "TEXT:text"
      input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
      output_stream: "TENSORS:tensors"
    }
  )pb");
  std::vector<Packet> output_packets;
  tool::AddVectorSink("tensors", &graph_config, &output_packets);

  std::string model_buffer =
      tasks::core::LoadBinaryContent(kTestModelPath.data());
  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
                   ModelMetadataExtractor::CreateFromModelBuffer(
                       model_buffer.data(), model_buffer.size()));
  // Run the graph.
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      graph_config,
      {{"metadata_extractor",
        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text", MakePacket<std::string>(text).At(Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());

  if (output_packets.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "output_packets has size $0, expected 1", output_packets.size()));
  }

  const std::vector<Tensor>& tensor_vec =
      output_packets[0].Get<std::vector<Tensor>>();
  if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
    return absl::InvalidArgumentError(absl::Substitute(
        "tensor_vec has size $0, expected $1", tensor_vec.size(),
        kNumInputTensorsForUniversalSentenceEncoder));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
    return absl::InvalidArgumentError(absl::Substitute(
        "tensor has element type $0, expected $1",
        tensor_vec[0].element_type(), Tensor::ElementType::kChar));
  }
  std::vector<std::string> results;
  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
    results.push_back(
        {tensor_vec[i].GetCpuReadView().buffer<char>(),
         static_cast<size_t>(tensor_vec[i].shape().num_elements())});
  }
  return results;
}

TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
  ASSERT_THAT(
      RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
      IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
}

}  // namespace
}  // namespace mediapipe
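As a worked illustration of the shapes documented in the calculator comment (query text and response context at shape [0], response text at [embedding_text_len]), a hypothetical additional assertion, reusing the Run helper from the test above, could check the decoded string sizes; "test_input_text" is 15 characters, so the sizes would be 0, 0, and 15. This check is not part of the committed test.

TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSEOutputSizes) {
  // Hypothetical extra check: only the response text tensor carries payload
  // when a single input text is preprocessed.
  absl::StatusOr<std::vector<std::string>> results =
      RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text");
  ASSERT_TRUE(results.ok());
  EXPECT_EQ((*results)[0].size(), 0u);
  EXPECT_EQ((*results)[1].size(), 0u);
  EXPECT_EQ((*results)[2].size(), 15u);
}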
mediapipe/tasks/testdata/text/BUILD
@@ -30,6 +30,7 @@ mediapipe_files(srcs = [
     "mobilebert_with_metadata.tflite",
     "test_model_text_classifier_bool_output.tflite",
     "test_model_text_classifier_with_regex_tokenizer.tflite",
+    "universal_sentence_encoder_qa_with_metadata.tflite",
 ])
 
 exports_files(srcs = [
@@ -89,3 +90,8 @@ filegroup(
     name = "mobilebert_embedding_model",
     srcs = ["mobilebert_embedding_with_metadata.tflite"],
 )
+
+filegroup(
+    name = "universal_sentence_encoder_qa",
+    data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
+)
third_party/external_files.bzl
@@ -736,6 +736,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/two_heads.tflite?generation=1661875968723352"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_universal_sentence_encoder_qa_with_metadata_tflite",
+        sha256 = "82c2d0450aa458adbec2f78eff33cfbf2a41b606b44246726ab67373926e32bc",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/universal_sentence_encoder_qa_with_metadata.tflite?generation=1665445919252005"],
+    )
+
     http_file(
         name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
         sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",