Deprecated output_type for the ImageSegmenter and InteractiveSegmenter APIs
This commit is contained in:
		
							parent
							
								
									c7aecb42ff
								
							
						
					
					
						commit
						3f68f90238
					
				|  | @ -93,6 +93,25 @@ py_test( | |||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "interactive_segmenter_test", | ||||
|     srcs = ["interactive_segmenter_test.py"], | ||||
|     data = [ | ||||
|         "//mediapipe/tasks/testdata/vision:test_images", | ||||
|         "//mediapipe/tasks/testdata/vision:test_models", | ||||
|     ], | ||||
|     deps = [ | ||||
|         "//mediapipe/python:_framework_bindings", | ||||
|         "//mediapipe/tasks/python/components/containers:keypoint", | ||||
|         "//mediapipe/tasks/python/components/containers:rect", | ||||
|         "//mediapipe/tasks/python/core:base_options", | ||||
|         "//mediapipe/tasks/python/test:test_utils", | ||||
|         "//mediapipe/tasks/python/vision:interactive_segmenter", | ||||
|         "//mediapipe/tasks/python/vision/core:image_processing_options", | ||||
|         "//mediapipe/tasks/python/vision/core:vision_task_running_mode", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "face_detector_test", | ||||
|     srcs = ["face_detector_test.py"], | ||||
|  |  | |||
|  | @ -30,10 +30,10 @@ from mediapipe.tasks.python.test import test_utils | |||
| from mediapipe.tasks.python.vision import image_segmenter | ||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode | ||||
| 
 | ||||
| ImageSegmenterResult = image_segmenter.ImageSegmenterResult | ||||
| _BaseOptions = base_options_module.BaseOptions | ||||
| _Image = image_module.Image | ||||
| _ImageFormat = image_frame.ImageFormat | ||||
| _OutputType = image_segmenter.ImageSegmenterOptions.OutputType | ||||
| _Activation = image_segmenter.ImageSegmenterOptions.Activation | ||||
| _ImageSegmenter = image_segmenter.ImageSegmenter | ||||
| _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions | ||||
|  | @ -42,11 +42,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode | |||
| _MODEL_FILE = 'deeplabv3.tflite' | ||||
| _IMAGE_FILE = 'segmentation_input_rotation0.jpg' | ||||
| _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' | ||||
| _CAT_IMAGE = 'cat.jpg' | ||||
| _CAT_MASK = 'cat_mask.jpg' | ||||
| _MASK_MAGNIFICATION_FACTOR = 10 | ||||
| _MASK_SIMILARITY_THRESHOLD = 0.98 | ||||
| _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' | ||||
| 
 | ||||
| 
 | ||||
| def _calculate_soft_iou(m1, m2): | ||||
|   intersection_sum = np.sum(m1 * m2) | ||||
|   union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum | ||||
| 
 | ||||
|   if union_sum > 0: | ||||
|     return intersection_sum / union_sum | ||||
|   else: | ||||
|     return 0 | ||||
| 
 | ||||
| 
 | ||||
| def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold): | ||||
|   actual_mask = actual_mask.numpy_view() | ||||
|   expected_mask = expected_mask.numpy_view() / 255.0 | ||||
| 
 | ||||
|   return ( | ||||
|       actual_mask.shape == expected_mask.shape | ||||
|       and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold | ||||
|   ) | ||||
| 
 | ||||
| 
 | ||||
| def _similar_to_uint8_mask(actual_mask, expected_mask): | ||||
|   actual_mask_pixels = actual_mask.numpy_view().flatten() | ||||
|   expected_mask_pixels = expected_mask.numpy_view().flatten() | ||||
|  | @ -84,6 +106,14 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
|     self.model_path = test_utils.get_test_data_path( | ||||
|         os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) | ||||
| 
 | ||||
|   def _load_segmentation_mask(self, file_path: str): | ||||
|     # Loads ground truth segmentation file. | ||||
|     gt_segmentation_data = cv2.imread( | ||||
|       test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)), | ||||
|       cv2.IMREAD_GRAYSCALE, | ||||
|     ) | ||||
|     return _Image(_ImageFormat.GRAY8, gt_segmentation_data) | ||||
| 
 | ||||
