From 67b72e4fe9b6765c3d134d88a6ba77ac50a35a05 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 18 Apr 2023 21:43:38 -0700 Subject: [PATCH] Code cleanup --- .../python/test/vision/image_segmenter_test.py | 7 ++++--- mediapipe/tasks/python/vision/image_segmenter.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 3458bb504..009dc685a 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -247,12 +247,13 @@ class ImageSegmenterTest(parameterized.TestCase): ) ) - def test_labels_succeeds(self): + @parameterized.parameters((True, False), (False, True)) + def test_labels_succeeds(self, output_category_mask, output_confidence_masks): expected_labels = _EXPECTED_LABELS base_options = _BaseOptions(model_asset_path=self.model_path) options = _ImageSegmenterOptions( - base_options=base_options, output_category_mask=True, - output_confidence_masks=False + base_options=base_options, output_category_mask=output_category_mask, + output_confidence_masks=output_confidence_masks ) with _ImageSegmenter.create_from_options(options) as segmenter: # Performs image segmentation on the input. diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index a6c9501c2..220d7818f 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -129,27 +129,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): Output tensors: (kTfLiteUInt8/kTfLiteFloat32) - list of segmented masks. - - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - - if `output_type` is CONFIDENCE_MASK, float32 Image list of size + - if `output_category_mask` is True, uint8 Image, Image vector of size 1. + - if `output_confidence_masks` is True, float32 Image list of size `channels`. - batch is always 1 An example of such model can be found at: https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 """ - def __init__(self, graph_config, running_mode, packet_callback): + def __init__(self, graph_config, running_mode, packet_callback) -> None: + """Initializes the `ImageSegmenter` object.""" super(ImageSegmenter, self).__init__( graph_config, running_mode, packet_callback ) self._populate_labels() - def _populate_labels(self): + def _populate_labels(self) -> None: """ Populate the labelmap in TensorsToSegmentationCalculator to labels field. - Returns: - Exception if there is an error during finding TensorsToSegmentationCalculator. - :return: + Raises: + Exception if there is an error during finding + TensorsToSegmentationCalculator. """ self._labels = [] graph_config = self._runner.get_graph_config()