From e364aeb35908cb03b08bd528aa252b8a36c5f0b2 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sat, 25 Mar 2023 00:44:30 -0700 Subject: [PATCH] Revised Interactive Segmenter API and added more tests --- mediapipe/tasks/python/test/vision/BUILD | 2 + .../test/vision/interactive_segmenter_test.py | 82 ++++++++++++++----- .../python/vision/interactive_segmenter.py | 3 +- 3 files changed, 67 insertions(+), 20 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index f42ba7104..b5579ea29 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -102,9 +102,11 @@ py_test( deps = [ "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:keypoint", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:interactive_segmenter", + "//mediapipe/tasks/python/vision/core:image_processing_options", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) diff --git a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py index 6a069c929..866cd19c1 100644 --- a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py @@ -24,20 +24,24 @@ import numpy as np from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image_frame from mediapipe.tasks.python.components.containers import keypoint as keypoint_module +from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import interactive_segmenter +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat _NormalizedKeypoint = keypoint_module.NormalizedKeypoint +_Rect = rect.Rect _OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _RegionOfInterest = interactive_segmenter.RegionOfInterest _Format = interactive_segmenter.RegionOfInterest.Format +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _MODEL_FILE = 'ptm_512_hdt_ptm_woid.tflite' @@ -50,11 +54,11 @@ _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' def _calculate_soft_iou(m1, m2): - intersection = np.sum(m1 * m2) - union = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection + intersection_sum = np.sum(m1 * m2) + union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum - if union > 0: - return intersection / union + if union_sum > 0: + return intersection_sum / union_sum else: return 0 @@ -189,9 +193,9 @@ class InteractiveSegmenterTest(parameterized.TestCase): @parameterized.parameters( (_RegionOfInterest.Format.KEYPOINT, _NormalizedKeypoint(0.44, 0.7), - _CATS_AND_DOGS_MASK_DOG_1, 0.58), + _CATS_AND_DOGS_MASK_DOG_1, 0.84), (_RegionOfInterest.Format.KEYPOINT, _NormalizedKeypoint(0.66, 0.66), - _CATS_AND_DOGS_MASK_DOG_2, 0.60) + _CATS_AND_DOGS_MASK_DOG_2, _MASK_SIMILARITY_THRESHOLD) ) def test_segment_succeeds_with_confidence_mask( self, format, keypoint, output_mask, similarity_threshold): @@ -203,26 +207,66 @@ class InteractiveSegmenterTest(parameterized.TestCase): options = _InteractiveSegmenterOptions( base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK) - segmenter = _InteractiveSegmenter.create_from_options(options) - # Perform segmentation - confidence_masks = segmenter.segment(self.test_image, roi) + with _InteractiveSegmenter.create_from_options(options) as segmenter: + # Perform segmentation + confidence_masks = segmenter.segment(self.test_image, roi) - # Check if confidence mask shape is correct. - self.assertLen( + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, 2, + 'Number of confidence masks must match with number of categories.') + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(output_mask) + + self.assertTrue( + _similar_to_float_mask(confidence_masks[1], expected_mask, + similarity_threshold)) + + def test_segment_succeeds_with_rotation(self): + # Creates segmenter. + base_options = _BaseOptions(model_asset_path=self.model_path) + roi = _RegionOfInterest( + format=_RegionOfInterest.Format.KEYPOINT, + keypoint=_NormalizedKeypoint(0.66, 0.66) + ) + + # Run segmentation on the model in CONFIDENCE_MASK mode. + options = _InteractiveSegmenterOptions( + base_options=base_options, + output_type=_OutputType.CONFIDENCE_MASK) + + with _InteractiveSegmenter.create_from_options(options) as segmenter: + # Perform segmentation + image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) + confidence_masks = segmenter.segment(self.test_image, roi) + + # Check if confidence mask shape is correct. + self.assertLen( confidence_masks, 2, 'Number of confidence masks must match with number of categories.') - # Loads ground truth segmentation file. - expected_mask = self._load_segmentation_mask(output_mask) + def test_segment_fails_with_roi_in_image_processing_options(self): + # Creates segmenter. + base_options = _BaseOptions(model_asset_path=self.model_path) + roi = _RegionOfInterest( + format=_RegionOfInterest.Format.KEYPOINT, + keypoint=_NormalizedKeypoint(0.66, 0.66) + ) - self.assertTrue( - _similar_to_float_mask(confidence_masks[1], expected_mask, - similarity_threshold)) + # Run segmentation on the model in CONFIDENCE_MASK mode. + options = _InteractiveSegmenterOptions( + base_options=base_options, + output_type=_OutputType.CONFIDENCE_MASK) - # Closes the segmenter explicitly when the segmenter is not used in - # a context. - segmenter.close() + with self.assertRaisesRegex( + ValueError, "This task doesn't support region-of-interest."): + with _InteractiveSegmenter.create_from_options(options) as segmenter: + # Perform segmentation + image_processing_options = _ImageProcessingOptions( + _Rect(left=0.1, top=0, right=0.9, bottom=1)) + segmenter.segment(self.test_image, roi, image_processing_options) if __name__ == '__main__': diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py index ad524dfa4..94665cbbf 100644 --- a/mediapipe/tasks/python/vision/interactive_segmenter.py +++ b/mediapipe/tasks/python/vision/interactive_segmenter.py @@ -254,7 +254,8 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image segmentation failed to run. """ - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect(image_processing_options, + roi_allowed=False) render_data_proto = _convert_roi_to_render_data(roi) output_packets = self._process_image_data( {