Revised image embedder implementation

This commit is contained in:
kinaryml 2022-11-07 13:59:07 -08:00
parent ba1ee5b404
commit 664d9c49e7
12 changed files with 79 additions and 150 deletions

View File

@ -92,7 +92,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
],
] + select({ ] + select({
# TODO: Build text_classifier_graph on Windows. # TODO: Build text_classifier_graph on Windows.
"//mediapipe:windows": [], "//mediapipe:windows": [],

View File

@ -22,8 +22,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding _FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding _QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
_EmbeddingEntryProto = embeddings_pb2.EmbeddingEntry _EmbeddingProto = embeddings_pb2.Embedding
_EmbeddingsProto = embeddings_pb2.Embeddings
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult _EmbeddingResultProto = embeddings_pb2.EmbeddingResult
@ -99,28 +98,34 @@ class QuantizedEmbedding:
@dataclasses.dataclass @dataclasses.dataclass
class EmbeddingEntry: class Embedding:
"""Floating-point or scalar-quantized embedding with an optional timestamp. """Embedding result for a given embedder head.
Attributes: Attributes:
embedding: The actual embedding, either floating-point or scalar-quantized. embedding: The actual embedding, either floating-point or scalar-quantized.
timestamp_ms: The optional timestamp (in milliseconds) associated to the head_index: The index of the embedder head that produced this embedding.
embedding entry. This is useful for time series use cases, e.g. audio This is useful for multi-head models.
embedding. head_name: The name of the embedder head, which is the corresponding tensor
metadata name (if any). This is useful for multi-head models.
""" """
embedding: np.ndarray embedding: np.ndarray
timestamp_ms: Optional[int] = None head_index: int
head_name: str
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbeddingEntryProto: def to_pb2(self) -> _EmbeddingProto:
"""Generates a EmbeddingEntry protobuf object.""" """Generates a Embedding protobuf object."""
if self.embedding.dtype == float: if self.embedding.dtype == float:
return _EmbeddingEntryProto(float_embedding=self.embedding) return _EmbeddingProto(float_embedding=self.embedding,
head_index=self.head_index,
head_name=self.head_name)
elif self.embedding.dtype == np.uint8: elif self.embedding.dtype == np.uint8:
return _EmbeddingEntryProto(quantized_embedding=bytes(self.embedding)) return _EmbeddingProto(quantized_embedding=bytes(self.embedding),
head_index=self.head_index,
head_name=self.head_name)
else: else:
raise ValueError("Invalid dtype. Only float and np.uint8 are supported.") raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")
@ -128,17 +133,21 @@ class EmbeddingEntry:
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def create_from_pb2( def create_from_pb2(
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry': cls, pb2_obj: _EmbeddingProto) -> 'Embedding':
"""Creates a `EmbeddingEntry` object from the given protobuf object.""" """Creates a `Embedding` object from the given protobuf object."""
quantized_embedding = np.array( quantized_embedding = np.array(
bytearray(pb2_obj.quantized_embedding.values)) bytearray(pb2_obj.quantized_embedding.values))
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float) float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
if len(quantized_embedding) == 0: if len(quantized_embedding) == 0:
return EmbeddingEntry(embedding=float_embedding) return Embedding(embedding=float_embedding,
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
else: else:
return EmbeddingEntry(embedding=quantized_embedding) return Embedding(embedding=quantized_embedding,
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object. """Checks if this object is equal to the given object.
@ -147,55 +156,7 @@ class EmbeddingEntry:
Returns: Returns:
True if the objects are equal. True if the objects are equal.
""" """
if not isinstance(other, EmbeddingEntry): if not isinstance(other, Embedding):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class Embeddings:
"""Embeddings for a given embedder head.
Attributes:
entries: A list of `ClassificationEntry` objects.
head_index: The index of the embedder head that produced this embedding.
This is useful for multi-head models.
head_name: The name of the embedder head, which is the corresponding tensor
metadata name (if any). This is useful for multi-head models.
"""
entries: List[EmbeddingEntry]
head_index: int
head_name: str
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbeddingsProto:
"""Generates a Embeddings protobuf object."""
return _EmbeddingsProto(
entries=[entry.to_pb2() for entry in self.entries],
head_index=self.head_index,
head_name=self.head_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _EmbeddingsProto) -> 'Embeddings':
"""Creates a `Embeddings` object from the given protobuf object."""
return Embeddings(
entries=[
EmbeddingEntry.create_from_pb2(entry)
for entry in pb2_obj.entries
],
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Embeddings):
return False return False
return self.to_pb2().__eq__(other.to_pb2()) return self.to_pb2().__eq__(other.to_pb2())
@ -203,12 +164,12 @@ class Embeddings:
@dataclasses.dataclass @dataclasses.dataclass
class EmbeddingResult: class EmbeddingResult:
"""Contains one set of results per embedder head. """Embedding results for a given embedder model.
Attributes: Attributes:
embeddings: A list of `Embeddings` objects. embeddings: A list of `Embedding` objects.
""" """
embeddings: List[Embeddings] embeddings: List[Embedding]
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbeddingResultProto: def to_pb2(self) -> _EmbeddingResultProto:
@ -225,7 +186,7 @@ class EmbeddingResult:
"""Creates a `EmbeddingResult` object from the given protobuf object.""" """Creates a `EmbeddingResult` object from the given protobuf object."""
return EmbeddingResult( return EmbeddingResult(
embeddings=[ embeddings=[
Embeddings.create_from_pb2(embedding) Embedding.create_from_pb2(embedding)
for embedding in pb2_obj.embeddings for embedding in pb2_obj.embeddings
]) ])

View File

@ -28,3 +28,12 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
], ],
) )
py_library(
name = "embedder_options",
srcs = ["embedder_options.py"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -16,7 +16,7 @@
import dataclasses import dataclasses
from typing import Any, Optional from typing import Any, Optional
from mediapipe.tasks.cc.components.proto import embedder_options_pb2 from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions

View File

@ -1,28 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "embedder_options",
srcs = ["embedder_options.py"],
deps = [
"//mediapipe/tasks/cc/components/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -1,13 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -23,6 +23,6 @@ py_library(
srcs = ["cosine_similarity.py"], srcs = ["cosine_similarity.py"],
deps = [ deps = [
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/proto:embedder_options", "//mediapipe/tasks/python/components/processors:embedder_options",
], ],
) )

View File

@ -16,9 +16,9 @@
import numpy as np import numpy as np
from mediapipe.tasks.python.components.containers import embeddings from mediapipe.tasks.python.components.containers import embeddings
from mediapipe.tasks.python.components.proto import embedder_options from mediapipe.tasks.python.components.processors import embedder_options
_EmbeddingEntry = embeddings.EmbeddingEntry _Embedding = embeddings.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptions = embedder_options.EmbedderOptions
@ -36,15 +36,15 @@ def _compute_cosine_similarity(u, v):
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float: def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
"""Utility function to compute cosine similarity between two embedding """Utility function to compute cosine similarity between two embedding.
entries. May return an InvalidArgumentError if e.g. the feature vectors are May return an InvalidArgumentError if e.g. the feature vectors are of
of different types (quantized vs. float), have different sizes, or have an different types (quantized vs. float), have different sizes, or have an
L2-norm of 0. L2-norm of 0.
Args: Args:
u: An embedding entry. u: An embedding.
v: An embedding entry. v: An embedding.
""" """
if len(u.embedding) != len(v.embedding): if len(u.embedding) != len(v.embedding):
raise ValueError(f"Cannot compute cosine similarity between embeddings " raise ValueError(f"Cannot compute cosine similarity between embeddings "

View File

@ -83,7 +83,7 @@ py_test(
], ],
deps = [ deps = [
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/proto:embedder_options", "//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity", "//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",

View File

@ -22,7 +22,7 @@ from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
@ -36,8 +36,7 @@ _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions _EmbedderOptions = embedder_options_module.EmbedderOptions
_FloatEmbedding = embeddings_module.FloatEmbedding _FloatEmbedding = embeddings_module.FloatEmbedding
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding _QuantizedEmbedding = embeddings_module.QuantizedEmbedding
_EmbeddingEntry = embeddings_module.EmbeddingEntry _Embedding = embeddings_module.Embedding
_Embeddings = embeddings_module.Embeddings
_EmbeddingResult = embeddings_module.EmbeddingResult _EmbeddingResult = embeddings_module.EmbeddingResult
_Image = image_module.Image _Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder _ImageEmbedder = image_embedder.ImageEmbedder
@ -81,12 +80,12 @@ class ImageEmbedderTest(parameterized.TestCase):
# Check embedding sizes. # Check embedding sizes.
def _check_embedding_size(result): def _check_embedding_size(result):
self.assertLen(result.embeddings, 1) self.assertLen(result.embeddings, 1)
embedding_entry = result.embeddings[0].entries[0] embedding_result = result.embeddings[0]
self.assertLen(embedding_entry.embedding, 1024) self.assertLen(embedding_result.embedding, 1024)
if quantize: if quantize:
self.assertEqual(embedding_entry.embedding.dtype, np.uint8) self.assertEqual(embedding_result.embedding.dtype, np.uint8)
else: else:
self.assertEqual(embedding_entry.embedding.dtype, float) self.assertEqual(embedding_result.embedding.dtype, float)
# Checks results sizes. # Checks results sizes.
_check_embedding_size(result0) _check_embedding_size(result0)
@ -94,7 +93,7 @@ class ImageEmbedderTest(parameterized.TestCase):
# Checks cosine similarity. # Checks cosine similarity.
similarity = _ImageEmbedder.cosine_similarity( similarity = _ImageEmbedder.cosine_similarity(
result0.embeddings[0].entries[0], result1.embeddings[0].entries[0]) result0.embeddings[0], result1.embeddings[0])
self.assertAlmostEqual(similarity, expected_similarity, self.assertAlmostEqual(similarity, expected_similarity,
delta=_SIMILARITY_TOLERANCE) delta=_SIMILARITY_TOLERANCE)
@ -134,7 +133,7 @@ class ImageEmbedderTest(parameterized.TestCase):
crop_result = embedder.embed(self.test_cropped_image) crop_result = embedder.embed(self.test_cropped_image)
# Check embedding value. # Check embedding value.
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0], self.assertAlmostEqual(image_result.embeddings[0].embedding[0],
expected_first_value) expected_first_value)
# Checks cosine similarity. # Checks cosine similarity.

View File

@ -114,6 +114,7 @@ py_library(
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",

View File

@ -22,7 +22,7 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.python.components.proto import embedder_options from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
@ -32,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_ImageEmbedderResult = embeddings_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptions = embedder_options.EmbedderOptions
@ -40,8 +41,8 @@ _TaskInfo = task_info_module.TaskInfo
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskRunner = task_runner_module.TaskRunner _TaskRunner = task_runner_module.TaskRunner
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out' _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
_EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT' _EMBEDDINGS_TAG = 'EMBEDDINGS'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
@ -140,15 +141,15 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
embedding_result_proto = packet_getter.get_proto( embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
embedding_result = embeddings_module.EmbeddingResult([ embeddings = embeddings_module.EmbeddingResult([
embeddings_module.Embeddings.create_from_pb2(embedding) embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings for embedding in embedding_result_proto.embeddings
]) ])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(embedding_result, image, options.result_callback(embeddings, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo( task_info = _TaskInfo(
@ -158,8 +159,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=[
':'.join([_EMBEDDING_RESULT_TAG, ':'.join([_EMBEDDINGS_TAG,
_EMBEDDING_RESULT_OUT_STREAM_NAME]), _EMBEDDINGS_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
], ],
task_options=options) task_options=options)
@ -173,7 +174,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> embeddings_module.EmbeddingResult: ) -> _ImageEmbedderResult:
"""Performs image embedding extraction on the provided MediaPipe Image. """Performs image embedding extraction on the provided MediaPipe Image.
Extraction is performed on the region of interest specified by the `roi` Extraction is performed on the region of interest specified by the `roi`
argument if provided, or on the entire image otherwise. argument if provided, or on the entire image otherwise.
@ -195,18 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
_NORM_RECT_STREAM_NAME: packet_creator.create_proto( _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())}) normalized_rect.to_pb2())})
embedding_result_proto = packet_getter.get_proto( embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
return embeddings_module.EmbeddingResult([ return embeddings_module.EmbeddingResult([
embeddings_module.Embeddings.create_from_pb2(embedding) embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings for embedding in embedding_result_proto.embeddings
]) ])
def embed_for_video( def embed_for_video(
self, image: image_module.Image, self, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> embeddings_module.EmbeddingResult: ) -> _ImageEmbedderResult:
"""Performs image embedding extraction on the provided video frames. """Performs image embedding extraction on the provided video frames.
Extraction is performed on the region of interested specified by the `roi` Extraction is performed on the region of interested specified by the `roi`
argument if provided, or on the entire image otherwise. argument if provided, or on the entire image otherwise.
@ -237,11 +238,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })
embedding_result_proto = packet_getter.get_proto( embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
return embeddings_module.EmbeddingResult([ return embeddings_module.EmbeddingResult([
embeddings_module.Embeddings.create_from_pb2(embedding) embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings for embedding in embedding_result_proto.embeddings
]) ])
def embed_async( def embed_async(
@ -289,8 +290,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
}) })
@staticmethod @staticmethod
def cosine_similarity(u: embeddings_module.EmbeddingEntry, def cosine_similarity(u: embeddings_module.Embedding,
v: embeddings_module.EmbeddingEntry) -> float: v: embeddings_module.Embedding) -> float:
"""Utility function to compute cosine similarity [1] between two embedding """Utility function to compute cosine similarity [1] between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are entries. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a of different types (quantized vs. float), have different sizes, or have a