From 17bb174444ed374578aaa38387f59ab8389b8822 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 9 Nov 2022 11:42:32 -0800 Subject: [PATCH] Refactored embeddings to embedding_result --- .../tasks/python/components/containers/BUILD | 4 +- .../{embeddings.py => embedding_result.py} | 0 mediapipe/tasks/python/components/utils/BUILD | 2 +- .../components/utils/cosine_similarity.py | 4 +- mediapipe/tasks/python/test/vision/BUILD | 2 +- .../python/test/vision/image_embedder_test.py | 14 ++--- mediapipe/tasks/python/vision/BUILD | 3 +- .../tasks/python/vision/image_embedder.py | 52 +++++++++---------- 8 files changed, 40 insertions(+), 41 deletions(-) rename mediapipe/tasks/python/components/containers/{embeddings.py => embedding_result.py} (100%) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 1f01e2955..7091c1f41 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -106,8 +106,8 @@ py_library( ) py_library( - name = "embeddings", - srcs = ["embeddings.py"], + name = "embedding_result", + srcs = ["embedding_result.py"], deps = [ "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/components/containers/embeddings.py b/mediapipe/tasks/python/components/containers/embedding_result.py similarity index 100% rename from mediapipe/tasks/python/components/containers/embeddings.py rename to mediapipe/tasks/python/components/containers/embedding_result.py diff --git a/mediapipe/tasks/python/components/utils/BUILD b/mediapipe/tasks/python/components/utils/BUILD index 6d00bb31a..50d4094c0 100644 --- a/mediapipe/tasks/python/components/utils/BUILD +++ b/mediapipe/tasks/python/components/utils/BUILD @@ -22,7 +22,7 @@ py_library( name = "cosine_similarity", srcs = ["cosine_similarity.py"], deps = [ - "//mediapipe/tasks/python/components/containers:embeddings", + "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/processors:embedder_options", ], ) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index 616f2651f..d6102f0b5 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -15,10 +15,10 @@ import numpy as np -from mediapipe.tasks.python.components.containers import embeddings +from mediapipe.tasks.python.components.containers import embedding_result from mediapipe.tasks.python.components.processors import embedder_options -_Embedding = embeddings.Embedding +_Embedding = embedding_result.Embedding _EmbedderOptions = embedder_options.EmbedderOptions diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 5fd396161..553e1f5a6 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -85,7 +85,7 @@ py_test( "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/utils:cosine_similarity", - "//mediapipe/tasks/python/components/containers:embeddings", + "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 0dfce91c2..e9ff50ed4 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -23,7 +23,7 @@ from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module -from mediapipe.tasks.python.components.containers import embeddings as embeddings_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -31,13 +31,13 @@ from mediapipe.tasks.python.vision import image_embedder from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +ImageEmbedderResult = embedding_result_module.EmbeddingResult _Rect = rect.Rect _BaseOptions = base_options_module.BaseOptions _EmbedderOptions = embedder_options_module.EmbedderOptions -_FloatEmbedding = embeddings_module.FloatEmbedding -_QuantizedEmbedding = embeddings_module.QuantizedEmbedding -_Embedding = embeddings_module.Embedding -_EmbeddingResult = embeddings_module.EmbeddingResult +_FloatEmbedding = embedding_result_module.FloatEmbedding +_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding +_Embedding = embedding_result_module.Embedding _Image = image_module.Image _ImageEmbedder = image_embedder.ImageEmbedder _ImageEmbedderOptions = image_embedder.ImageEmbedderOptions @@ -346,7 +346,7 @@ class ImageEmbedderTest(parameterized.TestCase): observed_timestamp_ms = -1 - def check_result(result: _EmbeddingResult, output_image: _Image, + def check_result(result: ImageEmbedderResult, output_image: _Image, timestamp_ms: int): # Checks cosine similarity. self._check_cosine_similarity(result, crop_result, quantize=False, @@ -378,7 +378,7 @@ class ImageEmbedderTest(parameterized.TestCase): image_processing_options = _ImageProcessingOptions(roi) observed_timestamp_ms = -1 - def check_result(result: _EmbeddingResult, output_image: _Image, + def check_result(result: ImageEmbedderResult, output_image: _Image, timestamp_ms: int): # Checks cosine similarity. self._check_cosine_similarity(result, crop_result, quantize=False, diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 3c040fb4d..0af8f07e1 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -113,7 +113,8 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", - "//mediapipe/tasks/python/components/containers:embeddings", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index 2851970d8..bec9682d0 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -22,9 +22,10 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 +from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.utils import cosine_similarity -from mediapipe.tasks.python.components.containers import embeddings as embeddings_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -32,7 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module -_ImageEmbedderResult = embeddings_module.EmbeddingResult +ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _EmbedderOptions = embedder_options.EmbedderOptions @@ -76,7 +77,7 @@ class ImageEmbedderOptions: quantize: Optional[bool] = None embedder_options: _EmbedderOptions = _EmbedderOptions() result_callback: Optional[ - Callable[[embeddings_module.EmbeddingResult, image_module.Image, + Callable[[ImageEmbedderResult, image_module.Image, int], None]] = None @doc_controls.do_not_generate_docs @@ -140,17 +141,17 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): def packets_callback(output_packets: Mapping[str, packet_module.Packet]): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return - embedding_result_proto = packet_getter.get_proto( - output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) - embeddings = embeddings_module.EmbeddingResult([ - embeddings_module.Embedding.create_from_pb2(embedding) - for embedding in embedding_result_proto.embeddings - ]) + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp - options.result_callback(embeddings, image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + options.result_callback( + ImageEmbedderResult.create_from_pb2(embedding_result_proto), + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -174,7 +175,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, image_processing_options: Optional[_ImageProcessingOptions] = None - ) -> _ImageEmbedderResult: + ) -> ImageEmbedderResult: """Performs image embedding extraction on the provided MediaPipe Image. Extraction is performed on the region of interest specified by the `roi` argument if provided, or on the entire image otherwise. @@ -195,19 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _NORM_RECT_STREAM_NAME: packet_creator.create_proto( normalized_rect.to_pb2())}) - embedding_result_proto = packet_getter.get_proto( - output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) - return embeddings_module.EmbeddingResult([ - embeddings_module.Embedding.create_from_pb2(embedding) - for embedding in embedding_result_proto.embeddings - ]) + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) + + return ImageEmbedderResult.create_from_pb2(embedding_result_proto) def embed_for_video( self, image: image_module.Image, timestamp_ms: int, image_processing_options: Optional[_ImageProcessingOptions] = None - ) -> _ImageEmbedderResult: + ) -> ImageEmbedderResult: """Performs image embedding extraction on the provided video frames. Extraction is performed on the region of interested specified by the `roi` argument if provided, or on the entire image otherwise. @@ -237,13 +237,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2()).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) }) - embedding_result_proto = packet_getter.get_proto( - output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) - return embeddings_module.EmbeddingResult([ - embeddings_module.Embedding.create_from_pb2(embedding) - for embedding in embedding_result_proto.embeddings - ]) + return ImageEmbedderResult.create_from_pb2(embedding_result_proto) def embed_async( self, @@ -290,8 +288,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): }) @staticmethod - def cosine_similarity(u: embeddings_module.Embedding, - v: embeddings_module.Embedding) -> float: + def cosine_similarity(u: embedding_result_module.Embedding, + v: embedding_result_module.Embedding) -> float: """Utility function to compute cosine similarity [1] between two embedding entries. May return an InvalidArgumentError if e.g. the feature vectors are of different types (quantized vs. float), have different sizes, or have a