Open-sources TextEmbedder.

PiperOrigin-RevId: 487041832
This commit is contained in:
MediaPipe Team 2022-11-08 13:46:50 -08:00 committed by Copybara-Service
parent ace098f370
commit 0363d60511
7 changed files with 641 additions and 0 deletions

View File

@ -0,0 +1,87 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "text_embedder",
srcs = ["text_embedder.cc"],
hdrs = ["text_embedder.h"],
deps = [
":text_embedder_graph",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "text_embedder_graph",
srcs = ["text_embedder_graph.cc"],
deps = [
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "text_embedder_test",
srcs = ["text_embedder_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
],
deps = [
":text_embedder",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "text_embedder_graph_options_proto",
srcs = ["text_embedder_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,36 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.text.text_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message TextEmbedderGraphOptions {
extend mediapipe.CalculatorOptions {
optional TextEmbedderGraphOptions ext = 477589892;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for configuring the embedder behavior, such as normalization or
// quantization.
optional components.processors.proto.EmbedderOptions embedder_options = 2;
}

View File

@ -0,0 +1,104 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
#include <memory>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
constexpr char kTextTag[] = "TEXT";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTextInStreamName[] = "text_in";
constexpr char kEmbeddingsStreamName[] = "embeddings_out";
constexpr char kGraphTypeName[] =
"mediapipe.tasks.text.text_embedder.TextEmbedderGraph";
using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Creates a MediaPipe graph config that contains a single node of type
// "mediapipe.tasks.text.text_embedder.TextEmbedderGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto) {
api2::builder::Graph graph;
auto& task_graph = graph.AddNode(kGraphTypeName);
task_graph.GetOptions<proto::TextEmbedderGraphOptions>().Swap(
options_proto.get());
graph.In(kTextTag).SetName(kTextInStreamName) >> task_graph.In(kTextTag);
task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
graph.Out(kEmbeddingsTag);
return graph.GetConfig();
}
// Converts the user-facing TextEmbedderOptions struct to the internal
// TextEmbedderGraphOptions proto.
std::unique_ptr<proto::TextEmbedderGraphOptions>
ConvertTextEmbedderOptionsToProto(TextEmbedderOptions* options) {
auto options_proto = std::make_unique<proto::TextEmbedderGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
auto embedder_options_proto =
std::make_unique<components::processors::proto::EmbedderOptions>(
components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto;
}
} // namespace
absl::StatusOr<std::unique_ptr<TextEmbedder>> TextEmbedder::Create(
std::unique_ptr<TextEmbedderOptions> options) {
std::unique_ptr<proto::TextEmbedderGraphOptions> options_proto =
ConvertTextEmbedderOptionsToProto(options.get());
return core::TaskApiFactory::Create<TextEmbedder,
proto::TextEmbedderGraphOptions>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver));
}
absl::StatusOr<TextEmbedderResult> TextEmbedder::Embed(absl::string_view text) {
ASSIGN_OR_RETURN(
auto output_packets,
runner_->Process(
{{kTextInStreamName, MakePacket<std::string>(std::string(text))}}));
return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
}
absl::StatusOr<double> TextEmbedder::CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v);
}
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -0,0 +1,96 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
namespace mediapipe::tasks::text::text_embedder {
// Alias the shared EmbeddingResult struct as result typo.
using TextEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
// Options for configuring a MediaPipe text embedder task.
struct TextEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// Options for configuring the embedder behavior, such as L2-normalization or
// scalar-quantization.
components::processors::EmbedderOptions embedder_options;
};
// Performs embedding extraction on text.
//
// This API expects a TFLite model with TFLite Model Metadata that contains the
// mandatory (described below) input tensors and output tensors. Metadata should
// contain the input process unit for the model's Tokenizer as well as input /
// output tensor metadata.
//
// TODO: Support Universal Sentence Encoder.
// Input tensors:
// (kTfLiteInt32)
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names
// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and
// segment ids respectively
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
// input ids
//
// At least one output tensor with:
// (kTfLiteFloat32)
// - `N` components corresponding to the `N` dimensions of the returned
// feature vector for this output layer.
// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
class TextEmbedder : core::BaseTaskApi {
public:
using BaseTaskApi::BaseTaskApi;
// Creates a TextEmbedder from the provided `options`. A non-default
// OpResolver can be specified in the BaseOptions in order to support custom
// Ops or specify a subset of built-in Ops.
static absl::StatusOr<std::unique_ptr<TextEmbedder>> Create(
std::unique_ptr<TextEmbedderOptions> options);
// Performs embedding extraction on the input `text`.
absl::StatusOr<TextEmbedderResult> Embed(absl::string_view text);
// Shuts down the TextEmbedder when all the work is done.
absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embeddings.
// May return an InvalidArgumentError if e.g. the embeddings are of different
// types (quantized vs. float), have different sizes, or have a an L2-norm of
// 0.
//
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity(
const components::containers::Embedding& u,
const components::containers::Embedding& v);
};
} // namespace mediapipe::tasks::text::text_embedder
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_

