Revised image embedder implementation

This commit is contained in:
kinaryml 2022-11-03 14:30:21 -07:00
parent 492607152a
commit e2d50745ac
9 changed files with 456 additions and 56 deletions

View File

@ -131,18 +131,14 @@ class EmbeddingEntry:
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry': cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
"""Creates a `EmbeddingEntry` object from the given protobuf object.""" """Creates a `EmbeddingEntry` object from the given protobuf object."""
if pb2_obj.float_embedding: quantized_embedding = np.array(
return EmbeddingEntry( bytearray(pb2_obj.quantized_embedding.values))
embedding=np.array(pb2_obj.float_embedding.values, dtype=float)) float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
elif pb2_obj.quantized_embedding:
return EmbeddingEntry(
embedding=np.array(bytearray(pb2_obj.quantized_embedding.values),
dtype=np.uint8))
if len(quantized_embedding) == 0:
return EmbeddingEntry(embedding=float_embedding)
else: else:
raise ValueError("Either float_embedding or quantized_embedding must " return EmbeddingEntry(embedding=quantized_embedding)
"exist.")
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object. """Checks if this object is equal to the given object.

View File

@ -45,8 +45,7 @@ class EmbedderOptions:
"""Generates a EmbedderOptions protobuf object.""" """Generates a EmbedderOptions protobuf object."""
return _EmbedderOptionsProto( return _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, l2_normalize=self.l2_normalize,
quantize=self.quantize quantize=self.quantize)
)
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -54,8 +53,7 @@ class EmbedderOptions:
"""Creates a `EmbedderOptions` object from the given protobuf object.""" """Creates a `EmbedderOptions` object from the given protobuf object."""
return EmbedderOptions( return EmbedderOptions(
l2_normalize=pb2_obj.l2_normalize, l2_normalize=pb2_obj.l2_normalize,
quantize=pb2_obj.quantize quantize=pb2_obj.quantize)
)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object. """Checks if this object is equal to the given object.

View File

