Code cleanup

This commit is contained in:
kinaryml 2023-04-18 21:43:38 -07:00
parent 1cb404bea1
commit 67b72e4fe9
2 changed files with 12 additions and 10 deletions

View File

@ -247,12 +247,13 @@ class ImageSegmenterTest(parameterized.TestCase):
)
)
def test_labels_succeeds(self):
@parameterized.parameters((True, False), (False, True))
def test_labels_succeeds(self, output_category_mask, output_confidence_masks):
expected_labels = _EXPECTED_LABELS
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions(
base_options=base_options, output_category_mask=True,
output_confidence_masks=False
base_options=base_options, output_category_mask=output_category_mask,
output_confidence_masks=output_confidence_masks
)
with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input.

View File

@ -129,27 +129,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Output tensors:
(kTfLiteUInt8/kTfLiteFloat32)
- list of segmented masks.
- if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
- if `output_type` is CONFIDENCE_MASK, float32 Image list of size
- if `output_category_mask` is True, uint8 Image, Image vector of size 1.
- if `output_confidence_masks` is True, float32 Image list of size
`channels`.
- batch is always 1
An example of such model can be found at:
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
"""
def __init__(self, graph_config, running_mode, packet_callback):
def __init__(self, graph_config, running_mode, packet_callback) -> None:
"""Initializes the `ImageSegmenter` object."""
super(ImageSegmenter, self).__init__(
graph_config, running_mode, packet_callback
)
self._populate_labels()
def _populate_labels(self):
def _populate_labels(self) -> None:
"""
Populate the labelmap in TensorsToSegmentationCalculator to labels field.
Returns:
Exception if there is an error during finding TensorsToSegmentationCalculator.
:return:
Raises:
Exception if there is an error during finding
TensorsToSegmentationCalculator.
"""
self._labels = []
graph_config = self._runner.get_graph_config()