diff --git a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc
index 743986943..ee3d8299c 100644
--- a/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc
+++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc
@@ -58,8 +58,6 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
     TextModelType::ModelType model_type) {
   switch (model_type) {
     case TextModelType::UNSPECIFIED_MODEL:
-      // TODO: Support the UniversalSentenceEncoder model.
-    case TextModelType::USE_MODEL:
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument, "Unspecified model type",
           MediaPipeTasksStatus::kInvalidArgumentError);
@@ -69,6 +67,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
       return "RegexPreprocessorCalculator";
     case TextModelType::STRING_MODEL:
       return "TextToTensorCalculator";
+    case TextModelType::USE_MODEL:
+      return "UniversalSentenceEncoderPreprocessorCalculator";
   }
 }
 
@@ -189,8 +189,12 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
     switch (options.model_type()) {
       case TextModelType::UNSPECIFIED_MODEL:
-      case TextModelType::STRING_MODEL:
+      case TextModelType::STRING_MODEL: {
+        break;
+      }
       case TextModelType::USE_MODEL: {
+        metadata_extractor_in >>
+            text_preprocessor.SideIn(kMetadataExtractorTag);
         break;
       }
       case TextModelType::BERT_MODEL: {
diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD
index 4c970159e..c6a2616b0 100644
--- a/mediapipe/tasks/cc/text/text_embedder/BUILD
+++ b/mediapipe/tasks/cc/text/text_embedder/BUILD
@@ -54,17 +54,21 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
+        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
         "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
         "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
         "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/text/utils:text_model_utils",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
 )
@@ -75,6 +79,7 @@ cc_test(
     data = [
         "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
         "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
+        "//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
     ],
     deps = [
         ":text_embedder",
diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.h b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h
index 81f90fd27..d729ff3c2 100644
--- a/mediapipe/tasks/cc/text/text_embedder/text_embedder.h
+++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h
@@ -46,24 +46,30 @@ struct TextEmbedderOptions {
 // Performs embedding extraction on text.
 //
 // This API expects a TFLite model with TFLite Model Metadata that contains the
-// mandatory (described below) input tensors and output tensors. Metadata should
-// contain the input process unit for the model's Tokenizer as well as input /
-// output tensor metadata.
+// mandatory (described below) input tensors and output tensors.
 //
-// TODO: Support Universal Sentence Encoder.
-// Input tensors:
-//   (kTfLiteInt32)
-//   - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
-//     "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
-//     segment ids respectively
-//   - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
-//     input ids
-//
-// At least one output tensor with:
-//   (kTfLiteFloat32)
-//   - `N` components corresponding to the `N` dimensions of the returned
-//     feature vector for this output layer.
-//   - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
+// 1. BERT-based model
+//    - 3 input tensors of size `[batch_size x bert_max_seq_len]` and type
+//      kTfLiteInt32 with names "ids", "mask", and "segment_ids" representing
+//      the input ids, mask ids, and segment ids respectively
+//    - at least one output tensor (all of type kTfLiteFloat32) with `N`
+//      components corresponding to the `N` dimensions of the returned
+//      feature vector for this output layer and with either 2 or 4 dimensions,
+//      i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
+//    - input process units for a BertTokenizer or SentencePieceTokenizer
+// 2. Regex-based model
+//    - 1 input tensor of size `[batch_size x max_seq_len]` and type
+//      kTfLiteInt32 representing the input ids
+//    - at least one output tensor (all of type kTfLiteFloat32) with `N`
+//      components corresponding to the `N` dimensions of the returned
+//      feature vector for this output layer and with either 2 or 4 dimensions,
+//      i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
+//    - input process units for a RegexTokenizer
+// 3. UniversalSentenceEncoder-based model
+//    - 3 input tensors with names "inp_text", "res_context" and "res_text"
+//    - 2 output tensors with names "query_encoding" and "response_encoding" of
+//      type kTfLiteFloat32. The "query_encoding" tensor is ignored and only
+//      "response_encoding" is used for the embedding.
 class TextEmbedder : core::BaseTaskApi {
  public:
   using BaseTaskApi::BaseTaskApi;
diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
index 225ef07bd..518695138 100644
--- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
+++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
@@ -15,20 +15,24 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" #include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" +#include "mediapipe/tasks/cc/text/utils/text_model_utils.h" namespace mediapipe::tasks::text::text_embedder { namespace { @@ -38,13 +42,17 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::components::processors::proto::TextModelType; using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::text::utils::GetModelType; constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kUSEQueryTensorName[] = "query_encoding"; + } // namespace // A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding @@ -128,12 +136,22 @@ class TextEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); + auto* postprocessing_options = &postprocessing.GetOptions< + components::processors::proto::EmbeddingPostprocessingGraphOptions>(); + + // The UniversalSentenceEncoder model has an extraneous output head. + std::vector filtered_head_names; + ASSIGN_OR_RETURN(TextModelType::ModelType model_type, + GetModelType(model_resources)); + if (model_type == TextModelType::USE_MODEL) { + postprocessing_options->mutable_tensors_to_embeddings_options() + ->add_ignored_head_names(kUSEQueryTensorName); + } + MP_RETURN_IF_ERROR( components::processors::ConfigureEmbeddingPostprocessingGraph( model_resources, task_options.embedder_options(), - &postprocessing - .GetOptions())); + postprocessing_options)); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the embedding result. diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.cc new file mode 100644 index 000000000..f3e248722 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.cc @@ -0,0 +1,40 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/core/shims/cc/kernels/register.h"
+
+namespace tflite::ops::custom {
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
+}  // namespace tflite::ops::custom
+
+namespace mediapipe::tasks::text::text_embedder {
+
+std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver() {
+  auto resolver =
+      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
+  resolver->AddCustom(
+      "TFSentencepieceTokenizeOp",
+      ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
+  resolver->AddCustom(
+      "RaggedTensorToTensor",
+      ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
+  return resolver;
+}
+
+}  // namespace mediapipe::tasks::text::text_embedder
diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h
new file mode 100644
index 000000000..42f82af20
--- /dev/null
+++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h
@@ -0,0 +1,32 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
+
+#include <memory>
+
+#include "tensorflow/lite/core/api/op_resolver.h"
+
+namespace mediapipe::tasks::text::text_embedder {
+
+// Creates a custom OpResolver containing the additional
+// SENTENCEPIECE_TOKENIZER and RAGGED_TENSOR_TO_TENSOR ops needed by
+// Universal Sentence Encoder-based models.
+std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver();
+
+}  // namespace mediapipe::tasks::text::text_embedder
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
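
Reviewer note (not part of the patch): the following is a minimal usage sketch of the new Universal Sentence Encoder path. It assumes the TextEmbedder::Create / Embed API and the op_resolver field on BaseOptions as they exist at this revision; the model path is purely illustrative and should point at the universal_sentence_encoder_qa test asset referenced in the BUILD change above.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h"

namespace mediapipe::tasks::text::text_embedder {

// Runs a Universal Sentence Encoder model through the TextEmbedder task.
absl::Status EmbedWithUniversalSentenceEncoder(const std::string& model_path) {
  auto options = std::make_unique<TextEmbedderOptions>();
  // Illustrative path; in a test this would come from the testdata target.
  options->base_options.model_asset_path = model_path;
  // USE models need the SentencePiece / RaggedTensor custom ops, so swap in
  // the op resolver provided by the new test utils.
  options->base_options.op_resolver = CreateUSEOpResolver();

  auto embedder_or = TextEmbedder::Create(std::move(options));
  if (!embedder_or.ok()) return embedder_or.status();
  std::unique_ptr<TextEmbedder> embedder = std::move(*embedder_or);

  // Only the "response_encoding" head reaches the result; "query_encoding"
  // is filtered out by the graph change above.
  auto result_or = embedder->Embed("it's a charming and often affecting journey");
  if (!result_or.ok()) return result_or.status();

  return absl::OkStatus();
}

}  // namespace mediapipe::tasks::text::text_embedder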