Refactored embeddings to embedding_result

This commit is contained in:
kinaryml 2022-11-09 11:42:32 -08:00
parent dc30cf9732
commit 17bb174444
8 changed files with 40 additions and 41 deletions

View File

@ -106,8 +106,8 @@ py_library(
)
py_library(
name = "embeddings",
srcs = ["embeddings.py"],
name = "embedding_result",
srcs = ["embedding_result.py"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -22,7 +22,7 @@ py_library(
name = "cosine_similarity",
srcs = ["cosine_similarity.py"],
deps = [
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
],
)

View File

@ -15,10 +15,10 @@
import numpy as np
from mediapipe.tasks.python.components.containers import embeddings
from mediapipe.tasks.python.components.containers import embedding_result
from mediapipe.tasks.python.components.processors import embedder_options
_Embedding = embeddings.Embedding
_Embedding = embedding_result.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions

View File

@ -85,7 +85,7 @@ py_test(
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",

View File

@ -23,7 +23,7 @@ from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
@ -31,13 +31,13 @@ from mediapipe.tasks.python.vision import image_embedder
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
ImageEmbedderResult = embedding_result_module.EmbeddingResult
_Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_FloatEmbedding = embeddings_module.FloatEmbedding
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
_Embedding = embeddings_module.Embedding
_EmbeddingResult = embeddings_module.EmbeddingResult
_FloatEmbedding = embedding_result_module.FloatEmbedding
_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding
_Embedding = embedding_result_module.Embedding
_Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
@ -346,7 +346,7 @@ class ImageEmbedderTest(parameterized.TestCase):
observed_timestamp_ms = -1
def check_result(result: _EmbeddingResult, output_image: _Image,
def check_result(result: ImageEmbedderResult, output_image: _Image,
timestamp_ms: int):
# Checks cosine similarity.
self._check_cosine_similarity(result, crop_result, quantize=False,
@ -378,7 +378,7 @@ class ImageEmbedderTest(parameterized.TestCase):
image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1
def check_result(result: _EmbeddingResult, output_image: _Image,
def check_result(result: ImageEmbedderResult, output_image: _Image,
timestamp_ms: int):
# Checks cosine similarity.
self._check_cosine_similarity(result, crop_result, quantize=False,

View File

@ -113,7 +113,8 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",

View File

@ -22,9 +22,10 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -32,7 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_ImageEmbedderResult = embeddings_module.EmbeddingResult
ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions
@ -76,7 +77,7 @@ class ImageEmbedderOptions:
quantize: Optional[bool] = None
embedder_options: _EmbedderOptions = _EmbedderOptions()
result_callback: Optional[
Callable[[embeddings_module.EmbeddingResult, image_module.Image,
Callable[[ImageEmbedderResult, image_module.Image,
int], None]] = None
@doc_controls.do_not_generate_docs
@ -140,17 +141,17 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
embeddings = embeddings_module.EmbeddingResult([
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(embeddings, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
ImageEmbedderResult.create_from_pb2(embedding_result_proto),
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -174,7 +175,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> _ImageEmbedderResult:
) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided MediaPipe Image.
Extraction is performed on the region of interest specified by the `roi`
argument if provided, or on the entire image otherwise.
@ -195,19 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())})
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
return embeddings_module.EmbeddingResult([
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
def embed_for_video(
self, image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> _ImageEmbedderResult:
) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided video frames.
Extraction is performed on the region of interested specified by the `roi`
argument if provided, or on the entire image otherwise.
@ -237,13 +237,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
return embeddings_module.EmbeddingResult([
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
def embed_async(
self,
@ -290,8 +288,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
})
@staticmethod
def cosine_similarity(u: embeddings_module.Embedding,
v: embeddings_module.Embedding) -> float:
def cosine_similarity(u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding) -> float:
"""Utility function to compute cosine similarity [1] between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a