Open-sources TextEmbedder.
PiperOrigin-RevId: 487041832
This commit is contained in:
parent
ace098f370
commit
0363d60511
87
mediapipe/tasks/cc/text/text_embedder/BUILD
Normal file
87
mediapipe/tasks/cc/text/text_embedder/BUILD
Normal file
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "text_embedder",
|
||||
srcs = ["text_embedder.cc"],
|
||||
hdrs = ["text_embedder.h"],
|
||||
deps = [
|
||||
":text_embedder_graph",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:base_task_api",
|
||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "text_embedder_graph",
|
||||
srcs = ["text_embedder_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "text_embedder_test",
|
||||
srcs = ["text_embedder_test.cc"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||
],
|
||||
deps = [
|
||||
":text_embedder",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
)
|
30
mediapipe/tasks/cc/text/text_embedder/proto/BUILD
Normal file
30
mediapipe/tasks/cc/text/text_embedder/proto/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "text_embedder_graph_options_proto",
|
||||
srcs = ["text_embedder_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.text.text_embedder.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message TextEmbedderGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TextEmbedderGraphOptions ext = 477589892;
|
||||
}
|
||||
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
||||
optional core.proto.BaseOptions base_options = 1;
|
||||
|
||||
// Options for configuring the embedder behavior, such as normalization or
|
||||
// quantization.
|
||||
optional components.processors.proto.EmbedderOptions embedder_options = 2;
|
||||
}
|
104
mediapipe/tasks/cc/text/text_embedder/text_embedder.cc
Normal file
104
mediapipe/tasks/cc/text/text_embedder/text_embedder.cc
Normal file
|
@ -0,0 +1,104 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
||||
constexpr char kTextTag[] = "TEXT";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kTextInStreamName[] = "text_in";
|
||||
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
|
||||
constexpr char kGraphTypeName[] =
|
||||
"mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
|
||||
|
||||
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
|
||||
// Creates a MediaPipe graph config that contains a single node of type
|
||||
// "mediapipe.tasks.text.text_embedder.TextEmbedderGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto) {
|
||||
api2::builder::Graph graph;
|
||||
auto& task_graph = graph.AddNode(kGraphTypeName);
|
||||
task_graph.GetOptions<proto::TextEmbedderGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
graph.In(kTextTag).SetName(kTextInStreamName) >> task_graph.In(kTextTag);
|
||||
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
|
||||
graph.Out(kEmbeddingsTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Converts the user-facing TextEmbedderOptions struct to the internal
|
||||
// TextEmbedderGraphOptions proto.
|
||||
std::unique_ptr<proto::TextEmbedderGraphOptions>
|
||||
ConvertTextEmbedderOptionsToProto(TextEmbedderOptions* options) {
|
||||
auto options_proto = std::make_unique<proto::TextEmbedderGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
auto embedder_options_proto =
|
||||
std::make_unique<components::processors::proto::EmbedderOptions>(
|
||||
components::processors::ConvertEmbedderOptionsToProto(
|
||||
&(options->embedder_options)));
|
||||
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<TextEmbedder>> TextEmbedder::Create(
|
||||
std::unique_ptr<TextEmbedderOptions> options) {
|
||||
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto =
|
||||
ConvertTextEmbedderOptionsToProto(options.get());
|
||||
return core::TaskApiFactory::Create<TextEmbedder,
|
||||
proto::TextEmbedderGraphOptions>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver));
|
||||
}
|
||||
|
||||
absl::StatusOr<TextEmbedderResult> TextEmbedder::Embed(absl::string_view text) {
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
runner_->Process(
|
||||
{{kTextInStreamName, MakePacket<std::string>(std::string(text))}}));
|
||||
return ConvertToEmbeddingResult(
|
||||
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
|
||||
}
|
||||
|
||||
absl::StatusOr<double> TextEmbedder::CosineSimilarity(
|
||||
const components::containers::Embedding& u,
|
||||
const components::containers::Embedding& v) {
|
||||
return components::utils::CosineSimilarity(u, v);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
96
mediapipe/tasks/cc/text/text_embedder/text_embedder.h
Normal file
96
mediapipe/tasks/cc/text/text_embedder/text_embedder.h
Normal file
|
@ -0,0 +1,96 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
|
||||
// Alias the shared EmbeddingResult struct as result typo.
|
||||
using TextEmbedderResult =
|
||||
::mediapipe::tasks::components::containers::EmbeddingResult;
|
||||
|
||||
// Options for configuring a MediaPipe text embedder task.
|
||||
struct TextEmbedderOptions {
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||
// file with metadata, accelerator options, op resolver, etc.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// Options for configuring the embedder behavior, such as L2-normalization or
|
||||
// scalar-quantization.
|
||||
components::processors::EmbedderOptions embedder_options;
|
||||
};
|
||||
|
||||
// Performs embedding extraction on text.
|
||||
//
|
||||
// This API expects a TFLite model with TFLite Model Metadata that contains the
|
||||
// mandatory (described below) input tensors and output tensors. Metadata should
|
||||
// contain the input process unit for the model's Tokenizer as well as input /
|
||||
// output tensor metadata.
|
||||
//
|
||||
// TODO: Support Universal Sentence Encoder.
|
||||
// Input tensors:
|
||||
// (kTfLiteInt32)
|
||||
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
|
||||
// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
|
||||
// segment ids respectively
|
||||
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
|
||||
// input ids
|
||||
//
|
||||
// At least one output tensor with:
|
||||
// (kTfLiteFloat32)
|
||||
// - `N` components corresponding to the `N` dimensions of the returned
|
||||
// feature vector for this output layer.
|
||||
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
|
||||
class TextEmbedder : core::BaseTaskApi {
|
||||
public:
|
||||
using BaseTaskApi::BaseTaskApi;
|
||||
|
||||
// Creates a TextEmbedder from the provided `options`. A non-default
|
||||
// OpResolver can be specified in the BaseOptions in order to support custom
|
||||
// Ops or specify a subset of built-in Ops.
|
||||
static absl::StatusOr<std::unique_ptr<TextEmbedder>> Create(
|
||||
std::unique_ptr<TextEmbedderOptions> options);
|
||||
|
||||
// Performs embedding extraction on the input `text`.
|
||||
absl::StatusOr<TextEmbedderResult> Embed(absl::string_view text);
|
||||
|
||||
// Shuts down the TextEmbedder when all the work is done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
|
||||
// Utility function to compute cosine similarity [1] between two embeddings.
|
||||
// May return an InvalidArgumentError if e.g. the embeddings are of different
|
||||
// types (quantized vs. float), have different sizes, or have a an L2-norm of
|
||||
// 0.
|
||||
//
|
||||
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
|
||||
static absl::StatusOr<double> CosineSimilarity(
|
||||
const components::containers::Embedding& u,
|
||||
const components::containers::Embedding& v);
|
||||
};
|
||||
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
|
145
mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
Normal file
145
mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc
Normal file
|
@ -0,0 +1,145 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
constexpr char kTextTag[] = "TEXT";
|
||||
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
|
||||
// extraction.
|
||||
// - Accepts input text and outputs embeddings on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
// Input text to perform embedding extraction on.
|
||||
//
|
||||
// Outputs:
|
||||
// EMBEDDINGS - EmbeddingResult
|
||||
// The embedding result.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.text.TextEmbedderGraph"
|
||||
// input_stream: "TEXT:text_in"
|
||||
// output_stream: "EMBEDDINGS:embedding_result_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.text.text_embedder.proto.TextEmbedderGraphOptions.ext] {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "/path/to/model.tflite"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TextEmbedderGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
CHECK(sc != nullptr);
|
||||
ASSIGN_OR_RETURN(const ModelResources* model_resources,
|
||||
CreateModelResources<proto::TextEmbedderGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
Source<EmbeddingResult> embedding_result_out,
|
||||
BuildTextEmbedderTask(sc->Options<proto::TextEmbedderGraphOptions>(),
|
||||
*model_resources,
|
||||
graph[Input<std::string>(kTextTag)], graph));
|
||||
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
// Adds a mediapipe TextEmbedder task graph into the provided
|
||||
// builder::Graph instance. The TextEmbedder task takes an input
|
||||
// text (std::string) and returns an embedding result.
|
||||
//
|
||||
// task_options: the mediapipe tasks TextEmbedderGraphOptions proto.
|
||||
// model_resources: the ModelResources object initialized from a
|
||||
// TextEmbedder model file with model metadata.
|
||||
// text_in: (std::string) stream to run embedding extraction on.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<EmbeddingResult>> BuildTextEmbedderTask(
|
||||
const proto::TextEmbedderGraphOptions& task_options,
|
||||
const ModelResources& model_resources, Source<std::string> text_in,
|
||||
Graph& graph) {
|
||||
// Adds preprocessing calculators and connects them to the text input
|
||||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
|
||||
model_resources,
|
||||
preprocessing.GetOptions<
|
||||
tasks::components::proto::TextPreprocessingGraphOptions>()));
|
||||
text_in >> preprocessing.In(kTextTag);
|
||||
|
||||
// Adds both InferenceCalculator and ModelResourcesCalculator.
|
||||
auto& inference = AddInference(
|
||||
model_resources, task_options.base_options().acceleration(), graph);
|
||||
// The metadata extractor side-output comes from the
|
||||
// ModelResourcesCalculator.
|
||||
inference.SideOut(kMetadataExtractorTag) >>
|
||||
preprocessing.SideIn(kMetadataExtractorTag);
|
||||
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
|
||||
|
||||
// Adds postprocessing calculators and connects its input stream to the
|
||||
// inference results.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
|
||||
model_resources, task_options.embedder_options(),
|
||||
&postprocessing.GetOptions<components::processors::proto::
|
||||
EmbeddingPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the embedding result.
|
||||
return postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)];
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::text::text_embedder::TextEmbedderGraph);
|
||||
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
143
mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
Normal file
143
mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc
Normal file
|
@ -0,0 +1,143 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||
|
||||
// Note that these models use dynamic-sized tensors.
|
||||
// Embedding model with BERT preprocessing.
|
||||
constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
|
||||
// Embedding model with regex preprocessing.
|
||||
constexpr char kRegexOneEmbeddingModel[] =
|
||||
"regex_one_embedding_with_metadata.tflite";
|
||||
|
||||
// Tolerance for embedding vector coordinate values.
|
||||
constexpr float kEpsilon = 1e-4;
|
||||
// Tolerancy for cosine similarity evaluation.
|
||||
constexpr double kSimilarityTolerancy = 1e-6;
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
class EmbedderTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
||||
auto text_embedder =
|
||||
TextEmbedder::Create(std::make_unique<TextEmbedderOptions>());
|
||||
ASSERT_EQ(text_embedder.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
ASSERT_THAT(
|
||||
text_embedder.status().message(),
|
||||
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
||||
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
||||
ASSERT_THAT(text_embedder.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithMobileBert) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileBert);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result0,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
ASSERT_EQ(result0.embeddings.size(), 1);
|
||||
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
|
||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||
ASSERT_EQ(result1.embeddings.size(), 1);
|
||||
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
|
||||
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST(EmbedTest, SucceedsWithRegexOneEmbeddingModel) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kRegexOneEmbeddingModel);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result0,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
EXPECT_EQ(result0.embeddings.size(), 1);
|
||||
EXPECT_EQ(result0.embeddings[0].float_embedding.size(), 16);
|
||||
|
||||
EXPECT_NEAR(result0.embeddings[0].float_embedding[0], 0.0309356f, kEpsilon);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||
EXPECT_EQ(result1.embeddings.size(), 1);
|
||||
EXPECT_EQ(result1.embeddings[0].float_embedding.size(), 16);
|
||||
|
||||
EXPECT_NEAR(result1.embeddings[0].float_embedding[0], 0.0312863f, kEpsilon);
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
EXPECT_NEAR(similarity, 0.999937, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileBert);
|
||||
options->embedder_options.quantize = true;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
ASSERT_EQ(result.embeddings.size(), 1);
|
||||
ASSERT_EQ(result.embeddings[0].quantized_embedding.size(), 512);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
Loading…
Reference in New Issue
Block a user