Deprecated output_type for the ImageSegmenter and InteractiveSegmenter APIs
This commit is contained in:
		
							parent
							
								
									c7aecb42ff
								
							
						
					
					
						commit
						3f68f90238
					
				|  | @ -93,6 +93,25 @@ py_test( | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | py_test( | ||||||
|  |     name = "interactive_segmenter_test", | ||||||
|  |     srcs = ["interactive_segmenter_test.py"], | ||||||
|  |     data = [ | ||||||
|  |         "//mediapipe/tasks/testdata/vision:test_images", | ||||||
|  |         "//mediapipe/tasks/testdata/vision:test_models", | ||||||
|  |     ], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/python:_framework_bindings", | ||||||
|  |         "//mediapipe/tasks/python/components/containers:keypoint", | ||||||
|  |         "//mediapipe/tasks/python/components/containers:rect", | ||||||
|  |         "//mediapipe/tasks/python/core:base_options", | ||||||
|  |         "//mediapipe/tasks/python/test:test_utils", | ||||||
|  |         "//mediapipe/tasks/python/vision:interactive_segmenter", | ||||||
|  |         "//mediapipe/tasks/python/vision/core:image_processing_options", | ||||||
|  |         "//mediapipe/tasks/python/vision/core:vision_task_running_mode", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| py_test( | py_test( | ||||||
|     name = "face_detector_test", |     name = "face_detector_test", | ||||||
|     srcs = ["face_detector_test.py"], |     srcs = ["face_detector_test.py"], | ||||||
|  |  | ||||||
|  | @ -30,10 +30,10 @@ from mediapipe.tasks.python.test import test_utils | ||||||
| from mediapipe.tasks.python.vision import image_segmenter | from mediapipe.tasks.python.vision import image_segmenter | ||||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode | from mediapipe.tasks.python.vision.core import vision_task_running_mode | ||||||
| 
 | 
 | ||||||
|  | ImageSegmenterResult = image_segmenter.ImageSegmenterResult | ||||||
| _BaseOptions = base_options_module.BaseOptions | _BaseOptions = base_options_module.BaseOptions | ||||||
| _Image = image_module.Image | _Image = image_module.Image | ||||||
| _ImageFormat = image_frame.ImageFormat | _ImageFormat = image_frame.ImageFormat | ||||||
| _OutputType = image_segmenter.ImageSegmenterOptions.OutputType |  | ||||||
| _Activation = image_segmenter.ImageSegmenterOptions.Activation | _Activation = image_segmenter.ImageSegmenterOptions.Activation | ||||||
| _ImageSegmenter = image_segmenter.ImageSegmenter | _ImageSegmenter = image_segmenter.ImageSegmenter | ||||||
| _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions | _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions | ||||||
|  | @ -42,11 +42,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode | ||||||
| _MODEL_FILE = 'deeplabv3.tflite' | _MODEL_FILE = 'deeplabv3.tflite' | ||||||
| _IMAGE_FILE = 'segmentation_input_rotation0.jpg' | _IMAGE_FILE = 'segmentation_input_rotation0.jpg' | ||||||
| _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' | _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' | ||||||
|  | _CAT_IMAGE = 'cat.jpg' | ||||||
|  | _CAT_MASK = 'cat_mask.jpg' | ||||||
| _MASK_MAGNIFICATION_FACTOR = 10 | _MASK_MAGNIFICATION_FACTOR = 10 | ||||||
| _MASK_SIMILARITY_THRESHOLD = 0.98 | _MASK_SIMILARITY_THRESHOLD = 0.98 | ||||||
| _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' | _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def _calculate_soft_iou(m1, m2): | ||||||
|  |   intersection_sum = np.sum(m1 * m2) | ||||||
|  |   union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum | ||||||
|  | 
 | ||||||
|  |   if union_sum > 0: | ||||||
|  |     return intersection_sum / union_sum | ||||||
|  |   else: | ||||||
|  |     return 0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold): | ||||||
|  |   actual_mask = actual_mask.numpy_view() | ||||||
|  |   expected_mask = expected_mask.numpy_view() / 255.0 | ||||||
|  | 
 | ||||||
|  |   return ( | ||||||
|  |       actual_mask.shape == expected_mask.shape | ||||||
|  |       and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold | ||||||
|  |   ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def _similar_to_uint8_mask(actual_mask, expected_mask): | def _similar_to_uint8_mask(actual_mask, expected_mask): | ||||||
|   actual_mask_pixels = actual_mask.numpy_view().flatten() |   actual_mask_pixels = actual_mask.numpy_view().flatten() | ||||||
|   expected_mask_pixels = expected_mask.numpy_view().flatten() |   expected_mask_pixels = expected_mask.numpy_view().flatten() | ||||||
|  | @ -84,6 +106,14 @@ class ImageSegmenterTest(parameterized.TestCase): | ||||||
|     self.model_path = test_utils.get_test_data_path( |     self.model_path = test_utils.get_test_data_path( | ||||||
|         os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) |         os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) | ||||||
| 
 | 
 | ||||||
|  |   def _load_segmentation_mask(self, file_path: str): | ||||||
|  |     # Loads ground truth segmentation file. | ||||||
|  |     gt_segmentation_data = cv2.imread( | ||||||
|  |       test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)), | ||||||
|  |       cv2.IMREAD_GRAYSCALE, | ||||||
|  |     ) | ||||||
|  |     return _Image(_ImageFormat.GRAY8, gt_segmentation_data) | ||||||
|  | 
 | ||||||
|   def test_create_from_file_succeeds_with_valid_model_path(self): |   def test_create_from_file_succeeds_with_valid_model_path(self): | ||||||
|     # Creates with default option and valid model file successfully. |     # Creates with default option and valid model file successfully. | ||||||
|     with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter: |     with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter: | ||||||
|  | @ -127,20 +157,19 @@ class ImageSegmenterTest(parameterized.TestCase): | ||||||
|       raise ValueError('model_file_type is invalid.') |       raise ValueError('model_file_type is invalid.') | ||||||
| 
 | 
 | ||||||
|     options = _ImageSegmenterOptions( |     options = _ImageSegmenterOptions( | ||||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) |         base_options=base_options, output_category_mask=True) | ||||||
|     segmenter = _ImageSegmenter.create_from_options(options) |     segmenter = _ImageSegmenter.create_from_options(options) | ||||||
| 
 | 
 | ||||||
|     # Performs image segmentation on the input. |     # Performs image segmentation on the input. | ||||||
|     category_masks = segmenter.segment(self.test_image) |     segmentation_result = segmenter.segment(self.test_image) | ||||||
|     self.assertLen(category_masks, 1) |     category_mask = segmentation_result.category_mask | ||||||
|     category_mask = category_masks[0] |  | ||||||
|     result_pixels = category_mask.numpy_view().flatten() |     result_pixels = category_mask.numpy_view().flatten() | ||||||
| 
 | 
 | ||||||
|     # Check if data type of `category_mask` is correct. |     # Check if data type of `category_mask` is correct. | ||||||
|     self.assertEqual(result_pixels.dtype, np.uint8) |     self.assertEqual(result_pixels.dtype, np.uint8) | ||||||
| 
 | 
 | ||||||
|     self.assertTrue( |     self.assertTrue( | ||||||
|         _similar_to_uint8_mask(category_masks[0], self.test_seg_image), |         _similar_to_uint8_mask(category_mask, self.test_seg_image), | ||||||
|         f'Number of pixels in the candidate mask differing from that of the ' |         f'Number of pixels in the candidate mask differing from that of the ' | ||||||
|         f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') |         f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') | ||||||
| 
 | 
 | ||||||
|  | @ -152,67 +181,33 @@ class ImageSegmenterTest(parameterized.TestCase): | ||||||
|     # Creates segmenter. |     # Creates segmenter. | ||||||
|     base_options = _BaseOptions(model_asset_path=self.model_path) |     base_options = _BaseOptions(model_asset_path=self.model_path) | ||||||
| 
 | 
 | ||||||
|     # Run segmentation on the model in CATEGORY_MASK mode. |     # Load the cat image. | ||||||
|     options = _ImageSegmenterOptions( |     test_image = _Image.create_from_file( | ||||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) |         test_utils.get_test_data_path( | ||||||
|     segmenter = _ImageSegmenter.create_from_options(options) |             os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) | ||||||
|     category_masks = segmenter.segment(self.test_image) |  | ||||||
|     category_mask = category_masks[0].numpy_view() |  | ||||||
| 
 | 
 | ||||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. |     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||||
|     options = _ImageSegmenterOptions( |     options = _ImageSegmenterOptions( | ||||||
|         base_options=base_options, |         base_options=base_options, | ||||||
|         output_type=_OutputType.CONFIDENCE_MASK, |  | ||||||
|         activation=_Activation.SOFTMAX) |         activation=_Activation.SOFTMAX) | ||||||
|     segmenter = _ImageSegmenter.create_from_options(options) | 
 | ||||||
|     confidence_masks = segmenter.segment(self.test_image) |     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||||
|  |       segmentation_result = segmenter.segment(test_image) | ||||||
|  |       confidence_masks = segmentation_result.confidence_masks | ||||||
| 
 | 
 | ||||||
|       # Check if confidence mask shape is correct. |       # Check if confidence mask shape is correct. | ||||||
|       self.assertLen( |       self.assertLen( | ||||||
|           confidence_masks, 21, |           confidence_masks, 21, | ||||||
|           'Number of confidence masks must match with number of categories.') |           'Number of confidence masks must match with number of categories.') | ||||||
| 
 | 
 | ||||||
|     # Gather the confidence masks in a single array `confidence_mask_array`. |       # Loads ground truth segmentation file. | ||||||
|     confidence_mask_array = np.array( |       expected_mask = self._load_segmentation_mask(_CAT_MASK) | ||||||
|         [confidence_mask.numpy_view() for confidence_mask in confidence_masks]) |  | ||||||
| 
 |  | ||||||
|     # Check if data type of `confidence_masks` are correct. |  | ||||||
|     self.assertEqual(confidence_mask_array.dtype, np.float32) |  | ||||||
| 
 |  | ||||||
|     # Compute the category mask from the created confidence mask. |  | ||||||
|     calculated_category_mask = np.argmax(confidence_mask_array, axis=0) |  | ||||||
|     self.assertListEqual( |  | ||||||
|         calculated_category_mask.tolist(), category_mask.tolist(), |  | ||||||
|         'Confidence mask does not match with the category mask.') |  | ||||||
| 
 |  | ||||||
|     # Closes the segmenter explicitly when the segmenter is not used in |  | ||||||
|     # a context. |  | ||||||
|     segmenter.close() |  | ||||||
| 
 |  | ||||||
|   @parameterized.parameters((ModelFileType.FILE_NAME), |  | ||||||
|                             (ModelFileType.FILE_CONTENT)) |  | ||||||
|   def test_segment_in_context(self, model_file_type): |  | ||||||
|     if model_file_type is ModelFileType.FILE_NAME: |  | ||||||
|       base_options = _BaseOptions(model_asset_path=self.model_path) |  | ||||||
|     elif model_file_type is ModelFileType.FILE_CONTENT: |  | ||||||
|       with open(self.model_path, 'rb') as f: |  | ||||||
|         model_contents = f.read() |  | ||||||
|       base_options = _BaseOptions(model_asset_buffer=model_contents) |  | ||||||
|     else: |  | ||||||
|       # Should never happen |  | ||||||
|       raise ValueError('model_file_type is invalid.') |  | ||||||
| 
 |  | ||||||
|     options = _ImageSegmenterOptions( |  | ||||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK) |  | ||||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: |  | ||||||
|       # Performs image segmentation on the input. |  | ||||||
|       category_masks = segmenter.segment(self.test_image) |  | ||||||
|       self.assertLen(category_masks, 1) |  | ||||||
| 
 | 
 | ||||||
|       self.assertTrue( |       self.assertTrue( | ||||||
|           _similar_to_uint8_mask(category_masks[0], self.test_seg_image), |           _similar_to_float_mask( | ||||||
|           f'Number of pixels in the candidate mask differing from that of the ' |               confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD | ||||||
|           f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') |           ) | ||||||
|  |       ) | ||||||
| 
 | 
 | ||||||
|   def test_missing_result_callback(self): |   def test_missing_result_callback(self): | ||||||
|     options = _ImageSegmenterOptions( |     options = _ImageSegmenterOptions( | ||||||
|  | @ -280,20 +275,49 @@ class ImageSegmenterTest(parameterized.TestCase): | ||||||
|           ValueError, r'Input timestamp must be monotonically increasing'): |           ValueError, r'Input timestamp must be monotonically increasing'): | ||||||
|         segmenter.segment_for_video(self.test_image, 0) |         segmenter.segment_for_video(self.test_image, 0) | ||||||
| 
 | 
 | ||||||
|   def test_segment_for_video(self): |   def test_segment_for_video_in_category_mask_mode(self): | ||||||
|     options = _ImageSegmenterOptions( |     options = _ImageSegmenterOptions( | ||||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), |         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||||
|         output_type=_OutputType.CATEGORY_MASK, |         output_category_mask=True, | ||||||
|         running_mode=_RUNNING_MODE.VIDEO) |         running_mode=_RUNNING_MODE.VIDEO) | ||||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: |     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||||
|       for timestamp in range(0, 300, 30): |       for timestamp in range(0, 300, 30): | ||||||
|         category_masks = segmenter.segment_for_video(self.test_image, timestamp) |         segmentation_result = segmenter.segment_for_video( | ||||||
|         self.assertLen(category_masks, 1) |             self.test_image, timestamp) | ||||||
|  |         category_mask = segmentation_result.category_mask | ||||||
|         self.assertTrue( |         self.assertTrue( | ||||||
|             _similar_to_uint8_mask(category_masks[0], self.test_seg_image), |             _similar_to_uint8_mask(category_mask, self.test_seg_image), | ||||||
|             f'Number of pixels in the candidate mask differing from that of the ' |             f'Number of pixels in the candidate mask differing from that of the ' | ||||||
|             f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') |             f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') | ||||||
| 
 | 
 | ||||||
|  |   def test_segment_for_video_in_confidence_mask_mode(self): | ||||||
|  |     # Load the cat image. | ||||||
|  |     test_image = _Image.create_from_file( | ||||||
|  |       test_utils.get_test_data_path( | ||||||
|  |         os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) | ||||||
|  | 
 | ||||||
|  |     options = _ImageSegmenterOptions( | ||||||
|  |       base_options=_BaseOptions(model_asset_path=self.model_path), | ||||||
|  |       running_mode=_RUNNING_MODE.VIDEO) | ||||||
|  |     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||||
|  |       for timestamp in range(0, 300, 30): | ||||||
|  |         segmentation_result = segmenter.segment_for_video( | ||||||
|  |           test_image, timestamp) | ||||||
|  |         confidence_masks = segmentation_result.confidence_masks | ||||||
|  | 
 | ||||||
|  |         # Check if confidence mask shape is correct. | ||||||
|  |         self.assertLen( | ||||||
|  |           confidence_masks, 21, | ||||||
|  |           'Number of confidence masks must match with number of categories.') | ||||||
|  | 
 | ||||||
|  |         # Loads ground truth segmentation file. | ||||||
|  |         expected_mask = self._load_segmentation_mask(_CAT_MASK) | ||||||
|  |         self.assertTrue( | ||||||
|  |           _similar_to_float_mask( | ||||||
|  |             confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD | ||||||
|  |           ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|   def test_calling_segment_in_live_stream_mode(self): |   def test_calling_segment_in_live_stream_mode(self): | ||||||
|     options = _ImageSegmenterOptions( |     options = _ImageSegmenterOptions( | ||||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), |         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||||
|  | @ -325,13 +349,13 @@ class ImageSegmenterTest(parameterized.TestCase): | ||||||
|           ValueError, r'Input timestamp must be monotonically increasing'): |           ValueError, r'Input timestamp must be monotonically increasing'): | ||||||
|         segmenter.segment_async(self.test_image, 0) |         segmenter.segment_async(self.test_image, 0) | ||||||
| 
 | 
 | ||||||
|   def test_segment_async_calls(self): |   def test_segment_async_calls_in_category_mask_mode(self): | ||||||
|     observed_timestamp_ms = -1 |     observed_timestamp_ms = -1 | ||||||
| 
 | 
 | ||||||
|     def check_result(result: List[image_module.Image], output_image: _Image, |     def check_result(result: ImageSegmenterResult, output_image: _Image, | ||||||
|                      timestamp_ms: int): |                      timestamp_ms: int): | ||||||
|       # Get the output category mask. |       # Get the output category mask. | ||||||
|       category_mask = result[0] |       category_mask = result.category_mask | ||||||
|       self.assertEqual(output_image.width, self.test_image.width) |       self.assertEqual(output_image.width, self.test_image.width) | ||||||
|       self.assertEqual(output_image.height, self.test_image.height) |       self.assertEqual(output_image.height, self.test_image.height) | ||||||
|       self.assertEqual(output_image.width, self.test_seg_image.width) |       self.assertEqual(output_image.width, self.test_seg_image.width) | ||||||
|  | @ -345,13 +369,49 @@ class ImageSegmenterTest(parameterized.TestCase): | ||||||
| 
 | 
 | ||||||
|     options = _ImageSegmenterOptions( |     options = _ImageSegmenterOptions( | ||||||
|         base_options=_BaseOptions(model_asset_path=self.model_path), |         base_options=_BaseOptions(model_asset_path=self.model_path), | ||||||
|         output_type=_OutputType.CATEGORY_MASK, |         output_category_mask=True, | ||||||
|         running_mode=_RUNNING_MODE.LIVE_STREAM, |         running_mode=_RUNNING_MODE.LIVE_STREAM, | ||||||
|         result_callback=check_result) |         result_callback=check_result) | ||||||
|     with _ImageSegmenter.create_from_options(options) as segmenter: |     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||||
|       for timestamp in range(0, 300, 30): |       for timestamp in range(0, 300, 30): | ||||||
|         segmenter.segment_async(self.test_image, timestamp) |         segmenter.segment_async(self.test_image, timestamp) | ||||||
| 
 | 
 | ||||||
|  |   def test_segment_async_calls_in_confidence_mask_mode(self): | ||||||
|  |     # Load the cat image. | ||||||
|  |     test_image = _Image.create_from_file( | ||||||
|  |       test_utils.get_test_data_path( | ||||||
|  |         os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) | ||||||
|  | 
 | ||||||
|  |     # Loads ground truth segmentation file. | ||||||
|  |     expected_mask = self._load_segmentation_mask(_CAT_MASK) | ||||||
|  |     observed_timestamp_ms = -1 | ||||||
|  | 
 | ||||||
|  |     def check_result(result: ImageSegmenterResult, output_image: _Image, | ||||||
|  |                      timestamp_ms: int): | ||||||
|  |       # Get the output category mask. | ||||||
|  |       confidence_masks = result.confidence_masks | ||||||
|  | 
 | ||||||
|  |       # Check if confidence mask shape is correct. | ||||||
|  |       self.assertLen( | ||||||
|  |         confidence_masks, 21, | ||||||
|  |         'Number of confidence masks must match with number of categories.') | ||||||
|  | 
 | ||||||
|  |       self.assertTrue( | ||||||
|  |         _similar_to_float_mask( | ||||||
|  |           confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD | ||||||
|  |         ) | ||||||
|  |       ) | ||||||
|  |       self.assertLess(observed_timestamp_ms, timestamp_ms) | ||||||
|  |       self.observed_timestamp_ms = timestamp_ms | ||||||
|  | 
 | ||||||
|  |     options = _ImageSegmenterOptions( | ||||||
|  |       base_options=_BaseOptions(model_asset_path=self.model_path), | ||||||
|  |       running_mode=_RUNNING_MODE.LIVE_STREAM, | ||||||
|  |       result_callback=check_result) | ||||||
|  |     with _ImageSegmenter.create_from_options(options) as segmenter: | ||||||
|  |       for timestamp in range(0, 300, 30): | ||||||
|  |         segmenter.segment_async(test_image, timestamp) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   absltest.main() |   absltest.main() | ||||||
|  |  | ||||||
|  | @ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils | ||||||
| from mediapipe.tasks.python.vision import interactive_segmenter | from mediapipe.tasks.python.vision import interactive_segmenter | ||||||
| from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module | from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module | ||||||
| 
 | 
 | ||||||
|  | InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult | ||||||
| _BaseOptions = base_options_module.BaseOptions | _BaseOptions = base_options_module.BaseOptions | ||||||
| _Image = image_module.Image | _Image = image_module.Image | ||||||
| _ImageFormat = image_frame.ImageFormat | _ImageFormat = image_frame.ImageFormat | ||||||
| _NormalizedKeypoint = keypoint_module.NormalizedKeypoint | _NormalizedKeypoint = keypoint_module.NormalizedKeypoint | ||||||
| _Rect = rect.Rect | _Rect = rect.Rect | ||||||
| _OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType |  | ||||||
| _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter | _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter | ||||||
| _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions | _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions | ||||||
| _RegionOfInterest = interactive_segmenter.RegionOfInterest | _RegionOfInterest = interactive_segmenter.RegionOfInterest | ||||||
|  | @ -200,15 +200,14 @@ class InteractiveSegmenterTest(parameterized.TestCase): | ||||||
|       raise ValueError('model_file_type is invalid.') |       raise ValueError('model_file_type is invalid.') | ||||||
| 
 | 
 | ||||||
|     options = _InteractiveSegmenterOptions( |     options = _InteractiveSegmenterOptions( | ||||||
|         base_options=base_options, output_type=_OutputType.CATEGORY_MASK |         base_options=base_options, output_category_mask=True | ||||||
|     ) |     ) | ||||||
|     segmenter = _InteractiveSegmenter.create_from_options(options) |     segmenter = _InteractiveSegmenter.create_from_options(options) | ||||||
| 
 | 
 | ||||||
|     # Performs image segmentation on the input. |     # Performs image segmentation on the input. | ||||||
|     roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) |     roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) | ||||||
|     category_masks = segmenter.segment(self.test_image, roi) |     segmentation_result = segmenter.segment(self.test_image, roi) | ||||||
|     self.assertLen(category_masks, 1) |     category_mask = segmentation_result.category_mask | ||||||
|     category_mask = category_masks[0] |  | ||||||
|     result_pixels = category_mask.numpy_view().flatten() |     result_pixels = category_mask.numpy_view().flatten() | ||||||
| 
 | 
 | ||||||
|     # Check if data type of `category_mask` is correct. |     # Check if data type of `category_mask` is correct. | ||||||
|  | @ -219,7 +218,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): | ||||||
| 
 | 
 | ||||||
|     self.assertTrue( |     self.assertTrue( | ||||||
|         _similar_to_uint8_mask( |         _similar_to_uint8_mask( | ||||||
|             category_masks[0], test_seg_image, similarity_threshold |             category_mask, test_seg_image, similarity_threshold | ||||||
|         ), |         ), | ||||||
|         ( |         ( | ||||||
|             'Number of pixels in the candidate mask differing from that of the' |             'Number of pixels in the candidate mask differing from that of the' | ||||||
|  | @ -253,13 +252,12 @@ class InteractiveSegmenterTest(parameterized.TestCase): | ||||||
|     roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) |     roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) | ||||||
| 
 | 
 | ||||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. |     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||||
|     options = _InteractiveSegmenterOptions( |     options = _InteractiveSegmenterOptions(base_options=base_options) | ||||||
|         base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|     with _InteractiveSegmenter.create_from_options(options) as segmenter: |     with _InteractiveSegmenter.create_from_options(options) as segmenter: | ||||||
|       # Perform segmentation |       # Perform segmentation | ||||||
|       confidence_masks = segmenter.segment(self.test_image, roi) |       segmentation_result = segmenter.segment(self.test_image, roi) | ||||||
|  |       confidence_masks = segmentation_result.confidence_masks | ||||||
| 
 | 
 | ||||||
|       # Check if confidence mask shape is correct. |       # Check if confidence mask shape is correct. | ||||||
|       self.assertLen( |       self.assertLen( | ||||||
|  | @ -286,16 +284,15 @@ class InteractiveSegmenterTest(parameterized.TestCase): | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. |     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||||
|     options = _InteractiveSegmenterOptions( |     options = _InteractiveSegmenterOptions(base_options=base_options) | ||||||
|         base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|     with _InteractiveSegmenter.create_from_options(options) as segmenter: |     with _InteractiveSegmenter.create_from_options(options) as segmenter: | ||||||
|       # Perform segmentation |       # Perform segmentation | ||||||
|       image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) |       image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) | ||||||
|       confidence_masks = segmenter.segment( |       segmentation_result = segmenter.segment( | ||||||
|           self.test_image, roi, image_processing_options |           self.test_image, roi, image_processing_options | ||||||
|       ) |       ) | ||||||
|  |       confidence_masks = segmentation_result.confidence_masks | ||||||
| 
 | 
 | ||||||
|       # Check if confidence mask shape is correct. |       # Check if confidence mask shape is correct. | ||||||
|       self.assertLen( |       self.assertLen( | ||||||
|  | @ -313,9 +310,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     # Run segmentation on the model in CONFIDENCE_MASK mode. |     # Run segmentation on the model in CONFIDENCE_MASK mode. | ||||||
|     options = _InteractiveSegmenterOptions( |     options = _InteractiveSegmenterOptions(base_options=base_options) | ||||||
|         base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|     with self.assertRaisesRegex( |     with self.assertRaisesRegex( | ||||||
|         ValueError, "This task doesn't support region-of-interest." |         ValueError, "This task doesn't support region-of-interest." | ||||||
|  |  | ||||||
|  | @ -31,7 +31,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api | ||||||
| from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module | from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module | ||||||
| from mediapipe.tasks.python.vision.core import vision_task_running_mode | from mediapipe.tasks.python.vision.core import vision_task_running_mode | ||||||
| 
 | 
 | ||||||
| ImageSegmenterResult = List[image_module.Image] |  | ||||||
| _NormalizedRect = rect.NormalizedRect | _NormalizedRect = rect.NormalizedRect | ||||||
| _BaseOptions = base_options_module.BaseOptions | _BaseOptions = base_options_module.BaseOptions | ||||||
| _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions | _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions | ||||||
|  | @ -42,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode | ||||||
| _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | ||||||
| _TaskInfo = task_info_module.TaskInfo | _TaskInfo = task_info_module.TaskInfo | ||||||
| 
 | 
 | ||||||
| _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' | _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' | ||||||
| _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' | _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' | ||||||
|  | _CATEGORY_MASK_STREAM_NAME = 'category_mask' | ||||||
|  | _CATEGORY_MASK_TAG = 'CATEGORY_MASK' | ||||||
| _IMAGE_IN_STREAM_NAME = 'image_in' | _IMAGE_IN_STREAM_NAME = 'image_in' | ||||||
| _IMAGE_OUT_STREAM_NAME = 'image_out' | _IMAGE_OUT_STREAM_NAME = 'image_out' | ||||||
| _IMAGE_TAG = 'IMAGE' | _IMAGE_TAG = 'IMAGE' | ||||||
|  | @ -53,6 +54,12 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' | ||||||
| _MICRO_SECONDS_PER_MILLISECOND = 1000 | _MICRO_SECONDS_PER_MILLISECOND = 1000 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @dataclasses.dataclass | ||||||
|  | class ImageSegmenterResult: | ||||||
|  |   confidence_masks: Optional[List[image_module.Image]] = None | ||||||
|  |   category_mask: Optional[image_module.Image] = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @dataclasses.dataclass | @dataclasses.dataclass | ||||||
| class ImageSegmenterOptions: | class ImageSegmenterOptions: | ||||||
|   """Options for the image segmenter task. |   """Options for the image segmenter task. | ||||||
|  | @ -64,19 +71,13 @@ class ImageSegmenterOptions: | ||||||
|       objects on single image inputs. 2) The video mode for segmenting objects |       objects on single image inputs. 2) The video mode for segmenting objects | ||||||
|       on the decoded frames of a video. 3) The live stream mode for segmenting |       on the decoded frames of a video. 3) The live stream mode for segmenting | ||||||
|       objects on a live stream of input data, such as from camera. |       objects on a live stream of input data, such as from camera. | ||||||
|     output_type: The output mask type allows specifying the type of |     output_confidence_masks: Whether to output confidence masks. | ||||||
|       post-processing to perform on the raw model results. |     output_category_mask: Whether to output category mask. | ||||||
|     activation: Activation function to apply to input tensor. |  | ||||||
|     result_callback: The user-defined result callback for processing live stream |     result_callback: The user-defined result callback for processing live stream | ||||||
|       data. The result callback should only be specified when the running mode |       data. The result callback should only be specified when the running mode | ||||||
|       is set to the live stream mode. |       is set to the live stream mode. | ||||||
|   """ |   """ | ||||||
| 
 | 
 | ||||||
|   class OutputType(enum.Enum): |  | ||||||
|     UNSPECIFIED = 0 |  | ||||||
|     CATEGORY_MASK = 1 |  | ||||||
|     CONFIDENCE_MASK = 2 |  | ||||||
| 
 |  | ||||||
|   class Activation(enum.Enum): |   class Activation(enum.Enum): | ||||||
|     NONE = 0 |     NONE = 0 | ||||||
|     SIGMOID = 1 |     SIGMOID = 1 | ||||||
|  | @ -84,7 +85,8 @@ class ImageSegmenterOptions: | ||||||
| 
 | 
 | ||||||
|   base_options: _BaseOptions |   base_options: _BaseOptions | ||||||
|   running_mode: _RunningMode = _RunningMode.IMAGE |   running_mode: _RunningMode = _RunningMode.IMAGE | ||||||
|   output_type: Optional[OutputType] = OutputType.CATEGORY_MASK |   output_confidence_masks: bool = True | ||||||
|  |   output_category_mask: bool = False | ||||||
|   activation: Optional[Activation] = Activation.NONE |   activation: Optional[Activation] = Activation.NONE | ||||||
|   result_callback: Optional[ |   result_callback: Optional[ | ||||||
|       Callable[[ImageSegmenterResult, image_module.Image, int], None] |       Callable[[ImageSegmenterResult, image_module.Image, int], None] | ||||||
|  | @ -98,7 +100,7 @@ class ImageSegmenterOptions: | ||||||
|         False if self.running_mode == _RunningMode.IMAGE else True |         False if self.running_mode == _RunningMode.IMAGE else True | ||||||
|     ) |     ) | ||||||
|     segmenter_options_proto = _SegmenterOptionsProto( |     segmenter_options_proto = _SegmenterOptionsProto( | ||||||
|         output_type=self.output_type.value, activation=self.activation.value |         activation=self.activation.value | ||||||
|     ) |     ) | ||||||
|     return _ImageSegmenterGraphOptionsProto( |     return _ImageSegmenterGraphOptionsProto( | ||||||
|         base_options=base_options_proto, |         base_options=base_options_proto, | ||||||
|  | @ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|     def packets_callback(output_packets: Mapping[str, packet.Packet]): |     def packets_callback(output_packets: Mapping[str, packet.Packet]): | ||||||
|       if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): |       if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): | ||||||
|         return |         return | ||||||
|       segmentation_result = packet_getter.get_image_list( | 
 | ||||||
|           output_packets[_SEGMENTATION_OUT_STREAM_NAME] |       segmentation_result = ImageSegmenterResult() | ||||||
|  | 
 | ||||||
|  |       if options.output_confidence_masks: | ||||||
|  |         segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||||
|  |             output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||||
|         ) |         ) | ||||||
|  | 
 | ||||||
|  |       if options.output_category_mask: | ||||||
|  |         segmentation_result.category_mask = packet_getter.get_image( | ||||||
|  |             output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|       image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) |       image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) | ||||||
|       timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp |       timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp | ||||||
|       options.result_callback( |       options.result_callback( | ||||||
|           segmentation_result, |           segmentation_result, | ||||||
|           image, |           image, | ||||||
|           timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, |           timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, | ||||||
|       ) |       ) | ||||||
| 
 | 
 | ||||||
|  |     output_streams = [ | ||||||
|  |         ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  |     if options.output_confidence_masks: | ||||||
|  |       output_streams.append( | ||||||
|  |         ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|  |     if options.output_category_mask: | ||||||
|  |       output_streams.append( | ||||||
|  |         ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|     task_info = _TaskInfo( |     task_info = _TaskInfo( | ||||||
|         task_graph=_TASK_GRAPH_NAME, |         task_graph=_TASK_GRAPH_NAME, | ||||||
|         input_streams=[ |         input_streams=[ | ||||||
|             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), |             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), | ||||||
|             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), |             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), | ||||||
|         ], |         ], | ||||||
|         output_streams=[ |         output_streams=output_streams, | ||||||
|             ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), |  | ||||||
|             ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), |  | ||||||
|         ], |  | ||||||
|         task_options=options, |         task_options=options, | ||||||
|     ) |     ) | ||||||
|     return cls( |     return cls( | ||||||
|  | @ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|             normalized_rect.to_pb2() |             normalized_rect.to_pb2() | ||||||
|         ), |         ), | ||||||
|     }) |     }) | ||||||
|     segmentation_result = packet_getter.get_image_list( |     segmentation_result = ImageSegmenterResult() | ||||||
|         output_packets[_SEGMENTATION_OUT_STREAM_NAME] | 
 | ||||||
|  |     if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: | ||||||
|  |       segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||||
|  |           output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||||
|       ) |       ) | ||||||
|  | 
 | ||||||
|  |     if _CATEGORY_MASK_STREAM_NAME in output_packets: | ||||||
|  |       segmentation_result.category_mask = packet_getter.get_image( | ||||||
|  |           output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|     return segmentation_result |     return segmentation_result | ||||||
| 
 | 
 | ||||||
|   def segment_for_video( |   def segment_for_video( | ||||||
|  | @ -285,9 +317,19 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|             normalized_rect.to_pb2() |             normalized_rect.to_pb2() | ||||||
|         ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), |         ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), | ||||||
|     }) |     }) | ||||||
|     segmentation_result = packet_getter.get_image_list( |     segmentation_result = ImageSegmenterResult() | ||||||
|         output_packets[_SEGMENTATION_OUT_STREAM_NAME] | 
 | ||||||
|  |     if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: | ||||||
|  |       segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||||
|  |           output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||||
|       ) |       ) | ||||||
|  | 
 | ||||||
