From 3f68f90238a36521994c9ab5ee8500473f7f657d Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 12 Apr 2023 14:37:16 -0700 Subject: [PATCH 1/4] Deprecated output_type for the ImageSegmenter and InteractiveSegmenter APIs --- mediapipe/tasks/python/test/vision/BUILD | 19 ++ .../test/vision/image_segmenter_test.py | 194 ++++++++++++------ .../test/vision/interactive_segmenter_test.py | 29 ++- .../tasks/python/vision/image_segmenter.py | 96 ++++++--- .../python/vision/interactive_segmenter.py | 64 ++++-- mediapipe/tasks/testdata/vision/BUILD | 2 + 6 files changed, 272 insertions(+), 132 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 704e1af5c..46838c264 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -93,6 +93,25 @@ py_test( ], ) +py_test( + name = "interactive_segmenter_test", + srcs = ["interactive_segmenter_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:keypoint", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:interactive_segmenter", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) + py_test( name = "face_detector_test", srcs = ["face_detector_test.py"], diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 7f0b47eb7..327925191 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -30,10 +30,10 @@ from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision.core import vision_task_running_mode +ImageSegmenterResult = image_segmenter.ImageSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat -_OutputType = image_segmenter.ImageSegmenterOptions.OutputType _Activation = image_segmenter.ImageSegmenterOptions.Activation _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions @@ -42,11 +42,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _MODEL_FILE = 'deeplabv3.tflite' _IMAGE_FILE = 'segmentation_input_rotation0.jpg' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' +_CAT_IMAGE = 'cat.jpg' +_CAT_MASK = 'cat_mask.jpg' _MASK_MAGNIFICATION_FACTOR = 10 _MASK_SIMILARITY_THRESHOLD = 0.98 _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' +def _calculate_soft_iou(m1, m2): + intersection_sum = np.sum(m1 * m2) + union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum + + if union_sum > 0: + return intersection_sum / union_sum + else: + return 0 + + +def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold): + actual_mask = actual_mask.numpy_view() + expected_mask = expected_mask.numpy_view() / 255.0 + + return ( + actual_mask.shape == expected_mask.shape + and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold + ) + + def _similar_to_uint8_mask(actual_mask, expected_mask): actual_mask_pixels = actual_mask.numpy_view().flatten() expected_mask_pixels = expected_mask.numpy_view().flatten() @@ -84,6 +106,14 @@ class ImageSegmenterTest(parameterized.TestCase): self.model_path = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) + def _load_segmentation_mask(self, file_path: str): + # Loads ground truth segmentation file. + gt_segmentation_data = cv2.imread( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)), + cv2.IMREAD_GRAYSCALE, + ) + return _Image(_ImageFormat.GRAY8, gt_segmentation_data) + def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter: @@ -127,20 +157,19 @@ class ImageSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) + base_options=base_options, output_category_mask=True) segmenter = _ImageSegmenter.create_from_options(options) # Performs image segmentation on the input. - category_masks = segmenter.segment(self.test_image) - self.assertLen(category_masks, 1) - category_mask = category_masks[0] + segmentation_result = segmenter.segment(self.test_image) + category_mask = segmentation_result.category_mask result_pixels = category_mask.numpy_view().flatten() # Check if data type of `category_mask` is correct. self.assertEqual(result_pixels.dtype, np.uint8) self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), + _similar_to_uint8_mask(category_mask, self.test_seg_image), f'Number of pixels in the candidate mask differing from that of the ' f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') @@ -152,67 +181,33 @@ class ImageSegmenterTest(parameterized.TestCase): # Creates segmenter. base_options = _BaseOptions(model_asset_path=self.model_path) - # Run segmentation on the model in CATEGORY_MASK mode. - options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) - segmenter = _ImageSegmenter.create_from_options(options) - category_masks = segmenter.segment(self.test_image) - category_mask = category_masks[0].numpy_view() + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) # Run segmentation on the model in CONFIDENCE_MASK mode. options = _ImageSegmenterOptions( base_options=base_options, - output_type=_OutputType.CONFIDENCE_MASK, activation=_Activation.SOFTMAX) - segmenter = _ImageSegmenter.create_from_options(options) - confidence_masks = segmenter.segment(self.test_image) - # Check if confidence mask shape is correct. - self.assertLen( - confidence_masks, 21, - 'Number of confidence masks must match with number of categories.') - - # Gather the confidence masks in a single array `confidence_mask_array`. - confidence_mask_array = np.array( - [confidence_mask.numpy_view() for confidence_mask in confidence_masks]) - - # Check if data type of `confidence_masks` are correct. - self.assertEqual(confidence_mask_array.dtype, np.float32) - - # Compute the category mask from the created confidence mask. - calculated_category_mask = np.argmax(confidence_mask_array, axis=0) - self.assertListEqual( - calculated_category_mask.tolist(), category_mask.tolist(), - 'Confidence mask does not match with the category mask.') - - # Closes the segmenter explicitly when the segmenter is not used in - # a context. - segmenter.close() - - @parameterized.parameters((ModelFileType.FILE_NAME), - (ModelFileType.FILE_CONTENT)) - def test_segment_in_context(self, model_file_type): - if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(model_asset_path=self.model_path) - elif model_file_type is ModelFileType.FILE_CONTENT: - with open(self.model_path, 'rb') as f: - model_contents = f.read() - base_options = _BaseOptions(model_asset_buffer=model_contents) - else: - # Should never happen - raise ValueError('model_file_type is invalid.') - - options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) with _ImageSegmenter.create_from_options(options) as segmenter: - # Performs image segmentation on the input. - category_masks = segmenter.segment(self.test_image) - self.assertLen(category_masks, 1) + segmentation_result = segmenter.segment(test_image) + confidence_masks = segmentation_result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, 21, + 'Number of confidence masks must match with number of categories.') + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(_CAT_MASK) self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) def test_missing_result_callback(self): options = _ImageSegmenterOptions( @@ -280,20 +275,49 @@ class ImageSegmenterTest(parameterized.TestCase): ValueError, r'Input timestamp must be monotonically increasing'): segmenter.segment_for_video(self.test_image, 0) - def test_segment_for_video(self): + def test_segment_for_video_in_category_mask_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - output_type=_OutputType.CATEGORY_MASK, + output_category_mask=True, running_mode=_RUNNING_MODE.VIDEO) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): - category_masks = segmenter.segment_for_video(self.test_image, timestamp) - self.assertLen(category_masks, 1) + segmentation_result = segmenter.segment_for_video( + self.test_image, timestamp) + category_mask = segmentation_result.category_mask self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), + _similar_to_uint8_mask(category_mask, self.test_seg_image), f'Number of pixels in the candidate mask differing from that of the ' f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + def test_segment_for_video_in_confidence_mask_mode(self): + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) + + options = _ImageSegmenterOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageSegmenter.create_from_options(options) as segmenter: + for timestamp in range(0, 300, 30): + segmentation_result = segmenter.segment_for_video( + test_image, timestamp) + confidence_masks = segmentation_result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, 21, + 'Number of confidence masks must match with number of categories.') + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(_CAT_MASK) + self.assertTrue( + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) + def test_calling_segment_in_live_stream_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), @@ -325,13 +349,13 @@ class ImageSegmenterTest(parameterized.TestCase): ValueError, r'Input timestamp must be monotonically increasing'): segmenter.segment_async(self.test_image, 0) - def test_segment_async_calls(self): + def test_segment_async_calls_in_category_mask_mode(self): observed_timestamp_ms = -1 - def check_result(result: List[image_module.Image], output_image: _Image, + def check_result(result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int): # Get the output category mask. - category_mask = result[0] + category_mask = result.category_mask self.assertEqual(output_image.width, self.test_image.width) self.assertEqual(output_image.height, self.test_image.height) self.assertEqual(output_image.width, self.test_seg_image.width) @@ -345,13 +369,49 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - output_type=_OutputType.CATEGORY_MASK, + output_category_mask=True, running_mode=_RUNNING_MODE.LIVE_STREAM, result_callback=check_result) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(self.test_image, timestamp) + def test_segment_async_calls_in_confidence_mask_mode(self): + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(_CAT_MASK) + observed_timestamp_ms = -1 + + def check_result(result: ImageSegmenterResult, output_image: _Image, + timestamp_ms: int): + # Get the output category mask. + confidence_masks = result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, 21, + 'Number of confidence masks must match with number of categories.') + + self.assertTrue( + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + options = _ImageSegmenterOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=check_result) + with _ImageSegmenter.create_from_options(options) as segmenter: + for timestamp in range(0, 300, 30): + segmenter.segment_async(test_image, timestamp) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py index e8c52ae3e..6af15aa09 100644 --- a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py @@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import interactive_segmenter from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat _NormalizedKeypoint = keypoint_module.NormalizedKeypoint _Rect = rect.Rect -_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _RegionOfInterest = interactive_segmenter.RegionOfInterest @@ -200,15 +200,14 @@ class InteractiveSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK + base_options=base_options, output_category_mask=True ) segmenter = _InteractiveSegmenter.create_from_options(options) # Performs image segmentation on the input. roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) - category_masks = segmenter.segment(self.test_image, roi) - self.assertLen(category_masks, 1) - category_mask = category_masks[0] + segmentation_result = segmenter.segment(self.test_image, roi) + category_mask = segmentation_result.category_mask result_pixels = category_mask.numpy_view().flatten() # Check if data type of `category_mask` is correct. @@ -219,7 +218,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): self.assertTrue( _similar_to_uint8_mask( - category_masks[0], test_seg_image, similarity_threshold + category_mask, test_seg_image, similarity_threshold ), ( 'Number of pixels in the candidate mask differing from that of the' @@ -253,13 +252,12 @@ class InteractiveSegmenterTest(parameterized.TestCase): roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK - ) + options = _InteractiveSegmenterOptions(base_options=base_options) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation - confidence_masks = segmenter.segment(self.test_image, roi) + segmentation_result = segmenter.segment(self.test_image, roi) + confidence_masks = segmentation_result.confidence_masks # Check if confidence mask shape is correct. self.assertLen( @@ -286,16 +284,15 @@ class InteractiveSegmenterTest(parameterized.TestCase): ) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK - ) + options = _InteractiveSegmenterOptions(base_options=base_options) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) - confidence_masks = segmenter.segment( + segmentation_result = segmenter.segment( self.test_image, roi, image_processing_options ) + confidence_masks = segmentation_result.confidence_masks # Check if confidence mask shape is correct. self.assertLen( @@ -313,9 +310,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): ) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK - ) + options = _InteractiveSegmenterOptions(base_options=base_options) with self.assertRaisesRegex( ValueError, "This task doesn't support region-of-interest." diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index e50ffbf79..102773173 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -31,7 +31,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode -ImageSegmenterResult = List[image_module.Image] _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions @@ -42,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' -_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' +_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' +_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' +_CATEGORY_MASK_STREAM_NAME = 'category_mask' +_CATEGORY_MASK_TAG = 'CATEGORY_MASK' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' @@ -53,6 +54,12 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +@dataclasses.dataclass +class ImageSegmenterResult: + confidence_masks: Optional[List[image_module.Image]] = None + category_mask: Optional[image_module.Image] = None + + @dataclasses.dataclass class ImageSegmenterOptions: """Options for the image segmenter task. @@ -64,19 +71,13 @@ class ImageSegmenterOptions: objects on single image inputs. 2) The video mode for segmenting objects on the decoded frames of a video. 3) The live stream mode for segmenting objects on a live stream of input data, such as from camera. - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. - activation: Activation function to apply to input tensor. + output_confidence_masks: Whether to output confidence masks. + output_category_mask: Whether to output category mask. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. """ - class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - class Activation(enum.Enum): NONE = 0 SIGMOID = 1 @@ -84,7 +85,8 @@ class ImageSegmenterOptions: base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK + output_confidence_masks: bool = True + output_category_mask: bool = False activation: Optional[Activation] = Activation.NONE result_callback: Optional[ Callable[[ImageSegmenterResult, image_module.Image, int], None] @@ -98,7 +100,7 @@ class ImageSegmenterOptions: False if self.running_mode == _RunningMode.IMAGE else True ) segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value, activation=self.activation.value + activation=self.activation.value ) return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, @@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): def packets_callback(output_packets: Mapping[str, packet.Packet]): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + + segmentation_result = ImageSegmenterResult() + + if options.output_confidence_masks: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if options.output_category_mask: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) - timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback( segmentation_result, image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, ) + output_streams = [ + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_confidence_masks: + output_streams.append( + ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) + ) + + if options.output_category_mask: + output_streams.append( + ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) + ) + task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], - output_streams=[ - ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ], + output_streams=output_streams, task_options=options, ) return cls( @@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = ImageSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result def segment_for_video( @@ -285,9 +317,19 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = ImageSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + + return segmentation_result def segment_async( diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py index 12a30b6ef..6f423d76c 100644 --- a/mediapipe/tasks/python/vision/interactive_segmenter.py +++ b/mediapipe/tasks/python/vision/interactive_segmenter.py @@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' -_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' +_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' +_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' +_CATEGORY_MASK_STREAM_NAME = 'category_mask' +_CATEGORY_MASK_TAG = 'CATEGORY_MASK' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _ROI_STREAM_NAME = 'roi_in' @@ -55,32 +57,32 @@ _TASK_GRAPH_NAME = ( ) +@dataclasses.dataclass +class InteractiveSegmenterResult: + confidence_masks: Optional[List[image_module.Image]] = None + category_mask: Optional[image_module.Image] = None + + @dataclasses.dataclass class InteractiveSegmenterOptions: """Options for the interactive segmenter task. Attributes: base_options: Base options for the interactive segmenter task. - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. + output_confidence_masks: Whether to output confidence masks. + output_category_mask: Whether to output category mask. """ - class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - base_options: _BaseOptions - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK + output_confidence_masks: bool = True + output_category_mask: bool = False @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: """Generates an InteractiveSegmenterOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False - segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value - ) + segmenter_options_proto = _SegmenterOptionsProto() return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto, @@ -192,6 +194,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If other types of error occurred. """ + output_streams = [ + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_confidence_masks: + output_streams.append( + ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) + ) + + if options.output_category_mask: + output_streams.append( + ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) + ) + task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, input_streams=[ @@ -199,10 +215,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): ':'.join([_ROI_TAG, _ROI_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], - output_streams=[ - ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ], + output_streams=output_streams, task_options=options, ) return cls( @@ -216,7 +229,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): image: image_module.Image, roi: RegionOfInterest, image_processing_options: Optional[_ImageProcessingOptions] = None, - ) -> List[image_module.Image]: + ) -> InteractiveSegmenterResult: """Performs the actual segmentation task on the provided MediaPipe Image. The image can be of any size with format RGB. @@ -248,7 +261,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = InteractiveSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 0de0c255c..23dc3ed5f 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -77,6 +77,7 @@ mediapipe_files(srcs = [ "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", "pose.jpg", "pose_detection.tflite", + "ptm_512_hdt_ptm_woid.tflite", "pose_landmark_lite.tflite", "pose_landmarker.task", "right_hands.jpg", @@ -187,6 +188,7 @@ filegroup( "mobilenet_v3_small_100_224_embedder.tflite", "palm_detection_full.tflite", "pose_detection.tflite", + "ptm_512_hdt_ptm_woid.tflite", "pose_landmark_lite.tflite", "pose_landmarker.task", "selfie_segm_128_128_3.tflite", From a03fa448dc09d75da7f5eecae16a1a2909a46df6 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 13 Apr 2023 11:55:37 -0700 Subject: [PATCH 2/4] Explicitly state the modes in the tests for ImageSegmenterOptions and InteractiveSegmenterOptions --- .../test/vision/image_segmenter_test.py | 32 +++++++++++++------ .../test/vision/interactive_segmenter_test.py | 18 ++++++++--- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 327925191..aa557281f 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -157,7 +157,9 @@ class ImageSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageSegmenterOptions( - base_options=base_options, output_category_mask=True) + base_options=base_options, output_category_mask=True, + output_confidence_masks=False + ) segmenter = _ImageSegmenter.create_from_options(options) # Performs image segmentation on the input. @@ -188,8 +190,9 @@ class ImageSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _ImageSegmenterOptions( - base_options=base_options, - activation=_Activation.SOFTMAX) + base_options=base_options, output_category_mask=False, + output_confidence_masks=True, activation=_Activation.SOFTMAX + ) with _ImageSegmenter.create_from_options(options) as segmenter: segmentation_result = segmenter.segment(test_image) @@ -279,7 +282,9 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), output_category_mask=True, - running_mode=_RUNNING_MODE.VIDEO) + output_confidence_masks=False, + running_mode=_RUNNING_MODE.VIDEO + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmentation_result = segmenter.segment_for_video( @@ -297,8 +302,10 @@ class ImageSegmenterTest(parameterized.TestCase): os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) options = _ImageSegmenterOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, output_category_mask=False, + output_confidence_masks=True + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmentation_result = segmenter.segment_for_video( @@ -370,8 +377,10 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), output_category_mask=True, + output_confidence_masks=False, running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result) + result_callback=check_result + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(self.test_image, timestamp) @@ -405,9 +414,12 @@ class ImageSegmenterTest(parameterized.TestCase): self.observed_timestamp_ms = timestamp_ms options = _ImageSegmenterOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result) + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + output_category_mask=False, + output_confidence_masks=True, + result_callback=check_result + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(test_image, timestamp) diff --git a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py index 6af15aa09..aea4f8a1d 100644 --- a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py @@ -200,7 +200,8 @@ class InteractiveSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _InteractiveSegmenterOptions( - base_options=base_options, output_category_mask=True + base_options=base_options, output_category_mask=True, + output_confidence_masks=False ) segmenter = _InteractiveSegmenter.create_from_options(options) @@ -252,7 +253,10 @@ class InteractiveSegmenterTest(parameterized.TestCase): roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions(base_options=base_options) + options = _InteractiveSegmenterOptions( + base_options=base_options, output_category_mask=False, + output_confidence_masks=True + ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation @@ -284,7 +288,10 @@ class InteractiveSegmenterTest(parameterized.TestCase): ) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions(base_options=base_options) + options = _InteractiveSegmenterOptions( + base_options=base_options, output_category_mask=False, + output_confidence_masks=True + ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation @@ -310,7 +317,10 @@ class InteractiveSegmenterTest(parameterized.TestCase): ) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions(base_options=base_options) + options = _InteractiveSegmenterOptions( + base_options=base_options, output_category_mask=False, + output_confidence_masks=True + ) with self.assertRaisesRegex( ValueError, "This task doesn't support region-of-interest." From a036bf70cc014ce5ba435ab53e3aa25b38a7ad0e Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 13 Apr 2023 21:09:01 -0700 Subject: [PATCH 3/4] Removed Activation from ImageSegmenterOptions --- .../tasks/python/test/vision/image_segmenter_test.py | 4 +--- mediapipe/tasks/python/vision/image_segmenter.py | 10 +--------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index aa557281f..de0a263f3 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -15,7 +15,6 @@ import enum import os -from typing import List from unittest import mock from absl.testing import absltest @@ -34,7 +33,6 @@ ImageSegmenterResult = image_segmenter.ImageSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat -_Activation = image_segmenter.ImageSegmenterOptions.Activation _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode @@ -191,7 +189,7 @@ class ImageSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _ImageSegmenterOptions( base_options=base_options, output_category_mask=False, - output_confidence_masks=True, activation=_Activation.SOFTMAX + output_confidence_masks=True ) with _ImageSegmenter.create_from_options(options) as segmenter: diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 102773173..c5204db47 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -78,16 +78,10 @@ class ImageSegmenterOptions: is set to the live stream mode. """ - class Activation(enum.Enum): - NONE = 0 - SIGMOID = 1 - SOFTMAX = 2 - base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE output_confidence_masks: bool = True output_category_mask: bool = False - activation: Optional[Activation] = Activation.NONE result_callback: Optional[ Callable[[ImageSegmenterResult, image_module.Image, int], None] ] = None @@ -99,9 +93,7 @@ class ImageSegmenterOptions: base_options_proto.use_stream_mode = ( False if self.running_mode == _RunningMode.IMAGE else True ) - segmenter_options_proto = _SegmenterOptionsProto( - activation=self.activation.value - ) + segmenter_options_proto = _SegmenterOptionsProto() return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto, From a745b71f97aa49a5533ddb282bfee2bd30654aa4 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Thu, 13 Apr 2023 21:11:20 -0700 Subject: [PATCH 4/4] Removed unused import --- mediapipe/tasks/python/vision/image_segmenter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index c5204db47..8edabe321 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -14,7 +14,6 @@ """MediaPipe image segmenter task.""" import dataclasses -import enum from typing import Callable, List, Mapping, Optional from mediapipe.python import packet_creator