Fixed some typos and revised image embedder tests
This commit is contained in:
		
							parent
							
								
									7ec0d8cf3b
								
							
						
					
					
						commit
						0e9b925726
					
				|  | @ -100,28 +100,22 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|       embedder = _ImageEmbedder.create_from_options(options) | ||||
|       self.assertIsInstance(embedder, _ImageEmbedder) | ||||
| 
 | ||||
|   def _check_cosine_similarity(self, result0, result1, quantize, | ||||
|                                expected_similarity): | ||||
|     # Checks head_index and head_name. | ||||
|     self.assertEqual(result0.embeddings[0].head_index, 0) | ||||
|     self.assertEqual(result1.embeddings[0].head_index, 0) | ||||
|     self.assertEqual(result0.embeddings[0].head_name, 'feature') | ||||
|     self.assertEqual(result1.embeddings[0].head_name, 'feature') | ||||
|   def _check_embedding_value(self, result, expected_first_value): | ||||
|     # Check embedding first value. | ||||
|     self.assertAlmostEqual(result.embeddings[0].embedding[0], | ||||
|                            expected_first_value, delta=_EPSILON) | ||||
| 
 | ||||
|     # Check embedding sizes. | ||||
|     def _check_embedding_size(result): | ||||
|       self.assertLen(result.embeddings, 1) | ||||
|       embedding_result = result.embeddings[0] | ||||
|       self.assertLen(embedding_result.embedding, 1024) | ||||
|       if quantize: | ||||
|         self.assertEqual(embedding_result.embedding.dtype, np.uint8) | ||||
|       else: | ||||
|         self.assertEqual(embedding_result.embedding.dtype, float) | ||||
| 
 | ||||
|     # Checks results sizes. | ||||
|     _check_embedding_size(result0) | ||||
|     _check_embedding_size(result1) | ||||
|   def _check_embedding_size(self, result, quantize, expected_embedding_size): | ||||
|     # Check embedding size. | ||||
|     self.assertLen(result.embeddings, 1) | ||||
|     embedding_result = result.embeddings[0] | ||||
|     self.assertLen(embedding_result.embedding, expected_embedding_size) | ||||
|     if quantize: | ||||
|       self.assertEqual(embedding_result.embedding.dtype, np.uint8) | ||||
|     else: | ||||
|       self.assertEqual(embedding_result.embedding.dtype, float) | ||||
| 
 | ||||
|   def _check_cosine_similarity(self, result0, result1, expected_similarity): | ||||
|     # Checks cosine similarity. | ||||
|     similarity = _ImageEmbedder.cosine_similarity( | ||||
|         result0.embeddings[0], result1.embeddings[0]) | ||||
|  | @ -129,13 +123,17 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|                            delta=_SIMILARITY_TOLERANCE) | ||||
| 
 | ||||
|   @parameterized.parameters( | ||||
|       (False, False, False, ModelFileType.FILE_NAME, 0.925519, -0.2101883), | ||||
|       (True, False, False, ModelFileType.FILE_NAME, 0.925519, -0.0142344), | ||||
|       # (False, True, False, ModelFileType.FILE_NAME, 0.926791, 229), | ||||
|       (False, False, True, ModelFileType.FILE_CONTENT, 0.999931, -0.195062) | ||||
|       (False, False, False, ModelFileType.FILE_NAME, | ||||
|        0.925519, 1024, (-0.2101883, -0.193027)), | ||||
|       (True, False, False, ModelFileType.FILE_NAME, | ||||
|        0.925519, 1024, (-0.0142344, -0.0131606)), | ||||
|       # (False, True, False, ModelFileType.FILE_NAME, | ||||
|       #  0.926791, 1024, (229, 231)), | ||||
|       (False, False, True, ModelFileType.FILE_CONTENT, | ||||
|        0.999931, 1024, (-0.195062, -0.193027)) | ||||
|   ) | ||||
|   def test_embed(self, l2_normalize, quantize, with_roi, model_file_type, | ||||
|                  expected_similarity, expected_first_value): | ||||
|                  expected_similarity, expected_size, expected_first_values): | ||||
|     # Creates embedder. | ||||
|     if model_file_type is ModelFileType.FILE_NAME: | ||||
|       base_options = _BaseOptions(model_asset_path=self.model_path) | ||||
|  | @ -163,12 +161,13 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|     image_result = embedder.embed(self.test_image, image_processing_options) | ||||
|     crop_result = embedder.embed(self.test_cropped_image) | ||||
| 
 | ||||
|     # Check embedding value. | ||||
|     self.assertAlmostEqual(image_result.embeddings[0].embedding[0], | ||||
|                            expected_first_value, delta=_EPSILON) | ||||
| 
 | ||||
