Removed an unnecessary check and updated tests to check if the masks are generated or not

This commit is contained in:
kinaryml 2023-04-24 11:23:27 -07:00
parent 9032bce577
commit b511822815
2 changed files with 55 additions and 19 deletions

View File

@ -104,12 +104,19 @@ class PoseLandmarkerTest(parameterized.TestCase):
self, self,
actual_result: PoseLandmarkerResult, actual_result: PoseLandmarkerResult,
expected_result: PoseLandmarkerResult, expected_result: PoseLandmarkerResult,
output_segmentation_masks: bool,
margin: float margin: float
): ):
self._expect_pose_landmarks_correct( self._expect_pose_landmarks_correct(
actual_result.pose_landmarks, expected_result.pose_landmarks, actual_result.pose_landmarks, expected_result.pose_landmarks,
margin margin
) )
if output_segmentation_masks:
self.assertIsInstance(actual_result.segmentation_masks, List)
for i, mask in enumerate(actual_result.segmentation_masks):
self.assertIsInstance(mask, _Image)
else:
self.assertIsNone(actual_result.segmentation_masks)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -141,12 +148,17 @@ class PoseLandmarkerTest(parameterized.TestCase):
self.assertIsInstance(landmarker, _PoseLandmarker) self.assertIsInstance(landmarker, _PoseLandmarker)
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, (ModelFileType.FILE_NAME, False,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)), _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT, (ModelFileType.FILE_CONTENT, False,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_NAME, True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT, True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)) _get_expected_pose_landmarker_result(_POSE_LANDMARKS))
) )
def test_detect(self, model_file_type, expected_detection_result): def test_detect(self, model_file_type, output_segmentation_masks,
expected_detection_result):
# Creates pose landmarker. # Creates pose landmarker.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
@ -158,7 +170,10 @@ class PoseLandmarkerTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _PoseLandmarkerOptions(base_options=base_options) options = _PoseLandmarkerOptions(
base_options=base_options,
output_segmentation_masks=output_segmentation_masks
)
landmarker = _PoseLandmarker.create_from_options(options) landmarker = _PoseLandmarker.create_from_options(options)
# Performs pose landmarks detection on the input. # Performs pose landmarks detection on the input.
@ -166,19 +181,27 @@ class PoseLandmarkerTest(parameterized.TestCase):
# Comparing results. # Comparing results.
self._expect_pose_landmarker_results_correct( self._expect_pose_landmarker_results_correct(
detection_result, expected_detection_result, _LANDMARKS_MARGIN detection_result, expected_detection_result, output_segmentation_masks,
_LANDMARKS_MARGIN
) )
# Closes the pose landmarker explicitly when the pose landmarker is not used # Closes the pose landmarker explicitly when the pose landmarker is not used
# in a context. # in a context.
landmarker.close() landmarker.close()
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, (ModelFileType.FILE_NAME, False,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)), _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT, (ModelFileType.FILE_CONTENT, False,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_NAME, True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT, True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)) _get_expected_pose_landmarker_result(_POSE_LANDMARKS))
) )
def test_detect_in_context(self, model_file_type, expected_detection_result): def test_detect_in_context(
self, model_file_type, output_segmentation_masks,
expected_detection_result
):
# Creates pose landmarker. # Creates pose landmarker.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
@ -190,14 +213,18 @@ class PoseLandmarkerTest(parameterized.TestCase):
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _PoseLandmarkerOptions(base_options=base_options) options = _PoseLandmarkerOptions(
base_options=base_options,
output_segmentation_masks=output_segmentation_masks
)
with _PoseLandmarker.create_from_options(options) as landmarker: with _PoseLandmarker.create_from_options(options) as landmarker:
# Performs pose landmarks detection on the input. # Performs pose landmarks detection on the input.
detection_result = landmarker.detect(self.test_image) detection_result = landmarker.detect(self.test_image)
# Comparing results. # Comparing results.
self._expect_pose_landmarker_results_correct( self._expect_pose_landmarker_results_correct(
detection_result, expected_detection_result, _LANDMARKS_MARGIN detection_result, expected_detection_result,
output_segmentation_masks, _LANDMARKS_MARGIN
) )
def test_detect_fails_with_region_of_interest(self): def test_detect_fails_with_region_of_interest(self):
@ -295,12 +322,15 @@ class PoseLandmarkerTest(parameterized.TestCase):
landmarker.detect_for_video(self.test_image, 0) landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters( @parameterized.parameters(
(_POSE_IMAGE, 0, (_POSE_IMAGE, 0, False,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)), _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0, (_POSE_IMAGE, 0, True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0, False,
PoseLandmarkerResult([], [], [])) PoseLandmarkerResult([], [], []))
) )
def test_detect_for_video(self, image_path, rotation, expected_result): def test_detect_for_video(self, image_path, rotation,
output_segmentation_masks, expected_result):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path)) test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions. # Set rotation parameters using ImageProcessingOptions.
@ -308,6 +338,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
rotation_degrees=rotation) rotation_degrees=rotation)
options = _PoseLandmarkerOptions( options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_segmentation_masks=output_segmentation_masks,
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker: with _PoseLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
@ -315,7 +346,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
image_processing_options) image_processing_options)
if result.pose_landmarks: if result.pose_landmarks:
self._expect_pose_landmarker_results_correct( self._expect_pose_landmarker_results_correct(
result, expected_result, _LANDMARKS_MARGIN result, expected_result, output_segmentation_masks,
_LANDMARKS_MARGIN
) )
else: else:
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
@ -352,12 +384,15 @@ class PoseLandmarkerTest(parameterized.TestCase):
landmarker.detect_async(self.test_image, 0) landmarker.detect_async(self.test_image, 0)
@parameterized.parameters( @parameterized.parameters(
(_POSE_IMAGE, 0, (_POSE_IMAGE, 0, False,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)), _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0, (_POSE_IMAGE, 0, True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0, False,
PoseLandmarkerResult([], [], [])) PoseLandmarkerResult([], [], []))
) )
def test_detect_async_calls(self, image_path, rotation, expected_result): def test_detect_async_calls(self, image_path, rotation,
output_segmentation_masks, expected_result):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path)) test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions. # Set rotation parameters using ImageProcessingOptions.
@ -369,7 +404,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
timestamp_ms: int): timestamp_ms: int):
if result.pose_landmarks: if result.pose_landmarks:
self._expect_pose_landmarker_results_correct( self._expect_pose_landmarker_results_correct(
result, expected_result, _LANDMARKS_MARGIN result, expected_result, output_segmentation_masks,
_LANDMARKS_MARGIN
) )
else: else:
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
@ -380,6 +416,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
options = _PoseLandmarkerOptions( options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_segmentation_masks=output_segmentation_masks,
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) result_callback=check_result)
with _PoseLandmarker.create_from_options(options) as landmarker: with _PoseLandmarker.create_from_options(options) as landmarker:

View File

@ -263,7 +263,6 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
) )
output_streams = [ output_streams = [
':'.join([_SEGMENTATION_MASK_TAG, _SEGMENTATION_MASK_STREAM_NAME]),
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]), ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
':'.join([ ':'.join([
_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME _POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME