Code cleanup

This commit is contained in:
kinaryml 2023-04-18 21:43:38 -07:00
parent 1cb404bea1
commit 67b72e4fe9
2 changed files with 12 additions and 10 deletions

View File

@ -247,12 +247,13 @@ class ImageSegmenterTest(parameterized.TestCase):
) )
) )
def test_labels_succeeds(self): @parameterized.parameters((True, False), (False, True))
def test_labels_succeeds(self, output_category_mask, output_confidence_masks):
expected_labels = _EXPECTED_LABELS expected_labels = _EXPECTED_LABELS
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, output_category_mask=True, base_options=base_options, output_category_mask=output_category_mask,
output_confidence_masks=False output_confidence_masks=output_confidence_masks
) )
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input. # Performs image segmentation on the input.

View File

@ -129,27 +129,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Output tensors: Output tensors:
(kTfLiteUInt8/kTfLiteFloat32) (kTfLiteUInt8/kTfLiteFloat32)
- list of segmented masks. - list of segmented masks.
- if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - if `output_category_mask` is True, uint8 Image, Image vector of size 1.
- if `output_type` is CONFIDENCE_MASK, float32 Image list of size - if `output_confidence_masks` is True, float32 Image list of size
`channels`. `channels`.
- batch is always 1 - batch is always 1
An example of such model can be found at: An example of such model can be found at:
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
""" """
def __init__(self, graph_config, running_mode, packet_callback): def __init__(self, graph_config, running_mode, packet_callback) -> None:
"""Initializes the `ImageSegmenter` object."""
super(ImageSegmenter, self).__init__( super(ImageSegmenter, self).__init__(
graph_config, running_mode, packet_callback graph_config, running_mode, packet_callback
) )
self._populate_labels() self._populate_labels()
def _populate_labels(self): def _populate_labels(self) -> None:
""" """
Populate the labelmap in TensorsToSegmentationCalculator to labels field. Populate the labelmap in TensorsToSegmentationCalculator to labels field.
Returns: Raises:
Exception if there is an error during finding TensorsToSegmentationCalculator. Exception if there is an error during finding
:return: TensorsToSegmentationCalculator.
""" """
self._labels = [] self._labels = []
graph_config = self._runner.get_graph_config() graph_config = self._runner.get_graph_config()