Refactored embeddings to embedding_result

This commit is contained in:
kinaryml 2022-11-09 11:42:32 -08:00
parent dc30cf9732
commit 17bb174444
8 changed files with 40 additions and 41 deletions

View File

@ -106,8 +106,8 @@ py_library(
) )
py_library( py_library(
name = "embeddings", name = "embedding_result",
srcs = ["embeddings.py"], srcs = ["embedding_result.py"],
deps = [ deps = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -22,7 +22,7 @@ py_library(
name = "cosine_similarity", name = "cosine_similarity",
srcs = ["cosine_similarity.py"], srcs = ["cosine_similarity.py"],
deps = [ deps = [
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/processors:embedder_options",
], ],
) )

View File

@ -15,10 +15,10 @@
import numpy as np import numpy as np
from mediapipe.tasks.python.components.containers import embeddings from mediapipe.tasks.python.components.containers import embedding_result
from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.processors import embedder_options
_Embedding = embeddings.Embedding _Embedding = embedding_result.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptions = embedder_options.EmbedderOptions

View File

@ -85,7 +85,7 @@ py_test(
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/processors:embedder_options", "//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",

View File

@ -23,7 +23,7 @@ from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -31,13 +31,13 @@ from mediapipe.tasks.python.vision import image_embedder
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
ImageEmbedderResult = embedding_result_module.EmbeddingResult
_Rect = rect.Rect _Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions _EmbedderOptions = embedder_options_module.EmbedderOptions
_FloatEmbedding = embeddings_module.FloatEmbedding _FloatEmbedding = embedding_result_module.FloatEmbedding
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding _QuantizedEmbedding = embedding_result_module.QuantizedEmbedding
_Embedding = embeddings_module.Embedding _Embedding = embedding_result_module.Embedding
_EmbeddingResult = embeddings_module.EmbeddingResult
_Image = image_module.Image _Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder _ImageEmbedder = image_embedder.ImageEmbedder
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions _ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
@ -346,7 +346,7 @@ class ImageEmbedderTest(parameterized.TestCase):
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _EmbeddingResult, output_image: _Image, def check_result(result: ImageEmbedderResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
# Checks cosine similarity. # Checks cosine similarity.
self._check_cosine_similarity(result, crop_result, quantize=False, self._check_cosine_similarity(result, crop_result, quantize=False,
@ -378,7 +378,7 @@ class ImageEmbedderTest(parameterized.TestCase):
image_processing_options = _ImageProcessingOptions(roi) image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _EmbeddingResult, output_image: _Image, def check_result(result: ImageEmbedderResult, output_image: _Image,
timestamp_ms: int): timestamp_ms: int):
# Checks cosine similarity. # Checks cosine similarity.
self._check_cosine_similarity(result, crop_result, quantize=False, self._check_cosine_similarity(result, crop_result, quantize=False,

View File

@ -113,7 +113,8 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -22,9 +22,10 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.python.components.processors import embedder_options from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -32,7 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_ImageEmbedderResult = embeddings_module.EmbeddingResult ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptions = embedder_options.EmbedderOptions
@ -76,7 +77,7 @@ class ImageEmbedderOptions:
quantize: Optional[bool] = None quantize: Optional[bool] = None
embedder_options: _EmbedderOptions = _EmbedderOptions() embedder_options: _EmbedderOptions = _EmbedderOptions()
result_callback: Optional[ result_callback: Optional[
Callable[[embeddings_module.EmbeddingResult, image_module.Image, Callable[[ImageEmbedderResult, image_module.Image,
int], None]] = None int], None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -140,16 +141,16 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet_module.Packet]): def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
embeddings = embeddings_module.EmbeddingResult([ embedding_result_proto = embeddings_pb2.EmbeddingResult()
embeddings_module.Embedding.create_from_pb2(embedding) embedding_result_proto.CopyFrom(
for embedding in embedding_result_proto.embeddings packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(embeddings, image, options.result_callback(
ImageEmbedderResult.create_from_pb2(embedding_result_proto),
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo( task_info = _TaskInfo(
@ -174,7 +175,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> _ImageEmbedderResult: ) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided MediaPipe Image. """Performs image embedding extraction on the provided MediaPipe Image.
Extraction is performed on the region of interest specified by the `roi` Extraction is performed on the region of interest specified by the `roi`
argument if provided, or on the entire image otherwise. argument if provided, or on the entire image otherwise.
@ -195,19 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto( _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())}) normalized_rect.to_pb2())})
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
return embeddings_module.EmbeddingResult([ embedding_result_proto = embeddings_pb2.EmbeddingResult()
embeddings_module.Embedding.create_from_pb2(embedding) embedding_result_proto.CopyFrom(
for embedding in embedding_result_proto.embeddings packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
])
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
def embed_for_video( def embed_for_video(
self, image: image_module.Image, self, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> _ImageEmbedderResult: ) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided video frames. """Performs image embedding extraction on the provided video frames.
Extraction is performed on the region of interest specified by the `roi` Extraction is performed on the region of interest specified by the `roi`
argument if provided, or on the entire image otherwise. argument if provided, or on the entire image otherwise.
@ -237,13 +237,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()).at( normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })
embedding_result_proto = packet_getter.get_proto( embedding_result_proto = embeddings_pb2.EmbeddingResult()
output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
return embeddings_module.EmbeddingResult([ return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
def embed_async( def embed_async(
self, self,
@ -290,8 +288,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
}) })
@staticmethod @staticmethod
def cosine_similarity(u: embeddings_module.Embedding, def cosine_similarity(u: embedding_result_module.Embedding,
v: embeddings_module.Embedding) -> float: v: embedding_result_module.Embedding) -> float:
"""Utility function to compute cosine similarity [1] between two embedding """Utility function to compute cosine similarity [1] between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are entries. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a of different types (quantized vs. float), have different sizes, or have a