Refactored embeddings to embedding_result
This commit is contained in:
parent
dc30cf9732
commit
17bb174444
|
@ -106,8 +106,8 @@ py_library(
|
|||
)
|
||||
|
||||
py_library(
|
||||
name = "embeddings",
|
||||
srcs = ["embeddings.py"],
|
||||
name = "embedding_result",
|
||||
srcs = ["embedding_result.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
|
|
|
@ -22,7 +22,7 @@ py_library(
|
|||
name = "cosine_similarity",
|
||||
srcs = ["cosine_similarity.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.components.containers import embeddings
|
||||
from mediapipe.tasks.python.components.containers import embedding_result
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
|
||||
_Embedding = embeddings.Embedding
|
||||
_Embedding = embedding_result.Embedding
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ py_test(
|
|||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
|
|
|
@ -23,7 +23,7 @@ from absl.testing import parameterized
|
|||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
@ -31,13 +31,13 @@ from mediapipe.tasks.python.vision import image_embedder
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
ImageEmbedderResult = embedding_result_module.EmbeddingResult
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||
_FloatEmbedding = embeddings_module.FloatEmbedding
|
||||
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
||||
_Embedding = embeddings_module.Embedding
|
||||
_EmbeddingResult = embeddings_module.EmbeddingResult
|
||||
_FloatEmbedding = embedding_result_module.FloatEmbedding
|
||||
_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding
|
||||
_Embedding = embedding_result_module.Embedding
|
||||
_Image = image_module.Image
|
||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
||||
|
@ -346,7 +346,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _EmbeddingResult, output_image: _Image,
|
||||
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
||||
|
@ -378,7 +378,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
|||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _EmbeddingResult, output_image: _Image,
|
||||
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
||||
|
|
|
@ -113,7 +113,8 @@ py_library(
|
|||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
|
|
|
@ -22,9 +22,10 @@ from mediapipe.python._framework_bindings import image as image_module
|
|||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||
from mediapipe.tasks.python.components.processors import embedder_options
|
||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -32,7 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_ImageEmbedderResult = embeddings_module.EmbeddingResult
|
||||
ImageEmbedderResult = embedding_result_module.EmbeddingResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
@ -76,7 +77,7 @@ class ImageEmbedderOptions:
|
|||
quantize: Optional[bool] = None
|
||||
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
||||
result_callback: Optional[
|
||||
Callable[[embeddings_module.EmbeddingResult, image_module.Image,
|
||||
Callable[[ImageEmbedderResult, image_module.Image,
|
||||
int], None]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
|
@ -140,17 +141,17 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||
return
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||
|
||||
embeddings = embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||
for embedding in embedding_result_proto.embeddings
|
||||
])
|
||||
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||
embedding_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||
|
||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||
options.result_callback(embeddings, image,
|
||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
options.result_callback(
|
||||
ImageEmbedderResult.create_from_pb2(embedding_result_proto),
|
||||
image,
|
||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
|
||||
task_info = _TaskInfo(
|
||||
task_graph=_TASK_GRAPH_NAME,
|
||||
|
@ -174,7 +175,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
self,
|
||||
image: image_module.Image,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> _ImageEmbedderResult:
|
||||
) -> ImageEmbedderResult:
|
||||
"""Performs image embedding extraction on the provided MediaPipe Image.
|
||||
Extraction is performed on the region of interest specified by the `roi`
|
||||
argument if provided, or on the entire image otherwise.
|
||||
|
@ -195,19 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||
normalized_rect.to_pb2())})
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||
|
||||
return embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||
for embedding in embedding_result_proto.embeddings
|
||||
])
|
||||
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||
embedding_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||
|
||||
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
|
||||
|
||||
def embed_for_video(
|
||||
self, image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> _ImageEmbedderResult:
|
||||
) -> ImageEmbedderResult:
|
||||
"""Performs image embedding extraction on the provided video frames.
|
||||
Extraction is performed on the region of interested specified by the `roi`
|
||||
argument if provided, or on the entire image otherwise.
|
||||
|
@ -237,13 +237,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
normalized_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
})
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
||||
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||
embedding_result_proto.CopyFrom(
|
||||
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||
|
||||
return embeddings_module.EmbeddingResult([
|
||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
||||
for embedding in embedding_result_proto.embeddings
|
||||
])
|
||||
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
|
||||
|
||||
def embed_async(
|
||||
self,
|
||||
|
@ -290,8 +288,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
})
|
||||
|
||||
@staticmethod
|
||||
def cosine_similarity(u: embeddings_module.Embedding,
|
||||
v: embeddings_module.Embedding) -> float:
|
||||
def cosine_similarity(u: embedding_result_module.Embedding,
|
||||
v: embedding_result_module.Embedding) -> float:
|
||||
"""Utility function to compute cosine similarity [1] between two embedding
|
||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
of different types (quantized vs. float), have different sizes, or have a
|
||||
|
|
Loading…
Reference in New Issue
Block a user