Code cleanup
This commit is contained in:
parent
1cb404bea1
commit
67b72e4fe9
|
@ -247,12 +247,13 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def test_labels_succeeds(self):
|
||||
@parameterized.parameters((True, False), (False, True))
|
||||
def test_labels_succeeds(self, output_category_mask, output_confidence_masks):
|
||||
expected_labels = _EXPECTED_LABELS
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options, output_category_mask=True,
|
||||
output_confidence_masks=False
|
||||
base_options=base_options, output_category_mask=output_category_mask,
|
||||
output_confidence_masks=output_confidence_masks
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
# Performs image segmentation on the input.
|
||||
|
|
|
@ -129,27 +129,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
Output tensors:
|
||||
(kTfLiteUInt8/kTfLiteFloat32)
|
||||
- list of segmented masks.
|
||||
- if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||
- if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
||||
- if `output_category_mask` is True, uint8 Image, Image vector of size 1.
|
||||
- if `output_confidence_masks` is True, float32 Image list of size
|
||||
`channels`.
|
||||
- batch is always 1
|
||||
|
||||
An example of such model can be found at:
|
||||
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
|
||||
"""
|
||||
def __init__(self, graph_config, running_mode, packet_callback):
|
||||
def __init__(self, graph_config, running_mode, packet_callback) -> None:
|
||||
"""Initializes the `ImageSegmenter` object."""
|
||||
super(ImageSegmenter, self).__init__(
|
||||
graph_config, running_mode, packet_callback
|
||||
)
|
||||
self._populate_labels()
|
||||
|
||||
def _populate_labels(self):
|
||||
def _populate_labels(self) -> None:
|
||||
"""
|
||||
Populate the labelmap in TensorsToSegmentationCalculator to labels field.
|
||||
|
||||
Returns:
|
||||
Exception if there is an error during finding TensorsToSegmentationCalculator.
|
||||
:return:
|
||||
Raises:
|
||||
Exception if there is an error during finding
|
||||
TensorsToSegmentationCalculator.
|
||||
"""
|
||||
self._labels = []
|
||||
graph_config = self._runner.get_graph_config()
|
||||
|
|
Loading…
Reference in New Issue
Block a user