Removed an unnecessary check and updated tests to check if the masks are generated or not
This commit is contained in:
		
							parent
							
								
									9032bce577
								
							
						
					
					
						commit
						b511822815
					
				|  | @ -104,12 +104,19 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|       self, | ||||
|       actual_result: PoseLandmarkerResult, | ||||
|       expected_result: PoseLandmarkerResult, | ||||
|       output_segmentation_masks: bool, | ||||
|       margin: float | ||||
|   ): | ||||
|     self._expect_pose_landmarks_correct( | ||||
|         actual_result.pose_landmarks, expected_result.pose_landmarks, | ||||
|         margin | ||||
|     ) | ||||
|     if output_segmentation_masks: | ||||
|       self.assertIsInstance(actual_result.segmentation_masks, List) | ||||
|       for i, mask in enumerate(actual_result.segmentation_masks): | ||||
|         self.assertIsInstance(mask, _Image) | ||||
|     else: | ||||
|       self.assertIsNone(actual_result.segmentation_masks) | ||||
| 
 | ||||
|   def test_create_from_file_succeeds_with_valid_model_path(self): | ||||
|     # Creates with default option and valid model file successfully. | ||||
|  | @ -141,12 +148,17 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|       self.assertIsInstance(landmarker, _PoseLandmarker) | ||||
| 
 | ||||
