Internal text task change.

PiperOrigin-RevId: 508568811
This commit is contained in:
MediaPipe Team 2023-02-09 22:29:39 -08:00 committed by Copybara-Service
parent d61b7dbef8
commit 915d2c7417
6 changed files with 128 additions and 23 deletions

View File

@ -58,8 +58,6 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
TextModelType::ModelType model_type) { TextModelType::ModelType model_type) {
switch (model_type) { switch (model_type) {
case TextModelType::UNSPECIFIED_MODEL: case TextModelType::UNSPECIFIED_MODEL:
// TODO: Support the UniversalSentenceEncoder model.
case TextModelType::USE_MODEL:
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, "Unspecified model type", absl::StatusCode::kInvalidArgument, "Unspecified model type",
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
@ -69,6 +67,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
return "RegexPreprocessorCalculator"; return "RegexPreprocessorCalculator";
case TextModelType::STRING_MODEL: case TextModelType::STRING_MODEL:
return "TextToTensorCalculator"; return "TextToTensorCalculator";
case TextModelType::USE_MODEL:
return "UniversalSentenceEncoderPreprocessorCalculator";
} }
} }
@ -189,8 +189,12 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
auto& text_preprocessor = graph.AddNode(preprocessor_name); auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.model_type()) { switch (options.model_type()) {
case TextModelType::UNSPECIFIED_MODEL: case TextModelType::UNSPECIFIED_MODEL:
case TextModelType::STRING_MODEL: case TextModelType::STRING_MODEL: {
break;
}
case TextModelType::USE_MODEL: { case TextModelType::USE_MODEL: {
metadata_extractor_in >>
text_preprocessor.SideIn(kMetadataExtractorTag);
break; break;
} }
case TextModelType::BERT_MODEL: { case TextModelType::BERT_MODEL: {

View File

@ -54,17 +54,21 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/text/utils:text_model_utils",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -75,6 +79,7 @@ cc_test(
data = [ data = [
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model", "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
"//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
], ],
deps = [ deps = [
":text_embedder", ":text_embedder",

View File

@ -46,24 +46,30 @@ struct TextEmbedderOptions {
// Performs embedding extraction on text. // Performs embedding extraction on text.
// //
// This API expects a TFLite model with TFLite Model Metadata that contains the // This API expects a TFLite model with TFLite Model Metadata that contains the
// mandatory (described below) input tensors and output tensors. Metadata should // mandatory (described below) input tensors and output tensors.
// contain the input process unit for the model's Tokenizer as well as input /
// output tensor metadata.
// //
// TODO: Support Universal Sentence Encoder. // 1. BERT-based model
// Input tensors: // - 3 input tensors of size `[batch_size x bert_max_seq_len]` and type
// (kTfLiteInt32) // kTfLiteInt32 with names "ids", "mask", and "segment_ids" representing
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names // the input ids, mask ids, and segment ids respectively
// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and // - at least one output tensor (all of type kTfLiteFloat32) with `N`
// segment ids respectively // components corresponding to the `N` dimensions of the returned
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the // feature vector for this output layer and with either 2 or 4 dimensions,
// input ids // i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
// // - input process units for a BertTokenizer or SentencePieceTokenizer
// At least one output tensor with: // 2. Regex-based model
// (kTfLiteFloat32) // - 1 input tensor of size `[batch_size x max_seq_len]` and type
// - `N` components corresponding to the `N` dimensions of the returned // kTfLiteInt32 representing the input ids
// feature vector for this output layer. // - at least one output tensor (all of type kTfLiteFloat32) with `N`
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. // components corresponding to the `N` dimensions of the returned
// feature vector for this output layer and with either 2 or 4 dimensions,
// i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
// - input process units for a RegexTokenizer
// 3. UniversalSentenceEncoder-based model
// - 3 input tensors with names "inp_text", "res_context" and "res_text"
// - 2 output tensors with names "query_encoding" and "response_encoding" of
// type kTfLiteFloat32. The "query_encoding" is filtered and only the other
// output tensor is used for the embedding.
class TextEmbedder : core::BaseTaskApi { class TextEmbedder : core::BaseTaskApi {
public: public:
using BaseTaskApi::BaseTaskApi; using BaseTaskApi::BaseTaskApi;

View File

@ -15,20 +15,24 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
namespace mediapipe::tasks::text::text_embedder { namespace mediapipe::tasks::text::text_embedder {
namespace { namespace {
@ -38,13 +42,17 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::text::utils::GetModelType;
constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTextTag[] = "TEXT"; constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kUSEQueryTensorName[] = "query_encoding";
} // namespace } // namespace
// A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding // A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
@ -128,12 +136,22 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
// inference results. // inference results.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
auto* postprocessing_options = &postprocessing.GetOptions<
components::processors::proto::EmbeddingPostprocessingGraphOptions>();
// The UniversalSentenceEncoder model has an extraneous output head.
std::vector<absl::string_view> filtered_head_names;
ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
GetModelType(model_resources));
if (model_type == TextModelType::USE_MODEL) {
postprocessing_options->mutable_tensors_to_embeddings_options()
->add_ignored_head_names(kUSEQueryTensorName);
}
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
components::processors::ConfigureEmbeddingPostprocessingGraph( components::processors::ConfigureEmbeddingPostprocessingGraph(
model_resources, task_options.embedder_options(), model_resources, task_options.embedder_options(),
&postprocessing postprocessing_options));
.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding result. // Outputs the embedding result.

View File

@ -0,0 +1,40 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
// Forward declarations of the registration functions for the custom TFLite
// ops required by Universal Sentence Encoder models. The definitions are
// supplied at link time by the corresponding kernel libraries.
namespace tflite::ops::custom {
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
} // namespace tflite::ops::custom
namespace mediapipe::tasks::text::text_embedder {
// Builds an op resolver for Universal Sentence Encoder models: all builtin
// ops plus the two custom ops those models depend on.
std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver() {
  auto op_resolver =
      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
  // Registration order is irrelevant; each custom op is looked up by name.
  op_resolver->AddCustom(
      "RaggedTensorToTensor",
      ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
  op_resolver->AddCustom(
      "TFSentencepieceTokenizeOp",
      ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
  return op_resolver;
}
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -0,0 +1,32 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
#include <memory>
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe::tasks::text::text_embedder {
// Creates a custom OpResolver containing the additional SENTENCEPIECE_TOKENIZER
// and RAGGED_TENSOR_TO_TENSOR ops needed by universal sentence encoder-based
// models. Ownership of the resolver is transferred to the caller via the
// returned unique_ptr.
std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver();
} // namespace mediapipe::tasks::text::text_embedder
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_