Internal text task change.

PiperOrigin-RevId: 508568811
This commit is contained in:
MediaPipe Team 2023-02-09 22:29:39 -08:00 committed by Copybara-Service
parent d61b7dbef8
commit 915d2c7417
6 changed files with 128 additions and 23 deletions

View File

@ -58,8 +58,6 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
TextModelType::ModelType model_type) {
switch (model_type) {
case TextModelType::UNSPECIFIED_MODEL:
// TODO: Support the UniversalSentenceEncoder model.
case TextModelType::USE_MODEL:
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, "Unspecified model type",
MediaPipeTasksStatus::kInvalidArgumentError);
@ -69,6 +67,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
return "RegexPreprocessorCalculator";
case TextModelType::STRING_MODEL:
return "TextToTensorCalculator";
case TextModelType::USE_MODEL:
return "UniversalSentenceEncoderPreprocessorCalculator";
}
}
@ -189,8 +189,12 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.model_type()) {
case TextModelType::UNSPECIFIED_MODEL:
case TextModelType::STRING_MODEL:
case TextModelType::STRING_MODEL: {
break;
}
case TextModelType::USE_MODEL: {
metadata_extractor_in >>
text_preprocessor.SideIn(kMetadataExtractorTag);
break;
}
case TextModelType::BERT_MODEL: {

View File

@ -54,17 +54,21 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/text/utils:text_model_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
@ -75,6 +79,7 @@ cc_test(
data = [
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
"//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
],
deps = [
":text_embedder",

View File

@ -46,24 +46,30 @@ struct TextEmbedderOptions {
// Performs embedding extraction on text.
//
// This API expects a TFLite model with TFLite Model Metadata that contains the
// mandatory (described below) input tensors and output tensors. Metadata should
// contain the input process unit for the model's Tokenizer as well as input /
// output tensor metadata.
// mandatory (described below) input tensors and output tensors.
//
// TODO: Support Universal Sentence Encoder.
// Input tensors:
// (kTfLiteInt32)
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
// segment ids respectively
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
// input ids
//
// At least one output tensor with:
// (kTfLiteFloat32)
// - `N` components corresponding to the `N` dimensions of the returned
// feature vector for this output layer.
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
// 1. BERT-based model
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` and type
// kTfLiteInt32 with names "ids", "mask", and "segment_ids" representing
// the input ids, mask ids, and segment ids respectively
// - at least one output tensor (all of type kTfLiteFloat32) with `N`
// components corresponding to the `N` dimensions of the returned
// feature vector for this output layer and with either 2 or 4 dimensions,
// i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
// - input process units for a BertTokenizer or SentencePieceTokenizer
// 2. Regex-based model
// - 1 input tensor of size `[batch_size x max_seq_len]` and type
// kTfLiteInt32 representing the input ids
// - at least one output tensor (all of type kTfLiteFloat32) with `N`
// components corresponding to the `N` dimensions of the returned
// feature vector for this output layer and with either 2 or 4 dimensions,
// i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
// - input process units for a RegexTokenizer
// 3. UniversalSentenceEncoder-based model
// - 3 input tensors with names "inp_text", "res_context" and "res_text"
// - 2 output tensors of type kTfLiteFloat32 with names "query_encoding"
// and "response_encoding". The "query_encoding" output is ignored; only
// "response_encoding" is used for the embedding.
class TextEmbedder : core::BaseTaskApi {
public:
using BaseTaskApi::BaseTaskApi;

View File

@ -15,20 +15,24 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
@ -38,13 +42,17 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::processors::proto::TextModelType;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::text::utils::GetModelType;
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kUSEQueryTensorName[] = "query_encoding";
} // namespace
// A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
@ -128,12 +136,22 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
auto* postprocessing_options = &postprocessing.GetOptions<
components::processors::proto::EmbeddingPostprocessingGraphOptions>();
// The UniversalSentenceEncoder model has an extraneous output head.
std::vector<absl::string_view> filtered_head_names;
ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
GetModelType(model_resources));
if (model_type == TextModelType::USE_MODEL) {
postprocessing_options->mutable_tensors_to_embeddings_options()
->add_ignored_head_names(kUSEQueryTensorName);
}
MP_RETURN_IF_ERROR(
components::processors::ConfigureEmbeddingPostprocessingGraph(
model_resources, task_options.embedder_options(),
&postprocessing
.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
postprocessing_options));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding result.

View File

@ -0,0 +1,40 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
namespace tflite::ops::custom {
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
} // namespace tflite::ops::custom
namespace mediapipe::tasks::text::text_embedder {
// Returns an op resolver for universal-sentence-encoder models: the full
// builtin op set extended with the two custom ops such models require.
std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver() {
  auto op_resolver =
      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
  // Registration order of the custom ops is irrelevant; each is looked up by
  // its op name at model-load time.
  op_resolver->AddCustom(
      "RaggedTensorToTensor",
      ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
  op_resolver->AddCustom(
      "TFSentencepieceTokenizeOp",
      ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
  return op_resolver;
}
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -0,0 +1,32 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
#include <memory>
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe::tasks::text::text_embedder {
// Creates a custom OpResolver containing the additional SENTENCEPIECE_TOKENIZER
// and RAGGED_TENSOR_TO_TENSOR ops needed by universal sentence encoder-based
// models.
std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver();
} // namespace mediapipe::tasks::text::text_embedder
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_