From 0363d60511676379ac558838592cf373c4416f8e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 13:46:50 -0800 Subject: [PATCH] Open-sources TextEmbedder. PiperOrigin-RevId: 487041832 --- mediapipe/tasks/cc/text/text_embedder/BUILD | 87 +++++++++++ .../tasks/cc/text/text_embedder/proto/BUILD | 30 ++++ .../proto/text_embedder_graph_options.proto | 36 +++++ .../cc/text/text_embedder/text_embedder.cc | 104 +++++++++++++ .../cc/text/text_embedder/text_embedder.h | 96 ++++++++++++ .../text/text_embedder/text_embedder_graph.cc | 145 ++++++++++++++++++ .../text/text_embedder/text_embedder_test.cc | 143 +++++++++++++++++ 7 files changed, 641 insertions(+) create mode 100644 mediapipe/tasks/cc/text/text_embedder/BUILD create mode 100644 mediapipe/tasks/cc/text/text_embedder/proto/BUILD create mode 100644 mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder.cc create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder.h create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc create mode 100644 mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD new file mode 100644 index 000000000..331902362 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -0,0 +1,87 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_embedder", + srcs = ["text_embedder.cc"], + hdrs = ["text_embedder.h"], + deps = [ + ":text_embedder_graph", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "text_embedder_graph", + srcs = ["text_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components:text_preprocessing_graph", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + 
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "text_embedder_test", + srcs = ["text_embedder_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + ], + deps = [ + ":text_embedder", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/BUILD b/mediapipe/tasks/cc/text/text_embedder/proto/BUILD new file mode 100644 index 000000000..146483af1 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "text_embedder_graph_options_proto", + srcs = ["text_embedder_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto new file mode 100644 index 000000000..6b8d41a57 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.proto @@ -0,0 +1,36 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.text.text_embedder.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +message TextEmbedderGraphOptions { + extend mediapipe.CalculatorOptions { + optional TextEmbedderGraphOptions ext = 477589892; + } + + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring the embedder behavior, such as normalization or + // quantization. + optional components.processors.proto.EmbedderOptions embedder_options = 2; +} diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc new file mode 100644 index 000000000..375058d57 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.cc @@ -0,0 +1,104 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h" + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" +#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" + +namespace mediapipe::tasks::text::text_embedder { +namespace { + +constexpr char kTextTag[] = "TEXT"; +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTextInStreamName[] = "text_in"; +constexpr char kEmbeddingsStreamName[] = "embeddings_out"; +constexpr char kGraphTypeName[] = + "mediapipe.tasks.text.text_embedder.TextEmbedderGraph"; + +using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; + +// Creates a MediaPipe graph config that contains a single node of type +// "mediapipe.tasks.text.text_embedder.TextEmbedderGraph". 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options_proto) { + api2::builder::Graph graph; + auto& task_graph = graph.AddNode(kGraphTypeName); + task_graph.GetOptions().Swap( + options_proto.get()); + graph.In(kTextTag).SetName(kTextInStreamName) >> task_graph.In(kTextTag); + task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >> + graph.Out(kEmbeddingsTag); + return graph.GetConfig(); +} + +// Converts the user-facing TextEmbedderOptions struct to the internal +// TextEmbedderGraphOptions proto. +std::unique_ptr +ConvertTextEmbedderOptionsToProto(TextEmbedderOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + auto embedder_options_proto = + std::make_unique( + components::processors::ConvertEmbedderOptionsToProto( + &(options->embedder_options))); + options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> TextEmbedder::Create( + std::unique_ptr options) { + std::unique_ptr options_proto = + ConvertTextEmbedderOptionsToProto(options.get()); + return core::TaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver)); +} + +absl::StatusOr TextEmbedder::Embed(absl::string_view text) { + ASSIGN_OR_RETURN( + auto output_packets, + runner_->Process( + {{kTextInStreamName, MakePacket(std::string(text))}})); + return ConvertToEmbeddingResult( + output_packets[kEmbeddingsStreamName].Get()); +} + +absl::StatusOr TextEmbedder::CosineSimilarity( + const components::containers::Embedding& u, + const components::containers::Embedding& v) { + return components::utils::CosineSimilarity(u, v); +} + +} // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder.h 
b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h new file mode 100644 index 000000000..81f90fd27 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder.h @@ -0,0 +1,96 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "mediapipe/tasks/cc/components/processors/embedder_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" + +namespace mediapipe::tasks::text::text_embedder { + +// Alias the shared EmbeddingResult struct as result type. +using TextEmbedderResult = + ::mediapipe::tasks::components::containers::EmbeddingResult; + +// Options for configuring a MediaPipe text embedder task. +struct TextEmbedderOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the embedder behavior, such as L2-normalization or + // scalar-quantization. 
+ components::processors::EmbedderOptions embedder_options; +}; + +// Performs embedding extraction on text. +// +// This API expects a TFLite model with TFLite Model Metadata that contains the +// mandatory (described below) input tensors and output tensors. Metadata should +// contain the input process unit for the model's Tokenizer as well as input / +// output tensor metadata. +// +// TODO: Support Universal Sentence Encoder. +// Input tensors: +// (kTfLiteInt32) +// - 3 input tensors of size `[batch_size x bert_max_seq_len]` with names +// "ids", "mask", and "segment_ids" representing the input ids, mask ids, and +// segment ids respectively +// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the +// input ids +// +// At least one output tensor with: +// (kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +class TextEmbedder : core::BaseTaskApi { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a TextEmbedder from the provided `options`. A non-default + // OpResolver can be specified in the BaseOptions in order to support custom + // Ops or specify a subset of built-in Ops. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs embedding extraction on the input `text`. + absl::StatusOr Embed(absl::string_view text); + + // Shuts down the TextEmbedder when all the work is done. + absl::Status Close() { return runner_->Close(); } + + // Utility function to compute cosine similarity [1] between two embeddings. + // May return an InvalidArgumentError if e.g. the embeddings are of different + // types (quantized vs. float), have different sizes, or have an L2-norm of + // 0. 
+ // + // [1]: https://en.wikipedia.org/wiki/Cosine_similarity + static absl::StatusOr CosineSimilarity( + const components::containers::Embedding& u, + const components::containers::Embedding& v); +}; + +} // namespace mediapipe::tasks::text::text_embedder + +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_EMBEDDER_TEXT_EMBEDDER_H_ diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc new file mode 100644 index 000000000..79eedb6b5 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -0,0 +1,145 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" +#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" +#include "mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options.pb.h" + +namespace mediapipe::tasks::text::text_embedder { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::core::ModelResources; + +constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; +constexpr char kTensorsTag[] = "TENSORS"; + +} // namespace + +// A "mediapipe.tasks.text.text_embedder.TextEmbedderGraph" performs text embedding +// extraction. +// - Accepts input text and outputs embeddings on CPU. +// +// Inputs: +// TEXT - std::string +// Input text to perform embedding extraction on. +// +// Outputs: +// EMBEDDINGS - EmbeddingResult +// The embedding result. 
+// +// Example: +// node { +// calculator: "mediapipe.tasks.text.text_embedder.TextEmbedderGraph" +// input_stream: "TEXT:text_in" +// output_stream: "EMBEDDINGS:embedding_result_out" +// options { +// [mediapipe.tasks.text.text_embedder.proto.TextEmbedderGraphOptions.ext] { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// } +// } +// } +class TextEmbedderGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + CHECK(sc != nullptr); + ASSIGN_OR_RETURN(const ModelResources* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + Source embedding_result_out, + BuildTextEmbedderTask(sc->Options(), + *model_resources, + graph[Input(kTextTag)], graph)); + embedding_result_out >> graph[Output(kEmbeddingsTag)]; + return graph.GetConfig(); + } + + private: + // Adds a mediapipe TextEmbedder task graph into the provided + // builder::Graph instance. The TextEmbedder task takes an input + // text (std::string) and returns an embedding result. + // + // task_options: the mediapipe tasks TextEmbedderGraphOptions proto. + // model_resources: the ModelResources object initialized from a + // TextEmbedder model file with model metadata. + // text_in: (std::string) stream to run embedding extraction on. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr> BuildTextEmbedderTask( + const proto::TextEmbedderGraphOptions& task_options, + const ModelResources& model_resources, Source text_in, + Graph& graph) { + // Adds preprocessing calculators and connects them to the text input + // stream. 
+ auto& preprocessing = + graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); + MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + model_resources, + preprocessing.GetOptions< + tasks::components::proto::TextPreprocessingGraphOptions>())); + text_in >> preprocessing.In(kTextTag); + + // Adds both InferenceCalculator and ModelResourcesCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + // The metadata extractor side-output comes from the + // ModelResourcesCalculator. + inference.SideOut(kMetadataExtractorTag) >> + preprocessing.SideIn(kMetadataExtractorTag); + preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); + + // Adds postprocessing calculators and connects its input stream to the + // inference results. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( + model_resources, task_options.embedder_options(), + &postprocessing.GetOptions())); + inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); + + // Outputs the embedding result. + return postprocessing[Output(kEmbeddingsTag)]; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::text::text_embedder::TextEmbedderGraph); + +} // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc new file mode 100644 index 000000000..fa3d8af91 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_embedder/text_embedder.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/embedding_result.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe::tasks::text::text_embedder { +namespace { + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; + +// Note that these models use dynamic-sized tensors. +// Embedding model with BERT preprocessing. +constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite"; +// Embedding model with regex preprocessing. +constexpr char kRegexOneEmbeddingModel[] = + "regex_one_embedding_with_metadata.tflite"; + +// Tolerance for embedding vector coordinate values. +constexpr float kEpsilon = 1e-4; +// Tolerance for cosine similarity evaluation. 
+constexpr double kSimilarityTolerancy = 1e-6; + +using ::mediapipe::file::JoinPath; +using ::testing::HasSubstr; +using ::testing::Optional; + +class EmbedderTest : public tflite_shims::testing::Test {}; + +TEST_F(EmbedderTest, FailsWithMissingModel) { + auto text_embedder = + TextEmbedder::Create(std::make_unique()); + ASSERT_EQ(text_embedder.status().code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT( + text_embedder.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + ASSERT_THAT(text_embedder.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(EmbedderTest, SucceedsWithMobileBert) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result0.embeddings.size(), 1); + ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512); + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + ASSERT_EQ(result1.embeddings.size(), 1); + ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512); + ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon); + + // Check cosine similarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.969514, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + +TEST(EmbedTest, SucceedsWithRegexOneEmbeddingModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kRegexOneEmbeddingModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + auto result0, + text_embedder->Embed("it's a charming and often affecting journey")); + EXPECT_EQ(result0.embeddings.size(), 1); + EXPECT_EQ(result0.embeddings[0].float_embedding.size(), 16); + + EXPECT_NEAR(result0.embeddings[0].float_embedding[0], 0.0309356f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + EXPECT_EQ(result1.embeddings.size(), 1); + EXPECT_EQ(result1.embeddings[0].float_embedding.size(), 16); + + EXPECT_NEAR(result1.embeddings[0].float_embedding[0], 0.0312863f, kEpsilon); + + // Check cosine similarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.999937, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + +TEST_F(EmbedderTest, SucceedsWithQuantization) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileBert); + options->embedder_options.quantize = true; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result.embeddings.size(), 1); + ASSERT_EQ(result.embeddings[0].quantized_embedding.size(), 512); + + MP_ASSERT_OK(text_embedder->Close()); +} + +} // namespace +} // namespace mediapipe::tasks::text::text_embedder