Adds a preprocessor for Universal Sentence Encoder models.
PiperOrigin-RevId: 481293992
This commit is contained in:
parent
5f3d5728e8
commit
17202af6f7
|
@ -0,0 +1,167 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
|
||||
constexpr absl::string_view kQueryTextMetadataName = "inp_text";
|
||||
constexpr absl::string_view kResponseContextMetadataName = "res_context";
|
||||
constexpr absl::string_view kResponseTextMetadataName = "res_text";
|
||||
|
||||
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
|
||||
|
||||
// Preprocesses input text into three kTfLiteString input tensors for a
|
||||
// Universal Sentence Encoder (USE) model.
|
||||
//
|
||||
// The associated USE model is expected to contain input tensors with metadata
|
||||
// names:
|
||||
//
|
||||
// Tensor | Metadata Name
|
||||
// ---------------- | ------------------
|
||||
// Query text | "inp_text"
|
||||
// Response context | "res_context"
|
||||
// Response text | "res_text"
|
||||
//
|
||||
// This calculator will return an error if the model does not have three input
|
||||
// tensors or if the tensors do not have metadata names corresponding to the
|
||||
// above names in some order. Additional details regarding these input
|
||||
// tensors are given in the Calculator "Outputs" section below.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
// The text to be embedded.
|
||||
// Side Inputs:
|
||||
// METADATA_EXTRACTOR - ModelMetadataExtractor
|
||||
// The metadata extractor for the USE model. Used to determine the order of
|
||||
// the three input Tensors for the USE model.
|
||||
//
|
||||
// Outputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing the three input Tensors for the USE model. The tensors
|
||||
// fit a question-answering setting and store a query text, a response
|
||||
// context, and a response text. This calculator will just be preprocessing
|
||||
// a single input text that will be stored in the response text tensor. The
|
||||
// query text and response context tensors will store empty strings.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "UniversalSentenceEncoderPreprocessorCalculator"
|
||||
// input_stream: "TEXT:text"
|
||||
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
// output_stream: "TENSORS:tensors"
|
||||
// }
|
||||
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
|
||||
"METADATA_EXTRACTOR"};
|
||||
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Indices of the three input tensors for the USE model. They should form the
|
||||
// set {0, 1, 2}.
|
||||
int query_text_tensor_index_ = 0;
|
||||
int response_context_tensor_index_ = 1;
|
||||
int response_text_tensor_index_ = 2;
|
||||
|
||||
// Tensor shapes for the model's input tensors.
|
||||
// The query text and response context tensors will only hold the empty
|
||||
// string, so their tensors will have shape [0], but the Universal Sentence
|
||||
// Encoder model's input signature requires them to be present. The response
|
||||
// text tensor will store the embedding text and have shape
|
||||
// [embedding_text_len].
|
||||
std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
|
||||
};
|
||||
|
||||
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
|
||||
CalculatorContext* cc) {
|
||||
const ModelMetadataExtractor* metadata_extractor =
|
||||
&kMetadataExtractorSideIn(cc).Get();
|
||||
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
|
||||
query_text_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kQueryTextMetadataName);
|
||||
response_context_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kResponseContextMetadataName);
|
||||
response_text_tensor_index_ = FindTensorIndexByMetadataName(
|
||||
input_tensors_metadata, kResponseTextMetadataName);
|
||||
|
||||
absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
|
||||
{query_text_tensor_index_, response_context_tensor_index_,
|
||||
response_text_tensor_index_});
|
||||
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
|
||||
query_text_tensor_index_, response_context_tensor_index_,
|
||||
response_text_tensor_index_));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
|
||||
CalculatorContext* cc) {
|
||||
absl::string_view text = kTextIn(cc).Get();
|
||||
const int text_len = static_cast<int>(text.length());
|
||||
tensor_shapes_[response_text_tensor_index_] = text_len;
|
||||
|
||||
std::vector<Tensor> input_tensors;
|
||||
input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
|
||||
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
|
||||
input_tensors.push_back(
|
||||
{Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
|
||||
}
|
||||
|
||||
std::memcpy(
|
||||
input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
|
||||
"", 0);
|
||||
std::memcpy(input_tensors[response_context_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<char>(),
|
||||
"", 0);
|
||||
std::memcpy(input_tensors[response_text_tensor_index_]
|
||||
.GetCpuWriteView()
|
||||
.buffer<char>(),
|
||||
text.data(), text_len * sizeof(char));
|
||||
kTensorsOut(cc).Send(std::move(input_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/options_map.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::IsOkAndHolds;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
|
||||
|
||||
constexpr absl::string_view kTestModelPath =
|
||||
"mediapipe/tasks/testdata/text/"
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||
|
||||
absl::StatusOr<std::vector<std::string>>
|
||||
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "text"
|
||||
output_stream: "tensors"
|
||||
node {
|
||||
calculator: "UniversalSentenceEncoderPreprocessorCalculator"
|
||||
input_stream: "TEXT:text"
|
||||
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
|
||||
output_stream: "TENSORS:tensors"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensors", &graph_config, &output_packets);
|
||||
|
||||
std::string model_buffer =
|
||||
tasks::core::LoadBinaryContent(kTestModelPath.data());
|
||||
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
|
||||
ModelMetadataExtractor::CreateFromModelBuffer(
|
||||
model_buffer.data(), model_buffer.size()));
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(
|
||||
graph_config,
|
||||
{{"metadata_extractor",
|
||||
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
|
||||
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"text", MakePacket<std::string>(text).At(Timestamp(0))));
|
||||
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
|
||||
|
||||
if (output_packets.size() != 1) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"output_packets has size $0, expected 1", output_packets.size()));
|
||||
}
|
||||
|
||||
const std::vector<Tensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"tensor_vec has size $0, expected $1", tensor_vec.size(),
|
||||
kNumInputTensorsForUniversalSentenceEncoder));
|
||||
}
|
||||
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
|
||||
Tensor::ElementType::kChar));
|
||||
}
|
||||
std::vector<std::string> results;
|
||||
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
|
||||
results.push_back(
|
||||
{tensor_vec[i].GetCpuReadView().buffer<char>(),
|
||||
static_cast<size_t>(tensor_vec[i].shape().num_elements())});
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
|
||||
ASSERT_THAT(
|
||||
RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
|
||||
IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
6
mediapipe/tasks/testdata/text/BUILD
vendored
6
mediapipe/tasks/testdata/text/BUILD
vendored
|
@ -30,6 +30,7 @@ mediapipe_files(srcs = [
|
|||
"mobilebert_with_metadata.tflite",
|
||||
"test_model_text_classifier_bool_output.tflite",
|
||||
"test_model_text_classifier_with_regex_tokenizer.tflite",
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite",
|
||||
])
|
||||
|
||||
exports_files(srcs = [
|
||||
|
@ -89,3 +90,8 @@ filegroup(
|
|||
name = "mobilebert_embedding_model",
|
||||
srcs = ["mobilebert_embedding_with_metadata.tflite"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "universal_sentence_encoder_qa",
|
||||
data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
|
||||
)
|
||||
|
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -736,6 +736,12 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/two_heads.tflite?generation=1661875968723352"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_universal_sentence_encoder_qa_with_metadata_tflite",
|
||||
sha256 = "82c2d0450aa458adbec2f78eff33cfbf2a41b606b44246726ab67373926e32bc",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/universal_sentence_encoder_qa_with_metadata.tflite?generation=1665445919252005"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
|
||||
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
|
||||
|
|
Loading…
Reference in New Issue
Block a user