|     # Checks cosine similarity. | ||||
|     self._check_cosine_similarity(image_result, crop_result, quantize, | ||||
|     # Checks embeddings and cosine similarity. | ||||
|     expected_result0_value, expected_result1_value = expected_first_values | ||||
|     self._check_embedding_size(image_result, quantize, expected_size) | ||||
|     self._check_embedding_size(crop_result, quantize, expected_size) | ||||
|     self._check_embedding_value(image_result, expected_result0_value) | ||||
|     self._check_embedding_value(crop_result, expected_result1_value) | ||||
|     self._check_cosine_similarity(image_result, crop_result, | ||||
|                                   expected_similarity) | ||||
|     # Closes the embedder explicitly when the embedder is not used in | ||||
|     # a context. | ||||
|  | @ -201,7 +200,7 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|       crop_result = embedder.embed(self.test_cropped_image) | ||||
| 
 | ||||
|       # Checks cosine similarity. | ||||
|       self._check_cosine_similarity(image_result, crop_result, quantize, | ||||
|       self._check_cosine_similarity(image_result, crop_result, | ||||
|                                     expected_similarity) | ||||
| 
 | ||||
|   def test_missing_result_callback(self): | ||||
|  | @ -283,8 +282,7 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|                                                 timestamp) | ||||
|         # Checks cosine similarity. | ||||
|         self._check_cosine_similarity( | ||||
|             image_result, crop_result, quantize=False, | ||||
|             expected_similarity=0.925519) | ||||
|             image_result, crop_result, expected_similarity=0.925519) | ||||
| 
 | ||||
|   def test_embed_for_video_succeeds_with_region_of_interest(self): | ||||
|     options = _ImageEmbedderOptions( | ||||
|  | @ -305,8 +303,7 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
| 
 | ||||
|         # Checks cosine similarity. | ||||
|         self._check_cosine_similarity( | ||||
|             image_result, crop_result, quantize=False, | ||||
|             expected_similarity=0.999931) | ||||
|             image_result, crop_result, expected_similarity=0.999931) | ||||
| 
 | ||||
|   def test_calling_embed_in_live_stream_mode(self): | ||||
|     options = _ImageEmbedderOptions( | ||||
|  | @ -352,8 +349,8 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|     def check_result(result: ImageEmbedderResult, output_image: _Image, | ||||
|                      timestamp_ms: int): | ||||
|       # Checks cosine similarity. | ||||
|       self._check_cosine_similarity(result, crop_result, quantize=False, | ||||
|         expected_similarity=0.925519) | ||||
|       self._check_cosine_similarity(result, crop_result, | ||||
|                                     expected_similarity=0.925519) | ||||
|       self.assertTrue( | ||||
|         np.array_equal(output_image.numpy_view(), | ||||
|                        self.test_image.numpy_view())) | ||||
|  | @ -384,7 +381,7 @@ class ImageEmbedderTest(parameterized.TestCase): | |||
|     def check_result(result: ImageEmbedderResult, output_image: _Image, | ||||
|                      timestamp_ms: int): | ||||
|       # Checks cosine similarity. | ||||
|       self._check_cosine_similarity(result, crop_result, quantize=False, | ||||
|       self._check_cosine_similarity(result, crop_result, | ||||
|                                     expected_similarity=0.999931) | ||||
|       self.assertTrue( | ||||
|         np.array_equal(output_image.numpy_view(), | ||||
|  |  | |||
|  | @ -20,7 +20,6 @@ from mediapipe.python import packet_creator | |||
| from mediapipe.python import packet_getter | ||||
| from mediapipe.python._framework_bindings import image as image_module | ||||
| from mediapipe.python._framework_bindings import packet as packet_module | ||||
| from mediapipe.python._framework_bindings import task_runner as task_runner_module | ||||
| from mediapipe.tasks.cc.vision.image_embedder.proto import image_embedder_graph_options_pb2 | ||||
| from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 | ||||
| from mediapipe.tasks.python.components.processors import embedder_options | ||||
|  | @ -40,7 +39,6 @@ _EmbedderOptions = embedder_options.EmbedderOptions | |||
| _RunningMode = running_mode_module.VisionTaskRunningMode | ||||
| _TaskInfo = task_info_module.TaskInfo | ||||
| _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | ||||
| _TaskRunner = task_runner_module.TaskRunner | ||||
| 
 | ||||
| _EMBEDDINGS_OUT_STREAM_NAME = 'embeddings_out' | ||||
| _EMBEDDINGS_TAG = 'EMBEDDINGS' | ||||
|  | @ -112,7 +110,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): | |||
|       `ImageEmbedderOptions`. | ||||
| 
 | ||||
|     Raises: | ||||
|       ValueError: If failed to create `ImageClassifier` object from the provided | ||||
|       ValueError: If failed to create `ImageEmbedder` object from the provided | ||||
|         file such as invalid file path. | ||||
|       RuntimeError: If other types of error occurred. | ||||
|     """ | ||||
|  | @ -185,7 +183,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): | |||
|       image_processing_options: Options for image processing. | ||||
| 
 | ||||
|     Returns: | ||||
|       A embedding result object that contains a list of embeddings. | ||||
|       An embedding result object that contains a list of embeddings. | ||||
| 
 | ||||
|     Raises: | ||||
|       ValueError: If any of the input arguments is invalid. | ||||
|  | @ -223,7 +221,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): | |||
|       image_processing_options: Options for image processing. | ||||
| 
 | ||||
|     Returns: | ||||
|       A embedding result object that contains a list of embeddings. | ||||
|       An embedding result object that contains a list of embeddings. | ||||
| 
 | ||||
|     Raises: | ||||
|       ValueError: If any of the input arguments is invalid. | ||||
|  | @ -265,7 +263,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): | |||
|     per input image. | ||||
| 
 | ||||
|     The `result_callback` provides: | ||||
|       - A embedding result object that contains a list of embeddings. | ||||
|       - An embedding result object that contains a list of embeddings. | ||||
|       - The input image that the image embedder runs on. | ||||
|       - The input timestamp in milliseconds. | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user