|   def test_create_from_file_succeeds_with_valid_model_path(self): | ||||
|     # Creates with default option and valid model file successfully. | ||||
|     with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter: | ||||
|  | @ -127,20 +157,19 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
|       raise ValueError('model_file_type is invalid.') | ||||
| 
 | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) | ||||
|         base_options=base_options, output_category_mask=True) | ||||
|     segmenter = _ImageSegmenter.create_from_options(options) | ||||
| 
 | ||||
|     # Performs image segmentation on the input. | ||||
|     category_masks = segmenter.segment(self.test_image) | ||||
|     self.assertLen(category_masks, 1) | ||||
|     category_mask = category_masks[0] | ||||
|     segmentation_result = segmenter.segment(self.test_image) | ||||
|     category_mask = segmentation_result.category_mask | ||||
|     result_pixels = category_mask.numpy_view().flatten() | ||||
| 
 | ||||
|     # Check if data type of `category_mask` is correct. | ||||
|     self.assertEqual(result_pixels.dtype, np.uint8) | ||||
| 
 | ||||
|     self.assertTrue( | ||||
|         _similar_to_uint8_mask(category_masks[0], self.test_seg_image), | ||||
|         _similar_to_uint8_mask(category_mask, self.test_seg_image), | ||||
|         f'Number of pixels in the candidate mask differing from that of the ' | ||||
|         f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') | ||||
| 
 | ||||
|  | @ -152,67 +181,33 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
|     # Creates segmenter. | ||||
|     base_options = _BaseOptions(model_asset_path=self.model_path) | ||||
| 
 | ||||
|     # Run segmentation on the model in CATEGORY_MASK mode. | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) | ||||
|     segmenter = _ImageSegmenter.create_from_options(options) | ||||
|     category_masks = segmenter.segment(self.test_image) | ||||
|     category_mask = category_masks[0].numpy_view() | ||||
|     # Load the cat image. | ||||
|     test_image = _Image.create_from_file( | ||||
|         test_utils.get_test_data_path( | ||||
|             os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) | ||||
| 
 | ||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=base_options, | ||||
|         output_type=_OutputType.CONFIDENCE_MASK, | ||||
|         activation=_Activation.SOFTMAX) | ||||
|     segmenter = _ImageSegmenter.create_from_options(options) | ||||
|     confidence_masks = segmenter.segment(self.test_image) | ||||
| 
 | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       segmentation_result = segmenter.segment(test_image) | ||||
|       confidence_masks = segmentation_result.confidence_masks | ||||
| 
 | ||||
|       # Check if confidence mask shape is correct. | ||||
|       self.assertLen( | ||||
|           confidence_masks, 21, | ||||
|           'Number of confidence masks must match with number of categories.') | ||||
| 
 | ||||
|     # Gather the confidence masks in a single array `confidence_mask_array`. | ||||
|     confidence_mask_array = np.array( | ||||
|         [confidence_mask.numpy_view() for confidence_mask in confidence_masks]) | ||||
| 
 | ||||
|     # Check if data type of `confidence_masks` are correct. | ||||
|     self.assertEqual(confidence_mask_array.dtype, np.float32) | ||||
| 
 | ||||
|     # Compute the category mask from the created confidence mask. | ||||
|     calculated_category_mask = np.argmax(confidence_mask_array, axis=0) | ||||
|     self.assertListEqual( | ||||
|         calculated_category_mask.tolist(), category_mask.tolist(), | ||||
|         'Confidence mask does not match with the category mask.') | ||||
| 
 | ||||
|     # Closes the segmenter explicitly when the segmenter is not used in | ||||
|     # a context. | ||||
|     segmenter.close() | ||||
| 
 | ||||
|   @parameterized.parameters((ModelFileType.FILE_NAME), | ||||
|                             (ModelFileType.FILE_CONTENT)) | ||||
|   def test_segment_in_context(self, model_file_type): | ||||
|     if model_file_type is ModelFileType.FILE_NAME: | ||||
|       base_options = _BaseOptions(model_asset_path=self.model_path) | ||||
|     elif model_file_type is ModelFileType.FILE_CONTENT: | ||||
|       with open(self.model_path, 'rb') as f: | ||||
|         model_contents = f.read() | ||||
|       base_options = _BaseOptions(model_asset_buffer=model_contents) | ||||
|     else: | ||||
|       # Should never happen | ||||
|       raise ValueError('model_file_type is invalid.') | ||||
| 
 | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       # Performs image segmentation on the input. | ||||
|       category_masks = segmenter.segment(self.test_image) | ||||
|       self.assertLen(category_masks, 1) | ||||
|       # Loads ground truth segmentation file. | ||||
|       expected_mask = self._load_segmentation_mask(_CAT_MASK) | ||||
| 
 | ||||
|       self.assertTrue( | ||||
|           _similar_to_uint8_mask(category_masks[0], self.test_seg_image), | ||||
|           f'Number of pixels in the candidate mask differing from that of the ' | ||||
|           f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') | ||||
|           _similar_to_float_mask( | ||||
|               confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD | ||||
|           ) | ||||
|       ) | ||||
| 
 | ||||
|   def test_missing_result_callback(self): | ||||
|     options = _ImageSegmenterOptions( | ||||
|  | @ -280,20 +275,49 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
|           ValueError, r'Input timestamp must be monotonically increasing'): | ||||
|         segmenter.segment_for_video(self.test_image, 0) | ||||
| 
 | ||||
|   def test_segment_for_video(self): | ||||
|   def test_segment_for_video_in_category_mask_mode(self): | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|         output_type=_OutputType.CATEGORY_MASK, | ||||
|         output_category_mask=True, | ||||
|         running_mode=_RUNNING_MODE.VIDEO) | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       for timestamp in range(0, 300, 30): | ||||
|         category_masks = segmenter.segment_for_video(self.test_image, timestamp) | ||||
|         self.assertLen(category_masks, 1) | ||||
|         segmentation_result = segmenter.segment_for_video( | ||||
|             self.test_image, timestamp) | ||||
|         category_mask = segmentation_result.category_mask | ||||
|         self.assertTrue( | ||||
|             _similar_to_uint8_mask(category_masks[0], self.test_seg_image), | ||||
|             _similar_to_uint8_mask(category_mask, self.test_seg_image), | ||||
|             f'Number of pixels in the candidate mask differing from that of the ' | ||||
|             f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') | ||||
| 
 | ||||
|   def test_segment_for_video_in_confidence_mask_mode(self): | ||||
|     # Load the cat image. | ||||
|     test_image = _Image.create_from_file( | ||||
|       test_utils.get_test_data_path( | ||||
|         os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) | ||||
| 
 | ||||
|     options = _ImageSegmenterOptions( | ||||
|       base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|       running_mode=_RUNNING_MODE.VIDEO) | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       for timestamp in range(0, 300, 30): | ||||
|         segmentation_result = segmenter.segment_for_video( | ||||
|           test_image, timestamp) | ||||
|         confidence_masks = segmentation_result.confidence_masks | ||||
| 
 | ||||
|         # Check if confidence mask shape is correct. | ||||
|         self.assertLen( | ||||
|           confidence_masks, 21, | ||||
|           'Number of confidence masks must match with number of categories.') | ||||
| 
 | ||||
|         # Loads ground truth segmentation file. | ||||
|         expected_mask = self._load_segmentation_mask(_CAT_MASK) | ||||
|         self.assertTrue( | ||||
|           _similar_to_float_mask( | ||||
|             confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD | ||||
|           ) | ||||
|         ) | ||||
| 
 | ||||
|   def test_calling_segment_in_live_stream_mode(self): | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|  | @ -325,13 +349,13 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
|           ValueError, r'Input timestamp must be monotonically increasing'): | ||||
|         segmenter.segment_async(self.test_image, 0) | ||||
| 
 | ||||
|   def test_segment_async_calls(self): | ||||
|   def test_segment_async_calls_in_category_mask_mode(self): | ||||
|     observed_timestamp_ms = -1 | ||||
| 
 | ||||
|     def check_result(result: List[image_module.Image], output_image: _Image, | ||||
|     def check_result(result: ImageSegmenterResult, output_image: _Image, | ||||
|                      timestamp_ms: int): | ||||
|       # Get the output category mask. | ||||
|       category_mask = result[0] | ||||
|       category_mask = result.category_mask | ||||
|       self.assertEqual(output_image.width, self.test_image.width) | ||||
|       self.assertEqual(output_image.height, self.test_image.height) | ||||
|       self.assertEqual(output_image.width, self.test_seg_image.width) | ||||
|  | @ -345,13 +369,49 @@ class ImageSegmenterTest(parameterized.TestCase): | |||
| 
 | ||||
|     options = _ImageSegmenterOptions( | ||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|         output_type=_OutputType.CATEGORY_MASK, | ||||
|         output_category_mask=True, | ||||
|         running_mode=_RUNNING_MODE.LIVE_STREAM, | ||||
|         result_callback=check_result) | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       for timestamp in range(0, 300, 30): | ||||
|         segmenter.segment_async(self.test_image, timestamp) | ||||
| 
 | ||||
|   def test_segment_async_calls_in_confidence_mask_mode(self): | ||||
|     # Load the cat image. | ||||
|     test_image = _Image.create_from_file( | ||||
|       test_utils.get_test_data_path( | ||||
|         os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) | ||||
| 
 | ||||
|     # Loads ground truth segmentation file. | ||||
|     expected_mask = self._load_segmentation_mask(_CAT_MASK) | ||||
|     observed_timestamp_ms = -1 | ||||
| 
 | ||||
|     def check_result(result: ImageSegmenterResult, output_image: _Image, | ||||
|                      timestamp_ms: int): | ||||
|       # Get the output category mask. | ||||
|       confidence_masks = result.confidence_masks | ||||
| 
 | ||||
|       # Check if confidence mask shape is correct. | ||||
|       self.assertLen( | ||||
|         confidence_masks, 21, | ||||
|         'Number of confidence masks must match with number of categories.') | ||||
| 
 | ||||
|       self.assertTrue( | ||||
|         _similar_to_float_mask( | ||||
|           confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD | ||||
|         ) | ||||
|       ) | ||||
|       self.assertLess(observed_timestamp_ms, timestamp_ms) | ||||
|       self.observed_timestamp_ms = timestamp_ms | ||||
| 
 | ||||
|     options = _ImageSegmenterOptions( | ||||
|       base_options=_BaseOptions(model_asset_path=self.model_path), | ||||
|       running_mode=_RUNNING_MODE.LIVE_STREAM, | ||||
|       result_callback=check_result) | ||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||
|       for timestamp in range(0, 300, 30): | ||||
|         segmenter.segment_async(test_image, timestamp) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   absltest.main() | ||||
|  |  | |||
|  | @ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils | |||
| from mediapipe.tasks.python.vision import interactive_segmenter | ||||
| from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module | ||||
| 
 | ||||
| InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult | ||||
| _BaseOptions = base_options_module.BaseOptions | ||||
| _Image = image_module.Image | ||||
| _ImageFormat = image_frame.ImageFormat | ||||
| _NormalizedKeypoint = keypoint_module.NormalizedKeypoint | ||||
| _Rect = rect.Rect | ||||
| _OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType | ||||
| _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter | ||||
| _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions | ||||
| _RegionOfInterest = interactive_segmenter.RegionOfInterest | ||||
|  | @ -200,15 +200,14 @@ class InteractiveSegmenterTest(parameterized.TestCase): | |||
|       raise ValueError('model_file_type is invalid.') | ||||
| 
 | ||||
|     options = _InteractiveSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK | ||||
|         base_options=base_options, output_category_mask=True | ||||
|     ) | ||||
|     segmenter = _InteractiveSegmenter.create_from_options(options) | ||||
| 
 | ||||
|     # Performs image segmentation on the input. | ||||
|     roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) | ||||
|     category_masks = segmenter.segment(self.test_image, roi) | ||||
|     self.assertLen(category_masks, 1) | ||||
|     category_mask = category_masks[0] | ||||
|     segmentation_result = segmenter.segment(self.test_image, roi) | ||||
|     category_mask = segmentation_result.category_mask | ||||
|     result_pixels = category_mask.numpy_view().flatten() | ||||
| 
 | ||||
|     # Check if data type of `category_mask` is correct. | ||||
|  | @ -219,7 +218,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): | |||
| 
 | ||||
