From 1604908a590da3c549e9eaea1f8ac7d926b32623 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 10 Nov 2022 02:16:51 -0800 Subject: [PATCH 1/6] Added files needed for the text embedder's implementation and tests --- mediapipe/python/BUILD | 1 + .../tasks/python/components/containers/BUILD | 9 + .../components/containers/embedding_result.py | 210 ++++++++++++++++++ .../tasks/python/components/processors/BUILD | 9 + .../components/processors/embedder_options.py | 70 ++++++ mediapipe/tasks/python/components/utils/BUILD | 28 +++ .../tasks/python/components/utils/__init__.py | 13 ++ .../components/utils/cosine_similarity.py | 61 +++++ mediapipe/tasks/python/test/text/BUILD | 17 ++ .../python/test/text/text_embedder_test.py | 207 +++++++++++++++++ mediapipe/tasks/python/text/BUILD | 20 ++ mediapipe/tasks/python/text/text_embedder.py | 166 ++++++++++++++ 12 files changed, 811 insertions(+) create mode 100644 mediapipe/tasks/python/components/containers/embedding_result.py create mode 100644 mediapipe/tasks/python/components/processors/embedder_options.py create mode 100644 mediapipe/tasks/python/components/utils/BUILD create mode 100644 mediapipe/tasks/python/components/utils/__init__.py create mode 100644 mediapipe/tasks/python/components/utils/cosine_similarity.py create mode 100644 mediapipe/tasks/python/test/text/text_embedder_test.py create mode 100644 mediapipe/tasks/python/text/text_embedder.py diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 8548a60d8..b19f17b29 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -98,6 +98,7 @@ cc_library( "//conditions:default": [ "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", ], }), ) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 9e0a90911..7091c1f41 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -104,3 +104,12 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "embedding_result", + srcs = ["embedding_result.py"], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/containers/embedding_result.py b/mediapipe/tasks/python/components/containers/embedding_result.py new file mode 100644 index 000000000..35882b9f3 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/embedding_result.py @@ -0,0 +1,210 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Embeddings data class.""" + +import dataclasses +from typing import Any, Optional, List + +import numpy as np +from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding +_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding +_EmbeddingProto = embeddings_pb2.Embedding +_EmbeddingResultProto = embeddings_pb2.EmbeddingResult + + +@dataclasses.dataclass +class FloatEmbedding: + """Defines a dense floating-point embedding. + + Attributes: + values: A NumPy array indicating the raw output of the embedding layer. + """ + + values: np.ndarray + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FloatEmbeddingProto: + """Generates a FloatEmbedding protobuf object.""" + return _FloatEmbeddingProto(values=self.values) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _FloatEmbeddingProto) -> 'FloatEmbedding': + """Creates a `FloatEmbedding` object from the given protobuf object.""" + return FloatEmbedding(values=np.array(pb2_obj.value_float, dtype=float)) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, FloatEmbedding): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class QuantizedEmbedding: + """Defines a dense scalar-quantized embedding. + + Attributes: + values: A NumPy array indicating the raw output of the embedding layer. + """ + + values: np.ndarray + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _QuantizedEmbeddingProto: + """Generates a QuantizedEmbedding protobuf object.""" + return _QuantizedEmbeddingProto(values=self.values) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _QuantizedEmbeddingProto) -> 'QuantizedEmbedding': + """Creates a `QuantizedEmbedding` object from the given protobuf object.""" + return QuantizedEmbedding( + values=np.array(bytearray(pb2_obj.value_string), dtype=np.uint8)) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, QuantizedEmbedding): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class Embedding: + """Embedding result for a given embedder head. + + Attributes: + embedding: The actual embedding, either floating-point or scalar-quantized. + head_index: The index of the embedder head that produced this embedding. + This is useful for multi-head models. + head_name: The name of the embedder head, which is the corresponding tensor + metadata name (if any). This is useful for multi-head models. + """ + + embedding: np.ndarray + head_index: Optional[int] = None + head_name: Optional[str] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _EmbeddingProto: + """Generates a Embedding protobuf object.""" + + if self.embedding.dtype == float: + return _EmbeddingProto(float_embedding=self.embedding, + head_index=self.head_index, + head_name=self.head_name) + + elif self.embedding.dtype == np.uint8: + return _EmbeddingProto(quantized_embedding=bytes(self.embedding), + head_index=self.head_index, + head_name=self.head_name) + + else: + raise ValueError("Invalid dtype. Only float and np.uint8 are supported.") + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _EmbeddingProto) -> 'Embedding': + """Creates a `Embedding` object from the given protobuf object.""" + + quantized_embedding = np.array( + bytearray(pb2_obj.quantized_embedding.values)) + float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float) + + if len(quantized_embedding) == 0: + return Embedding(embedding=float_embedding, + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) + else: + return Embedding(embedding=quantized_embedding, + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, Embedding): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class EmbeddingResult: + """Embedding results for a given embedder model. + Attributes: + embeddings: A list of `Embedding` objects. + timestamp_ms: The optional timestamp (in milliseconds) of the start of the + chunk of data corresponding to these results. This is only used for + embedding extraction on time series (e.g. audio embedding). In these use + cases, the amount of data to process might exceed the maximum size that + the model can process: to solve this, the input data is split into + multiple chunks starting at different timestamps. + """ + + embeddings: List[Embedding] + timestamp_ms: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _EmbeddingResultProto: + """Generates a EmbeddingResult protobuf object.""" + return _EmbeddingResultProto( + embeddings=[ + embedding.to_pb2() for embedding in self.embeddings + ]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _EmbeddingResultProto) -> 'EmbeddingResult': + """Creates a `EmbeddingResult` object from the given protobuf object.""" + return EmbeddingResult( + embeddings=[ + Embedding.create_from_pb2(embedding) + for embedding in pb2_obj.embeddings + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, EmbeddingResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index f87a579b0..eef368db0 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -28,3 +28,12 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "embedder_options", + srcs = ["embedder_options.py"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/processors/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py new file mode 100644 index 000000000..dcd316dcd --- /dev/null +++ b/mediapipe/tasks/python/components/processors/embedder_options.py @@ -0,0 +1,70 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Embedder options data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions + + +@dataclasses.dataclass +class EmbedderOptions: + """Shared options used by all embedding extraction tasks. + + Attributes: + l2_normalize: Whether to normalize the returned feature vector with L2 norm. + Use this option only if the model does not already contain a native + L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and + L2 norm is thus achieved through TF Lite inference. + quantize: Whether the returned embedding should be quantized to bytes via + scalar quantization. Embeddings are implicitly assumed to be unit-norm and + therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + the l2_normalize option if this is not the case. + """ + + l2_normalize: Optional[bool] = None + quantize: Optional[bool] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _EmbedderOptionsProto: + """Generates a EmbedderOptions protobuf object.""" + return _EmbedderOptionsProto( + l2_normalize=self.l2_normalize, + quantize=self.quantize) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _EmbedderOptionsProto) -> 'EmbedderOptions': + """Creates a `EmbedderOptions` object from the given protobuf object.""" + return EmbedderOptions( + l2_normalize=pb2_obj.l2_normalize, + quantize=pb2_obj.quantize) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, EmbedderOptions): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD new file mode 100644 index 000000000..50d4094c0 --- /dev/null +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -0,0 +1,28 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.py"], + deps = [ + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/components/processors:embedder_options", + ], +) diff --git a/mediapipe/tasks/python/components/utils/__init__.py b/mediapipe/tasks/python/components/utils/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/components/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py new file mode 100644 index 000000000..d6102f0b5 --- /dev/null +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -0,0 +1,61 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Cosine similarity utilities.""" + +import numpy as np + +from mediapipe.tasks.python.components.containers import embedding_result +from mediapipe.tasks.python.components.processors import embedder_options + +_Embedding = embedding_result.Embedding +_EmbedderOptions = embedder_options.EmbedderOptions + + +def _compute_cosine_similarity(u, v): + if len(u.embedding) <= 0: + raise ValueError("Cannot compute cosing similarity on empty embeddings.") + + norm_u = np.linalg.norm(u.embedding) + norm_v = np.linalg.norm(v.embedding) + + if norm_u <= 0 or norm_v <= 0: + raise ValueError( + "Cannot compute cosine similarity on embedding with 0 norm.") + + return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) + + +def cosine_similarity(u: _Embedding, v: _Embedding) -> float: + """Utility function to compute cosine similarity between two embedding. + May return an InvalidArgumentError if e.g. the feature vectors are of + different types (quantized vs. float), have different sizes, or have an + L2-norm of 0. + + Args: + u: An embedding. + v: An embedding. + """ + if len(u.embedding) != len(v.embedding): + raise ValueError(f"Cannot compute cosine similarity between embeddings " + f"of different sizes " + f"({len(u.embedding)} vs. {len(v.embedding)}).") + + if u.embedding.dtype == float and v.embedding.dtype == float: + return _compute_cosine_similarity(u, v) + + if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8: + return _compute_cosine_similarity(u, v) + + raise ValueError("Cannot compute cosine similarity between quantized and " + "float embeddings.") diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index d7176b0a5..976ea1ec2 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -34,3 +34,20 @@ py_test( "//mediapipe/tasks/python/text:text_classifier", ], ) + +py_test( + name = "text_embedder_test", + srcs = ["text_embedder_test.py"], + data = [ + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + ], + deps = [ + "//mediapipe/tasks/python/components/processors:embedder_options", + "//mediapipe/tasks/python/components/utils:cosine_similarity", + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/text:text_embedder", + ], +) diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py new file mode 100644 index 000000000..b39f51a8d --- /dev/null +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -0,0 +1,207 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for text embedder.""" + +import enum +import os +from unittest import mock + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.text import text_embedder + +ImageEmbedderResult = embedding_result_module.EmbeddingResult +_BaseOptions = base_options_module.BaseOptions +_EmbedderOptions = embedder_options_module.EmbedderOptions +_FloatEmbedding = embedding_result_module.FloatEmbedding +_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding +_Embedding = embedding_result_module.Embedding +_TextEmbedder = text_embedder.TextEmbedder +_TextEmbedderOptions = text_embedder.TextEmbedderOptions + +_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' +_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' +# Tolerance for embedding vector coordinate values. +_EPSILON = 1e-4 +# Tolerance for cosine similarity evaluation. +_SIMILARITY_TOLERANCE = 1e-6 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class ImageEmbedderTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _TextEmbedder.create_from_model_path(self.model_path) as embedder: + self.assertIsInstance(embedder, _TextEmbedder) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _TextEmbedderOptions(base_options=base_options) + with _TextEmbedder.create_from_options(options) as embedder: + self.assertIsInstance(embedder, _TextEmbedder) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _TextEmbedderOptions(base_options=base_options) + _TextEmbedder.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _TextEmbedderOptions(base_options=base_options) + embedder = _TextEmbedder.create_from_options(options) + self.assertIsInstance(embedder, _TextEmbedder) + + def _check_embedding_value(self, result, expected_first_value): + # Check embedding first value. + self.assertAlmostEqual(result.embeddings[0].embedding[0], + expected_first_value, delta=_EPSILON) + + def _check_embedding_size(self, result, quantize, expected_embedding_size): + # Check embedding size. + self.assertLen(result.embeddings, 1) + embedding_result = result.embeddings[0] + self.assertLen(embedding_result.embedding, expected_embedding_size) + if quantize: + self.assertEqual(embedding_result.embedding.dtype, np.uint8) + else: + self.assertEqual(embedding_result.embedding.dtype, float) + + def _check_cosine_similarity(self, result0, result1, expected_similarity): + # Checks cosine similarity. + similarity = _TextEmbedder.cosine_similarity( + result0.embeddings[0], result1.embeddings[0]) + self.assertAlmostEqual(similarity, expected_similarity, + delta=_SIMILARITY_TOLERANCE) + + @parameterized.parameters( + (False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, + 0.969514, 512, (19.9016, 22.626251)), + (True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, + 0.969514, 512, (0.0585837, 0.0723035)), + (False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME, + 0.999937, 16, (0.0309356, 0.0312863)), + (True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT, + 0.999937, 16, (0.549632, 0.552879)), + ) + def test_embed(self, l2_normalize, quantize, model_name, model_file_type, + expected_similarity, expected_size, expected_first_values): + # Creates embedder. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, model_name)) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + embedder_options = _EmbedderOptions(l2_normalize=l2_normalize, + quantize=quantize) + options = _TextEmbedderOptions( + base_options=base_options, embedder_options=embedder_options) + embedder = _TextEmbedder.create_from_options(options) + + # Extracts both embeddings. + positive_text0 = "it's a charming and often affecting journey" + positive_text1 = "what a great and fantastic trip" + + result0 = embedder.embed(positive_text0) + result1 = embedder.embed(positive_text1) + + # Checks embeddings and cosine similarity. + expected_result0_value, expected_result1_value = expected_first_values + self._check_embedding_size(result0, quantize, expected_size) + self._check_embedding_size(result1, quantize, expected_size) + self._check_embedding_value(result0, expected_result0_value) + self._check_embedding_value(result1, expected_result1_value) + self._check_cosine_similarity(result0, result1, expected_similarity) + # Closes the embedder explicitly when the embedder is not used in + # a context. + embedder.close() + + @parameterized.parameters( + (False, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, + 0.969514, 512, (19.9016, 22.626251)), + (True, False, _BERT_MODEL_FILE, ModelFileType.FILE_NAME, + 0.969514, 512, (0.0585837, 0.0723035)), + (False, False, _REGEX_MODEL_FILE, ModelFileType.FILE_NAME, + 0.999937, 16, (0.0309356, 0.0312863)), + (True, False, _REGEX_MODEL_FILE, ModelFileType.FILE_CONTENT, + 0.999937, 16, (0.549632, 0.552879)), + ) + def test_embed_in_context(self, l2_normalize, quantize, model_name, + model_file_type, expected_similarity, + expected_size, expected_first_values): + # Creates embedder. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, model_name)) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + embedder_options = _EmbedderOptions(l2_normalize=l2_normalize, + quantize=quantize) + options = _TextEmbedderOptions( + base_options=base_options, embedder_options=embedder_options) + with _TextEmbedder.create_from_options(options) as embedder: + # Extracts both embeddings. + positive_text0 = "it's a charming and often affecting journey" + positive_text1 = "what a great and fantastic trip" + + result0 = embedder.embed(positive_text0) + result1 = embedder.embed(positive_text1) + + # Checks embeddings and cosine similarity. + expected_result0_value, expected_result1_value = expected_first_values + self._check_embedding_size(result0, quantize, expected_size) + self._check_embedding_size(result1, quantize, expected_size) + self._check_embedding_value(result0, expected_result0_value) + self._check_embedding_value(result1, expected_result1_value) + self._check_cosine_similarity(result0, result1, expected_similarity) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index fd5d701b4..8372c4bdb 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -36,3 +36,23 @@ py_library( "//mediapipe/tasks/python/text/core:base_text_task_api", ], ) + +py_library( + name = "text_embedder", + srcs = [ + "text_embedder.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/components/processors:embedder_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/text/core:base_text_task_api", + ], +) diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py new file mode 100644 index 000000000..0f22caca4 --- /dev/null +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -0,0 +1,166 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe text embedder task.""" + +import dataclasses +from typing import Callable, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 +from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.python.components.processors import embedder_options +from mediapipe.tasks.python.components.utils import cosine_similarity +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.text.core import base_text_task_api + +TextEmbedderResult = embedding_result_module.EmbeddingResult +_BaseOptions = base_options_module.BaseOptions +_TextEmbedderGraphOptionsProto = text_embedder_graph_options_pb2.TextEmbedderGraphOptions +_EmbedderOptions = embedder_options.EmbedderOptions +_TaskInfo = task_info_module.TaskInfo + +_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' +_EMBEDDINGS_TAG = 'EMBEDDINGS' +_TEXT_IN_STREAM_NAME = 'text_in' +_TEXT_TAG = 'TEXT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class TextEmbedderOptions: + """Options for the text embedder task. + + Attributes: + base_options: Base options for the text embedder task. + embedder_options: Options for the text embedder task. + """ + base_options: _BaseOptions + embedder_options: _EmbedderOptions = _EmbedderOptions() + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _TextEmbedderGraphOptionsProto: + """Generates an TextEmbedderOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + embedder_options_proto = self.embedder_options.to_pb2() + + return _TextEmbedderGraphOptionsProto( + base_options=base_options_proto, + embedder_options=embedder_options_proto + ) + + +class TextEmbedder(base_text_task_api.BaseTextTaskApi): + """Class that performs embedding extraction on text.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder': + """Creates an `TextEmbedder` object from a TensorFlow Lite model and the + default `TextEmbedderOptions`. + + Args: + model_path: Path to the model. + + Returns: + `TextEmbedder` object that's created from the model file and the default + `TextEmbedderOptions`. + + Raises: + ValueError: If failed to create `TextEmbedder` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = TextEmbedderOptions(base_options=base_options) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: TextEmbedderOptions) -> 'TextEmbedder': + """Creates the `TextEmbedder` object from text embedder options. + + Args: + options: Options for the text embedder task. + + Returns: + `TextEmbedder` object that's created from `options`. + + Raises: + ValueError: If failed to create `TextEmbedder` object from + `TextEmbedderOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])], + output_streams=[ + ':'.join([ + _EMBEDDINGS_TAG, + _EMBEDDINGS_OUT_STREAM_NAME + ]) + ], + task_options=options) + return cls(task_info.generate_graph_config()) + + def embed( + self, + text: str, + ) -> TextEmbedderResult: + """Performs text embedding extraction on the provided text. + + Args: + text: The input text. + + Returns: + An embedding result object that contains a list of embeddings. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If text embedder failed to run. + """ + output_packets = self._runner.process( + {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}) + + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) + + return TextEmbedderResult.create_from_pb2(embedding_result_proto) + + @staticmethod + def cosine_similarity(u: embedding_result_module.Embedding, + v: embedding_result_module.Embedding) -> float: + """Utility function to compute cosine similarity [1] between two embedding + entries. May return an InvalidArgumentError if e.g. the feature vectors are + of different types (quantized vs. float), have different sizes, or have a + an L2-norm of 0. + + Args: + u: An embedding entry. + v: An embedding entry. + + Returns: + The cosine similarity for the two embeddings. + + Raises: + ValueError: May return an error if e.g. the feature vectors are of + different types (quantized vs. float), have different sizes, or have + an L2-norm of 0 + """ + return cosine_similarity.cosine_similarity(u, v) From fae77fc742e4fe96ac8d86936ba38f7716889c9a Mon Sep 17 00:00:00 2001 From: Kinar R <42828719+kinaryml@users.noreply.github.com> Date: Sat, 12 Nov 2022 01:27:20 +0530 Subject: [PATCH 2/6] Update text_embedder_test.py --- mediapipe/tasks/python/test/text/text_embedder_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index b39f51a8d..8ef616150 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -50,7 +50,7 @@ class ModelFileType(enum.Enum): FILE_NAME = 2 -class ImageEmbedderTest(parameterized.TestCase): +class TextEmbedderTest(parameterized.TestCase): def setUp(self): super().setUp() From a8103629c7d9cc36118b9c066488b71e406b9673 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sat, 12 Nov 2022 07:42:46 -0800 Subject: [PATCH 3/6] Updated Text Embedder API --- mediapipe/tasks/python/text/BUILD | 1 - mediapipe/tasks/python/text/text_embedder.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index d0304fe17..15a5372ea 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -43,7 +43,6 @@ py_library( "text_embedder.py", ], deps = [ - "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_py_pb2", diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index 0f22caca4..fc0faa73e 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -14,11 +14,9 @@ """MediaPipe text embedder task.""" import dataclasses -from typing import Callable, Mapping, Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter -from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.text.text_embedder.proto import text_embedder_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.python.components.processors import embedder_options @@ -70,7 +68,7 @@ class TextEmbedder(base_text_task_api.BaseTextTaskApi): """Class that performs embedding extraction on text.""" @classmethod - def create_from_model_path(cls, model_path: str) -> 'ImageEmbedder': + def create_from_model_path(cls, model_path: str) -> 'TextEmbedder': """Creates an `TextEmbedder` object from a TensorFlow Lite model and the default `TextEmbedderOptions`. @@ -147,7 +145,9 @@ class TextEmbedder(base_text_task_api.BaseTextTaskApi): def cosine_similarity(u: embedding_result_module.Embedding, v: embedding_result_module.Embedding) -> float: """Utility function to compute cosine similarity [1] between two embedding - entries. May return an InvalidArgumentError if e.g. the feature vectors are + entries. + + May return an InvalidArgumentError if e.g. the feature vectors are of different types (quantized vs. float), have different sizes, or have a an L2-norm of 0. From 157092d93e80b95962c9c44a4df147bfcaac59cd Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sat, 12 Nov 2022 07:47:32 -0800 Subject: [PATCH 4/6] Removed unused dataclasses --- .../components/containers/embedding_result.py | 36 ------------------- .../python/test/text/text_embedder_test.py | 3 -- 2 files changed, 39 deletions(-) diff --git a/mediapipe/tasks/python/components/containers/embedding_result.py b/mediapipe/tasks/python/components/containers/embedding_result.py index 8ddbb3ae5..63c315603 100644 --- a/mediapipe/tasks/python/components/containers/embedding_result.py +++ b/mediapipe/tasks/python/components/containers/embedding_result.py @@ -26,42 +26,6 @@ _EmbeddingProto = embeddings_pb2.Embedding _EmbeddingResultProto = embeddings_pb2.EmbeddingResult -@dataclasses.dataclass -class FloatEmbedding: - """Defines a dense floating-point embedding. - - Attributes: - values: A NumPy array indicating the raw output of the embedding layer. - """ - - values: np.ndarray - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _FloatEmbeddingProto) -> 'FloatEmbedding': - """Creates a `FloatEmbedding` object from the given protobuf object.""" - return FloatEmbedding(values=np.array(pb2_obj.values, dtype=float)) - - -@dataclasses.dataclass -class QuantizedEmbedding: - """Defines a dense scalar-quantized embedding. - - Attributes: - values: A NumPy array indicating the raw output of the embedding layer. - """ - - values: np.ndarray - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2( - cls, pb2_obj: _QuantizedEmbeddingProto) -> 'QuantizedEmbedding': - """Creates a `QuantizedEmbedding` object from the given protobuf object.""" - return QuantizedEmbedding( - values=np.array(bytearray(pb2_obj.values), dtype=np.uint8)) - - @dataclasses.dataclass class Embedding: """Embedding result for a given embedder head. diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 8ef616150..e1f1a45e8 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -15,7 +15,6 @@ import enum import os -from unittest import mock import numpy as np from absl.testing import absltest @@ -30,8 +29,6 @@ from mediapipe.tasks.python.text import text_embedder ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _EmbedderOptions = embedder_options_module.EmbedderOptions -_FloatEmbedding = embedding_result_module.FloatEmbedding -_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding _Embedding = embedding_result_module.Embedding _TextEmbedder = text_embedder.TextEmbedder _TextEmbedderOptions = text_embedder.TextEmbedderOptions From a3788a23bc1ad376bfc9c4d8bd4e10bc5c8db9c3 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sat, 12 Nov 2022 07:48:55 -0800 Subject: [PATCH 5/6] Removed unused code in image_embedder_test --- mediapipe/tasks/python/test/vision/image_embedder_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index d28320d71..097196fb9 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -35,8 +35,6 @@ ImageEmbedderResult = embedding_result_module.EmbeddingResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions _EmbedderOptions = embedder_options_module.EmbedderOptions -_FloatEmbedding = embedding_result_module.FloatEmbedding -_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding _Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder From a7ed160a8ed410ee21a1574e6261039cb37c8a94 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sat, 12 Nov 2022 08:55:56 -0800 Subject: [PATCH 6/6] Fixed a bug in embedding_result --- .../tasks/python/components/containers/embedding_result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/tasks/python/components/containers/embedding_result.py b/mediapipe/tasks/python/components/containers/embedding_result.py index 63c315603..999f74535 100644 --- a/mediapipe/tasks/python/components/containers/embedding_result.py +++ b/mediapipe/tasks/python/components/containers/embedding_result.py @@ -51,7 +51,7 @@ class Embedding: bytearray(pb2_obj.quantized_embedding.values)) float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float) - if not quantized_embedding: + if not pb2_obj.quantized_embedding.values: return Embedding( embedding=float_embedding, head_index=pb2_obj.head_index,