Internal text task change.
PiperOrigin-RevId: 508568811
parent d61b7dbef8
commit 915d2c7417
@@ -58,8 +58,6 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
     TextModelType::ModelType model_type) {
   switch (model_type) {
     case TextModelType::UNSPECIFIED_MODEL:
-      // TODO: Support the UniversalSentenceEncoder model.
-    case TextModelType::USE_MODEL:
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument, "Unspecified model type",
           MediaPipeTasksStatus::kInvalidArgumentError);
@@ -69,6 +67,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromModelType(
       return "RegexPreprocessorCalculator";
     case TextModelType::STRING_MODEL:
       return "TextToTensorCalculator";
+    case TextModelType::USE_MODEL:
+      return "UniversalSentenceEncoderPreprocessorCalculator";
   }
 }
 
@@ -189,8 +189,12 @@ class TextPreprocessingGraph : public mediapipe::Subgraph {
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
     switch (options.model_type()) {
       case TextModelType::UNSPECIFIED_MODEL:
-      case TextModelType::STRING_MODEL:
+      case TextModelType::STRING_MODEL: {
+        break;
+      }
       case TextModelType::USE_MODEL: {
+        metadata_extractor_in >>
+            text_preprocessor.SideIn(kMetadataExtractorTag);
         break;
       }
       case TextModelType::BERT_MODEL: {
@@ -54,17 +54,21 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
+        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
         "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
         "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:text_model_type_cc_proto",
         "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
         "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/text/utils:text_model_utils",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
 )
@@ -75,6 +79,7 @@ cc_test(
     data = [
         "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
         "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
+        "//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa",
     ],
     deps = [
         ":text_embedder",
@@ -46,24 +46,30 @@ struct TextEmbedderOptions {
 // Performs embedding extraction on text.
 //
 // This API expects a TFLite model with TFLite Model Metadata that contains the
-// mandatory (described below) input tensors and output tensors. Metadata should
-// contain the input process unit for the model's Tokenizer as well as input /
-// output tensor metadata.
+// mandatory (described below) input tensors and output tensors.
 //
-// TODO: Support Universal Sentence Encoder.
-// Input tensors:
-//   (kTfLiteInt32)
-//   - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
-//     "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
-//     segment ids respectively
-//   - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
-//     input ids
-//
-// At least one output tensor with:
-//   (kTfLiteFloat32)
-//   - `N` components corresponding to the `N` dimensions of the returned
-//     feature vector for this output layer.
-//   - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
+// 1. BERT-based model
+//   - 3 input tensors of size `[batch_size x bert_max_seq_len]` and type
+//     kTfLiteInt32 with names "ids", "mask", and "segment_ids" representing
+//     the input ids, mask ids, and segment ids respectively
+//   - at least one output tensor (all of type kTfLiteFloat32) with `N`
+//     components corresponding to the `N` dimensions of the returned
+//     feature vector for this output layer and with either 2 or 4 dimensions,
+//     i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
+//   - input process units for a BertTokenizer or SentencePieceTokenizer
+// 2. Regex-based model
+//   - 1 input tensor of size `[batch_size x max_seq_len]` and type
+//     kTfLiteInt32 representing the input ids
+//   - at least one output tensor (all of type kTfLiteFloat32) with `N`
+//     components corresponding to the `N` dimensions of the returned
+//     feature vector for this output layer and with either 2 or 4 dimensions,
+//     i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
+//   - input process units for a RegexTokenizer
+// 3. UniversalSentenceEncoder-based model
+//   - 3 input tensors with names "inp_text", "res_context" and "res_text"
+//   - 2 output tensors with names "query_encoding" and "response_encoding" of
+//     type kTfLiteFloat32. The "query_encoding" is filtered and only the other
+//     output tensor is used for the embedding.
 class TextEmbedder : core::BaseTaskApi {
  public:
   using BaseTaskApi::BaseTaskApi;
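The rewritten header comment above describes the three model families the task accepts. For orientation, here is a minimal usage sketch of the task-level API. It is not part of this commit: the Create()/Embed() entry points and the base_options.model_asset_path member are assumed from other MediaPipe Tasks C++ APIs rather than shown in this diff, so treat the exact names as illustrative.

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"

namespace {

// Hypothetical usage sketch, not part of this commit. The option field names
// below are assumptions based on other MediaPipe Tasks C++ APIs.
absl::Status RunTextEmbedder() {
  auto options = std::make_unique<
      mediapipe::tasks::text::text_embedder::TextEmbedderOptions>();
  // Any of the three model families described in the header comment can be
  // supplied here.
  options->base_options.model_asset_path = "/path/to/embedding_model.tflite";

  auto embedder_or = mediapipe::tasks::text::text_embedder::TextEmbedder::
      Create(std::move(options));
  if (!embedder_or.ok()) return embedder_or.status();

  auto result_or = (*embedder_or)->Embed("It's a charming movie.");
  if (!result_or.ok()) return result_or.status();
  // The result holds one embedding per (non-filtered) output head.
  return absl::OkStatus();
}

}  // namespace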
@@ -15,20 +15,24 @@ limitations under the License.
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
 #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/core/model_task_graph.h"
 #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
 #include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
+#include "mediapipe/tasks/cc/text/utils/text_model_utils.h"
 
 namespace mediapipe::tasks::text::text_embedder {
 namespace {
@@ -38,13 +42,17 @@ using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
 using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
+using ::mediapipe::tasks::components::processors::proto::TextModelType;
 using ::mediapipe::tasks::core::ModelResources;
+using ::mediapipe::tasks::text::utils::GetModelType;
 
 constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
 constexpr char kTextTag[] = "TEXT";
 constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
 constexpr char kTensorsTag[] = "TENSORS";
 
+constexpr char kUSEQueryTensorName[] = "query_encoding";
+
 }  // namespace
 
 // A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
@@ -128,12 +136,22 @@ class TextEmbedderGraph : public core::ModelTaskGraph {
     // inference results.
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
+    auto* postprocessing_options = &postprocessing.GetOptions<
+        components::processors::proto::EmbeddingPostprocessingGraphOptions>();
+
+    // The UniversalSentenceEncoder model has an extraneous output head.
+    std::vector<absl::string_view> filtered_head_names;
+    ASSIGN_OR_RETURN(TextModelType::ModelType model_type,
+                     GetModelType(model_resources));
+    if (model_type == TextModelType::USE_MODEL) {
+      postprocessing_options->mutable_tensors_to_embeddings_options()
+          ->add_ignored_head_names(kUSEQueryTensorName);
+    }
+
     MP_RETURN_IF_ERROR(
         components::processors::ConfigureEmbeddingPostprocessingGraph(
             model_resources, task_options.embedder_options(),
-            &postprocessing
-                 .GetOptions<components::processors::proto::
-                                 EmbeddingPostprocessingGraphOptions>()));
+            postprocessing_options));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
 
     // Outputs the embedding result.
@@ -0,0 +1,40 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/core/shims/cc/kernels/register.h"
+
+namespace tflite::ops::custom {
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
+}  // namespace tflite::ops::custom
+
+namespace mediapipe::tasks::text::text_embedder {
+
+std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver() {
+  auto resolver =
+      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
+  resolver->AddCustom(
+      "TFSentencepieceTokenizeOp",
+      ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
+  resolver->AddCustom(
+      "RaggedTensorToTensor",
+      ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
+  return resolver;
+}
+
+}  // namespace mediapipe::tasks::text::text_embedder
@@ -0,0 +1,32 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
+
+#include <memory>
+
+#include "tensorflow/lite/core/api/op_resolver.h"
+
+namespace mediapipe::tasks::text::text_embedder {
+
+// Creates a custom OpResolver containing the additional SENTENCEPIECE_TOKENIZER
+// and RAGGED_TENSOR_TO_TENSOR ops needed by universal sentence encoder-based
+// models.
+std::unique_ptr<tflite::OpResolver> CreateUSEOpResolver();
+
+}  // namespace mediapipe::tasks::text::text_embedder
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_TEST_UTILS_H_
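The helper declared above supplies the SentencePiece and RaggedTensorToTensor custom ops that USE-based models need at inference time. A sketch of how a test might plug it in follows; it is not part of this commit, and it assumes TextEmbedderOptions exposes a base_options.op_resolver field (as MediaPipe Tasks C++ BaseOptions generally does), which this diff itself does not show.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder_test_utils.h"

namespace mediapipe::tasks::text::text_embedder {

// Hypothetical test-side usage, not part of this commit. The op_resolver and
// model_asset_path fields on base_options are assumptions, not shown in this
// diff.
absl::Status EmbedWithUseModel(const std::string& model_path) {
  auto options = std::make_unique<TextEmbedderOptions>();
  options->base_options.model_asset_path = model_path;
  // Register the SentencePiece / RaggedTensorToTensor custom ops so a
  // universal sentence encoder model (e.g. the universal_sentence_encoder_qa
  // test asset) can be resolved.
  options->base_options.op_resolver = CreateUSEOpResolver();

  auto embedder_or = TextEmbedder::Create(std::move(options));
  if (!embedder_or.ok()) return embedder_or.status();

  auto result_or = (*embedder_or)->Embed("What is your favourite book?");
  if (!result_or.ok()) return result_or.status();
  return absl::OkStatus();
}

}  // namespace mediapipe::tasks::text::text_embedder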