|     self.assertTrue( | ||||
|         _similar_to_uint8_mask( | ||||
|             category_masks[0], test_seg_image, similarity_threshold | ||||
|             category_mask, test_seg_image, similarity_threshold | ||||
|         ), | ||||
|         ( | ||||
|             'Number of pixels in the candidate mask differing from that of the' | ||||
|  | @ -253,13 +252,12 @@ class InteractiveSegmenterTest(parameterized.TestCase): | |||
|     roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) | ||||
| 
 | ||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||
|     options = _InteractiveSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK | ||||
|     ) | ||||
|     options = _InteractiveSegmenterOptions(base_options=base_options) | ||||
| 
 | ||||
|     with _InteractiveSegmenter.create_from_options(options) as segmenter: | ||||
|       # Perform segmentation | ||||
|       confidence_masks = segmenter.segment(self.test_image, roi) | ||||
|       segmentation_result = segmenter.segment(self.test_image, roi) | ||||
|       confidence_masks = segmentation_result.confidence_masks | ||||
| 
 | ||||
|       # Check if confidence mask shape is correct. | ||||
|       self.assertLen( | ||||
|  | @ -286,16 +284,15 @@ class InteractiveSegmenterTest(parameterized.TestCase): | |||
|     ) | ||||
| 
 | ||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||