@ -0,0 +1,28 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "cosine_similarity",
srcs = ["cosine_similarity.py"],
deps = [
"//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/proto:embedder_options",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,61 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cosine similarity utilities."""
import numpy as np
from mediapipe.tasks.python.components.containers import embeddings
from mediapipe.tasks.python.components.proto import embedder_options
_EmbeddingEntry = embeddings.EmbeddingEntry
_EmbedderOptions = embedder_options.EmbedderOptions
def _compute_cosine_similarity(u, v):
if len(u.embedding) <= 0:
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
norm_u = np.linalg.norm(u.embedding)
norm_v = np.linalg.norm(v.embedding)
if norm_u <= 0 or norm_v <= 0:
raise ValueError(
"Cannot compute cosine similarity on embedding with 0 norm.")
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float:
"""Utility function to compute cosine similarity between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have an
L2-norm of 0.
Args:
u: An embedding entry.
v: An embedding entry.
"""
if len(u.embedding) != len(v.embedding):
raise ValueError(f"Cannot compute cosine similarity between embeddings "
f"of different sizes "
f"({len(u.embedding)} vs. {len(v.embedding)}).")
if u.embedding.dtype == float and v.embedding.dtype == float:
return _compute_cosine_similarity(u, v)
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
return _compute_cosine_similarity(u, v)
raise ValueError("Cannot compute cosine similarity between quantized and "
"float embeddings.")

View File

@ -84,11 +84,13 @@ py_test(
deps = [ deps = [
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/proto:embedder_options", "//mediapipe/tasks/python/components/proto:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/components/containers:embeddings", "//mediapipe/tasks/python/components/containers:embeddings",
"//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_embedder", "//mediapipe/tasks/python/vision:image_embedder",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],
) )

View File

@ -14,6 +14,7 @@
"""Tests for image embedder.""" """Tests for image embedder."""
import enum import enum
import os
from unittest import mock from unittest import mock
import numpy as np import numpy as np
@ -23,31 +24,32 @@ from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_embedder from mediapipe.tasks.python.vision import image_embedder
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_NormalizedRect = rect_module.NormalizedRect _Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions _EmbedderOptions = embedder_options_module.EmbedderOptions
_FloatEmbedding = embeddings_module.FloatEmbedding _FloatEmbedding = embeddings_module.FloatEmbedding
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding _QuantizedEmbedding = embeddings_module.QuantizedEmbedding
_ClassificationEntry = embeddings_module.EmbeddingEntry _EmbeddingEntry = embeddings_module.EmbeddingEntry
_Classifications = embeddings_module.Embeddings _Embeddings = embeddings_module.Embeddings
_ClassificationResult = embeddings_module.EmbeddingResult _EmbeddingResult = embeddings_module.EmbeddingResult
_Image = image_module.Image _Image = image_module.Image
_ImageEmbedder = image_embedder.ImageEmbedder _ImageEmbedder = image_embedder.ImageEmbedder
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions _ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL_FILE = 'mobilenet_v3_small_100_224_embedder.tflite' _MODEL_FILE = 'mobilenet_v3_small_100_224_embedder.tflite'
_IMAGE_FILE = 'burger.jpg' _BURGER_IMAGE_FILE = 'burger.jpg'
_ALLOW_LIST = ['cheeseburger', 'guacamole'] _BURGER_CROPPED_IMAGE_FILE = 'burger_crop.jpg'
_DENY_LIST = ['cheeseburger'] _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
_SCORE_THRESHOLD = 0.5 _SIMILARITY_TOLERANCE = 1e-6
_MAX_RESULTS = 3
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):
@ -55,18 +57,55 @@ class ModelFileType(enum.Enum):
FILE_NAME = 2 FILE_NAME = 2
class ImageClassifierTest(parameterized.TestCase): class ImageEmbedderTest(parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE)) test_utils.get_test_data_path(
self.model_path = test_utils.get_test_data_path(_MODEL_FILE) os.path.join(_TEST_DATA_DIR, _BURGER_IMAGE_FILE)))
self.test_cropped_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _BURGER_CROPPED_IMAGE_FILE)))
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
def _check_cosine_similarity(self, result0, result1, quantize,
expected_similarity):
# Checks head_index and head_name.
self.assertEqual(result0.embeddings[0].head_index, 0)
self.assertEqual(result1.embeddings[0].head_index, 0)
self.assertEqual(result0.embeddings[0].head_name, 'feature')
self.assertEqual(result1.embeddings[0].head_name, 'feature')
# Check embedding sizes.
def _check_embedding_size(result):
self.assertLen(result.embeddings, 1)
embedding_entry = result.embeddings[0].entries[0]
self.assertLen(embedding_entry.embedding, 1024)
if quantize:
self.assertEqual(embedding_entry.embedding.dtype, np.uint8)
else:
self.assertEqual(embedding_entry.embedding.dtype, float)
# Checks results sizes.
_check_embedding_size(result0)
_check_embedding_size(result1)
# Checks cosine similarity.
similarity = _ImageEmbedder.cosine_similarity(
result0.embeddings[0].entries[0], result1.embeddings[0].entries[0])
self.assertAlmostEqual(similarity, expected_similarity,
delta=_SIMILARITY_TOLERANCE)
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, False, False), (False, False, False, ModelFileType.FILE_NAME, 0.925519, -0.2101883),
(ModelFileType.FILE_CONTENT, False, False)) (True, False, False, ModelFileType.FILE_NAME, 0.925519, -0.0142344),
def test_embed(self, model_file_type, l2_normalize, quantize): # (False, True, False, ModelFileType.FILE_NAME, 0.926791, 229),
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, -0.195062)
)
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
expected_similarity, expected_first_value):
# Creates embedder. # Creates embedder.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
@ -84,15 +123,254 @@ class ImageClassifierTest(parameterized.TestCase):
base_options=base_options, embedder_options=embedder_options) base_options=base_options, embedder_options=embedder_options)
embedder = _ImageEmbedder.create_from_options(options) embedder = _ImageEmbedder.create_from_options(options)
# Performs image embedding extraction on the input. image_processing_options = None
image_result = embedder.embed(self.test_image) if with_roi:
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
image_processing_options = _ImageProcessingOptions(roi)
# TODO: Verify results. # Extracts both embeddings.
image_result = embedder.embed(self.test_image, image_processing_options)
crop_result = embedder.embed(self.test_cropped_image)
# Check embedding value.
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0],
expected_first_value)
# Checks cosine similarity.
self._check_cosine_similarity(image_result, crop_result, quantize,
expected_similarity)
# Closes the embedder explicitly when the classifier is not used in # Closes the embedder explicitly when the classifier is not used in
# a context. # a context.
embedder.close() embedder.close()
@parameterized.parameters(
(False, False, ModelFileType.FILE_NAME, 0.925519),
(False, False, ModelFileType.FILE_CONTENT, 0.925519))
def test_embed_in_context(self, l2_normalize, quantize, model_file_type,
expected_similarity):
# Creates embedder.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
embedder_options = _EmbedderOptions(l2_normalize=l2_normalize,
quantize=quantize)
options = _ImageEmbedderOptions(
base_options=base_options, embedder_options=embedder_options)
with _ImageEmbedder.create_from_options(options) as embedder:
# Extracts both embeddings.
image_result = embedder.embed(self.test_image)
crop_result = embedder.embed(self.test_cropped_image)
# Checks cosine similarity.
self._check_cosine_similarity(image_result, crop_result, quantize,
expected_similarity)
def test_missing_result_callback(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _ImageEmbedder.create_from_options(options) as unused_embedder:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _ImageEmbedder.create_from_options(options) as unused_embedder:
pass
def test_calling_embed_for_video_in_image_mode(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
embedder.embed_for_video(self.test_image, 0)
def test_calling_embed_async_in_image_mode(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
embedder.embed_async(self.test_image, 0)
def test_calling_embed_in_video_mode(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
embedder.embed(self.test_image)
def test_calling_embed_async_in_video_mode(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
embedder.embed_async(self.test_image, 0)
def test_embed_for_video_with_out_of_order_timestamp(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder:
unused_result = embedder.embed_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
embedder.embed_for_video(self.test_image, 0)
def test_embed_for_video(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder0, \
_ImageEmbedder.create_from_options(options) as embedder1:
for timestamp in range(0, 300, 30):
# Extracts both embeddings.
image_result = embedder0.embed_for_video(self.test_image, timestamp)
crop_result = embedder1.embed_for_video(self.test_cropped_image,
timestamp)
# Checks cosine similarity.
self._check_cosine_similarity(
image_result, crop_result, quantize=False,
expected_similarity=0.925519)
def test_embed_for_video_succeeds_with_region_of_interest(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageEmbedder.create_from_options(options) as embedder0, \
_ImageEmbedder.create_from_options(options) as embedder1:
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
image_processing_options = _ImageProcessingOptions(roi)
for timestamp in range(0, 300, 30):
# Extracts both embeddings.
image_result = embedder0.embed_for_video(self.test_image, timestamp,
image_processing_options)
crop_result = embedder1.embed_for_video(self.test_cropped_image,
timestamp)
# Checks cosine similarity.
self._check_cosine_similarity(
image_result, crop_result, quantize=False,
expected_similarity=0.999931)
def test_calling_embed_in_live_stream_mode(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
embedder.embed(self.test_image)
def test_calling_classify_for_video_in_live_stream_mode(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
embedder.embed_for_video(self.test_image, 0)
def test_classify_async_calls_with_illegal_timestamp(self):
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageEmbedder.create_from_options(options) as embedder:
embedder.embed_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
embedder.embed_async(self.test_image, 0)
def test_embed_async_calls(self):
# Get the embedding result for the cropped image.
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder:
crop_result = embedder.embed(self.test_cropped_image)
observed_timestamp_ms = -1
def check_result(result: _EmbeddingResult, output_image: _Image,
timestamp_ms: int):
# Checks cosine similarity.
self._check_cosine_similarity(result, crop_result, quantize=False,
expected_similarity=0.925519)
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _ImageEmbedder.create_from_options(options) as embedder:
for timestamp in range(0, 300, 30):
embedder.embed_async(self.test_image, timestamp)
def test_classify_async_succeeds_with_region_of_interest(self):
# Get the embedding result for the cropped image.
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageEmbedder.create_from_options(options) as embedder:
crop_result = embedder.embed(self.test_cropped_image)
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1
def check_result(result: _EmbeddingResult, output_image: _Image,
timestamp_ms: int):
# Checks cosine similarity.
self._check_cosine_similarity(result, crop_result, quantize=False,
expected_similarity=0.999931)
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _ImageEmbedder.create_from_options(options) as embedder:
for timestamp in range(0, 300, 30):
embedder.embed_async(self.test_image, timestamp,
image_processing_options)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -118,6 +118,7 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api", "//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],
) )

View File

@ -23,20 +23,21 @@ from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
from mediapipe.tasks.python.components.proto import embedder_options from mediapipe.tasks.python.components.proto import embedder_options
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_NormalizedRect = rect_module.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_EmbedderOptions = embedder_options.EmbedderOptions _EmbedderOptions = embedder_options.EmbedderOptions
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskRunner = task_runner_module.TaskRunner _TaskRunner = task_runner_module.TaskRunner
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out' _EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
@ -44,17 +45,12 @@ _EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_NORM_RECT_NAME = 'norm_rect_in' _NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT' _NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
def _build_full_image_norm_rect() -> _NormalizedRect:
# Builds a NormalizedRect covering the entire image.
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
@dataclasses.dataclass @dataclasses.dataclass
class ImageEmbedderOptions: class ImageEmbedderOptions:
"""Options for the image embedder task. """Options for the image embedder task.
@ -75,6 +71,8 @@ class ImageEmbedderOptions:
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
embedder_options: _EmbedderOptions = _EmbedderOptions() embedder_options: _EmbedderOptions = _EmbedderOptions()
result_callback: Optional[ result_callback: Optional[
Callable[[embeddings_module.EmbeddingResult, image_module.Image, Callable[[embeddings_module.EmbeddingResult, image_module.Image,
@ -157,7 +155,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=[
':'.join([_EMBEDDING_RESULT_TAG, ':'.join([_EMBEDDING_RESULT_TAG,
@ -174,7 +172,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
def embed( def embed(
self, self,
image: image_module.Image, image: image_module.Image,
roi: Optional[_NormalizedRect] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> embeddings_module.EmbeddingResult: ) -> embeddings_module.EmbeddingResult:
"""Performs image embedding extraction on the provided MediaPipe Image. """Performs image embedding extraction on the provided MediaPipe Image.
Extraction is performed on the region of interest specified by the `roi` Extraction is performed on the region of interest specified by the `roi`
@ -182,7 +180,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
roi: The region of interest. image_processing_options: Options for image processing.
Returns: Returns:
A embedding result object that contains a list of embeddings. A embedding result object that contains a list of embeddings.
@ -191,10 +189,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If image embedder failed to run. RuntimeError: If image embedder failed to run.
""" """
norm_rect = roi if roi is not None else _build_full_image_norm_rect() normalized_rect = self.convert_to_normalized_rect(image_processing_options)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())}) _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2())})
embedding_result_proto = packet_getter.get_proto( embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
@ -206,7 +205,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
def embed_for_video( def embed_for_video(
self, image: image_module.Image, self, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
roi: Optional[_NormalizedRect] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> embeddings_module.EmbeddingResult: ) -> embeddings_module.EmbeddingResult:
"""Performs image embedding extraction on the provided video frames. """Performs image embedding extraction on the provided video frames.
Extraction is performed on the region of interested specified by the `roi` Extraction is performed on the region of interested specified by the `roi`
@ -220,7 +219,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds. timestamp_ms: The timestamp of the input video frame in milliseconds.
roi: The region of interest. image_processing_options: Options for image processing.
Returns: Returns:
A embedding result object that contains a list of embeddings. A embedding result object that contains a list of embeddings.
@ -229,12 +228,13 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If image embedder failed to run. RuntimeError: If image embedder failed to run.
""" """
norm_rect = roi if roi is not None else _build_full_image_norm_rect() normalized_rect = self.convert_to_normalized_rect(image_processing_options)
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })
embedding_result_proto = packet_getter.get_proto( embedding_result_proto = packet_getter.get_proto(
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME]) output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
@ -248,7 +248,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
roi: Optional[_NormalizedRect] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> None: ) -> None:
""" Sends live image data to embedder, and the results will be available via """ Sends live image data to embedder, and the results will be available via
the "result_callback" provided in the ImageEmbedderOptions. Embedding the "result_callback" provided in the ImageEmbedderOptions. Embedding
@ -273,16 +273,39 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds. timestamp_ms: The timestamp of the input image in milliseconds.
roi: The region of interest. image_processing_options: Options for image processing.
Raises: Raises:
ValueError: If the current input timestamp is smaller than what the image ValueError: If the current input timestamp is smaller than what the image
embedder has already processed. embedder has already processed.
""" """
norm_rect = roi if roi is not None else _build_full_image_norm_rect() normalized_rect = self.convert_to_normalized_rect(image_processing_options)
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at( _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })
@staticmethod
def cosine_similarity(u: embeddings_module.EmbeddingEntry,
v: embeddings_module.EmbeddingEntry) -> float:
"""Utility function to compute cosine similarity [1] between two embedding
entries. May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a
an L2-norm of 0.
Args:
u: An embedding entry.
v: An embedding entry.
Returns:
The cosine similarity for the two embeddings.
Raises:
ValueError: May return an error if e.g. the feature vectors are of
different types (quantized vs. float), have different sizes, or have
an L2-norm of 0
"""
return cosine_similarity.cosine_similarity(u, v)