diff --git a/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator.cc b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator.cc
new file mode 100644
index 000000000..e589289f6
--- /dev/null
+++ b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator.cc
@@ -0,0 +1,173 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <array>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_context.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+
+namespace mediapipe {
+namespace api2 {
+
+using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
+using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+
+constexpr absl::string_view kQueryTextMetadataName = "inp_text";
+constexpr absl::string_view kResponseContextMetadataName = "res_context";
+constexpr absl::string_view kResponseTextMetadataName = "res_text";
+
+constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
+
+// Preprocesses input text into three kTfLiteString input tensors for a
+// Universal Sentence Encoder (USE) model.
+//
+// The associated USE model is expected to contain input tensors with metadata
+// names:
+//
+// Tensor           | Metadata Name
+// ---------------- | ------------------
+// Query text       | "inp_text"
+// Response context | "res_context"
+// Response text    | "res_text"
+//
+// This calculator will return an error if the model does not have three input
+// tensors or if the tensors do not have metadata names corresponding to the
+// above names in some order. Additional details regarding these input
+// tensors are given in the Calculator "Outputs" section below.
+//
+// Inputs:
+//   TEXT - std::string
+//     The text to be embedded.
+// Side Inputs:
+//   METADATA_EXTRACTOR - ModelMetadataExtractor
+//     The metadata extractor for the USE model. Used to determine the order of
+//     the three input Tensors for the USE model.
+//
+// Outputs:
+//   TENSORS - std::vector<Tensor>
+//     Vector containing the three input Tensors for the USE model. The tensors
+//     fit a question-answering setting and store a query text, a response
+//     context, and a response text. This calculator will just be preprocessing
+//     a single input text that will be stored in the response text tensor. The
+//     query text and response context tensors will store empty strings.
+//
+// Example:
+// node {
+//   calculator: "UniversalSentenceEncoderPreprocessorCalculator"
+//   input_stream: "TEXT:text"
+//   input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
+//   output_stream: "TENSORS:tensors"
+// }
+class UniversalSentenceEncoderPreprocessorCalculator : public Node {
+ public:
+  static constexpr Input<std::string> kTextIn{"TEXT"};
+  static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
+      "METADATA_EXTRACTOR"};
+  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
+
+  MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
+
+  absl::Status Open(CalculatorContext* cc) override;
+  absl::Status Process(CalculatorContext* cc) override;
+
+ private:
+  // Indices of the three input tensors for the USE model. They should form the
+  // set {0, 1, 2}.
+  int query_text_tensor_index_ = 0;
+  int response_context_tensor_index_ = 1;
+  int response_text_tensor_index_ = 2;
+
+  // Tensor shapes for the model's input tensors.
+  // The query text and response context tensors will only hold the empty
+  // string, so their tensors will have shape [0], but the Universal Sentence
+  // Encoder model's input signature requires them to be present. The response
+  // text tensor will store the embedding text and have shape
+  // [embedding_text_len].
+  // Value-initialized so the two empty-string entries are 0 rather than
+  // indeterminate; Process() only overwrites the response text entry.
+  std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_{};
+};
+
+// Resolves the index of each of the three USE input tensors from its metadata
+// name and fails unless the indices are exactly the set {0, 1, 2}.
+absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
+    CalculatorContext* cc) {
+  const ModelMetadataExtractor* metadata_extractor =
+      &kMetadataExtractorSideIn(cc).Get();
+  auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
+  query_text_tensor_index_ = FindTensorIndexByMetadataName(
+      input_tensors_metadata, kQueryTextMetadataName);
+  response_context_tensor_index_ = FindTensorIndexByMetadataName(
+      input_tensors_metadata, kResponseContextMetadataName);
+  response_text_tensor_index_ = FindTensorIndexByMetadataName(
+      input_tensors_metadata, kResponseTextMetadataName);
+
+  absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
+      {query_text_tensor_index_, response_context_tensor_index_,
+       response_text_tensor_index_});
+  if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
+    return absl::InvalidArgumentError(absl::Substitute(
+        "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
+        query_text_tensor_index_, response_context_tensor_index_,
+        response_text_tensor_index_));
+  }
+  return absl::OkStatus();
+}
+
+// Sends the three kChar tensors for one input text: the response text tensor
+// receives the text bytes and the other two receive zero bytes.
+absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
+    CalculatorContext* cc) {
+  absl::string_view text = kTextIn(cc).Get();
+  const int text_len = static_cast<int>(text.length());
+  tensor_shapes_[response_text_tensor_index_] = text_len;
+
+  std::vector<Tensor> input_tensors;
+  input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
+  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
+    input_tensors.push_back(
+        {Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
+  }
+
+  std::memcpy(
+      input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
+      "", 0);
+  std::memcpy(input_tensors[response_context_tensor_index_]
+                  .GetCpuWriteView()
+                  .buffer<char>(),
+              "", 0);
+  std::memcpy(input_tensors[response_text_tensor_index_]
+                  .GetCpuWriteView()
+                  .buffer<char>(),
+              text.data(), text_len * sizeof(char));
+  kTensorsOut(cc).Send(std::move(input_tensors));
+  return absl::OkStatus();
+}
+
+MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc
new file mode 100644
index 000000000..d5f252b57
--- /dev/null
+++ b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc
@@ -0,0 +1,111 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_macros.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/options_map.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+
+namespace mediapipe {
+namespace {
+
+using ::mediapipe::IsOkAndHolds;
+using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+using ::testing::ElementsAreArray;
+
+constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
+
+constexpr absl::string_view kTestModelPath =
+    "mediapipe/tasks/testdata/text/"
+    "universal_sentence_encoder_qa_with_metadata.tflite";
+
+// Runs the calculator on `text` and returns each output tensor as a string.
+absl::StatusOr<std::vector<std::string>>
+RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: "text"
+    output_stream: "tensors"
+    node {
+      calculator: "UniversalSentenceEncoderPreprocessorCalculator"
+      input_stream: "TEXT:text"
+      input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
+      output_stream: "TENSORS:tensors"
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("tensors", &graph_config, &output_packets);
+
+  std::string model_buffer =
+      tasks::core::LoadBinaryContent(kTestModelPath.data());
+  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
+                   ModelMetadataExtractor::CreateFromModelBuffer(
+                       model_buffer.data(), model_buffer.size()));
+  // Run the graph.
+  CalculatorGraph graph;
+  MP_RETURN_IF_ERROR(graph.Initialize(
+      graph_config,
+      {{"metadata_extractor",
+        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
+  MP_RETURN_IF_ERROR(graph.StartRun({}));
+  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
+      "text", MakePacket<std::string>(text).At(Timestamp(0))));
+  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
+
+  if (output_packets.size() != 1) {
+    return absl::InvalidArgumentError(absl::Substitute(
+        "output_packets has size $0, expected 1", output_packets.size()));
+  }
+
+  const auto& tensor_vec = output_packets[0].Get<std::vector<Tensor>>();
+  if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
+    return absl::InvalidArgumentError(absl::Substitute(
+        "tensor_vec has size $0, expected $1", tensor_vec.size(),
+        kNumInputTensorsForUniversalSentenceEncoder));
+  }
+  std::vector<std::string> results;
+  for (const Tensor& tensor : tensor_vec) {
+    if (tensor.element_type() != Tensor::ElementType::kChar) {
+      return absl::InvalidArgumentError(absl::Substitute(
+          "tensor has element type $0, expected $1", tensor.element_type(),
+          Tensor::ElementType::kChar));
+    }
+    results.push_back({tensor.GetCpuReadView().buffer<char>(),
+                       static_cast<size_t>(tensor.shape().num_elements())});
+  }
+  return results;
+}
+
+TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
+  ASSERT_THAT(
+      RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
+      IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
+}
+
+}  // namespace
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD
index a12e607bb..6cce5ae41 100644
--- a/mediapipe/tasks/testdata/text/BUILD
+++ b/mediapipe/tasks/testdata/text/BUILD
@@ -30,6 +30,7 @@ mediapipe_files(srcs = [
     "mobilebert_with_metadata.tflite",
     "test_model_text_classifier_bool_output.tflite",
     "test_model_text_classifier_with_regex_tokenizer.tflite",
+    "universal_sentence_encoder_qa_with_metadata.tflite",
 ])
 
 exports_files(srcs = [
@@ -89,3 +90,8 @@ filegroup(
     name = "mobilebert_embedding_model",
     srcs = ["mobilebert_embedding_with_metadata.tflite"],
 )
+
+filegroup(
+    name = "universal_sentence_encoder_qa",
+    srcs = ["universal_sentence_encoder_qa_with_metadata.tflite"],
+)
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 3e651a3a0..b42019a17 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -736,6 +736,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/two_heads.tflite?generation=1661875968723352"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_universal_sentence_encoder_qa_with_metadata_tflite",
+        sha256 = "82c2d0450aa458adbec2f78eff33cfbf2a41b606b44246726ab67373926e32bc",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/universal_sentence_encoder_qa_with_metadata.tflite?generation=1665445919252005"],
+    )
+
     http_file(
         name = "com_google_mediapipe_vocab_for_regex_tokenizer_txt",
         sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",