|     options = _InteractiveSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK | ||||
|     ) | ||||
|     options = _InteractiveSegmenterOptions(base_options=base_options) | ||||
| 
 | ||||
|     with _InteractiveSegmenter.create_from_options(options) as segmenter: | ||||
|       # Perform segmentation | ||||
|       image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) | ||||
|       confidence_masks = segmenter.segment( | ||||
|       segmentation_result = segmenter.segment( | ||||
|           self.test_image, roi, image_processing_options | ||||
|       ) | ||||
|       confidence_masks = segmentation_result.confidence_masks | ||||
| 
 | ||||
|       # Check if confidence mask shape is correct. | ||||
|       self.assertLen( | ||||
|  | @ -313,9 +310,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): | |||
|     ) | ||||
| 
 | ||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||
|     options = _InteractiveSegmenterOptions( | ||||
|         base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK | ||||
|     ) | ||||
|     options = _InteractiveSegmenterOptions(base_options=base_options) | ||||
| 
 | ||||
|     with self.assertRaisesRegex( | ||||
|         ValueError, "This task doesn't support region-of-interest." | ||||
|  |  | |||
|  | @ -31,7 +31,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api | |||
| from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module | ||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode | ||||
| 
 | ||||
| ImageSegmenterResult = List[image_module.Image] | ||||
| _NormalizedRect = rect.NormalizedRect | ||||
| _BaseOptions = base_options_module.BaseOptions | ||||
| _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions | ||||
|  | @ -42,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode | |||
| _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | ||||
| _TaskInfo = task_info_module.TaskInfo | ||||
| 
 | ||||
| _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' | ||||
| _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' | ||||
| _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' | ||||
| _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' | ||||
| _CATEGORY_MASK_STREAM_NAME = 'category_mask' | ||||
| _CATEGORY_MASK_TAG = 'CATEGORY_MASK' | ||||
| _IMAGE_IN_STREAM_NAME = 'image_in' | ||||
| _IMAGE_OUT_STREAM_NAME = 'image_out' | ||||
| _IMAGE_TAG = 'IMAGE' | ||||
|  | @ -53,6 +54,12 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' | |||
| _MICRO_SECONDS_PER_MILLISECOND = 1000 | ||||
| 
 | ||||
| 
 | ||||
| @dataclasses.dataclass | ||||
| class ImageSegmenterResult: | ||||
|   confidence_masks: Optional[List[image_module.Image]] = None | ||||
|   category_mask: Optional[image_module.Image] = None | ||||
| 
 | ||||
| 
 | ||||
| @dataclasses.dataclass | ||||
| class ImageSegmenterOptions: | ||||
|   """Options for the image segmenter task. | ||||
|  | @ -64,19 +71,13 @@ class ImageSegmenterOptions: | |||
|       objects on single image inputs. 2) The video mode for segmenting objects | ||||
|       on the decoded frames of a video. 3) The live stream mode for segmenting | ||||
|       objects on a live stream of input data, such as from camera. | ||||
|     output_type: The output mask type allows specifying the type of | ||||
|       post-processing to perform on the raw model results. | ||||
|     activation: Activation function to apply to input tensor. | ||||
|     output_confidence_masks: Whether to output confidence masks. | ||||
|     output_category_mask: Whether to output category mask. | ||||
|     result_callback: The user-defined result callback for processing live stream | ||||
|       data. The result callback should only be specified when the running mode | ||||
|       is set to the live stream mode. | ||||
|   """ | ||||
| 
 | ||||
|   class OutputType(enum.Enum): | ||||
|     UNSPECIFIED = 0 | ||||
|     CATEGORY_MASK = 1 | ||||
|     CONFIDENCE_MASK = 2 | ||||
| 
 | ||||
|   class Activation(enum.Enum): | ||||
|     NONE = 0 | ||||
|     SIGMOID = 1 | ||||
|  | @ -84,7 +85,8 @@ class ImageSegmenterOptions: | |||
| 
 | ||||
|   base_options: _BaseOptions | ||||
|   running_mode: _RunningMode = _RunningMode.IMAGE | ||||
|   output_type: Optional[OutputType] = OutputType.CATEGORY_MASK | ||||
|   output_confidence_masks: bool = True | ||||
|   output_category_mask: bool = False | ||||
|   activation: Optional[Activation] = Activation.NONE | ||||
|   result_callback: Optional[ | ||||
|       Callable[[ImageSegmenterResult, image_module.Image, int], None] | ||||
|  | @ -98,7 +100,7 @@ class ImageSegmenterOptions: | |||
|         False if self.running_mode == _RunningMode.IMAGE else True | ||||
|     ) | ||||
|     segmenter_options_proto = _SegmenterOptionsProto( | ||||
|         output_type=self.output_type.value, activation=self.activation.value | ||||
|         activation=self.activation.value | ||||
|     ) | ||||
|     return _ImageSegmenterGraphOptionsProto( | ||||
|         base_options=base_options_proto, | ||||
|  | @ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|     def packets_callback(output_packets: Mapping[str, packet.Packet]): | ||||
|       if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): | ||||
|         return | ||||
|       segmentation_result = packet_getter.get_image_list( | ||||
|           output_packets[_SEGMENTATION_OUT_STREAM_NAME] | ||||
| 
 | ||||
|       segmentation_result = ImageSegmenterResult() | ||||
| 
 | ||||
|       if options.output_confidence_masks: | ||||
|         segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||
|             output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||
|         ) | ||||
| 
 | ||||
|       if options.output_category_mask: | ||||
|         segmentation_result.category_mask = packet_getter.get_image( | ||||
|             output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||
|         ) | ||||
| 
 | ||||
|       image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) | ||||
|       timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp | ||||
|       timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp | ||||
|       options.result_callback( | ||||
|           segmentation_result, | ||||
|           image, | ||||
|           timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, | ||||
|       ) | ||||
| 
 | ||||
