Fixed some typos and revised image embedder tests
This commit is contained in:
parent
7ec0d8cf3b
commit
0e9b925726
|
@ -100,28 +100,22 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
embedder = _ImageEmbedder.create_from_options(options)
|
embedder = _ImageEmbedder.create_from_options(options)
|
||||||
self.assertIsInstance(embedder, _ImageEmbedder)
|
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||||
|
|
||||||
def _check_cosine_similarity(self, result0, result1, quantize,
|
def _check_embedding_value(self, result, expected_first_value):
|
||||||
expected_similarity):
|
# Check embedding first value.
|
||||||
# Checks head_index and head_name.
|
self.assertAlmostEqual(result.embeddings[0].embedding[0],
|
||||||
self.assertEqual(result0.embeddings[0].head_index, 0)
|
expected_first_value, delta=_EPSILON)
|
||||||
self.assertEqual(result1.embeddings[0].head_index, 0)
|
|
||||||
self.assertEqual(result0.embeddings[0].head_name, 'feature')
|
|
||||||
self.assertEqual(result1.embeddings[0].head_name, 'feature')
|
|
||||||
|
|
||||||
# Check embedding sizes.
|
def _check_embedding_size(self, result, quantize, expected_embedding_size):
|
||||||
def _check_embedding_size(result):
|
# Check embedding size.
|
||||||
self.assertLen(result.embeddings, 1)
|
self.assertLen(result.embeddings, 1)
|
||||||
embedding_result = result.embeddings[0]
|
embedding_result = result.embeddings[0]
|
||||||
self.assertLen(embedding_result.embedding, 1024)
|
self.assertLen(embedding_result.embedding, expected_embedding_size)
|
||||||
if quantize:
|
if quantize:
|
||||||
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
|
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(embedding_result.embedding.dtype, float)
|
self.assertEqual(embedding_result.embedding.dtype, float)
|
||||||
|
|
||||||
# Checks results sizes.
|
def _check_cosine_similarity(self, result0, result1, expected_similarity):
|
||||||
_check_embedding_size(result0)
|
|
||||||
_check_embedding_size(result1)
|
|
||||||
|
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
similarity = _ImageEmbedder.cosine_similarity(
|
similarity = _ImageEmbedder.cosine_similarity(
|
||||||
result0.embeddings[0], result1.embeddings[0])
|
result0.embeddings[0], result1.embeddings[0])
|
||||||
|
@ -129,13 +123,17 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
delta=_SIMILARITY_TOLERANCE)
|
delta=_SIMILARITY_TOLERANCE)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(False, False, False, ModelFileType.FILE_NAME, 0.925519, -0.2101883),
|
(False, False, False, ModelFileType.FILE_NAME,
|
||||||
(True, False, False, ModelFileType.FILE_NAME, 0.925519, -0.0142344),
|
0.925519, 1024, (-0.2101883, -0.193027)),
|
||||||
# (False, True, False, ModelFileType.FILE_NAME, 0.926791, 229),
|
(True, False, False, ModelFileType.FILE_NAME,
|
||||||
(False, False, True, ModelFileType.FILE_CONTENT, 0.999931, -0.195062)
|
0.925519, 1024, (-0.0142344, -0.0131606)),
|
||||||
|
# (False, True, False, ModelFileType.FILE_NAME,
|
||||||
|
# 0.926791, 1024, (229, 231)),
|
||||||
|
(False, False, True, ModelFileType.FILE_CONTENT,
|
||||||
|
0.999931, 1024, (-0.195062, -0.193027))
|
||||||
)
|
)
|
||||||
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
||||||
expected_similarity, expected_first_value):
|
expected_similarity, expected_size, expected_first_values):
|
||||||
# Creates embedder.
|
# Creates embedder.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
@ -163,12 +161,13 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
image_result = embedder.embed(self.test_image, image_processing_options)
|
image_result = embedder.embed(self.test_image, image_processing_options)
|
||||||
crop_result = embedder.embed(self.test_cropped_image)
|
crop_result = embedder.embed(self.test_cropped_image)
|
||||||
|
|
||||||
# Check embedding value.
|
# Checks embeddings and cosine similarity.
|
||||||
self.assertAlmostEqual(image_result.embeddings[0].embedding[0],
|
expected_result0_value, expected_result1_value = expected_first_values
|
||||||
expected_first_value, delta=_EPSILON)
|
self._check_embedding_size(image_result, quantize, expected_size)
|
||||||
|
self._check_embedding_size(crop_result, quantize, expected_size)
|
||||||
# Checks cosine similarity.
|
self._check_embedding_value(image_result, expected_result0_value)
|
||||||
self._check_cosine_similarity(image_result, crop_result, quantize,
|
self._check_embedding_value(crop_result, expected_result1_value)
|
||||||
|
self._check_cosine_similarity(image_result, crop_result,
|
||||||
expected_similarity)
|
expected_similarity)
|
||||||
# Closes the embedder explicitly when the embedder is not used in
|
# Closes the embedder explicitly when the embedder is not used in
|
||||||
# a context.
|
# a context.
|
||||||
|
@ -201,7 +200,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
crop_result = embedder.embed(self.test_cropped_image)
|
crop_result = embedder.embed(self.test_cropped_image)
|
||||||
|
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(image_result, crop_result, quantize,
|
self._check_cosine_similarity(image_result, crop_result,
|
||||||
expected_similarity)
|
expected_similarity)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
def test_missing_result_callback(self):
|
||||||
|
@ -283,8 +282,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
timestamp)
|
timestamp)
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(
|
self._check_cosine_similarity(
|
||||||
image_result, crop_result, quantize=False,
|
image_result, crop_result, expected_similarity=0.925519)
|
||||||
expected_similarity=0.925519)
|
|
||||||
|
|
||||||
def test_embed_for_video_succeeds_with_region_of_interest(self):
|
def test_embed_for_video_succeeds_with_region_of_interest(self):
|
||||||
options = _ImageEmbedderOptions(
|
options = _ImageEmbedderOptions(
|
||||||
|
@ -305,8 +303,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(
|
self._check_cosine_similarity(
|
||||||
image_result, crop_result, quantize=False,
|
image_result, crop_result, expected_similarity=0.999931)
|
||||||
expected_similarity=0.999931)
|
|
||||||
|
|
||||||
def test_calling_embed_in_live_stream_mode(self):
|
def test_calling_embed_in_live_stream_mode(self):
|
||||||
options = _ImageEmbedderOptions(
|
options = _ImageEmbedderOptions(
|
||||||
|
@ -352,7 +349,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
self._check_cosine_similarity(result, crop_result,
|
||||||
expected_similarity=0.925519)
|
expected_similarity=0.925519)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(output_image.numpy_view(),
|
np.array_equal(output_image.numpy_view(),
|
||||||
|
@ -384,7 +381,7 @@ class ImageEmbedderTest(parameterized.TestCase):
|
||||||
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
def check_result(result: ImageEmbedderResult, output_image: _Image,
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
# Checks cosine similarity.
|
# Checks cosine similarity.
|
||||||
self._check_cosine_similarity(result, crop_result, quantize=False,
|
self._check_cosine_similarity(result, crop_result,
|
||||||
expected_similarity=0.999931)
|
expected_similarity=0.999931)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.array_equal(output_image.numpy_view(),
|
np.array_equal(output_image.numpy_view(),
|
||||||
|
|
|
@ -20,7 +20,6 @@ from mediapipe.python import packet_creator
|
||||||
from mediapipe.python import packet_getter
|
from mediapipe.python import packet_getter
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
|
||||||
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2
|
||||||
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
|
||||||
from mediapipe.tasks.python.components.processors import embedder_options
|
from mediapipe.tasks.python.components.processors import embedder_options
|
||||||
|
@ -40,7 +39,6 @@ _EmbedderOptions = embedder_options.EmbedderOptions
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskRunner = task_runner_module.TaskRunner
|
|
||||||
|
|
||||||
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
|
_EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out'
|
||||||
_EMBEDDINGS_TAG = 'EMBEDDINGS'
|
_EMBEDDINGS_TAG = 'EMBEDDINGS'
|
||||||
|
@ -112,7 +110,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
`ImageEmbedderOptions`.
|
`ImageEmbedderOptions`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If failed to create `ImageClassifier` object from the provided
|
ValueError: If failed to create `ImageEmbedder` object from the provided
|
||||||
file such as invalid file path.
|
file such as invalid file path.
|
||||||
RuntimeError: If other types of error occurred.
|
RuntimeError: If other types of error occurred.
|
||||||
"""
|
"""
|
||||||
|
@ -185,7 +183,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
image_processing_options: Options for image processing.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A embedding result object that contains a list of embeddings.
|
An embedding result object that contains a list of embeddings.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
@ -223,7 +221,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
image_processing_options: Options for image processing.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A embedding result object that contains a list of embeddings.
|
An embedding result object that contains a list of embeddings.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
@ -265,7 +263,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
|
||||||
per input image.
|
per input image.
|
||||||
|
|
||||||
The `result_callback` provides:
|
The `result_callback` provides:
|
||||||
- A embedding result object that contains a list of embeddings.
|
- An embedding result object that contains a list of embeddings.
|
||||||
- The input image that the image embedder runs on.
|
- The input image that the image embedder runs on.
|
||||||
- The input timestamp in milliseconds.
|
- The input timestamp in milliseconds.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user