From 0e9b9257262abc625089664e13571af7168ff576 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 10 Nov 2022 02:30:17 -0800 Subject: [PATCH] Fixed some typos and revised image embedder tests --- .../python/test/vision/image_embedder_test.py | 75 +++++++++---------- .../tasks/python/vision/image_embedder.py | 10 +-- 2 files changed, 40 insertions(+), 45 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 18fd644a2..3df1ed1c5 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -100,28 +100,22 @@ class ImageEmbedderTest(parameterized.TestCase): embedder = _ImageEmbedder.create_from_options(options) self.assertIsInstance(embedder, _ImageEmbedder) - def _check_cosine_similarity(self, result0, result1, quantize, - expected_similarity): - # Checks head_index and head_name. - self.assertEqual(result0.embeddings[0].head_index, 0) - self.assertEqual(result1.embeddings[0].head_index, 0) - self.assertEqual(result0.embeddings[0].head_name, 'feature') - self.assertEqual(result1.embeddings[0].head_name, 'feature') + def _check_embedding_value(self, result, expected_first_value): + # Check embedding first value. + self.assertAlmostEqual(result.embeddings[0].embedding[0], + expected_first_value, delta=_EPSILON) - # Check embedding sizes. - def _check_embedding_size(result): - self.assertLen(result.embeddings, 1) - embedding_result = result.embeddings[0] - self.assertLen(embedding_result.embedding, 1024) - if quantize: - self.assertEqual(embedding_result.embedding.dtype, np.uint8) - else: - self.assertEqual(embedding_result.embedding.dtype, float) - - # Checks results sizes. - _check_embedding_size(result0) - _check_embedding_size(result1) + def _check_embedding_size(self, result, quantize, expected_embedding_size): + # Check embedding size. + self.assertLen(result.embeddings, 1) + embedding_result = result.embeddings[0] + self.assertLen(embedding_result.embedding, expected_embedding_size) + if quantize: + self.assertEqual(embedding_result.embedding.dtype, np.uint8) + else: + self.assertEqual(embedding_result.embedding.dtype, float) + def _check_cosine_similarity(self, result0, result1, expected_similarity): # Checks cosine similarity. similarity = _ImageEmbedder.cosine_similarity( result0.embeddings[0], result1.embeddings[0]) @@ -129,13 +123,17 @@ class ImageEmbedderTest(parameterized.TestCase): delta=_SIMILARITY_TOLERANCE) @parameterized.parameters( - (False, False, False, ModelFileType.FILE_NAME, 0.925519, -0.2101883), - (True, False, False, ModelFileType.FILE_NAME, 0.925519, -0.0142344), - # (False, True, False, ModelFileType.FILE_NAME, 0.926791, 229), - (False, False, True, ModelFileType.FILE_CONTENT, 0.999931, -0.195062) + (False, False, False, ModelFileType.FILE_NAME, + 0.925519, 1024, (-0.2101883, -0.193027)), + (True, False, False, ModelFileType.FILE_NAME, + 0.925519, 1024, (-0.0142344, -0.0131606)), + # (False, True, False, ModelFileType.FILE_NAME, + # 0.926791, 1024, (229, 231)), + (False, False, True, ModelFileType.FILE_CONTENT, + 0.999931, 1024, (-0.195062, -0.193027)) ) def test_embed(self, l2_normalize, quantize, with_roi, model_file_type, - expected_similarity, expected_first_value): + expected_similarity, expected_size, expected_first_values): # Creates embedder. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(model_asset_path=self.model_path) @@ -163,12 +161,13 @@ class ImageEmbedderTest(parameterized.TestCase): image_result = embedder.embed(self.test_image, image_processing_options) crop_result = embedder.embed(self.test_cropped_image) - # Check embedding value. - self.assertAlmostEqual(image_result.embeddings[0].embedding[0], - expected_first_value, delta=_EPSILON) - - # Checks cosine similarity. - self._check_cosine_similarity(image_result, crop_result, quantize, + # Checks embeddings and cosine similarity. + expected_result0_value, expected_result1_value = expected_first_values + self._check_embedding_size(image_result, quantize, expected_size) + self._check_embedding_size(crop_result, quantize, expected_size) + self._check_embedding_value(image_result, expected_result0_value) + self._check_embedding_value(crop_result, expected_result1_value) + self._check_cosine_similarity(image_result, crop_result, expected_similarity) # Closes the embedder explicitly when the embedder is not used in # a context. @@ -201,7 +200,7 @@ class ImageEmbedderTest(parameterized.TestCase): crop_result = embedder.embed(self.test_cropped_image) # Checks cosine similarity. - self._check_cosine_similarity(image_result, crop_result, quantize, + self._check_cosine_similarity(image_result, crop_result, expected_similarity) def test_missing_result_callback(self): @@ -283,8 +282,7 @@ class ImageEmbedderTest(parameterized.TestCase): timestamp) # Checks cosine similarity. self._check_cosine_similarity( - image_result, crop_result, quantize=False, - expected_similarity=0.925519) + image_result, crop_result, expected_similarity=0.925519) def test_embed_for_video_succeeds_with_region_of_interest(self): options = _ImageEmbedderOptions( @@ -305,8 +303,7 @@ class ImageEmbedderTest(parameterized.TestCase): # Checks cosine similarity. self._check_cosine_similarity( - image_result, crop_result, quantize=False, - expected_similarity=0.999931) + image_result, crop_result, expected_similarity=0.999931) def test_calling_embed_in_live_stream_mode(self): options = _ImageEmbedderOptions( @@ -352,8 +349,8 @@ class ImageEmbedderTest(parameterized.TestCase): def check_result(result: ImageEmbedderResult, output_image: _Image, timestamp_ms: int): # Checks cosine similarity. - self._check_cosine_similarity(result, crop_result, quantize=False, - expected_similarity=0.925519) + self._check_cosine_similarity(result, crop_result, + expected_similarity=0.925519) self.assertTrue( np.array_equal(output_image.numpy_view(), self.test_image.numpy_view())) @@ -384,7 +381,7 @@ class ImageEmbedderTest(parameterized.TestCase): def check_result(result: ImageEmbedderResult, output_image: _Image, timestamp_ms: int): # Checks cosine similarity. - self._check_cosine_similarity(result, crop_result, quantize=False, + self._check_cosine_similarity(result, crop_result, expected_similarity=0.999931) self.assertTrue( np.array_equal(output_image.numpy_view(), diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index bec9682d0..e696ebdc8 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -20,7 +20,6 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module -from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 from mediapipe.tasks.python.components.processors import embedder_options @@ -40,7 +39,6 @@ _EmbedderOptions = embedder_options.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_TaskRunner = task_runner_module.TaskRunner _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' _EMBEDDINGS_TAG = 'EMBEDDINGS' @@ -112,7 +110,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): `ImageEmbedderOptions`. Raises: - ValueError: If failed to create `ImageClassifier` object from the provided + ValueError: If failed to create `ImageEmbedder` object from the provided file such as invalid file path. RuntimeError: If other types of error occurred. """ @@ -185,7 +183,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): image_processing_options: Options for image processing. Returns: - A embedding result object that contains a list of embeddings. + An embedding result object that contains a list of embeddings. Raises: ValueError: If any of the input arguments is invalid. @@ -223,7 +221,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): image_processing_options: Options for image processing. Returns: - A embedding result object that contains a list of embeddings. + An embedding result object that contains a list of embeddings. Raises: ValueError: If any of the input arguments is invalid. @@ -265,7 +263,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): per input image. The `result_callback` provides: - - A embedding result object that contains a list of embeddings. + - An embedding result object that contains a list of embeddings. - The input image that the image embedder runs on. - The input timestamp in milliseconds.