|     output_streams = [ | ||||
|         ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), | ||||
|     ] | ||||
| 
 | ||||
|     if options.output_confidence_masks: | ||||
|       output_streams.append( | ||||
|         ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) | ||||
|       ) | ||||
| 
 | ||||
|     if options.output_category_mask: | ||||
|       output_streams.append( | ||||
|         ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) | ||||
|       ) | ||||
| 
 | ||||
|     task_info = _TaskInfo( | ||||
|         task_graph=_TASK_GRAPH_NAME, | ||||
|         input_streams=[ | ||||
|             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), | ||||
|             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), | ||||
|         ], | ||||
|         output_streams=[ | ||||
|             ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), | ||||
|             ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), | ||||
|         ], | ||||
|         output_streams=output_streams, | ||||
|         task_options=options, | ||||
|     ) | ||||
|     return cls( | ||||
|  | @ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|             normalized_rect.to_pb2() | ||||
|         ), | ||||
|     }) | ||||
|     segmentation_result = packet_getter.get_image_list( | ||||
|         output_packets[_SEGMENTATION_OUT_STREAM_NAME] | ||||
|     segmentation_result = ImageSegmenterResult() | ||||
| 
 | ||||
|     if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: | ||||
|       segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||
|           output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||
|       ) | ||||
| 
 | ||||
