Changed labels to be a property
This commit is contained in:
parent
d621df8046
commit
1cb404bea1
|
@ -247,14 +247,16 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_labels_succeeds(self):
|
def test_labels_succeeds(self):
|
||||||
expected_labels = _EXPECTED_LABELS
|
expected_labels = _EXPECTED_LABELS
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
base_options=base_options, output_category_mask=True,
|
||||||
|
output_confidence_masks=False
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
# Performs image segmentation on the input.
|
# Performs image segmentation on the input.
|
||||||
actual_labels = segmenter.get_labels()
|
actual_labels = segmenter.labels
|
||||||
self.assertListEqual(actual_labels, expected_labels)
|
self.assertListEqual(actual_labels, expected_labels)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
def test_missing_result_callback(self):
|
||||||
|
|
|
@ -151,7 +151,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
Exception if there is an error during finding TensorsToSegmentationCalculator.
|
Exception if there is an error during finding TensorsToSegmentationCalculator.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
self.labels = []
|
self._labels = []
|
||||||
graph_config = self._runner.get_graph_config()
|
graph_config = self._runner.get_graph_config()
|
||||||
found_tensors_to_segmentation = False
|
found_tensors_to_segmentation = False
|
||||||
|
|
||||||
|
@ -170,7 +170,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
for i in range(len(options.label_items)):
|
for i in range(len(options.label_items)):
|
||||||
if i not in options.label_items:
|
if i not in options.label_items:
|
||||||
raise Exception(f"The labelmap has no expected key: {i}.")
|
raise Exception(f"The labelmap has no expected key: {i}.")
|
||||||
self.labels.append(options.label_items[i].name)
|
self._labels.append(options.label_items[i].name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
|
def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter':
|
||||||
|
@ -271,19 +271,6 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
packets_callback if options.result_callback else None,
|
packets_callback if options.result_callback else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_labels(self):
|
|
||||||
""" Get the category label list of the ImageSegmenter can recognize.
|
|
||||||
|
|
||||||
For CATEGORY_MASK type, the index in the category mask corresponds to the
|
|
||||||
category in the label list.
|
|
||||||
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
|
|
||||||
category in the label list.
|
|
||||||
|
|
||||||
If there is no label map provided in the model file, empty label list is
|
|
||||||
returned.
|
|
||||||
"""
|
|
||||||
return self.labels
|
|
||||||
|
|
||||||
def segment(
|
def segment(
|
||||||
self,
|
self,
|
||||||
image: image_module.Image,
|
image: image_module.Image,
|
||||||
|
@ -427,3 +414,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
||||||
normalized_rect.to_pb2()
|
normalized_rect.to_pb2()
|
||||||
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self) -> List[str]:
|
||||||
|
""" Get the category label list of the ImageSegmenter can recognize.
|
||||||
|
|
||||||
|
For CATEGORY_MASK type, the index in the category mask corresponds to the
|
||||||
|
category in the label list.
|
||||||
|
For CONFIDENCE_MASK type, the output mask list at index corresponds to the
|
||||||
|
category in the label list.
|
||||||
|
|
||||||
|
If there is no label map provided in the model file, empty label list is
|
||||||
|
returned.
|
||||||
|
"""
|
||||||
|
return self._labels
|
||||||
|
|
Loading…
Reference in New Issue
Block a user