|  |     if _CATEGORY_MASK_STREAM_NAME in output_packets: | ||||||
|  |       segmentation_result.category_mask = packet_getter.get_image( | ||||||
|  |           output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|     return segmentation_result |     return segmentation_result | ||||||
| 
 | 
 | ||||||
|   def segment_async( |   def segment_async( | ||||||
|  |  | ||||||
|  | @ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode | ||||||
| _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions | ||||||
| _TaskInfo = task_info_module.TaskInfo | _TaskInfo = task_info_module.TaskInfo | ||||||
| 
 | 
 | ||||||
| _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' | _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' | ||||||
| _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' | _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' | ||||||
|  | _CATEGORY_MASK_STREAM_NAME = 'category_mask' | ||||||
|  | _CATEGORY_MASK_TAG = 'CATEGORY_MASK' | ||||||
| _IMAGE_IN_STREAM_NAME = 'image_in' | _IMAGE_IN_STREAM_NAME = 'image_in' | ||||||
| _IMAGE_OUT_STREAM_NAME = 'image_out' | _IMAGE_OUT_STREAM_NAME = 'image_out' | ||||||
| _ROI_STREAM_NAME = 'roi_in' | _ROI_STREAM_NAME = 'roi_in' | ||||||
|  | @ -55,32 +57,32 @@ _TASK_GRAPH_NAME = ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @dataclasses.dataclass | ||||||
|  | class InteractiveSegmenterResult: | ||||||
|  |   confidence_masks: Optional[List[image_module.Image]] = None | ||||||
|  |   category_mask: Optional[image_module.Image] = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @dataclasses.dataclass | @dataclasses.dataclass | ||||||
| class InteractiveSegmenterOptions: | class InteractiveSegmenterOptions: | ||||||
|   """Options for the interactive segmenter task. |   """Options for the interactive segmenter task. | ||||||
| 
 | 
 | ||||||
|   Attributes: |   Attributes: | ||||||
|     base_options: Base options for the interactive segmenter task. |     base_options: Base options for the interactive segmenter task. | ||||||
|     output_type: The output mask type allows specifying the type of |     output_confidence_masks: Whether to output confidence masks. | ||||||
|       post-processing to perform on the raw model results. |     output_category_mask: Whether to output category mask. | ||||||
|   """ |   """ | ||||||
| 
 | 
 | ||||||
|   class OutputType(enum.Enum): |  | ||||||
|     UNSPECIFIED = 0 |  | ||||||
|     CATEGORY_MASK = 1 |  | ||||||
|     CONFIDENCE_MASK = 2 |  | ||||||
| 
 |  | ||||||
|   base_options: _BaseOptions |   base_options: _BaseOptions | ||||||
|   output_type: Optional[OutputType] = OutputType.CATEGORY_MASK |   output_confidence_masks: bool = True | ||||||
|  |   output_category_mask: bool = False | ||||||
| 
 | 
 | ||||||
|   @doc_controls.do_not_generate_docs |   @doc_controls.do_not_generate_docs | ||||||
|   def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: |   def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: | ||||||
|     """Generates an InteractiveSegmenterOptions protobuf object.""" |     """Generates an InteractiveSegmenterOptions protobuf object.""" | ||||||
|     base_options_proto = self.base_options.to_pb2() |     base_options_proto = self.base_options.to_pb2() | ||||||
|     base_options_proto.use_stream_mode = False |     base_options_proto.use_stream_mode = False | ||||||
|     segmenter_options_proto = _SegmenterOptionsProto( |     segmenter_options_proto = _SegmenterOptionsProto() | ||||||
|         output_type=self.output_type.value |  | ||||||
|     ) |  | ||||||
|     return _ImageSegmenterGraphOptionsProto( |     return _ImageSegmenterGraphOptionsProto( | ||||||
|         base_options=base_options_proto, |         base_options=base_options_proto, | ||||||
|         segmenter_options=segmenter_options_proto, |         segmenter_options=segmenter_options_proto, | ||||||
|  | @ -192,6 +194,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|       RuntimeError: If other types of error occurred. |       RuntimeError: If other types of error occurred. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|  |     output_streams = [ | ||||||
|  |       ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  |     if options.output_confidence_masks: | ||||||
|  |       output_streams.append( | ||||||
|  |         ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|  |     if options.output_category_mask: | ||||||
|  |       output_streams.append( | ||||||
|  |         ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|     task_info = _TaskInfo( |     task_info = _TaskInfo( | ||||||
|         task_graph=_TASK_GRAPH_NAME, |         task_graph=_TASK_GRAPH_NAME, | ||||||
|         input_streams=[ |         input_streams=[ | ||||||
|  | @ -199,10 +215,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|             ':'.join([_ROI_TAG, _ROI_STREAM_NAME]), |             ':'.join([_ROI_TAG, _ROI_STREAM_NAME]), | ||||||
|             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), |             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), | ||||||
|         ], |         ], | ||||||
|         output_streams=[ |         output_streams=output_streams, | ||||||
|             ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), |  | ||||||
|             ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), |  | ||||||
|         ], |  | ||||||
|         task_options=options, |         task_options=options, | ||||||
|     ) |     ) | ||||||
|     return cls( |     return cls( | ||||||
|  | @ -216,7 +229,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|       image: image_module.Image, |       image: image_module.Image, | ||||||
|       roi: RegionOfInterest, |       roi: RegionOfInterest, | ||||||
|       image_processing_options: Optional[_ImageProcessingOptions] = None, |       image_processing_options: Optional[_ImageProcessingOptions] = None, | ||||||
|   ) -> List[image_module.Image]: |   ) -> InteractiveSegmenterResult: | ||||||
|     """Performs the actual segmentation task on the provided MediaPipe Image. |     """Performs the actual segmentation task on the provided MediaPipe Image. | ||||||
| 
 | 
 | ||||||
|     The image can be of any size with format RGB. |     The image can be of any size with format RGB. | ||||||
|  | @ -248,7 +261,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): | ||||||
|             normalized_rect.to_pb2() |             normalized_rect.to_pb2() | ||||||
|         ), |         ), | ||||||
|     }) |     }) | ||||||
|     segmentation_result = packet_getter.get_image_list( |     segmentation_result = InteractiveSegmenterResult() | ||||||
|         output_packets[_SEGMENTATION_OUT_STREAM_NAME] | 
 | ||||||
|  |     if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: | ||||||
|  |       segmentation_result.confidence_masks = packet_getter.get_image_list( | ||||||
|  |         output_packets[_CONFIDENCE_MASKS_STREAM_NAME] | ||||||
|       ) |       ) | ||||||
|  | 
 | ||||||