View File

@ -0,0 +1,145 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources;
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
} // namespace
// A "mediapipe.tasks.text.TextEmbedderGraph" performs text embedding
// extraction.
// - Accepts input text and outputs embeddings on CPU.
//
// Inputs:
// TEXT - std::string
// Input text to perform embedding extraction on.
//
// Outputs:
// EMBEDDINGS - EmbeddingResult
// The embedding result.
//
// Example:
// node {
// calculator: "mediapipe.tasks.text.TextEmbedderGraph"
// input_stream: "TEXT:text_in"
// output_stream: "EMBEDDINGS:embedding_result_out"
// options {
// [mediapipe.tasks.text.text_embedder.proto.TextEmbedderGraphOptions.ext] {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// }
// }
// }
class TextEmbedderGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
CHECK(sc != nullptr);
ASSIGN_OR_RETURN(const ModelResources* model_resources,
CreateModelResources<proto::TextEmbedderGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
Source<EmbeddingResult> embedding_result_out,
BuildTextEmbedderTask(sc->Options<proto::TextEmbedderGraphOptions>(),
*model_resources,
graph[Input<std::string>(kTextTag)], graph));
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
return graph.GetConfig();
}
private:
// Adds a mediapipe TextEmbedder task graph into the provided
// builder::Graph instance. The TextEmbedder task takes an input
// text (std::string) and returns an embedding result.
//
// task_options: the mediapipe tasks TextEmbedderGraphOptions proto.
// model_resources: the ModelResources object initialized from a
// TextEmbedder model file with model metadata.
// text_in: (std::string) stream to run embedding extraction on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<EmbeddingResult>> BuildTextEmbedderTask(
const proto::TextEmbedderGraphOptions& task_options,
const ModelResources& model_resources, Source<std::string> text_in,
Graph& graph) {
// Adds preprocessing calculators and connects them to the text input
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
model_resources,
preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator.
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
// The metadata extractor side-output comes from the
// ModelResourcesCalculator.
inference.SideOut(kMetadataExtractorTag) >>
preprocessing.SideIn(kMetadataExtractorTag);
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
// Adds postprocessing calculators and connects its input stream to the
// inference results.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding result.
return postprocessing[Output<EmbeddingResult>(kEmbeddingsTag)];
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::text::text_embedder::TextEmbedderGraph);
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -0,0 +1,143 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h"
#include <memory>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
// Note that these models use dynamic-sized tensors.
// Embedding model with BERT preprocessing.
constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
// Embedding model with regex preprocessing.
constexpr char kRegexOneEmbeddingModel[] =
"regex_one_embedding_with_metadata.tflite";
// Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4;
// Tolerancy for cosine similarity evaluation.
constexpr double kSimilarityTolerancy = 1e-6;
using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
using ::testing::Optional;
class EmbedderTest : public tflite_shims::testing::Test {};
TEST_F(EmbedderTest, FailsWithMissingModel) {
auto text_embedder =
TextEmbedder::Create(std::make_unique<TextEmbedderOptions>());
ASSERT_EQ(text_embedder.status().code(), absl::StatusCode::kInvalidArgument);
ASSERT_THAT(
text_embedder.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
ASSERT_THAT(text_embedder.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(EmbedderTest, SucceedsWithMobileBert) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileBert);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result0.embeddings.size(), 1);
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST(EmbedTest, SucceedsWithRegexOneEmbeddingModel) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kRegexOneEmbeddingModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result0,
text_embedder->Embed("it's a charming and often affecting journey"));
EXPECT_EQ(result0.embeddings.size(), 1);
EXPECT_EQ(result0.embeddings[0].float_embedding.size(), 16);
EXPECT_NEAR(result0.embeddings[0].float_embedding[0], 0.0309356f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
EXPECT_EQ(result1.embeddings.size(), 1);
EXPECT_EQ(result1.embeddings[0].float_embedding.size(), 16);
EXPECT_NEAR(result1.embeddings[0].float_embedding[0], 0.0312863f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.999937, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithQuantization) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileBert);
options->embedder_options.quantize = true;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result.embeddings.size(), 1);
ASSERT_EQ(result.embeddings[0].quantized_embedding.size(), 512);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace
} // namespace mediapipe::tasks::text::text_embedder