diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 327925191..aa557281f 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -157,7 +157,9 @@ class ImageSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageSegmenterOptions( - base_options=base_options, output_category_mask=True) + base_options=base_options, output_category_mask=True, + output_confidence_masks=False + ) segmenter = _ImageSegmenter.create_from_options(options) # Performs image segmentation on the input. @@ -188,8 +190,9 @@ class ImageSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _ImageSegmenterOptions( - base_options=base_options, - activation=_Activation.SOFTMAX) + base_options=base_options, output_category_mask=False, + output_confidence_masks=True, activation=_Activation.SOFTMAX + ) with _ImageSegmenter.create_from_options(options) as segmenter: segmentation_result = segmenter.segment(test_image) @@ -279,7 +282,9 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), output_category_mask=True, - running_mode=_RUNNING_MODE.VIDEO) + output_confidence_masks=False, + running_mode=_RUNNING_MODE.VIDEO + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmentation_result = segmenter.segment_for_video( @@ -297,8 +302,10 @@ class ImageSegmenterTest(parameterized.TestCase): os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))) options = _ImageSegmenterOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, output_category_mask=False, + output_confidence_masks=True + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmentation_result = segmenter.segment_for_video( @@ -370,8 +377,10 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), output_category_mask=True, + output_confidence_masks=False, running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result) + result_callback=check_result + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(self.test_image, timestamp) @@ -405,9 +414,12 @@ class ImageSegmenterTest(parameterized.TestCase): self.observed_timestamp_ms = timestamp_ms options = _ImageSegmenterOptions( - base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result) + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + output_category_mask=False, + output_confidence_masks=True, + result_callback=check_result + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(test_image, timestamp) diff --git a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py index 6af15aa09..aea4f8a1d 100644 --- a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py @@ -200,7 +200,8 @@ class InteractiveSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _InteractiveSegmenterOptions( - base_options=base_options, output_category_mask=True + base_options=base_options, output_category_mask=True, + output_confidence_masks=False ) segmenter = _InteractiveSegmenter.create_from_options(options) @@ -252,7 +253,10 @@ class InteractiveSegmenterTest(parameterized.TestCase): roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions(base_options=base_options) + options = _InteractiveSegmenterOptions( + base_options=base_options, output_category_mask=False, + output_confidence_masks=True + ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation @@ -284,7 +288,10 @@ class InteractiveSegmenterTest(parameterized.TestCase): ) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions(base_options=base_options) + options = _InteractiveSegmenterOptions( + base_options=base_options, output_category_mask=False, + output_confidence_masks=True + ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation @@ -310,7 +317,10 @@ class InteractiveSegmenterTest(parameterized.TestCase): ) # Run segmentation on the model in CONFIDENCE_MASK mode. - options = _InteractiveSegmenterOptions(base_options=base_options) + options = _InteractiveSegmenterOptions( + base_options=base_options, output_category_mask=False, + output_confidence_masks=True + ) with self.assertRaisesRegex( ValueError, "This task doesn't support region-of-interest."