Explicitly state the modes in the tests for ImageSegmenterOptions and InteractiveSegmenterOptions

This commit is contained in:
kinaryml 2023-04-13 11:55:37 -07:00
parent 3f68f90238
commit a03fa448dc
2 changed files with 36 additions and 14 deletions

View File

@ -157,7 +157,9 @@ class ImageSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_category_mask=True)
base_options=base_options, output_category_mask=True,
output_confidence_masks=False
)
segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input.
@ -188,8 +190,9 @@ class ImageSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _ImageSegmenterOptions(
base_options=base_options,
activation=_Activation.SOFTMAX)
base_options=base_options, output_category_mask=False,
output_confidence_masks=True, activation=_Activation.SOFTMAX
)
with _ImageSegmenter.create_from_options(options) as segmenter:
segmentation_result = segmenter.segment(test_image)
@ -279,7 +282,9 @@ class ImageSegmenterTest(parameterized.TestCase):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
output_category_mask=True,
running_mode=_RUNNING_MODE.VIDEO)
output_confidence_masks=False,
running_mode=_RUNNING_MODE.VIDEO
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmentation_result = segmenter.segment_for_video(
@ -297,8 +302,10 @@ class ImageSegmenterTest(parameterized.TestCase):
os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)))
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO, output_category_mask=False,
output_confidence_masks=True
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmentation_result = segmenter.segment_for_video(
@ -370,8 +377,10 @@ class ImageSegmenterTest(parameterized.TestCase):
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
output_category_mask=True,
output_confidence_masks=False,
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
result_callback=check_result
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp)
@ -405,9 +414,12 @@ class ImageSegmenterTest(parameterized.TestCase):
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
output_category_mask=False,
output_confidence_masks=True,
result_callback=check_result
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(test_image, timestamp)

View File

@ -200,7 +200,8 @@ class InteractiveSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _InteractiveSegmenterOptions(
base_options=base_options, output_category_mask=True
base_options=base_options, output_category_mask=True,
output_confidence_masks=False
)
segmenter = _InteractiveSegmenter.create_from_options(options)
@ -252,7 +253,10 @@ class InteractiveSegmenterTest(parameterized.TestCase):
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions(base_options=base_options)
options = _InteractiveSegmenterOptions(
base_options=base_options, output_category_mask=False,
output_confidence_masks=True
)
with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation
@ -284,7 +288,10 @@ class InteractiveSegmenterTest(parameterized.TestCase):
)
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions(base_options=base_options)
options = _InteractiveSegmenterOptions(
base_options=base_options, output_category_mask=False,
output_confidence_masks=True
)
with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation
@ -310,7 +317,10 @@ class InteractiveSegmenterTest(parameterized.TestCase):
)
# Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions(base_options=base_options)
options = _InteractiveSegmenterOptions(
base_options=base_options, output_category_mask=False,
output_confidence_masks=True
)
with self.assertRaisesRegex(
ValueError, "This task doesn't support region-of-interest."