Removed an unnecessary check and updated tests to check if the masks are generated or not
This commit is contained in:
parent
9032bce577
commit
b511822815
|
@ -104,12 +104,19 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
self,
|
self,
|
||||||
actual_result: PoseLandmarkerResult,
|
actual_result: PoseLandmarkerResult,
|
||||||
expected_result: PoseLandmarkerResult,
|
expected_result: PoseLandmarkerResult,
|
||||||
|
output_segmentation_masks: bool,
|
||||||
margin: float
|
margin: float
|
||||||
):
|
):
|
||||||
self._expect_pose_landmarks_correct(
|
self._expect_pose_landmarks_correct(
|
||||||
actual_result.pose_landmarks, expected_result.pose_landmarks,
|
actual_result.pose_landmarks, expected_result.pose_landmarks,
|
||||||
margin
|
margin
|
||||||
)
|
)
|
||||||
|
if output_segmentation_masks:
|
||||||
|
self.assertIsInstance(actual_result.segmentation_masks, List)
|
||||||
|
for i, mask in enumerate(actual_result.segmentation_masks):
|
||||||
|
self.assertIsInstance(mask, _Image)
|
||||||
|
else:
|
||||||
|
self.assertIsNone(actual_result.segmentation_masks)
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
@ -141,12 +148,17 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
self.assertIsInstance(landmarker, _PoseLandmarker)
|
self.assertIsInstance(landmarker, _PoseLandmarker)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_NAME,
|
(ModelFileType.FILE_NAME, False,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
(ModelFileType.FILE_CONTENT,
|
(ModelFileType.FILE_CONTENT, False,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(ModelFileType.FILE_NAME, True,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(ModelFileType.FILE_CONTENT, True,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
|
||||||
)
|
)
|
||||||
def test_detect(self, model_file_type, expected_detection_result):
|
def test_detect(self, model_file_type, output_segmentation_masks,
|
||||||
|
expected_detection_result):
|
||||||
# Creates pose landmarker.
|
# Creates pose landmarker.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
@ -158,7 +170,10 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=base_options,
|
||||||
|
output_segmentation_masks=output_segmentation_masks
|
||||||
|
)
|
||||||
landmarker = _PoseLandmarker.create_from_options(options)
|
landmarker = _PoseLandmarker.create_from_options(options)
|
||||||
|
|
||||||
# Performs pose landmarks detection on the input.
|
# Performs pose landmarks detection on the input.
|
||||||
|
@ -166,19 +181,27 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self._expect_pose_landmarker_results_correct(
|
self._expect_pose_landmarker_results_correct(
|
||||||
detection_result, expected_detection_result, _LANDMARKS_MARGIN
|
detection_result, expected_detection_result, output_segmentation_masks,
|
||||||
|
_LANDMARKS_MARGIN
|
||||||
)
|
)
|
||||||
# Closes the pose landmarker explicitly when the pose landmarker is not used
|
# Closes the pose landmarker explicitly when the pose landmarker is not used
|
||||||
# in a context.
|
# in a context.
|
||||||
landmarker.close()
|
landmarker.close()
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_NAME,
|
(ModelFileType.FILE_NAME, False,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
(ModelFileType.FILE_CONTENT,
|
(ModelFileType.FILE_CONTENT, False,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(ModelFileType.FILE_NAME, True,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(ModelFileType.FILE_CONTENT, True,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
|
||||||
)
|
)
|
||||||
def test_detect_in_context(self, model_file_type, expected_detection_result):
|
def test_detect_in_context(
|
||||||
|
self, model_file_type, output_segmentation_masks,
|
||||||
|
expected_detection_result
|
||||||
|
):
|
||||||
# Creates pose landmarker.
|
# Creates pose landmarker.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
@ -190,14 +213,18 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
# Should never happen
|
# Should never happen
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=base_options,
|
||||||
|
output_segmentation_masks=output_segmentation_masks
|
||||||
|
)
|
||||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
# Performs pose landmarks detection on the input.
|
# Performs pose landmarks detection on the input.
|
||||||
detection_result = landmarker.detect(self.test_image)
|
detection_result = landmarker.detect(self.test_image)
|
||||||
|
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self._expect_pose_landmarker_results_correct(
|
self._expect_pose_landmarker_results_correct(
|
||||||
detection_result, expected_detection_result, _LANDMARKS_MARGIN
|
detection_result, expected_detection_result,
|
||||||
|
output_segmentation_masks, _LANDMARKS_MARGIN
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_detect_fails_with_region_of_interest(self):
|
def test_detect_fails_with_region_of_interest(self):
|
||||||
|
@ -295,12 +322,15 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
landmarker.detect_for_video(self.test_image, 0)
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(_POSE_IMAGE, 0,
|
(_POSE_IMAGE, 0, False,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
(_BURGER_IMAGE, 0,
|
(_POSE_IMAGE, 0, True,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(_BURGER_IMAGE, 0, False,
|
||||||
PoseLandmarkerResult([], [], []))
|
PoseLandmarkerResult([], [], []))
|
||||||
)
|
)
|
||||||
def test_detect_for_video(self, image_path, rotation, expected_result):
|
def test_detect_for_video(self, image_path, rotation,
|
||||||
|
output_segmentation_masks, expected_result):
|
||||||
test_image = _Image.create_from_file(
|
test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(image_path))
|
test_utils.get_test_data_path(image_path))
|
||||||
# Set rotation parameters using ImageProcessingOptions.
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
@ -308,6 +338,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
rotation_degrees=rotation)
|
rotation_degrees=rotation)
|
||||||
options = _PoseLandmarkerOptions(
|
options = _PoseLandmarkerOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
output_segmentation_masks=output_segmentation_masks,
|
||||||
running_mode=_RUNNING_MODE.VIDEO)
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
|
@ -315,7 +346,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
image_processing_options)
|
image_processing_options)
|
||||||
if result.pose_landmarks:
|
if result.pose_landmarks:
|
||||||
self._expect_pose_landmarker_results_correct(
|
self._expect_pose_landmarker_results_correct(
|
||||||
result, expected_result, _LANDMARKS_MARGIN
|
result, expected_result, output_segmentation_masks,
|
||||||
|
_LANDMARKS_MARGIN
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(result, expected_result)
|
self.assertEqual(result, expected_result)
|
||||||
|
@ -352,12 +384,15 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
landmarker.detect_async(self.test_image, 0)
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
(_POSE_IMAGE, 0,
|
(_POSE_IMAGE, 0, False,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
(_BURGER_IMAGE, 0,
|
(_POSE_IMAGE, 0, True,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(_BURGER_IMAGE, 0, False,
|
||||||
PoseLandmarkerResult([], [], []))
|
PoseLandmarkerResult([], [], []))
|
||||||
)
|
)
|
||||||
def test_detect_async_calls(self, image_path, rotation, expected_result):
|
def test_detect_async_calls(self, image_path, rotation,
|
||||||
|
output_segmentation_masks, expected_result):
|
||||||
test_image = _Image.create_from_file(
|
test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(image_path))
|
test_utils.get_test_data_path(image_path))
|
||||||
# Set rotation parameters using ImageProcessingOptions.
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
@ -369,7 +404,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
timestamp_ms: int):
|
timestamp_ms: int):
|
||||||
if result.pose_landmarks:
|
if result.pose_landmarks:
|
||||||
self._expect_pose_landmarker_results_correct(
|
self._expect_pose_landmarker_results_correct(
|
||||||
result, expected_result, _LANDMARKS_MARGIN
|
result, expected_result, output_segmentation_masks,
|
||||||
|
_LANDMARKS_MARGIN
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(result, expected_result)
|
self.assertEqual(result, expected_result)
|
||||||
|
@ -380,6 +416,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
|
|
||||||
options = _PoseLandmarkerOptions(
|
options = _PoseLandmarkerOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
output_segmentation_masks=output_segmentation_masks,
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
result_callback=check_result)
|
result_callback=check_result)
|
||||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
|
|
@ -263,7 +263,6 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
||||||
)
|
)
|
||||||
|
|
||||||
output_streams = [
|
output_streams = [
|
||||||
':'.join([_SEGMENTATION_MASK_TAG, _SEGMENTATION_MASK_STREAM_NAME]),
|
|
||||||
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
|
':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
|
||||||
':'.join([
|
':'.join([
|
||||||
_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME
|
_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME
|
||||||
|
|
Loading…
Reference in New Issue
Block a user