Revised image embedder implementation
This commit is contained in:
parent
492607152a
commit
e2d50745ac
|
@ -131,18 +131,14 @@ class EmbeddingEntry:
|
||||||
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
|
cls, pb2_obj: _EmbeddingEntryProto) -> 'EmbeddingEntry':
|
||||||
"""Creates a `EmbeddingEntry` object from the given protobuf object."""
|
"""Creates a `EmbeddingEntry` object from the given protobuf object."""
|
||||||
|
|
||||||
if pb2_obj.float_embedding:
|
quantized_embedding = np.array(
|
||||||
return EmbeddingEntry(
|
bytearray(pb2_obj.quantized_embedding.values))
|
||||||
embedding=np.array(pb2_obj.float_embedding.values, dtype=float))
|
float_embedding = np.array(pb2_obj.float_embedding.values, dtype=float)
|
||||||
|
|
||||||
elif pb2_obj.quantized_embedding:
|
|
||||||
return EmbeddingEntry(
|
|
||||||
embedding=np.array(bytearray(pb2_obj.quantized_embedding.values),
|
|
||||||
dtype=np.uint8))
|
|
||||||
|
|
||||||
|
if len(quantized_embedding) == 0:
|
||||||
|
return EmbeddingEntry(embedding=float_embedding)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either float_embedding or quantized_embedding must "
|
return EmbeddingEntry(embedding=quantized_embedding)
|
||||||
"exist.")
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Checks if this object is equal to the given object.
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
|
@ -45,8 +45,7 @@ class EmbedderOptions:
|
||||||
"""Generates a EmbedderOptions protobuf object."""
|
"""Generates a EmbedderOptions protobuf object."""
|
||||||
return _EmbedderOptionsProto(
|
return _EmbedderOptionsProto(
|
||||||
l2_normalize=self.l2_normalize,
|
l2_normalize=self.l2_normalize,
|
||||||
quantize=self.quantize
|
quantize=self.quantize)
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
|
@ -54,8 +53,7 @@ class EmbedderOptions:
|
||||||
"""Creates a `EmbedderOptions` object from the given protobuf object."""
|
"""Creates a `EmbedderOptions` object from the given protobuf object."""
|
||||||
return EmbedderOptions(
|
return EmbedderOptions(
|
||||||
l2_normalize=pb2_obj.l2_normalize,
|
l2_normalize=pb2_obj.l2_normalize,
|
||||||
quantize=pb2_obj.quantize
|
quantize=pb2_obj.quantize)
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Checks if this object is equal to the given object.
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
28
mediapipe/tasks/python/components/utils/BUILD
Normal file
28
mediapipe/tasks/python/components/utils/BUILD
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "cosine_similarity",
|
||||||
|
srcs = ["cosine_similarity.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||||
|
"//mediapipe/tasks/python/components/proto:embedder_options",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/tasks/python/components/utils/__init__.py
Normal file
13
mediapipe/tasks/python/components/utils/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
61
mediapipe/tasks/python/components/utils/cosine_similarity.py
Normal file
61
mediapipe/tasks/python/components/utils/cosine_similarity.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Cosine similarity utilities."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.components.containers import embeddings
|
||||||
|
from mediapipe.tasks.python.components.proto import embedder_options
|
||||||
|
|
||||||
|
_EmbeddingEntry = embeddings.EmbeddingEntry
|
||||||
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_cosine_similarity(u, v):
|
||||||
|
if len(u.embedding) <= 0:
|
||||||
|
raise ValueError("Cannot compute cosing similarity on empty embeddings.")
|
||||||
|
|
||||||
|
norm_u = np.linalg.norm(u.embedding)
|
||||||
|
norm_v = np.linalg.norm(v.embedding)
|
||||||
|
|
||||||
|
if norm_u <= 0 or norm_v <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot compute cosine similarity on embedding with 0 norm.")
|
||||||
|
|
||||||
|
return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v)
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(u: _EmbeddingEntry, v: _EmbeddingEntry) -> float:
  """Computes the cosine similarity between two embedding entries.

  Args:
    u: An embedding entry.
    v: An embedding entry.

  Returns:
    The cosine similarity between the embeddings of `u` and `v`.

  Raises:
    ValueError: If the embeddings have different sizes, or are not both
      float or both quantized (uint8). `_compute_cosine_similarity` also
      raises it for empty embeddings or embeddings with an L2-norm of 0.
  """
  size_u, size_v = len(u.embedding), len(v.embedding)
  if size_u != size_v:
    raise ValueError("Cannot compute cosine similarity between embeddings "
                     "of different sizes "
                     f"({size_u} vs. {size_v}).")

  dtype_u, dtype_v = u.embedding.dtype, v.embedding.dtype
  both_float = dtype_u == float and dtype_v == float
  both_quantized = dtype_u == np.uint8 and dtype_v == np.uint8

  # Any other dtype pairing is rejected.
  if not (both_float or both_quantized):
    raise ValueError("Cannot compute cosine similarity between quantized and "
                     "float embeddings.")

  return _compute_cosine_similarity(u, v)
|
|
@ -84,11 +84,13 @@ py_test(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/tasks/python/components/proto:embedder_options",
|
"//mediapipe/tasks/python/components/proto:embedder_options",
|
||||||
|
"//mediapipe/tasks/python/components/utils:cosine_similarity",
|
||||||
"//mediapipe/tasks/python/components/containers:embeddings",
|
"//mediapipe/tasks/python/components/containers:embeddings",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
"//mediapipe/tasks/python/vision:image_embedder",
|
"//mediapipe/tasks/python/vision:image_embedder",
|
||||||
|
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
"""Tests for image embedder."""
|
"""Tests for image embedder."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import os
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -23,31 +24,32 @@ from absl.testing import parameterized
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
|
from mediapipe.tasks.python.components.proto import embedder_options as embedder_options_module
|
||||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
from mediapipe.tasks.python.vision import image_embedder
|
from mediapipe.tasks.python.vision import image_embedder
|
||||||
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
_NormalizedRect = rect_module.NormalizedRect
|
_Rect = rect.Rect
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
_EmbedderOptions = embedder_options_module.EmbedderOptions
|
||||||
_FloatEmbedding = embeddings_module.FloatEmbedding
|
_FloatEmbedding = embeddings_module.FloatEmbedding
|
||||||
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
_QuantizedEmbedding = embeddings_module.QuantizedEmbedding
|
||||||
_ClassificationEntry = embeddings_module.EmbeddingEntry
|
_EmbeddingEntry = embeddings_module.EmbeddingEntry
|
||||||
_Classifications = embeddings_module.Embeddings
|
_Embeddings = embeddings_module.Embeddings
|
||||||
_ClassificationResult = embeddings_module.EmbeddingResult
|
_EmbeddingResult = embeddings_module.EmbeddingResult
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||||
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
||||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||||
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
|
|
||||||
_MODEL_FILE = 'mobilenet_v3_small_100_224_embedder.tflite'
|
_MODEL_FILE = 'mobilenet_v3_small_100_224_embedder.tflite'
|
||||||
_IMAGE_FILE = 'burger.jpg'
|
_BURGER_IMAGE_FILE = 'burger.jpg'
|
||||||
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
_BURGER_CROPPED_IMAGE_FILE = 'burger_crop.jpg'
|
||||||
_DENY_LIST = ['cheeseburger']
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
_SCORE_THRESHOLD = 0.5
|
_SIMILARITY_TOLERANCE = 1e-6
|
||||||
_MAX_RESULTS = 3
|
|
||||||
|
|
||||||
|
|
||||||
class ModelFileType(enum.Enum):
|
class ModelFileType(enum.Enum):
|
||||||
|
@ -55,18 +57,55 @@ class ModelFileType(enum.Enum):
|
||||||
FILE_NAME = 2
|
FILE_NAME = 2
|
||||||
|
|
||||||
|
|
||||||
class ImageClassifierTest(parameterized.TestCase):
|
class ImageEmbedderTest(parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.test_image = _Image.create_from_file(
|
self.test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
test_utils.get_test_data_path(
|
||||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
os.path.join(_TEST_DATA_DIR, _BURGER_IMAGE_FILE)))
|
||||||
|
self.test_cropped_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _BURGER_CROPPED_IMAGE_FILE)))
|
||||||
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||||
|
|
||||||
|
def _check_cosine_similarity(self, result0, result1, quantize,
                             expected_similarity):
  """Asserts that two embedding results are well-formed and similar.

  Args:
    result0: First embedding result to check.
    result1: Second embedding result to check.
    quantize: If truthy, embeddings are expected to have dtype uint8;
      otherwise float.
    expected_similarity: Expected cosine similarity between the first entry
      of each result, compared within `_SIMILARITY_TOLERANCE`.
  """
  results = (result0, result1)

  # Checks head_index, then head_name, on both results.
  for result in results:
    self.assertEqual(result.embeddings[0].head_index, 0)
  for result in results:
    self.assertEqual(result.embeddings[0].head_name, 'feature')

  # Checks the number of embeddings plus the size and dtype of the first
  # entry.
  expected_dtype = np.uint8 if quantize else float
  for result in results:
    self.assertLen(result.embeddings, 1)
    entry = result.embeddings[0].entries[0]
    self.assertLen(entry.embedding, 1024)
    self.assertEqual(entry.embedding.dtype, expected_dtype)

  # Checks cosine similarity between the first entries.
  first_entry = result0.embeddings[0].entries[0]
  second_entry = result1.embeddings[0].entries[0]
  similarity = _ImageEmbedder.cosine_similarity(first_entry, second_entry)
  self.assertAlmostEqual(
      similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_NAME, False, False),
|
(False, False, False, ModelFileType.FILE_NAME, 0.925519, -0.2101883),
|
||||||
(ModelFileType.FILE_CONTENT, False, False))
|
(True, False, False, ModelFileType.FILE_NAME, 0.925519, -0.0142344),
|
||||||
def test_embed(self, model_file_type, l2_normalize, quantize):
|
# (False, True, False, ModelFileType.FILE_NAME, 0.926791, 229),
|
||||||
|
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, -0.195062)
|
||||||
|
)
|
||||||
|
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
||||||
|
expected_similarity, expected_first_value):
|
||||||
# Creates embedder.
|
# Creates embedder.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
@ -84,15 +123,254 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
base_options=base_options, embedder_options=embedder_options)
|
base_options=base_options, embedder_options=embedder_options)
|
||||||
embedder = _ImageEmbedder.create_from_options(options)
|
embedder = _ImageEmbedder.create_from_options(options)
|
||||||
|
|
||||||
# Performs image embedding extraction on the input.
|
image_processing_options = None
|
||||||
image_result = embedder.embed(self.test_image)
|
if with_roi:
|
||||||
|
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||||
|
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||||
|
image_processing_options = _ImageProcessingOptions(roi)
|
||||||
|
|
||||||
# TODO: Verify results.
|
# Extracts both embeddings.
|
||||||
|
image_result = embedder.embed(self.test_image, image_processing_options)
|
||||||
|
crop_result = embedder.embed(self.test_cropped_image)
|
||||||
|
|
||||||
|
# Check embedding value.
|
||||||
|
self.assertAlmostEqual(image_result.embeddings[0].entries[0].embedding[0],
|
||||||
|
expected_first_value)
|
||||||
|
|
||||||
|
# Checks cosine similarity.
|
||||||
|
self._check_cosine_similarity(image_result, crop_result, quantize,
|
||||||
|
expected_similarity)
|
||||||
# Closes the embedder explicitly when the classifier is not used in
|
# Closes the embedder explicitly when the classifier is not used in
|
||||||
# a context.
|
# a context.
|
||||||
embedder.close()
|
embedder.close()
|
||||||
|
|
||||||
|
@parameterized.parameters(
    (False, False, ModelFileType.FILE_NAME, 0.925519),
    (False, False, ModelFileType.FILE_CONTENT, 0.925519))
def test_embed_in_context(self, l2_normalize, quantize, model_file_type,
                          expected_similarity):
  """Embeds the full and cropped images with the embedder used as a context.

  Args:
    l2_normalize: Value forwarded to `_EmbedderOptions.l2_normalize`.
    quantize: Value forwarded to `_EmbedderOptions.quantize`.
    model_file_type: Whether the model is given by path or by content.
    expected_similarity: Expected cosine similarity between both results.
  """
  # Builds the base options from either the model path or its raw bytes.
  if model_file_type is ModelFileType.FILE_CONTENT:
    with open(self.model_path, 'rb') as model_file:
      model_bytes = model_file.read()
    base_options = _BaseOptions(model_asset_buffer=model_bytes)
  elif model_file_type is ModelFileType.FILE_NAME:
    base_options = _BaseOptions(model_asset_path=self.model_path)
  else:
    # Should never happen
    raise ValueError('model_file_type is invalid.')

  options = _ImageEmbedderOptions(
      base_options=base_options,
      embedder_options=_EmbedderOptions(
          l2_normalize=l2_normalize, quantize=quantize))

  with _ImageEmbedder.create_from_options(options) as embedder:
    # Extracts both embeddings.
    full_result = embedder.embed(self.test_image)
    cropped_result = embedder.embed(self.test_cropped_image)

    # Checks cosine similarity.
    self._check_cosine_similarity(full_result, cropped_result, quantize,
                                  expected_similarity)
|
||||||
|
|
||||||
|
def test_missing_result_callback(self):
  """Expects a ValueError for live stream mode without a result callback."""
  live_stream_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM)
  expected_error = self.assertRaisesRegex(
      ValueError, r'result callback must be provided')
  # The embedder is created inside the assertion context, so the creation
  # itself is what must raise.
  with expected_error, _ImageEmbedder.create_from_options(
      live_stream_options) as unused_embedder:
    pass
|
||||||
|
|
||||||
|
@parameterized.parameters(_RUNNING_MODE.IMAGE, _RUNNING_MODE.VIDEO)
def test_illegal_result_callback(self, running_mode):
  """Expects a ValueError when a callback is set outside live stream mode.

  Args:
    running_mode: The running mode to create the embedder with.
  """
  options_with_callback = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=running_mode,
      result_callback=mock.MagicMock())
  expected_error = self.assertRaisesRegex(
      ValueError, r'result callback should not be provided')
  # The embedder is created inside the assertion context, so the creation
  # itself is what must raise.
  with expected_error, _ImageEmbedder.create_from_options(
      options_with_callback) as unused_embedder:
    pass
|
||||||
|
|
||||||
|
def test_calling_embed_for_video_in_image_mode(self):
  """Expects `embed_for_video` to raise ValueError in image mode."""
  image_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.IMAGE)
  with _ImageEmbedder.create_from_options(
      image_mode_options) as embedder, self.assertRaisesRegex(
          ValueError, r'not initialized with the video mode'):
    embedder.embed_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_embed_async_in_image_mode(self):
  """Expects `embed_async` to raise ValueError in image mode."""
  image_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.IMAGE)
  with _ImageEmbedder.create_from_options(
      image_mode_options) as embedder, self.assertRaisesRegex(
          ValueError, r'not initialized with the live stream mode'):
    embedder.embed_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_embed_in_video_mode(self):
  """Expects `embed` to raise ValueError in video mode."""
  video_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO)
  with _ImageEmbedder.create_from_options(
      video_mode_options) as embedder, self.assertRaisesRegex(
          ValueError, r'not initialized with the image mode'):
    embedder.embed(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_embed_async_in_video_mode(self):
  """Expects `embed_async` to raise ValueError in video mode."""
  video_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO)
  with _ImageEmbedder.create_from_options(
      video_mode_options) as embedder, self.assertRaisesRegex(
          ValueError, r'not initialized with the live stream mode'):
    embedder.embed_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_embed_for_video_with_out_of_order_timestamp(self):
  """Expects a ValueError when a video timestamp goes backwards."""
  video_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO)
  with _ImageEmbedder.create_from_options(video_mode_options) as embedder:
    # The first call at timestamp 1 sets the reference; its result is unused.
    embedder.embed_for_video(self.test_image, 1)
    expected_error = self.assertRaisesRegex(
        ValueError, r'Input timestamp must be monotonically increasing')
    with expected_error:
      embedder.embed_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_embed_for_video(self):
  """Embeds the full and cropped images frame by frame in video mode."""
  video_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO)
  # One embedder per image, each fed the same sequence of timestamps.
  with _ImageEmbedder.create_from_options(video_mode_options) as embedder0:
    with _ImageEmbedder.create_from_options(video_mode_options) as embedder1:
      for timestamp_ms in range(0, 300, 30):
        # Extracts both embeddings.
        full_result = embedder0.embed_for_video(self.test_image,
                                                timestamp_ms)
        cropped_result = embedder1.embed_for_video(self.test_cropped_image,
                                                   timestamp_ms)
        # Checks cosine similarity.
        self._check_cosine_similarity(
            full_result,
            cropped_result,
            quantize=False,
            expected_similarity=0.925519)
|
||||||
|
|
||||||
|
def test_embed_for_video_succeeds_with_region_of_interest(self):
  """Embeds an image region and the cropped image in video mode."""
  video_mode_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO)
  # One embedder per image, each fed the same sequence of timestamps.
  with _ImageEmbedder.create_from_options(video_mode_options) as embedder0:
    with _ImageEmbedder.create_from_options(video_mode_options) as embedder1:
      # Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
      image_processing_options = _ImageProcessingOptions(
          _Rect(left=0, top=0, right=0.833333, bottom=1))

      for timestamp_ms in range(0, 300, 30):
        # Extracts both embeddings; only the full image uses the region.
        full_result = embedder0.embed_for_video(self.test_image, timestamp_ms,
                                                image_processing_options)
        cropped_result = embedder1.embed_for_video(self.test_cropped_image,
                                                   timestamp_ms)

        # Checks cosine similarity.
        self._check_cosine_similarity(
            full_result,
            cropped_result,
            quantize=False,
            expected_similarity=0.999931)
|
||||||
|
|
||||||
|
def test_calling_embed_in_live_stream_mode(self):
  """Expects `embed` to raise ValueError in live stream mode."""
  live_stream_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock())
  with _ImageEmbedder.create_from_options(
      live_stream_options) as embedder, self.assertRaisesRegex(
          ValueError, r'not initialized with the image mode'):
    embedder.embed(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_classify_for_video_in_live_stream_mode(self):
  """Expects `embed_for_video` to raise ValueError in live stream mode."""
  live_stream_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock())
  with _ImageEmbedder.create_from_options(
      live_stream_options) as embedder, self.assertRaisesRegex(
          ValueError, r'not initialized with the video mode'):
    embedder.embed_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_classify_async_calls_with_illegal_timestamp(self):
  """Expects a ValueError when a live stream timestamp goes backwards."""
  live_stream_options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock())
  with _ImageEmbedder.create_from_options(live_stream_options) as embedder:
    # The first call at timestamp 100 sets the reference.
    embedder.embed_async(self.test_image, 100)
    expected_error = self.assertRaisesRegex(
        ValueError, r'Input timestamp must be monotonically increasing')
    with expected_error:
      embedder.embed_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_embed_async_calls(self):
  """Checks the results delivered to the callback in live stream mode.

  The cropped image is first embedded in image mode to get a reference
  result. The full image is then sent through `embed_async` at increasing
  timestamps, and the callback checks each result against that reference.
  """
  # Get the embedding result for the cropped image.
  options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.IMAGE)
  with _ImageEmbedder.create_from_options(options) as embedder:
    crop_result = embedder.embed(self.test_cropped_image)

  # Kept on the test instance so the callback can both read and update it.
  # A plain local would stay at -1 forever, because the callback writes to
  # `self`, and the monotonicity assertion would never check anything.
  self.observed_timestamp_ms = -1

  def check_result(result: _EmbeddingResult, output_image: _Image,
                   timestamp_ms: int):
    # Checks cosine similarity.
    self._check_cosine_similarity(result, crop_result, quantize=False,
                                  expected_similarity=0.925519)
    # The image handed to the callback must equal the input image.
    self.assertTrue(
        np.array_equal(output_image.numpy_view(),
                       self.test_image.numpy_view()))
    # Timestamps must be strictly increasing across callback invocations.
    self.assertLess(self.observed_timestamp_ms, timestamp_ms)
    self.observed_timestamp_ms = timestamp_ms

  options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=check_result)
  with _ImageEmbedder.create_from_options(options) as embedder:
    for timestamp in range(0, 300, 30):
      embedder.embed_async(self.test_image, timestamp)
|
||||||
|
|
||||||
|
def test_classify_async_succeeds_with_region_of_interest(self):
  """Checks live stream results when a region of interest is supplied.

  The cropped image is first embedded in image mode to get a reference
  result. The full image is then sent through `embed_async` with a
  region of interest at increasing timestamps, and the callback checks
  each result against that reference.
  """
  # Get the embedding result for the cropped image.
  options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.IMAGE)
  with _ImageEmbedder.create_from_options(options) as embedder:
    crop_result = embedder.embed(self.test_cropped_image)

  # Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
  roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
  image_processing_options = _ImageProcessingOptions(roi)

  # Kept on the test instance so the callback can both read and update it.
  # A plain local would stay at -1 forever, because the callback writes to
  # `self`, and the monotonicity assertion would never check anything.
  self.observed_timestamp_ms = -1

  def check_result(result: _EmbeddingResult, output_image: _Image,
                   timestamp_ms: int):
    # Checks cosine similarity.
    self._check_cosine_similarity(result, crop_result, quantize=False,
                                  expected_similarity=0.999931)
    # The image handed to the callback must equal the input image.
    self.assertTrue(
        np.array_equal(output_image.numpy_view(),
                       self.test_image.numpy_view()))
    # Timestamps must be strictly increasing across callback invocations.
    self.assertLess(self.observed_timestamp_ms, timestamp_ms)
    self.observed_timestamp_ms = timestamp_ms

  options = _ImageEmbedderOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=check_result)
  with _ImageEmbedder.create_from_options(options) as embedder:
    for timestamp in range(0, 300, 30):
      embedder.embed_async(self.test_image, timestamp,
                           image_processing_options)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -118,6 +118,7 @@ py_library(
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
"//mediapipe/tasks/python/core:task_info",
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
|
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/python/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,20 +23,21 @@ from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||||
from mediapipe.tasks.python.components.proto import embedder_options
|
from mediapipe.tasks.python.components.proto import embedder_options
|
||||||
|
from mediapipe.tasks.python.components.utils import cosine_similarity
|
||||||
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
from mediapipe.tasks.python.components.containers import embeddings as embeddings_module
|
||||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||||
|
|
||||||
_NormalizedRect = rect_module.NormalizedRect
|
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
|
||||||
_EmbedderOptions = embedder_options.EmbedderOptions
|
_EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskRunner = task_runner_module.TaskRunner
|
_TaskRunner = task_runner_module.TaskRunner
|
||||||
|
|
||||||
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
|
_EMBEDDING_RESULT_OUT_STREAM_NAME = 'embedding_result_out'
|
||||||
|
@ -44,17 +45,12 @@ _EMBEDDING_RESULT_TAG = 'EMBEDDING_RESULT'
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
_IMAGE_TAG = 'IMAGE'
|
_IMAGE_TAG = 'IMAGE'
|
||||||
_NORM_RECT_NAME = 'norm_rect_in'
|
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
|
||||||
_NORM_RECT_TAG = 'NORM_RECT'
|
_NORM_RECT_TAG = 'NORM_RECT'
|
||||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'
|
||||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||||
|
|
||||||
|
|
||||||
def _build_full_image_norm_rect() -> _NormalizedRect:
|
|
||||||
# Builds a NormalizedRect covering the entire image.
|
|
||||||
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ImageEmbedderOptions:
|
class ImageEmbedderOptions:
|
||||||
"""Options for the image embedder task.
|
"""Options for the image embedder task.
|
||||||
|
@ -75,6 +71,8 @@ class ImageEmbedderOptions:
|
||||||
"""
|
"""
|
||||||
base_options: _BaseOptions
|
base_options: _BaseOptions
|
||||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
|
l2_normalize: Optional[bool] = None
|
||||||
|
quantize: Optional[bool] = None
|
||||||
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
embedder_options: _EmbedderOptions = _EmbedderOptions()
|
||||||
result_callback: Optional[
|
result_callback: Optional[
|
||||||
Callable[[embeddings_module.EmbeddingResult, image_module.Image,
|
Callable[[embeddings_module.EmbeddingResult, image_module.Image,
|
||||||
|
@ -157,7 +155,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
task_graph=_TASK_GRAPH_NAME,
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
input_streams=[
|
input_streams=[
|
||||||
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
|
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||||
],
|
],
|
||||||
output_streams=[
|
output_streams=[
|
||||||
':'.join([_EMBEDDING_RESULT_TAG,
|
':'.join([_EMBEDDING_RESULT_TAG,
|
||||||
|
@ -174,7 +172,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
def embed(
|
def embed(
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
roi: Optional[_NormalizedRect] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> embeddings_module.EmbeddingResult:
|
) -> embeddings_module.EmbeddingResult:
|
||||||
"""Performs image embedding extraction on the provided MediaPipe Image.
|
"""Performs image embedding extraction on the provided MediaPipe Image.
|
||||||
Extraction is performed on the region of interest specified by the `roi`
|
Extraction is performed on the region of interest specified by the `roi`
|
||||||
|
@ -182,7 +180,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: MediaPipe Image.
|
image: MediaPipe Image.
|
||||||
roi: The region of interest.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An embedding result object that contains a list of embeddings.
|
An embedding result object that contains a list of embeddings.
|
||||||
|
@ -191,10 +189,11 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
RuntimeError: If image embedder failed to run.
|
RuntimeError: If image embedder failed to run.
|
||||||
"""
|
"""
|
||||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||||
output_packets = self._process_image_data({
|
output_packets = self._process_image_data({
|
||||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
||||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())})
|
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||||
|
normalized_rect.to_pb2())})
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
embedding_result_proto = packet_getter.get_proto(
|
||||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||||
|
|
||||||
|
@ -206,7 +205,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
def embed_for_video(
|
def embed_for_video(
|
||||||
self, image: image_module.Image,
|
self, image: image_module.Image,
|
||||||
timestamp_ms: int,
|
timestamp_ms: int,
|
||||||
roi: Optional[_NormalizedRect] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> embeddings_module.EmbeddingResult:
|
) -> embeddings_module.EmbeddingResult:
|
||||||
"""Performs image embedding extraction on the provided video frames.
|
"""Performs image embedding extraction on the provided video frames.
|
||||||
Extraction is performed on the region of interest specified by the `roi`
|
Extraction is performed on the region of interest specified by the `roi`
|
||||||
|
@ -220,7 +219,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
Args:
|
Args:
|
||||||
image: MediaPipe Image.
|
image: MediaPipe Image.
|
||||||
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||||
roi: The region of interest.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An embedding result object that contains a list of embeddings.
|
An embedding result object that contains a list of embeddings.
|
||||||
|
@ -229,12 +228,13 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
RuntimeError: If image embedder failed to run.
|
RuntimeError: If image embedder failed to run.
|
||||||
"""
|
"""
|
||||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||||
output_packets = self._process_video_data({
|
output_packets = self._process_video_data({
|
||||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
|
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at(
|
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
normalized_rect.to_pb2()).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
})
|
})
|
||||||
embedding_result_proto = packet_getter.get_proto(
|
embedding_result_proto = packet_getter.get_proto(
|
||||||
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
output_packets[_EMBEDDING_RESULT_OUT_STREAM_NAME])
|
||||||
|
@ -248,7 +248,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
timestamp_ms: int,
|
timestamp_ms: int,
|
||||||
roi: Optional[_NormalizedRect] = None
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
""" Sends live image data to embedder, and the results will be available via
|
""" Sends live image data to embedder, and the results will be available via
|
||||||
the "result_callback" provided in the ImageEmbedderOptions. Embedding
|
the "result_callback" provided in the ImageEmbedderOptions. Embedding
|
||||||
|
@ -273,16 +273,39 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
Args:
|
Args:
|
||||||
image: MediaPipe Image.
|
image: MediaPipe Image.
|
||||||
timestamp_ms: The timestamp of the input image in milliseconds.
|
timestamp_ms: The timestamp of the input image in milliseconds.
|
||||||
roi: The region of interest.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the current input timestamp is smaller than what the image
|
ValueError: If the current input timestamp is smaller than what the image
|
||||||
embedder has already processed.
|
embedder has already processed.
|
||||||
"""
|
"""
|
||||||
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
|
||||||
self._send_live_stream_data({
|
self._send_live_stream_data({
|
||||||
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
|
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()).at(
|
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
|
||||||
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
normalized_rect.to_pb2()).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cosine_similarity(u: embeddings_module.EmbeddingEntry,
|
||||||
|
v: embeddings_module.EmbeddingEntry) -> float:
|
||||||
|
"""Utility function to compute cosine similarity [1] between two embedding
|
||||||
|
entries. May return an InvalidArgumentError if e.g. the feature vectors are
|
||||||
|
of different types (quantized vs. float), have different sizes, or have
|
||||||
|
an L2-norm of 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
u: An embedding entry.
|
||||||
|
v: An embedding entry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The cosine similarity for the two embeddings.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: May return an error if e.g. the feature vectors are of
|
||||||
|
different types (quantized vs. float), have different sizes, or have
|
||||||
|
an L2-norm of 0
|
||||||
|
"""
|
||||||
|
return cosine_similarity.cosine_similarity(u, v)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user