Refactored embeddings to embedding_result
This commit is contained in:
parent
dc30cf9732
commit
17bb174444
|
@ -106,8 +106,8 @@ py_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "embeddings",
|
name = "embedding_result",
|
||||||
srcs = ["embeddings.py"],
|
srcs = ["embedding_result.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
|
|
@ -22,7 +22,7 @@ py_library(
|
||||||
name = "cosine_similarity",
|
name = "cosine_similarity",
|
||||||
srcs = ["cosine_similarity.py"],
|
srcs = ["cosine_similarity.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,10 +15,10 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mediapipe.tasks.python.components.containers import embeddings
|
from mediapipe.tasks.python.components.containers import embedding_result
|
||||||
from mediapipe.tasks.python.components.processors import embedder_options
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
|
|
||||||
_Embedding = embeddings.Embedding
|
_Embedding = embedding_result.Embedding
|
||||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -85,7 +85,7 @@ py_test(
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/tasks/python/components/processors:embedder_options",
|
"//mediapipe/tasks/python/components/processors:embedder_options",
|
||||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
|
|
@ -23,7 +23,7 @@ from absl.testing import parameterized
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
|
||||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||||
from mediapipe.tasks.python.components.containers import rect
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
@ -31,13 +31,13 @@ from mediapipe.tasks.python.vision import image_embedder
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
|
ImageEmbedderResult = embedding_result_module.EmbeddingResult
|
||||||
_Rect = rect.Rect
|
_Rect = rect.Rect
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||||
_FloatEmbedding = embeddings_module.FloatEmbedding
|
_FloatEmbedding = embedding_result_module.FloatEmbedding
|
||||||
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
_QuantizedEmbedding = embedding_result_module.QuantizedEmbedding
|
||||||
_Embedding = embeddings_module.Embedding
|
_Embedding = embedding_result_module.Embedding
|
||||||
_EmbeddingResult = embeddings_module.EmbeddingResult
|
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||||
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
||||||
|
@ -346,7 +346,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
|
|
||||||
observed_timestamp_ms = -1
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
def check_result(result: _EmbeddingResult, output_image: _Image,
|
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
self._check_cosine_similarity(result, crop_result, quantize=False,
|
||||||
|
@ -378,7 +378,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
image_processing_options = _ImageProcessingOptions(roi)
|
image_processing_options = _ImageProcessingOptions(roi)
|
||||||
observed_timestamp_ms = -1
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
def check_result(result: _EmbeddingResult, output_image: _Image,
|
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
self._check_cosine_similarity(result, crop_result, quantize=False,
|
||||||
|
|
|
@ -113,7 +113,8 @@ py_library(
|
||||||
"//mediapipe/python:packet_creator",
|
"//mediapipe/python:packet_creator",
|
||||||
"//mediapipe/python:packet_getter",
|
"//mediapipe/python:packet_getter",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
|
|
@ -22,9 +22,10 @@ from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||||
from mediapipe.tasks.python.components.processors import embedder_options
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
@ -32,7 +33,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
_ImageEmbedderResult = embeddings_module.EmbeddingResult
|
ImageEmbedderResult = embedding_result_module.EmbeddingResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
@ -76,7 +77,7 @@ class ImageEmbedderOptions:
|
||||||
quantize: Optional[bool] = None
|
quantize: Optional[bool] = None
|
||||||
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
||||||
result_callback: Optional[
|
result_callback: Optional[
|
||||||
Callable[[embeddings_module.EmbeddingResult, image_module.Image,
|
Callable[[ImageEmbedderResult, image_module.Image,
|
||||||
int], None]] = None
|
int], None]] = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
|
@ -140,16 +141,16 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
||||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||||
return
|
return
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
|
||||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
|
||||||
|
|
||||||
embeddings = embeddings_module.EmbeddingResult([
|
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
embedding_result_proto.CopyFrom(
|
||||||
for embedding in embedding_result_proto.embeddings
|
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||||
])
|
|
||||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
options.result_callback(embeddings, image,
|
options.result_callback(
|
||||||
|
ImageEmbedderResult.create_from_pb2(embedding_result_proto),
|
||||||
|
image,
|
||||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
task_info = _TaskInfo(
|
task_info = _TaskInfo(
|
||||||
|
@ -174,7 +175,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> _ImageEmbedderResult:
|
) -> ImageEmbedderResult:
|
||||||
"""Performs image embedding extraction on the provided MediaPipe Image.
|
"""Performs image embedding extraction on the provided MediaPipe Image.
|
||||||
Extraction is performed on the region of interest specified by the `roi`
|
Extraction is performed on the region of interest specified by the `roi`
|
||||||
argument if provided, or on the entire image otherwise.
|
argument if provided, or on the entire image otherwise.
|
||||||
|
@ -195,19 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
||||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||||
normalized_rect.to_pb2())})
|
normalized_rect.to_pb2())})
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
|
||||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
|
||||||
|
|
||||||
return embeddings_module.EmbeddingResult([
|
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
embedding_result_proto.CopyFrom(
|
||||||
for embedding in embedding_result_proto.embeddings
|
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||||
])
|
|
||||||
|
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
|
||||||
|
|
||||||
def embed_for_video(
|
def embed_for_video(
|
||||||
self, image: image_module.Image,
|
self, image: image_module.Image,
|
||||||
timestamp_ms: int,
|
timestamp_ms: int,
|
||||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> _ImageEmbedderResult:
|
) -> ImageEmbedderResult:
|
||||||
"""Performs image embedding extraction on the provided video frames.
|
"""Performs image embedding extraction on the provided video frames.
|
||||||
Extraction is performed on the region of interest specified by the `roi`
|
Extraction is performed on the region of interest specified by the `roi`
|
||||||
argument if provided, or on the entire image otherwise.
|
argument if provided, or on the entire image otherwise.
|
||||||
|
@ -237,13 +237,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
normalized_rect.to_pb2()).at(
|
normalized_rect.to_pb2()).at(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
})
|
})
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
embedding_result_proto = embeddings_pb2.EmbeddingResult()
|
||||||
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
|
embedding_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
return embeddings_module.EmbeddingResult([
|
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
|
||||||
embeddings_module.Embedding.create_from_pb2(embedding)
|
|
||||||
for embedding in embedding_result_proto.embeddings
|
|
||||||
])
|
|
||||||
|
|
||||||
def embed_async(
|
def embed_async(
|
||||||
self,
|
self,
|
||||||
|
@ -290,8 +288,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
})
|
})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cosine_similarity(u: embeddings_module.Embedding,
|
def cosine_similarity(u: embedding_result_module.Embedding,
|
||||||
v: embeddings_module.Embedding) -> float:
|
v: embedding_result_module.Embedding) -> float:
|
||||||
"""Utility function to compute cosine similarity [1] between two embedding
|
"""Utility function to compute cosine similarity [1] between two embedding
|
||||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||||
of different types (quantized vs. float), have different sizes, or have a
|
of different types (quantized vs. float), have different sizes, or have a
|
||||||
|
|
Loading…
Reference in New Issue
Block a user