From 664d9c49e7cf120b753fb572c936be569e93b217 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 7 Nov 2022 13:59:07 -0800 Subject: [PATCH] Revised image embedder implementation --- mediapipe/python/BUILD | 1 - .../components/containers/embeddings.py | 99 ++++++------------- .../tasks/python/components/processors/BUILD | 9 ++ .../{proto => processors}/embedder_options.py | 2 +- mediapipe/tasks/python/components/proto/BUILD | 28 ------ .../tasks/python/components/proto/__init__.py | 13 --- mediapipe/tasks/python/components/utils/BUILD | 2 +- .../components/utils/cosine_similarity.py | 16 +-- mediapipe/tasks/python/test/vision/BUILD | 2 +- .../python/test/vision/image_embedder_test.py | 17 ++-- mediapipe/tasks/python/vision/BUILD | 1 + .../tasks/python/vision/image_embedder.py | 39 ++++---- 12 files changed, 79 insertions(+), 150 deletions(-) rename mediapipe/tasks/python/components/{proto => processors}/embedder_options.py (96%) delete mode 100644 mediapipe/tasks/python/components/proto/BUILD delete mode 100644 mediapipe/tasks/python/components/proto/__init__.py diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 2423370e6..0f049f305 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -92,7 +92,6 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", - ], ] + select({ # TODO: Build text_classifier_graph on Windows. "//mediapipe:windows": [], diff --git a/mediapipe/tasks/python/components/containers/embeddings.py b/mediapipe/tasks/python/components/containers/embeddings.py index c1185c84f..3a024079a 100644 --- a/mediapipe/tasks/python/components/containers/embeddings.py +++ b/mediapipe/tasks/python/components/containers/embeddings.py @@ -22,8 +22,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls _FloatEmbeddingProto = embeddings_pb2.FloatEmbedding _QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding -_EmbeddingEntryProto = embeddings_pb2.EmbeddingEntry -_EmbeddingsProto = embeddings_pb2.Embeddings +_EmbeddingProto = embeddings_pb2.Embedding _EmbeddingResultProto = embeddings_pb2.EmbeddingResult @@ -99,28 +98,34 @@ class QuantizedEmbedding: @dataclasses.dataclass -class EmbeddingEntry: - """Floating-point or scalar-quantized embedding with an optional timestamp. +class Embedding: + """Embedding result for a given embedder head. Attributes: embedding: The actual embedding, either floating-point or scalar-quantized. - timestamp_ms: The optional timestamp (in milliseconds) associated to the - embedding entry. This is useful for time series use cases, e.g. audio - embedding. + head_index: The index of the embedder head that produced this embedding. + This is useful for multi-head models. + head_name: The name of the embedder head, which is the corresponding tensor + metadata name (if any). This is useful for multi-head models. """ embedding: np.ndarray - timestamp_ms: Optional[int] = None + head_index: int + head_name: str @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbeddingEntryProto: - """Generates a EmbeddingEntry protobuf object.""" + def to_pb2(self) -> _EmbeddingProto: + """Generates a Embedding protobuf object.""" if self.embedding.dtype == float: - return _EmbeddingEntryProto(float_embedding=self.embedding) + return _EmbeddingProto(float_embedding=self.embedding, + head_index=self.head_index, + head_name=self.head_name) elif self.embedding.dtype == np.uint8: - return _EmbeddingEntryProto(quantized_embedding=bytes(self.embedding)) + return _EmbeddingProto(quantized_embedding=bytes(self.embedding), + head_index=self.head_index, + head_name=self.head_name) else: raise ValueError("Invalid dtype. Only float and np.uint8 are supported.") @@ -128,17 +133,21 @@ class EmbeddingEntry: @classmethod @doc_controls.do_not_generate_docs def create_from_pb2( - cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry': - """Creates a `EmbeddingEntry` object from the given protobuf object.""" + cls, pb2_obj: _EmbeddingProto) -> 'Embedding': + """Creates a `Embedding` object from the given protobuf object.""" quantized_embedding = np.array( bytearray(pb2_obj.quantized_embedding.values)) float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float) if len(quantized_embedding) == 0: - return EmbeddingEntry(embedding=float_embedding) + return Embedding(embedding=float_embedding, + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) else: - return EmbeddingEntry(embedding=quantized_embedding) + return Embedding(embedding=quantized_embedding, + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. @@ -147,55 +156,7 @@ class EmbeddingEntry: Returns: True if the objects are equal. """ - if not isinstance(other, EmbeddingEntry): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class Embeddings: - """Embeddings for a given embedder head. - Attributes: - entries: A list of `ClassificationEntry` objects. - head_index: The index of the embedder head that produced this embedding. - This is useful for multi-head models. - head_name: The name of the embedder head, which is the corresponding tensor - metadata name (if any). This is useful for multi-head models. - """ - - entries: List[EmbeddingEntry] - head_index: int - head_name: str - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _EmbeddingsProto: - """Generates a Embeddings protobuf object.""" - return _EmbeddingsProto( - entries=[entry.to_pb2() for entry in self.entries], - head_index=self.head_index, - head_name=self.head_name) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _EmbeddingsProto) -> 'Embeddings': - """Creates a `Embeddings` object from the given protobuf object.""" - return Embeddings( - entries=[ - EmbeddingEntry.create_from_pb2(entry) - for entry in pb2_obj.entries - ], - head_index=pb2_obj.head_index, - head_name=pb2_obj.head_name) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. - """ - if not isinstance(other, Embeddings): + if not isinstance(other, Embedding): return False return self.to_pb2().__eq__(other.to_pb2()) @@ -203,12 +164,12 @@ class Embeddings: @dataclasses.dataclass class EmbeddingResult: - """Contains one set of results per embedder head. + """Embedding results for a given embedder model. Attributes: - embeddings: A list of `Embeddings` objects. + embeddings: A list of `Embedding` objects. """ - embeddings: List[Embeddings] + embeddings: List[Embedding] @doc_controls.do_not_generate_docs def to_pb2(self) -> _EmbeddingResultProto: @@ -225,7 +186,7 @@ class EmbeddingResult: """Creates a `EmbeddingResult` object from the given protobuf object.""" return EmbeddingResult( embeddings=[ - Embeddings.create_from_pb2(embedding) + Embedding.create_from_pb2(embedding) for embedding in pb2_obj.embeddings ]) diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD index f87a579b0..eef368db0 100644 --- a/mediapipe/tasks/python/components/processors/BUILD +++ b/mediapipe/tasks/python/components/processors/BUILD @@ -28,3 +28,12 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "embedder_options", + srcs = ["embedder_options.py"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/proto/embedder_options.py b/mediapipe/tasks/python/components/processors/embedder_options.py similarity index 96% rename from mediapipe/tasks/python/components/proto/embedder_options.py rename to mediapipe/tasks/python/components/processors/embedder_options.py index 49bcfb985..dcd316dcd 100644 --- a/mediapipe/tasks/python/components/proto/embedder_options.py +++ b/mediapipe/tasks/python/components/processors/embedder_options.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any, Optional -from mediapipe.tasks.cc.components.proto import embedder_options_pb2 +from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions diff --git a/mediapipe/tasks/python/components/proto/BUILD b/mediapipe/tasks/python/components/proto/BUILD deleted file mode 100644 index 973f150ca..000000000 --- a/mediapipe/tasks/python/components/proto/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Placeholder for internal Python strict library compatibility macro. - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -py_library( - name = "embedder_options", - srcs = ["embedder_options.py"], - deps = [ - "//mediapipe/tasks/cc/components/proto:embedder_options_py_pb2", - "//mediapipe/tasks/python/core:optional_dependencies", - ], -) diff --git a/mediapipe/tasks/python/components/proto/__init__.py b/mediapipe/tasks/python/components/proto/__init__.py deleted file mode 100644 index 65c1214af..000000000 --- a/mediapipe/tasks/python/components/proto/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index 7ec01a034..6d00bb31a 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -23,6 +23,6 @@ py_library( srcs = ["cosine_similarity.py"], deps = [ "//mediapipe/tasks/python/components/containers:embeddings", - "//mediapipe/tasks/python/components/proto:embedder_options", + "//mediapipe/tasks/python/components/processors:embedder_options", ], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 8723a55eb..616f2651f 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -16,9 +16,9 @@ import numpy as np from mediapipe.tasks.python.components.containers import embeddings -from mediapipe.tasks.python.components.proto import embedder_options +from mediapipe.tasks.python.components.processors import embedder_options -_EmbeddingEntry = embeddings.EmbeddingEntry +_Embedding = embeddings.Embedding _EmbedderOptions = embedder_options.EmbedderOptions @@ -36,15 +36,15 @@ def _compute_cosine_similarity(u, v): return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) -def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float: - """Utility function to compute cosine similarity between two embedding - entries. May return an InvalidArgumentError if e.g. the feature vectors are - of different types (quantized vs. float), have different sizes, or have an +def cosine_similarity(u: _Embedding, v: _Embedding) -> float: + """Utility function to compute cosine similarity between two embedding. + May return an InvalidArgumentError if e.g. the feature vectors are of + different types (quantized vs. float), have different sizes, or have an L2-norm of 0. Args: - u: An embedding entry. - v: An embedding entry. + u: An embedding. + v: An embedding. """ if len(u.embedding) != len(v.embedding): raise ValueError(f"Cannot compute cosine similarity between embeddings " diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 9fee8a023..5fd396161 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -83,7 +83,7 @@ py_test( ], deps = [ "//mediapipe/python:_framework_bindings", - "//mediapipe/tasks/python/components/proto:embedder_options", + "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:rect", diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 4f109ea29..7bcf48d71 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -22,7 +22,7 @@ from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module +from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.core import base_options as base_options_module @@ -36,8 +36,7 @@ _BaseOptions = base_options_module.BaseOptions _EmbedderOptions = embedder_options_module.EmbedderOptions _FloatEmbedding = embeddings_module.FloatEmbedding _QuantizedEmbedding = embeddings_module.QuantizedEmbedding -_EmbeddingEntry = embeddings_module.EmbeddingEntry -_Embeddings = embeddings_module.Embeddings +_Embedding = embeddings_module.Embedding _EmbeddingResult = embeddings_module.EmbeddingResult _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder @@ -81,12 +80,12 @@ class ImageEmbedderTest(parameterized.TestCase): # Check embedding sizes. def _check_embedding_size(result): self.assertLen(result.embeddings, 1) - embedding_entry = result.embeddings[0].entries[0] - self.assertLen(embedding_entry.embedding, 1024) + embedding_result = result.embeddings[0] + self.assertLen(embedding_result.embedding, 1024) if quantize: - self.assertEqual(embedding_entry.embedding.dtype, np.uint8) + self.assertEqual(embedding_result.embedding.dtype, np.uint8) else: - self.assertEqual(embedding_entry.embedding.dtype, float) + self.assertEqual(embedding_result.embedding.dtype, float) # Checks results sizes. _check_embedding_size(result0) @@ -94,7 +93,7 @@ class ImageEmbedderTest(parameterized.TestCase): # Checks cosine similarity. similarity = _ImageEmbedder.cosine_similarity( - result0.embeddings[0].entries[0], result1.embeddings[0].entries[0]) + result0.embeddings[0], result1.embeddings[0]) self.assertAlmostEqual(similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) @@ -134,7 +133,7 @@ class ImageEmbedderTest(parameterized.TestCase): crop_result = embedder.embed(self.test_cropped_image) # Check embedding value. - self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0], + self.assertAlmostEqual(image_result.embeddings[0].embedding[0], expected_first_value) # Checks cosine similarity. diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 9d406be28..1d60fa5b5 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -114,6 +114,7 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:embeddings", + "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index e287593f5..2851970d8 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -22,7 +22,7 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 -from mediapipe.tasks.python.components.proto import embedder_options +from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.core import base_options as base_options_module @@ -32,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_ImageEmbedderResult = embeddings_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _EmbedderOptions = embedder_options.EmbedderOptions @@ -40,8 +41,8 @@ _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskRunner = task_runner_module.TaskRunner -_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out' -_EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT' +_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' +_EMBEDDINGS_TAG = 'EMBEDDINGS' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' @@ -140,15 +141,15 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return embedding_result_proto = packet_getter.get_proto( - output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) + output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) - embedding_result = embeddings_module.EmbeddingResult([ - embeddings_module.Embeddings.create_from_pb2(embedding) + embeddings = embeddings_module.EmbeddingResult([ + embeddings_module.Embedding.create_from_pb2(embedding) for embedding in embedding_result_proto.embeddings ]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp - options.result_callback(embedding_result, image, + options.result_callback(embeddings, image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( @@ -158,8 +159,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], output_streams=[ - ':'.join([_EMBEDDING_RESULT_TAG, - _EMBEDDING_RESULT_OUT_STREAM_NAME]), + ':'.join([_EMBEDDINGS_TAG, + _EMBEDDINGS_OUT_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ], task_options=options) @@ -173,7 +174,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, image_processing_options: Optional[_ImageProcessingOptions] = None - ) -> embeddings_module.EmbeddingResult: + ) -> _ImageEmbedderResult: """Performs image embedding extraction on the provided MediaPipe Image. Extraction is performed on the region of interest specified by the `roi` argument if provided, or on the entire image otherwise. @@ -195,18 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): _NORM_RECT_STREAM_NAME: packet_creator.create_proto( normalized_rect.to_pb2())}) embedding_result_proto = packet_getter.get_proto( - output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) + output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) return embeddings_module.EmbeddingResult([ - embeddings_module.Embeddings.create_from_pb2(embedding) - for embedding in embedding_result_proto.embeddings + embeddings_module.Embedding.create_from_pb2(embedding) + for embedding in embedding_result_proto.embeddings ]) def embed_for_video( self, image: image_module.Image, timestamp_ms: int, image_processing_options: Optional[_ImageProcessingOptions] = None - ) -> embeddings_module.EmbeddingResult: + ) -> _ImageEmbedderResult: """Performs image embedding extraction on the provided video frames. Extraction is performed on the region of interested specified by the `roi` argument if provided, or on the entire image otherwise. @@ -237,11 +238,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) embedding_result_proto = packet_getter.get_proto( - output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) + output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) return embeddings_module.EmbeddingResult([ - embeddings_module.Embeddings.create_from_pb2(embedding) - for embedding in embedding_result_proto.embeddings + embeddings_module.Embedding.create_from_pb2(embedding) + for embedding in embedding_result_proto.embeddings ]) def embed_async( @@ -289,8 +290,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): }) @staticmethod - def cosine_similarity(u: embeddings_module.EmbeddingEntry, - v: embeddings_module.EmbeddingEntry) -> float: + def cosine_similarity(u: embeddings_module.Embedding, + v: embeddings_module.Embedding) -> float: """Utility function to compute cosine similarity [1] between two embedding entries. May return an InvalidArgumentError if e.g. the feature vectors are of different types (quantized vs. float), have different sizes, or have a