From 1cb404bea16a4f36df8abeb583a72b9819776583 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Tue, 18 Apr 2023 21:31:14 -0700 Subject: [PATCH] Changed labels to be a property --- .../test/vision/image_segmenter_test.py | 8 +++-- .../tasks/python/vision/image_segmenter.py | 31 ++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index b54b53994..3458bb504 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -247,14 +247,16 @@ class ImageSegmenterTest(parameterized.TestCase): ) ) - def test_get_labels_succeeds(self): + def test_labels_succeeds(self): expected_labels = _EXPECTED_LABELS base_options = _BaseOptions(model_asset_path=self.model_path) options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) + base_options=base_options, output_category_mask=True, + output_confidence_masks=False + ) with _ImageSegmenter.create_from_options(options) as segmenter: # Performs image segmentation on the input. - actual_labels = segmenter.get_labels() + actual_labels = segmenter.labels self.assertListEqual(actual_labels, expected_labels) def test_missing_result_callback(self): diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 4119f2632..a6c9501c2 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -151,7 +151,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): Exception if there is an error during finding TensorsToSegmentationCalculator. :return: """ - self.labels = [] + self._labels = [] graph_config = self._runner.get_graph_config() found_tensors_to_segmentation = False @@ -170,7 +170,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): for i in range(len(options.label_items)): if i not in options.label_items: raise Exception(f"The labelmap has no expected key: {i}.") - self.labels.append(options.label_items[i].name) + self._labels.append(options.label_items[i].name) @classmethod def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter': @@ -271,19 +271,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): packets_callback if options.result_callback else None, ) - def get_labels(self): - """ Get the category label list of the ImageSegmenter can recognize. - - For CATEGORY_MASK type, the index in the category mask corresponds to the - category in the label list. - For CONFIDENCE_MASK type, the output mask list at index corresponds to the - category in the label list. - - If there is no label map provided in the model file, empty label list is - returned. - """ - return self.labels - def segment( self, image: image_module.Image, @@ -427,3 +414,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) + + @property + def labels(self) -> List[str]: + """ Get the category label list of the ImageSegmenter can recognize. + + For CATEGORY_MASK type, the index in the category mask corresponds to the + category in the label list. + For CONFIDENCE_MASK type, the output mask list at index corresponds to the + category in the label list. + + If there is no label map provided in the model file, empty label list is + returned. + """ + return self._labels