Changed labels to be a property
This commit is contained in:
parent
d621df8046
commit
1cb404bea1
|
@ -247,14 +247,16 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def test_get_labels_succeeds(self):
|
||||
def test_labels_succeeds(self):
|
||||
expected_labels = _EXPECTED_LABELS
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
||||
base_options=base_options, output_category_mask=True,
|
||||
output_confidence_masks=False
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
# Performs image segmentation on the input.
|
||||
actual_labels = segmenter.get_labels()
|
||||
actual_labels = segmenter.labels
|
||||
self.assertListEqual(actual_labels, expected_labels)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
|
|
|
@ -151,7 +151,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
Exception if there is an error during finding TensorsToSegmentationCalculator.
|
||||
:return:
|
||||
"""
|
||||
self.labels = []
|
||||
self._labels = []
|
||||
graph_config = self._runner.get_graph_config()
|
||||
found_tensors_to_segmentation = False
|
||||
|
||||
|
@ -170,7 +170,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
for i in range(len(options.label_items)):
|
||||
if i not in options.label_items:
|
||||
raise Exception(f"The labelmap has no expected key: {i}.")
|
||||
self.labels.append(options.label_items[i].name)
|
||||
self._labels.append(options.label_items[i].name)
|
||||
|
||||
@classmethod
|
||||
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
|
||||
|
@ -271,19 +271,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
packets_callback if options.result_callback else None,
|
||||
)
|
||||
|
||||
def get_labels(self):
|
||||
""" Get the category label list of the ImageSegmenter can recognize.
|
||||
|
||||
For CATEGORY_MASK type, the index in the category mask corresponds to the
|
||||
category in the label list.
|
||||
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
|
||||
category in the label list.
|
||||
|
||||
If there is no label map provided in the model file, empty label list is
|
||||
returned.
|
||||
"""
|
||||
return self.labels
|
||||
|
||||
def segment(
|
||||
self,
|
||||
image: image_module.Image,
|
||||
|
@ -427,3 +414,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
normalized_rect.to_pb2()
|
||||
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||
})
|
||||
|
||||
@property
|
||||
def labels(self) -> List[str]:
|
||||
""" Get the category label list of the ImageSegmenter can recognize.
|
||||
|
||||
For CATEGORY_MASK type, the index in the category mask corresponds to the
|
||||
category in the label list.
|
||||
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
|
||||
category in the label list.
|
||||
|
||||
If there is no label map provided in the model file, empty label list is
|
||||
returned.
|
||||
"""
|
||||
return self._labels
|
||||
|
|
Loading…
Reference in New Issue
Block a user