|  |     if _CATEGORY_MASK_STREAM_NAME in output_packets: | ||||||
|  |       segmentation_result.category_mask = packet_getter.get_image( | ||||||
|  |         output_packets[_CATEGORY_MASK_STREAM_NAME] | ||||||
|  |       ) | ||||||
|  | 
 | ||||||
|     return segmentation_result |     return segmentation_result | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								mediapipe/tasks/testdata/vision/BUILD
									
									
									
									
										vendored
									
									
								
							|  | @ -77,6 +77,7 @@ mediapipe_files(srcs = [ | ||||||
|     "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", |     "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", | ||||||
|     "pose.jpg", |     "pose.jpg", | ||||||
|     "pose_detection.tflite", |     "pose_detection.tflite", | ||||||
|  |     "ptm_512_hdt_ptm_woid.tflite", | ||||||
|     "pose_landmark_lite.tflite", |     "pose_landmark_lite.tflite", | ||||||
|     "pose_landmarker.task", |     "pose_landmarker.task", | ||||||
|     "right_hands.jpg", |     "right_hands.jpg", | ||||||
|  | @ -187,6 +188,7 @@ filegroup( | ||||||
|         "mobilenet_v3_small_100_224_embedder.tflite", |         "mobilenet_v3_small_100_224_embedder.tflite", | ||||||
|         "palm_detection_full.tflite", |         "palm_detection_full.tflite", | ||||||
|         "pose_detection.tflite", |         "pose_detection.tflite", | ||||||
|  |         "ptm_512_hdt_ptm_woid.tflite", | ||||||
|         "pose_landmark_lite.tflite", |         "pose_landmark_lite.tflite", | ||||||
|         "pose_landmarker.task", |         "pose_landmarker.task", | ||||||
|         "selfie_segm_128_128_3.tflite", |         "selfie_segm_128_128_3.tflite", | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user