|   @parameterized.parameters( | ||||
|       (ModelFileType.FILE_NAME, | ||||
|       (ModelFileType.FILE_NAME, False, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (ModelFileType.FILE_CONTENT, | ||||
|       (ModelFileType.FILE_CONTENT, False, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (ModelFileType.FILE_NAME, True, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (ModelFileType.FILE_CONTENT, True, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)) | ||||
|   ) | ||||
|   def test_detect(self, model_file_type, expected_detection_result): | ||||
|   def test_detect(self, model_file_type, output_segmentation_masks, | ||||
|                   expected_detection_result): | ||||
|     # Creates pose landmarker. | ||||
|     if model_file_type is ModelFileType.FILE_NAME: | ||||
|       base_options = _BaseOptions(model_asset_path=self.model_path) | ||||
|  | @ -158,7 +170,10 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|       # Should never happen | ||||
|       raise ValueError('model_file_type is invalid.') | ||||
| 
 | ||||
|     options = _PoseLandmarkerOptions(base_options=base_options) | ||||
|     options = _PoseLandmarkerOptions( | ||||
|         base_options=base_options, | ||||
|         output_segmentation_masks=output_segmentation_masks | ||||
|     ) | ||||
|     landmarker = _PoseLandmarker.create_from_options(options) | ||||
| 
 | ||||
|     # Performs pose landmarks detection on the input. | ||||
|  | @ -166,19 +181,27 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
| 
 | ||||
|     # Comparing results. | ||||
|     self._expect_pose_landmarker_results_correct( | ||||
|         detection_result, expected_detection_result, _LANDMARKS_MARGIN | ||||
|         detection_result, expected_detection_result, output_segmentation_masks, | ||||
|         _LANDMARKS_MARGIN | ||||
|     ) | ||||
|     # Closes the pose landmarker explicitly when the pose landmarker is not used | ||||
|     # in a context. | ||||
|     landmarker.close() | ||||
| 
 | ||||
|   @parameterized.parameters( | ||||
|       (ModelFileType.FILE_NAME, | ||||
|       (ModelFileType.FILE_NAME, False, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (ModelFileType.FILE_CONTENT, | ||||
|       (ModelFileType.FILE_CONTENT, False, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (ModelFileType.FILE_NAME, True, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (ModelFileType.FILE_CONTENT, True, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)) | ||||
|   ) | ||||
|   def test_detect_in_context(self, model_file_type, expected_detection_result): | ||||
|   def test_detect_in_context( | ||||
|       self, model_file_type, output_segmentation_masks, | ||||
|       expected_detection_result | ||||
|   ): | ||||
|     # Creates pose landmarker. | ||||
|     if model_file_type is ModelFileType.FILE_NAME: | ||||
|       base_options = _BaseOptions(model_asset_path=self.model_path) | ||||
|  | @ -190,14 +213,18 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|       # Should never happen | ||||
|       raise ValueError('model_file_type is invalid.') | ||||
| 
 | ||||
|     options = _PoseLandmarkerOptions(base_options=base_options) | ||||
|     options = _PoseLandmarkerOptions( | ||||
|         base_options=base_options, | ||||
|         output_segmentation_masks=output_segmentation_masks | ||||
|     ) | ||||
|     with _PoseLandmarker.create_from_options(options) as landmarker: | ||||
|       # Performs pose landmarks detection on the input. | ||||
|       detection_result = landmarker.detect(self.test_image) | ||||
| 
 | ||||
|       # Comparing results. | ||||
|       self._expect_pose_landmarker_results_correct( | ||||
|         detection_result, expected_detection_result, _LANDMARKS_MARGIN | ||||
|           detection_result, expected_detection_result, | ||||
|           output_segmentation_masks, _LANDMARKS_MARGIN | ||||
|       ) | ||||
| 
 | ||||
|   def test_detect_fails_with_region_of_interest(self): | ||||
|  | @ -295,12 +322,15 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|         landmarker.detect_for_video(self.test_image, 0) | ||||
| 
 | ||||
|   @parameterized.parameters( | ||||
|       (_POSE_IMAGE, 0, | ||||
|       (_POSE_IMAGE, 0, False, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (_BURGER_IMAGE, 0, | ||||
|       (_POSE_IMAGE, 0, True, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (_BURGER_IMAGE, 0, False, | ||||
|        PoseLandmarkerResult([], [], [])) | ||||
|   ) | ||||
|   def test_detect_for_video(self, image_path, rotation, expected_result): | ||||
|   def test_detect_for_video(self, image_path, rotation, | ||||
|                             output_segmentation_masks, expected_result): | ||||
|     test_image = _Image.create_from_file( | ||||
|         test_utils.get_test_data_path(image_path)) | ||||
|     # Set rotation parameters using ImageProcessingOptions. | ||||
|  | @ -308,6 +338,7 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|         rotation_degrees=rotation) | ||||
|     options = _PoseLandmarkerOptions( | ||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|         output_segmentation_masks=output_segmentation_masks, | ||||
|         running_mode=_RUNNING_MODE.VIDEO) | ||||
|     with _PoseLandmarker.create_from_options(options) as landmarker: | ||||
|       for timestamp in range(0, 300, 30): | ||||
|  | @ -315,7 +346,8 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|                                              image_processing_options) | ||||
|         if result.pose_landmarks: | ||||
|           self._expect_pose_landmarker_results_correct( | ||||
|               result, expected_result, _LANDMARKS_MARGIN | ||||
|               result, expected_result, output_segmentation_masks, | ||||
|               _LANDMARKS_MARGIN | ||||
|           ) | ||||
|         else: | ||||
|           self.assertEqual(result, expected_result) | ||||
|  | @ -352,12 +384,15 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|         landmarker.detect_async(self.test_image, 0) | ||||
| 
 | ||||
|   @parameterized.parameters( | ||||
|       (_POSE_IMAGE, 0, | ||||
|       (_POSE_IMAGE, 0, False, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (_BURGER_IMAGE, 0, | ||||
|       (_POSE_IMAGE, 0, True, | ||||
|        _get_expected_pose_landmarker_result(_POSE_LANDMARKS)), | ||||
|       (_BURGER_IMAGE, 0, False, | ||||
|        PoseLandmarkerResult([], [], [])) | ||||
|   ) | ||||
|   def test_detect_async_calls(self, image_path, rotation, expected_result): | ||||
|   def test_detect_async_calls(self, image_path, rotation, | ||||
|                               output_segmentation_masks, expected_result): | ||||
|     test_image = _Image.create_from_file( | ||||
|         test_utils.get_test_data_path(image_path)) | ||||
|     # Set rotation parameters using ImageProcessingOptions. | ||||
|  | @ -369,7 +404,8 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
|                      timestamp_ms: int): | ||||
|       if result.pose_landmarks: | ||||
|         self._expect_pose_landmarker_results_correct( | ||||
|             result, expected_result, _LANDMARKS_MARGIN | ||||
|             result, expected_result, output_segmentation_masks, | ||||
|             _LANDMARKS_MARGIN | ||||
|         ) | ||||
|       else: | ||||
|         self.assertEqual(result, expected_result) | ||||
|  | @ -380,6 +416,7 @@ class PoseLandmarkerTest(parameterized.TestCase): | |||
| 
 | ||||
|     options = _PoseLandmarkerOptions( | ||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|         output_segmentation_masks=output_segmentation_masks, | ||||
|         running_mode=_RUNNING_MODE.LIVE_STREAM, | ||||
|         result_callback=check_result) | ||||
|     with _PoseLandmarker.create_from_options(options) as landmarker: | ||||
|  |  | |||
|  | @ -263,7 +263,6 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi): | |||
|       ) | ||||
| 
 | ||||
|     output_streams = [ | ||||
|         ':'.join([_SEGMENTATION_MASK_TAG, _SEGMENTATION_MASK_STREAM_NAME]), | ||||
|         ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]), | ||||
|         ':'.join([ | ||||
|           _POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user