Revised image embedder implementation

This commit is contained in:
kinaryml 2022-11-07 13:59:07 -08:00
parent ba1ee5b404
commit 664d9c49e7
12 changed files with 79 additions and 150 deletions

View File

@ -92,7 +92,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
],
] + select({
# TODO: Build text_classifier_graph on Windows.
"//mediapipe:windows": [],

View File

@ -22,8 +22,7 @@ from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
_EmbeddingEntryProto = embeddings_pb2.EmbeddingEntry
_EmbeddingsProto = embeddings_pb2.Embeddings
_EmbeddingProto = embeddings_pb2.Embedding
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult
@ -99,28 +98,34 @@ class QuantizedEmbedding:
@dataclasses.dataclass
class EmbeddingEntry:
"""Floating-point or scalar-quantized embedding with an optional timestamp.
class Embedding:
"""Embedding result for a given embedder head.
Attributes:
embedding: The actual embedding, either floating-point or scalar-quantized.
timestamp_ms: The optional timestamp (in milliseconds) associated to the
embedding entry. This is useful for time series use cases, e.g. audio
embedding.
head_index: The index of the embedder head that produced this embedding.
This is useful for multi-head models.
head_name: The name of the embedder head, which is the corresponding tensor
metadata name (if any). This is useful for multi-head models.
"""
embedding: np.ndarray
timestamp_ms: Optional[int] = None
head_index: int
head_name: str
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbeddingEntryProto:
"""Generates a EmbeddingEntry protobuf object."""
def to_pb2(self) -> _EmbeddingProto:
"""Generates an Embedding protobuf object."""
if self.embedding.dtype == float:
return _EmbeddingEntryProto(float_embedding=self.embedding)
return _EmbeddingProto(float_embedding=self.embedding,
head_index=self.head_index,
head_name=self.head_name)
elif self.embedding.dtype == np.uint8:
return _EmbeddingEntryProto(quantized_embedding=bytes(self.embedding))
return _EmbeddingProto(quantized_embedding=bytes(self.embedding),
head_index=self.head_index,
head_name=self.head_name)
else:
raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")
@ -128,17 +133,21 @@ class EmbeddingEntry:
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
"""Creates a `EmbeddingEntry` object from the given protobuf object."""
cls, pb2_obj: _EmbeddingProto) -> 'Embedding':
"""Creates an `Embedding` object from the given protobuf object."""
quantized_embedding = np.array(
bytearray(pb2_obj.quantized_embedding.values))
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
if len(quantized_embedding) == 0:
return EmbeddingEntry(embedding=float_embedding)
return Embedding(embedding=float_embedding,
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
else:
return EmbeddingEntry(embedding=quantized_embedding)
return Embedding(embedding=quantized_embedding,
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
@ -147,55 +156,7 @@ class EmbeddingEntry:
Returns:
True if the objects are equal.
"""
if not isinstance(other, EmbeddingEntry):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class Embeddings:
"""Embeddings for a given embedder head.
Attributes:
entries: A list of `ClassificationEntry` objects.
head_index: The index of the embedder head that produced this embedding.
This is useful for multi-head models.
head_name: The name of the embedder head, which is the corresponding tensor
metadata name (if any). This is useful for multi-head models.
"""
entries: List[EmbeddingEntry]
head_index: int
head_name: str
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbeddingsProto:
"""Generates a Embeddings protobuf object."""
return _EmbeddingsProto(
entries=[entry.to_pb2() for entry in self.entries],
head_index=self.head_index,
head_name=self.head_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _EmbeddingsProto) -> 'Embeddings':
"""Creates a `Embeddings` object from the given protobuf object."""
return Embeddings(
entries=[
EmbeddingEntry.create_from_pb2(entry)
for entry in pb2_obj.entries
],
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Embeddings):
if not isinstance(other, Embedding):
return False
return self.to_pb2().__eq__(other.to_pb2())
@ -203,12 +164,12 @@ class Embeddings:
@dataclasses.dataclass
class EmbeddingResult:
"""Contains one set of results per embedder head.
"""Embedding results for a given embedder model.
Attributes:
embeddings: A list of `Embeddings` objects.
embeddings: A list of `Embedding` objects.
"""
embeddings: List[Embeddings]
embeddings: List[Embedding]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _EmbeddingResultProto:
@ -225,7 +186,7 @@ class EmbeddingResult:
"""Creates a `EmbeddingResult` object from the given protobuf object."""
return EmbeddingResult(
embeddings=[
Embeddings.create_from_pb2(embedding)
Embedding.create_from_pb2(embedding)
for embedding in pb2_obj.embeddings
])

View File

@ -28,3 +28,12 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "embedder_options",
srcs = ["embedder_options.py"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -16,7 +16,7 @@
import dataclasses
from typing import Any, Optional
from mediapipe.tasks.cc.components.proto import embedder_options_pb2
from mediapipe.tasks.cc.components.processors.proto import embedder_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions

View File

@ -1,28 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "embedder_options",
srcs = ["embedder_options.py"],
deps = [
"//mediapipe/tasks/cc/components/proto:embedder_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -1,13 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -23,6 +23,6 @@ py_library(
srcs = ["cosine_similarity.py"],
deps = [
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/proto:embedder_options",
"//mediapipe/tasks/python/components/processors:embedder_options",
],
)

View File

@ -16,9 +16,9 @@
import numpy as np
from mediapipe.tasks.python.components.containers import embeddings
from mediapipe.tasks.python.components.proto import embedder_options
from mediapipe.tasks.python.components.processors import embedder_options
_EmbeddingEntry = embeddings.EmbeddingEntry
_Embedding = embeddings.Embedding
_EmbedderOptions = embedder_options.EmbedderOptions
@ -36,15 +36,15 @@ def _compute_cosine_similarity(u, v):
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float:
"""Utility function to compute cosine similarity between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have an
def cosine_similarity(u: _Embedding, v: _Embedding) -> float:
"""Utility function to compute cosine similarity between two embeddings.
May return an InvalidArgumentError if e.g. the feature vectors are of
different types (quantized vs. float), have different sizes, or have an
L2-norm of 0.
Args:
u: An embedding entry.
v: An embedding entry.
u: An embedding.
v: An embedding.
"""
if len(u.embedding) != len(v.embedding):
raise ValueError(f"Cannot compute cosine similarity between embeddings "

View File

@ -83,7 +83,7 @@ py_test(
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/proto:embedder_options",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/containers:rect",

View File

@ -22,7 +22,7 @@ from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
@ -36,8 +36,7 @@ _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_FloatEmbedding = embeddings_module.FloatEmbedding
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
_EmbeddingEntry = embeddings_module.EmbeddingEntry
_Embeddings = embeddings_module.Embeddings
_Embedding = embeddings_module.Embedding
_EmbeddingResult = embeddings_module.EmbeddingResult
_Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder
@ -81,12 +80,12 @@ class ImageEmbedderTest(parameterized.TestCase):
# Check embedding sizes.
def _check_embedding_size(result):
self.assertLen(result.embeddings, 1)
embedding_entry = result.embeddings[0].entries[0]
self.assertLen(embedding_entry.embedding, 1024)
embedding_result = result.embeddings[0]
self.assertLen(embedding_result.embedding, 1024)
if quantize:
self.assertEqual(embedding_entry.embedding.dtype, np.uint8)
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
else:
self.assertEqual(embedding_entry.embedding.dtype, float)
self.assertEqual(embedding_result.embedding.dtype, float)
# Checks results sizes.
_check_embedding_size(result0)
@ -94,7 +93,7 @@ class ImageEmbedderTest(parameterized.TestCase):
# Checks cosine similarity.
similarity = _ImageEmbedder.cosine_similarity(
result0.embeddings[0].entries[0], result1.embeddings[0].entries[0])
result0.embeddings[0], result1.embeddings[0])
self.assertAlmostEqual(similarity, expected_similarity,
delta=_SIMILARITY_TOLERANCE)
@ -134,7 +133,7 @@ class ImageEmbedderTest(parameterized.TestCase):
crop_result = embedder.embed(self.test_cropped_image)
# Check embedding value.
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0],
self.assertAlmostEqual(image_result.embeddings[0].embedding[0],
expected_first_value)
# Checks cosine similarity.

View File

@ -114,6 +114,7 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",

View File

@ -22,7 +22,7 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.python.components.proto import embedder_options
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.core import base_options as base_options_module
@ -32,6 +32,7 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_ImageEmbedderResult = embeddings_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions
@ -40,8 +41,8 @@ _TaskInfo = task_info_module.TaskInfo
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskRunner = task_runner_module.TaskRunner
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
_EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT'
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
_EMBEDDINGS_TAG = 'EMBEDDINGS'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
@ -140,15 +141,15 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
embedding_result = embeddings_module.EmbeddingResult([
embeddings_module.Embeddings.create_from_pb2(embedding)
embeddings = embeddings_module.EmbeddingResult([
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(embedding_result, image,
options.result_callback(embeddings, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
@ -158,8 +159,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_EMBEDDING_RESULT_TAG,
_EMBEDDING_RESULT_OUT_STREAM_NAME]),
':'.join([_EMBEDDINGS_TAG,
_EMBEDDINGS_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
@ -173,7 +174,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> embeddings_module.EmbeddingResult:
) -> _ImageEmbedderResult:
"""Performs image embedding extraction on the provided MediaPipe Image.
Extraction is performed on the region of interest specified by the `roi`
argument if provided, or on the entire image otherwise.
@ -195,18 +196,18 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())})
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
return embeddings_module.EmbeddingResult([
embeddings_module.Embeddings.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
def embed_for_video(
self, image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> embeddings_module.EmbeddingResult:
) -> _ImageEmbedderResult:
"""Performs image embedding extraction on the provided video frames.
Extraction is performed on the region of interest specified by the `roi`
argument if provided, or on the entire image otherwise.
@ -237,11 +238,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
return embeddings_module.EmbeddingResult([
embeddings_module.Embeddings.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
embeddings_module.Embedding.create_from_pb2(embedding)
for embedding in embedding_result_proto.embeddings
])
def embed_async(
@ -289,8 +290,8 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
})
@staticmethod
def cosine_similarity(u: embeddings_module.EmbeddingEntry,
v: embeddings_module.EmbeddingEntry) -> float:
def cosine_similarity(u: embeddings_module.Embedding,
v: embeddings_module.Embedding) -> float:
"""Utility function to compute cosine similarity [1] between two
embeddings. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a