|     if _CATEGORY_MASK_STREAM_NAME in output_packets: | ||||
|       segmentation_result.category_mask = packet_getter.get_image( | ||||
|           output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||
|       ) | ||||
| 
 | ||||
|     return segmentation_result | ||||
| 
 | ||||
|   def segment_for_video( | ||||
|  | @ -285,9 +317,19 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|             normalized_rect.to_pb2() | ||||
|         ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), | ||||
|     }) | ||||
|     segmentation_result = packet_getter.get_image_list( | ||||
|         output_packets[_SEGMENTATION_OUT_STREAM_NAME] | ||||
|     segmentation_result = ImageSegmenterResult() | ||||
| 
 | ||||
|     if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: | ||||
|       segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||
|           output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||
|       ) | ||||
| 
 | ||||
|     if _CATEGORY_MASK_STREAM_NAME in output_packets: | ||||
|       segmentation_result.category_mask = packet_getter.get_image( | ||||
|           output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||
|       ) | ||||
| 
 | ||||
| 
 | ||||
|     return segmentation_result | ||||
| 
 | ||||
|   def segment_async( | ||||
|  |  | |||
|  | @ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode | |||
| _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | ||||
| _TaskInfo = task_info_module.TaskInfo | ||||
| 
 | ||||
| _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' | ||||
| _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' | ||||
| _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' | ||||
| _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' | ||||
| _CATEGORY_MASK_STREAM_NAME = 'category_mask' | ||||
| _CATEGORY_MASK_TAG = 'CATEGORY_MASK' | ||||
| _IMAGE_IN_STREAM_NAME = 'image_in' | ||||
| _IMAGE_OUT_STREAM_NAME = 'image_out' | ||||
| _ROI_STREAM_NAME = 'roi_in' | ||||
|  | @ -55,32 +57,32 @@ _TASK_GRAPH_NAME = ( | |||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @dataclasses.dataclass | ||||
| class InteractiveSegmenterResult: | ||||
|   confidence_masks: Optional[List[image_module.Image]] = None | ||||
|   category_mask: Optional[image_module.Image] = None | ||||
| 
 | ||||
| 
 | ||||
| @dataclasses.dataclass | ||||
| class InteractiveSegmenterOptions: | ||||
|   """Options for the interactive segmenter task. | ||||
| 
 | ||||
|   Attributes: | ||||
|     base_options: Base options for the interactive segmenter task. | ||||
|     output_type: The output mask type allows specifying the type of | ||||
|       post-processing to perform on the raw model results. | ||||
|     output_confidence_masks: Whether to output confidence masks. | ||||
|     output_category_mask: Whether to output category mask. | ||||
|   """ | ||||
| 
 | ||||
|   class OutputType(enum.Enum): | ||||
|     UNSPECIFIED = 0 | ||||
|     CATEGORY_MASK = 1 | ||||
|     CONFIDENCE_MASK = 2 | ||||
| 
 | ||||
|   base_options: _BaseOptions | ||||
|   output_type: Optional[OutputType] = OutputType.CATEGORY_MASK | ||||
|   output_confidence_masks: bool = True | ||||
|   output_category_mask: bool = False | ||||
| 
 | ||||
|   @doc_controls.do_not_generate_docs | ||||
|   def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: | ||||
|     """Generates an InteractiveSegmenterOptions protobuf object.""" | ||||
|     base_options_proto = self.base_options.to_pb2() | ||||
|     base_options_proto.use_stream_mode = False | ||||
|     segmenter_options_proto = _SegmenterOptionsProto( | ||||
|         output_type=self.output_type.value | ||||
|     ) | ||||
|     segmenter_options_proto = _SegmenterOptionsProto() | ||||
|     return _ImageSegmenterGraphOptionsProto( | ||||
|         base_options=base_options_proto, | ||||
|         segmenter_options=segmenter_options_proto, | ||||
|  | @ -192,6 +194,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|       RuntimeError: If other types of error occurred. | ||||
|     """ | ||||
| 
 | ||||
|     output_streams = [ | ||||
|       ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), | ||||
|     ] | ||||
| 
 | ||||
|     if options.output_confidence_masks: | ||||
|       output_streams.append( | ||||
|         ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) | ||||
|       ) | ||||
| 
 | ||||
|     if options.output_category_mask: | ||||
|       output_streams.append( | ||||
|         ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) | ||||
|       ) | ||||
| 
 | ||||
|     task_info = _TaskInfo( | ||||
|         task_graph=_TASK_GRAPH_NAME, | ||||
|         input_streams=[ | ||||
|  | @ -199,10 +215,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|             ':'.join([_ROI_TAG, _ROI_STREAM_NAME]), | ||||
|             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), | ||||
|         ], | ||||
|         output_streams=[ | ||||
|             ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), | ||||
|             ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), | ||||
|         ], | ||||
|         output_streams=output_streams, | ||||
|         task_options=options, | ||||
|     ) | ||||
|     return cls( | ||||
|  | @ -216,7 +229,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|       image: image_module.Image, | ||||
|       roi: RegionOfInterest, | ||||
|       image_processing_options: Optional[_ImageProcessingOptions] = None, | ||||
|   ) -> List[image_module.Image]: | ||||
|   ) -> InteractiveSegmenterResult: | ||||
|     """Performs the actual segmentation task on the provided MediaPipe Image. | ||||
| 
 | ||||
|     The image can be of any size with format RGB. | ||||
|  | @ -248,7 +261,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | |||
|             normalized_rect.to_pb2() | ||||
|         ), | ||||
|     }) | ||||
|     segmentation_result = packet_getter.get_image_list( | ||||
|         output_packets[_SEGMENTATION_OUT_STREAM_NAME] | ||||
|     segmentation_result = InteractiveSegmenterResult() | ||||
| 
 | ||||
|     if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: | ||||
|       segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||
|         output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||
|       ) | ||||
| 
 | ||||
|     if _CATEGORY_MASK_STREAM_NAME in output_packets: | ||||
|       segmentation_result.category_mask = packet_getter.get_image( | ||||
|         output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||
|       ) | ||||
| 
 | ||||
|     return segmentation_result | ||||
|  |  | |||
							
								
								
									
										2
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							|  | @ -77,6 +77,7 @@ mediapipe_files(srcs = [ | |||
|     "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", | ||||
|     "pose.jpg", | ||||
|     "pose_detection.tflite", | ||||
|     "ptm_512_hdt_ptm_woid.tflite", | ||||
|     "pose_landmark_lite.tflite", | ||||
|     "pose_landmarker.task", | ||||
|     "right_hands.jpg", | ||||
|  | @ -187,6 +188,7 @@ filegroup( | |||
|         "mobilenet_v3_small_100_224_embedder.tflite", | ||||
|         "palm_detection_full.tflite", | ||||
|         "pose_detection.tflite", | ||||
|         "ptm_512_hdt_ptm_woid.tflite", | ||||
|         "pose_landmark_lite.tflite", | ||||
|         "pose_landmarker.task", | ||||
|         "selfie_segm_128_128_3.tflite", | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user