Revised image embedder implementation
This commit is contained in:
parent
492607152a
commit
e2d50745ac
|
@ -131,18 +131,14 @@ class EmbeddingEntry:
|
|||
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
|
||||
"""Creates a `EmbeddingEntry` object from the given protobuf object."""
|
||||
|
||||
if pb2_obj.float_embedding:
|
||||
return EmbeddingEntry(
|
||||
embedding=np.array(pb2_obj.float_embedding.values, dtype=float))
|
||||
|
||||
elif pb2_obj.quantized_embedding:
|
||||
return EmbeddingEntry(
|
||||
embedding=np.array(bytearray(pb2_obj.quantized_embedding.values),
|
||||
dtype=np.uint8))
|
||||
quantized_embedding = np.array(
|
||||
bytearray(pb2_obj.quantized_embedding.values))
|
||||
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
|
||||
|
||||
if len(quantized_embedding) == 0:
|
||||
return EmbeddingEntry(embedding=float_embedding)
|
||||
else:
|
||||
raise ValueError("Either float_embedding or quantized_embedding must "
|
||||
"exist.")
|
||||
return EmbeddingEntry(embedding=quantized_embedding)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
|
|
@ -45,8 +45,7 @@ class EmbedderOptions:
|
|||
"""Generates a EmbedderOptions protobuf object."""
|
||||
return _EmbedderOptionsProto(
|
||||
l2_normalize=self.l2_normalize,
|
||||
quantize=self.quantize
|
||||
)
|
||||
quantize=self.quantize)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
|
@ -54,8 +53,7 @@ class EmbedderOptions:
|
|||
"""Creates a `EmbedderOptions` object from the given protobuf object."""
|
||||
return EmbedderOptions(
|
||||
l2_normalize=pb2_obj.l2_normalize,
|
||||
quantize=pb2_obj.quantize
|
||||
)
|
||||
quantize=pb2_obj.quantize)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
|
28
mediapipe/tasks/python/components/utils/BUILD
Normal file
28
mediapipe/tasks/python/components/utils/BUILD
Normal file
|
@ -0,0 +1,28 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict library compatibility macro.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "cosine_similarity",
|
||||
srcs = ["cosine_similarity.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
||||
],
|
||||
)
|
13
mediapipe/tasks/python/components/utils/__init__.py
Normal file
13
mediapipe/tasks/python/components/utils/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
61
mediapipe/tasks/python/components/utils/cosine_similarity.py
Normal file
61
mediapipe/tasks/python/components/utils/cosine_similarity.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Cosine similarity utilities."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.components.containers import embeddings
|
||||
from mediapipe.tasks.python.components.proto import embedder_options
|
||||
|
||||
_EmbeddingEntry = embeddings.EmbeddingEntry
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
|
||||
|
||||
def _compute_cosine_similarity(u, v):
|
||||
if len(u.embedding) <= 0:
|
||||
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
|
||||
|
||||
norm_u = np.linalg.norm(u.embedding)
|
||||
norm_v = np.linalg.norm(v.embedding)
|
||||
|
||||
if norm_u <= 0 or norm_v <= 0:
|
||||
raise ValueError(
|
||||
"Cannot compute cosine similarity on embedding with 0 norm.")
|
||||
|
||||
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
||||
|
||||
|
||||
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float:
|
||||
"""Utility function to compute cosine similarity between two embedding
|
||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
of different types (quantized vs. float), have different sizes, or have an
|
||||
L2-norm of 0.
|
||||
|
||||
Args:
|
||||
u: An embedding entry.
|
||||
v: An embedding entry.
|
||||
"""
|
||||
if len(u.embedding) != len(v.embedding):
|
||||
raise ValueError(f"Cannot compute cosine similarity between embeddings "
|
||||
f"of different sizes "
|
||||
f"({len(u.embedding)} vs. {len(v.embedding)}).")
|
||||
|
||||
if u.embedding.dtype == float and v.embedding.dtype == float:
|
||||
return _compute_cosine_similarity(u, v)
|
||||
|
||||
if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8:
|
||||
return _compute_cosine_similarity(u, v)
|
||||
|
||||
raise ValueError("Cannot compute cosine similarity between quantized and "
|
||||
"float embeddings.")
|
|
@ -84,11 +84,13 @@ py_test(
|
|||
deps = [
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
||||
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
"//mediapipe/tasks/python/vision:image_embedder",
|
||||
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
"""Tests for image embedder."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
|
@ -23,31 +24,32 @@ from absl.testing import parameterized
|
|||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
|
||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_embedder
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_NormalizedRect = rect_module.NormalizedRect
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||
_FloatEmbedding = embeddings_module.FloatEmbedding
|
||||
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
||||
_ClassificationEntry = embeddings_module.EmbeddingEntry
|
||||
_Classifications = embeddings_module.Embeddings
|
||||
_ClassificationResult = embeddings_module.EmbeddingResult
|
||||
_EmbeddingEntry = embeddings_module.EmbeddingEntry
|
||||
_Embeddings = embeddings_module.Embeddings
|
||||
_EmbeddingResult = embeddings_module.EmbeddingResult
|
||||
_Image = image_module.Image
|
||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL_FILE = 'mobilenet_v3_small_100_224_embedder.tflite'
|
||||
_IMAGE_FILE = 'burger.jpg'
|
||||
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||
_DENY_LIST = ['cheeseburger']
|
||||
_SCORE_THRESHOLD = 0.5
|
||||
_MAX_RESULTS = 3
|
||||
_BURGER_IMAGE_FILE = 'burger.jpg'
|
||||
_BURGER_CROPPED_IMAGE_FILE = 'burger_crop.jpg'
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
_SIMILARITY_TOLERANCE = 1e-6
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
|
@ -55,18 +57,55 @@ class ModelFileType(enum.Enum):
|
|||
FILE_NAME = 2
|
||||
|
||||
|
||||
class ImageClassifierTest(parameterized.TestCase):
|
||||
class ImageEmbedderTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _BURGER_IMAGE_FILE)))
|
||||
self.test_cropped_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _BURGER_CROPPED_IMAGE_FILE)))
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||
|
||||
def _check_cosine_similarity(self, result0, result1, quantize,
|
||||
expected_similarity):
|
||||
# Checks head_index and head_name.
|
||||
self.assertEqual(result0.embeddings[0].head_index, 0)
|
||||
self.assertEqual(result1.embeddings[0].head_index, 0)
|
||||
self.assertEqual(result0.embeddings[0].head_name, 'feature')
|
||||
self.assertEqual(result1.embeddings[0].head_name, 'feature')
|
||||
|
||||
# Check embedding sizes.
|
||||
def _check_embedding_size(result):
|
||||
self.assertLen(result.embeddings, 1)
|
||||
embedding_entry = result.embeddings[0].entries[0]
|
||||
self.assertLen(embedding_entry.embedding, 1024)
|
||||
if quantize:
|
||||
self.assertEqual(embedding_entry.embedding.dtype, np.uint8)
|
||||
else:
|
||||
self.assertEqual(embedding_entry.embedding.dtype, float)
|
||||
|
||||
# Checks results sizes.
|
||||
_check_embedding_size(result0)
|
||||
_check_embedding_size(result1)
|
||||
|
||||
# Checks cosine similarity.
|
||||
similarity = _ImageEmbedder.cosine_similarity(
|
||||
result0.embeddings[0].entries[0], result1.embeddings[0].entries[0])
|
||||
self.assertAlmostEqual(similarity, expected_similarity,
|
||||
delta=_SIMILARITY_TOLERANCE)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, False, False),
|
||||
(ModelFileType.FILE_CONTENT, False, False))
|
||||
def test_embed(self, model_file_type, l2_normalize, quantize):
|
||||
(False, False, False, ModelFileType.FILE_NAME, 0.925519, -0.2101883),
|
||||
(True, False, False, ModelFileType.FILE_NAME, 0.925519, -0.0142344),
|
||||
# (False, True, False, ModelFileType.FILE_NAME, 0.926791, 229),
|
||||
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, -0.195062)
|
||||
)
|
||||
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
||||
expected_similarity, expected_first_value):
|
||||
# Creates embedder.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
|
@ -84,15 +123,254 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
base_options=base_options, embedder_options=embedder_options)
|
||||
embedder = _ImageEmbedder.create_from_options(options)
|
||||
|
||||
# Performs image embedding extraction on the input.
|
||||
image_result = embedder.embed(self.test_image)
|
||||
image_processing_options = None
|
||||
if with_roi:
|
||||
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
|
||||
# TODO: Verify results.
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder.embed(self.test_image, image_processing_options)
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Check embedding value.
|
||||
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0],
|
||||
expected_first_value)
|
||||
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(image_result, crop_result, quantize,
|
||||
expected_similarity)
|
||||
# Closes the embedder explicitly when the classifier is not used in
|
||||
# a context.
|
||||
embedder.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(False, False, ModelFileType.FILE_NAME, 0.925519),
|
||||
(False, False, ModelFileType.FILE_CONTENT, 0.925519))
|
||||
def test_embed_in_context(self, l2_normalize, quantize, model_file_type,
|
||||
expected_similarity):
|
||||
# Creates embedder.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
embedder_options = _EmbedderOptions(l2_normalize=l2_normalize,
|
||||
quantize=quantize)
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=base_options, embedder_options=embedder_options)
|
||||
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder.embed(self.test_image)
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(image_result, crop_result, quantize,
|
||||
expected_similarity)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'result callback must be provided'):
|
||||
with _ImageEmbedder.create_from_options(options) as unused_embedder:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock())
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'result callback should not be provided'):
|
||||
with _ImageEmbedder.create_from_options(options) as unused_embedder:
|
||||
pass
|
||||
|
||||
def test_calling_embed_for_video_in_image_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the video mode'):
|
||||
embedder.embed_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_embed_async_in_image_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the live stream mode'):
|
||||
embedder.embed_async(self.test_image, 0)
|
||||
|
||||
def test_calling_embed_in_video_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the image mode'):
|
||||
embedder.embed(self.test_image)
|
||||
|
||||
def test_calling_embed_async_in_video_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the live stream mode'):
|
||||
embedder.embed_async(self.test_image, 0)
|
||||
|
||||
def test_embed_for_video_with_out_of_order_timestamp(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
unused_result = embedder.embed_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
embedder.embed_for_video(self.test_image, 0)
|
||||
|
||||
def test_embed_for_video(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder0, \
|
||||
_ImageEmbedder.create_from_options(options) as embedder1:
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder0.embed_for_video(self.test_image, timestamp)
|
||||
crop_result = embedder1.embed_for_video(self.test_cropped_image,
|
||||
timestamp)
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(
|
||||
image_result, crop_result, quantize=False,
|
||||
expected_similarity=0.925519)
|
||||
|
||||
def test_embed_for_video_succeeds_with_region_of_interest(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder0, \
|
||||
_ImageEmbedder.create_from_options(options) as embedder1:
|
||||
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder0.embed_for_video(self.test_image, timestamp,
|
||||
image_processing_options)
|
||||
crop_result = embedder1.embed_for_video(self.test_cropped_image,
|
||||
timestamp)
|
||||
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(
|
||||
image_result, crop_result, quantize=False,
|
||||
expected_similarity=0.999931)
|
||||
|
||||
def test_calling_embed_in_live_stream_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the image mode'):
|
||||
embedder.embed(self.test_image)
|
||||
|
||||
def test_calling_classify_for_video_in_live_stream_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the video mode'):
|
||||
embedder.embed_for_video(self.test_image, 0)
|
||||
|
||||
def test_classify_async_calls_with_illegal_timestamp(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
embedder.embed_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
embedder.embed_async(self.test_image, 0)
|
||||
|
||||
def test_embed_async_calls(self):
|
||||
# Get the embedding result for the cropped image.
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _EmbeddingResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
||||
expected_similarity=0.925519)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
for timestamp in range(0, 300, 30):
|
||||
embedder.embed_async(self.test_image, timestamp)
|
||||
|
||||
def test_classify_async_succeeds_with_region_of_interest(self):
|
||||
# Get the embedding result for the cropped image.
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _EmbeddingResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
||||
expected_similarity=0.999931)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
for timestamp in range(0, 300, 30):
|
||||
embedder.embed_async(self.test_image, timestamp,
|
||||
image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -118,6 +118,7 @@ py_library(
|
|||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
"//mediapipe/tasks/python/core:task_info",
|
||||
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -23,20 +23,21 @@ from mediapipe.python._framework_bindings import packet as packet_module
|
|||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.proto import embedder_options
|
||||
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_NormalizedRect = rect_module.NormalizedRect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskRunner = task_runner_module.TaskRunner
|
||||
|
||||
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
|
||||
|
@ -44,17 +45,12 @@ _EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT'
|
|||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
_NORM_RECT_NAME = 'norm_rect_in'
|
||||
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
|
||||
_NORM_RECT_TAG = 'NORM_RECT'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
||||
def _build_full_image_norm_rect() -> _NormalizedRect:
|
||||
# Builds a NormalizedRect covering the entire image.
|
||||
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageEmbedderOptions:
|
||||
"""Options for the image embedder task.
|
||||
|
@ -75,6 +71,8 @@ class ImageEmbedderOptions:
|
|||
"""
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
l2_normalize: Optional[bool] = None
|
||||
quantize: Optional[bool] = None
|
||||
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
||||
result_callback: Optional[
|
||||
Callable[[embeddings_module.EmbeddingResult, image_module.Image,
|
||||
|
@ -157,7 +155,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
task_graph=_TASK_GRAPH_NAME,
|
||||
input_streams=[
|
||||
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
|
||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=[
|
||||
':'.join([_EMBEDDING_RESULT_TAG,
|
||||
|
@ -174,7 +172,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
def embed(
|
||||
self,
|
||||
image: image_module.Image,
|
||||
roi: Optional[_NormalizedRect] = None
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> embeddings_module.EmbeddingResult:
|
||||
"""Performs image embedding extraction on the provided MediaPipe Image.
|
||||
Extraction is performed on the region of interest specified by the `roi`
|
||||
|
@ -182,7 +180,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
Args:
|
||||
image: MediaPipe Image.
|
||||
roi: The region of interest.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Returns:
|
||||
A embedding result object that contains a list of embeddings.
|
||||
|
@ -191,10 +189,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
ValueError: If any of the input arguments is invalid.
|
||||
RuntimeError: If image embedder failed to run.
|
||||
"""
|
||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||
output_packets = self._process_image_data({
|
||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())})
|
||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||
normalized_rect.to_pb2())})
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||
|
||||
|
@ -206,7 +205,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
def embed_for_video(
|
||||
self, image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
roi: Optional[_NormalizedRect] = None
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> embeddings_module.EmbeddingResult:
|
||||
"""Performs image embedding extraction on the provided video frames.
|
||||
Extraction is performed on the region of interested specified by the `roi`
|
||||
|
@ -220,7 +219,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
Args:
|
||||
image: MediaPipe Image.
|
||||
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||
roi: The region of interest.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Returns:
|
||||
A embedding result object that contains a list of embeddings.
|
||||
|
@ -229,12 +228,13 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
ValueError: If any of the input arguments is invalid.
|
||||
RuntimeError: If image embedder failed to run.
|
||||
"""
|
||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||
output_packets = self._process_video_data({
|
||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||
normalized_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
})
|
||||
embedding_result_proto = packet_getter.get_proto(
|
||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||
|
@ -248,7 +248,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
self,
|
||||
image: image_module.Image,
|
||||
timestamp_ms: int,
|
||||
roi: Optional[_NormalizedRect] = None
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||
) -> None:
|
||||
""" Sends live image data to embedder, and the results will be available via
|
||||
the "result_callback" provided in the ImageEmbedderOptions. Embedding
|
||||
|
@ -273,16 +273,39 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
|||
Args:
|
||||
image: MediaPipe Image.
|
||||
timestamp_ms: The timestamp of the input image in milliseconds.
|
||||
roi: The region of interest.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Raises:
|
||||
ValueError: If the current input timestamp is smaller than what the image
|
||||
embedder has already processed.
|
||||
"""
|
||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||
self._send_live_stream_data({
|
||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||
normalized_rect.to_pb2()).at(
|
||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def cosine_similarity(u: embeddings_module.EmbeddingEntry,
|
||||
v: embeddings_module.EmbeddingEntry) -> float:
|
||||
"""Utility function to compute cosine similarity [1] between two embedding
|
||||
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||
of different types (quantized vs. float), have different sizes, or have a
|
||||
an L2-norm of 0.
|
||||
|
||||
Args:
|
||||
u: An embedding entry.
|
||||
v: An embedding entry.
|
||||
|
||||
Returns:
|
||||
The cosine similarity for the two embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: May return an error if e.g. the feature vectors are of
|
||||
different types (quantized vs. float), have different sizes, or have
|
||||
an L2-norm of 0
|
||||
"""
|
||||
return cosine_similarity.cosine_similarity(u, v)
|
||||
|
|
Loading…
Reference in New Issue
Block a user