Changed labels to be a property

This commit is contained in:
kinaryml 2023-04-18 21:31:14 -07:00
parent d621df8046
commit 1cb404bea1
2 changed files with 21 additions and 18 deletions

View File

@ -247,14 +247,16 @@ class ImageSegmenterTest(parameterized.TestCase):
)
)
def test_get_labels_succeeds(self):
def test_labels_succeeds(self):
expected_labels = _EXPECTED_LABELS
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
base_options=base_options, output_category_mask=True,
output_confidence_masks=False
)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.
actual_labels = segmenter.get_labels()
actual_labels = segmenter.labels
self.assertListEqual(actual_labels, expected_labels)
def test_missing_result_callback(self):

View File

@ -151,7 +151,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Exception if there is an error during finding TensorsToSegmentationCalculator.
:return:
"""
self.labels = []
self._labels = []
graph_config = self._runner.get_graph_config()
found_tensors_to_segmentation = False
@ -170,7 +170,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
for i in range(len(options.label_items)):
if i not in options.label_items:
raise Exception(f"The labelmap has no expected key: {i}.")
self.labels.append(options.label_items[i].name)
self._labels.append(options.label_items[i].name)
@classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
@ -271,19 +271,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
packets_callback if options.result_callback else None,
)
def get_labels(self):
""" Get the category label list of the ImageSegmenter can recognize.
For CATEGORY_MASK type, the index in the category mask corresponds to the
category in the label list.
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
category in the label list.
If there is no label map provided in the model file, empty label list is
returned.
"""
return self.labels
def segment(
self,
image: image_module.Image,
@ -427,3 +414,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
@property
def labels(self) -> List[str]:
""" Get the category label list of the ImageSegmenter can recognize.
For CATEGORY_MASK type, the index in the category mask corresponds to the
category in the label list.
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
category in the label list.
If there is no label map provided in the model file, empty label list is
returned.
"""
return self._labels