Added more pose landmarker tests and updated face landmarker tests to cover all the results

This commit is contained in:
kinaryml 2023-04-18 22:45:46 -07:00
parent 39742b6641
commit 1688d0fa79
3 changed files with 281 additions and 57 deletions

View File

@ -58,17 +58,20 @@ _FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02
def _get_expected_face_landmarks(file_path: str): def _get_expected_face_landmarks(file_path: str):
proto_file_path = test_utils.get_test_data_path(file_path) proto_file_path = test_utils.get_test_data_path(file_path)
face_landmarks_results = []
with open(proto_file_path, 'rb') as f: with open(proto_file_path, 'rb') as f:
proto = landmark_pb2.NormalizedLandmarkList() proto = landmark_pb2.NormalizedLandmarkList()
text_format.Parse(f.read(), proto) text_format.Parse(f.read(), proto)
face_landmarks = [] face_landmarks = []
for landmark in proto.landmark: for landmark in proto.landmark:
face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark)) face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
return face_landmarks face_landmarks_results.append(face_landmarks)
return face_landmarks_results
def _get_expected_face_blendshapes(file_path: str): def _get_expected_face_blendshapes(file_path: str):
proto_file_path = test_utils.get_test_data_path(file_path) proto_file_path = test_utils.get_test_data_path(file_path)
face_blendshapes_results = []
with open(proto_file_path, 'rb') as f: with open(proto_file_path, 'rb') as f:
proto = classification_pb2.ClassificationList() proto = classification_pb2.ClassificationList()
text_format.Parse(f.read(), proto) text_format.Parse(f.read(), proto)
@ -84,7 +87,8 @@ def _get_expected_face_blendshapes(file_path: str):
category_name=face_blendshapes.label, category_name=face_blendshapes.label,
) )
) )
return face_blendshapes_categories face_blendshapes_results.append(face_blendshapes_categories)
return face_blendshapes_results
def _get_expected_facial_transformation_matrixes(): def _get_expected_facial_transformation_matrixes():
@ -119,13 +123,14 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Expects to have the same number of faces detected. # Expects to have the same number of faces detected.
self.assertLen(actual_landmarks, len(expected_landmarks)) self.assertLen(actual_landmarks, len(expected_landmarks))
for i, elem in enumerate(actual_landmarks): for i, _ in enumerate(actual_landmarks):
self.assertAlmostEqual( for j, elem in enumerate(actual_landmarks[i]):
elem.x, expected_landmarks[i].x, delta=_LANDMARKS_DIFF_MARGIN self.assertAlmostEqual(
) elem.x, expected_landmarks[i][j].x, delta=_LANDMARKS_DIFF_MARGIN
self.assertAlmostEqual( )
elem.y, expected_landmarks[i].y, delta=_LANDMARKS_DIFF_MARGIN self.assertAlmostEqual(
) elem.y, expected_landmarks[i][j].y, delta=_LANDMARKS_DIFF_MARGIN
)
def _expect_blendshapes_correct( def _expect_blendshapes_correct(
self, actual_blendshapes, expected_blendshapes self, actual_blendshapes, expected_blendshapes
@ -133,13 +138,14 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Expects to have the same number of blendshapes. # Expects to have the same number of blendshapes.
self.assertLen(actual_blendshapes, len(expected_blendshapes)) self.assertLen(actual_blendshapes, len(expected_blendshapes))
for i, elem in enumerate(actual_blendshapes): for i, _ in enumerate(actual_blendshapes):
self.assertEqual(elem.index, expected_blendshapes[i].index) for j, elem in enumerate(actual_blendshapes[i]):
self.assertAlmostEqual( self.assertEqual(elem.index, expected_blendshapes[i][j].index)
elem.score, self.assertAlmostEqual(
expected_blendshapes[i].score, elem.score,
delta=_BLENDSHAPES_DIFF_MARGIN, expected_blendshapes[i][j].score,
) delta=_BLENDSHAPES_DIFF_MARGIN,
)
def _expect_facial_transformation_matrixes_correct( def _expect_facial_transformation_matrixes_correct(
self, actual_matrix_list, expected_matrix_list self, actual_matrix_list, expected_matrix_list
@ -236,11 +242,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results. # Comparing results.
if expected_face_landmarks is not None: if expected_face_landmarks is not None:
self._expect_landmarks_correct( self._expect_landmarks_correct(
detection_result.face_landmarks[0], expected_face_landmarks detection_result.face_landmarks, expected_face_landmarks
) )
if expected_face_blendshapes is not None: if expected_face_blendshapes is not None:
self._expect_blendshapes_correct( self._expect_blendshapes_correct(
detection_result.face_blendshapes[0], expected_face_blendshapes detection_result.face_blendshapes, expected_face_blendshapes
) )
if expected_facial_transformation_matrixes is not None: if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct( self._expect_facial_transformation_matrixes_correct(
@ -302,11 +308,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results. # Comparing results.
if expected_face_landmarks is not None: if expected_face_landmarks is not None:
self._expect_landmarks_correct( self._expect_landmarks_correct(
detection_result.face_landmarks[0], expected_face_landmarks detection_result.face_landmarks, expected_face_landmarks
) )
if expected_face_blendshapes is not None: if expected_face_blendshapes is not None:
self._expect_blendshapes_correct( self._expect_blendshapes_correct(
detection_result.face_blendshapes[0], expected_face_blendshapes detection_result.face_blendshapes, expected_face_blendshapes
) )
if expected_facial_transformation_matrixes is not None: if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct( self._expect_facial_transformation_matrixes_correct(
@ -446,11 +452,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results. # Comparing results.
if expected_face_landmarks is not None: if expected_face_landmarks is not None:
self._expect_landmarks_correct( self._expect_landmarks_correct(
detection_result.face_landmarks[0], expected_face_landmarks detection_result.face_landmarks, expected_face_landmarks
) )
if expected_face_blendshapes is not None: if expected_face_blendshapes is not None:
self._expect_blendshapes_correct( self._expect_blendshapes_correct(
detection_result.face_blendshapes[0], expected_face_blendshapes detection_result.face_blendshapes, expected_face_blendshapes
) )
if expected_facial_transformation_matrixes is not None: if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct( self._expect_facial_transformation_matrixes_correct(
@ -523,11 +529,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results. # Comparing results.
if expected_face_landmarks is not None: if expected_face_landmarks is not None:
self._expect_landmarks_correct( self._expect_landmarks_correct(
result.face_landmarks[0], expected_face_landmarks result.face_landmarks, expected_face_landmarks
) )
if expected_face_blendshapes is not None: if expected_face_blendshapes is not None:
self._expect_blendshapes_correct( self._expect_blendshapes_correct(
result.face_blendshapes[0], expected_face_blendshapes result.face_blendshapes, expected_face_blendshapes
) )
if expected_facial_transformation_matrixes is not None: if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct( self._expect_facial_transformation_matrixes_correct(

View File

@ -14,6 +14,7 @@
"""Tests for pose landmarker.""" """Tests for pose landmarker."""
import enum import enum
from typing import List
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
@ -49,8 +50,8 @@ _POSE_LANDMARKER_BUNDLE_ASSET_FILE = 'pose_landmarker.task'
_BURGER_IMAGE = 'burger.jpg' _BURGER_IMAGE = 'burger.jpg'
_POSE_IMAGE = 'pose.jpg' _POSE_IMAGE = 'pose.jpg'
_POSE_LANDMARKS = 'pose_landmarks.pbtxt' _POSE_LANDMARKS = 'pose_landmarks.pbtxt'
_LANDMARKS_ERROR_TOLERANCE = 0.03 _LANDMARKS_DIFF_MARGIN = 0.03
_LANDMARKS_ON_VIDEO_ERROR_TOLERANCE = 0.03 _LANDMARKS_ON_VIDEO_DIFF_MARGIN = 0.03
def _get_expected_pose_landmarker_result( def _get_expected_pose_landmarker_result(
@ -85,33 +86,34 @@ class PoseLandmarkerTest(parameterized.TestCase):
self.model_path = test_utils.get_test_data_path( self.model_path = test_utils.get_test_data_path(
_POSE_LANDMARKER_BUNDLE_ASSET_FILE) _POSE_LANDMARKER_BUNDLE_ASSET_FILE)
def _expect_pose_landmarker_results_correct( def _expect_pose_landmarks_correct(
self, self,
actual_result: PoseLandmarkerResult, actual_landmarks: List[List[landmark_module.NormalizedLandmark]],
expected_result: PoseLandmarkerResult, expected_landmarks: List[List[landmark_module.NormalizedLandmark]],
error_tolerance: float diff_margin: float
): ):
# Expects to have the same number of poses detected. # Expects to have the same number of poses detected.
self.assertLen(actual_result.pose_landmarks, self.assertLen(actual_landmarks, len(expected_landmarks))
len(expected_result.pose_landmarks))
self.assertLen(actual_result.pose_world_landmarks, for i, _ in enumerate(actual_landmarks):
len(expected_result.pose_world_landmarks)) for j, elem in enumerate(actual_landmarks[i]):
self.assertLen(actual_result.pose_auxiliary_landmarks, self.assertAlmostEqual(
len(expected_result.pose_auxiliary_landmarks)) elem.x, expected_landmarks[i][j].x, delta=diff_margin
# Actual landmarks match expected landmarks. )
actual_landmarks = actual_result.pose_landmarks[0] self.assertAlmostEqual(
expected_landmarks = expected_result.pose_landmarks[0] elem.y, expected_landmarks[i][j].y, delta=diff_margin
for i, pose_landmark in enumerate(actual_landmarks): )
self.assertAlmostEqual(
pose_landmark.x, def _expect_pose_landmarker_results_correct(
expected_landmarks[i].x, self,
delta=error_tolerance actual_result: PoseLandmarkerResult,
) expected_result: PoseLandmarkerResult,
self.assertAlmostEqual( diff_margin: float
pose_landmark.y, ):
expected_landmarks[i].y, self._expect_pose_landmarks_correct(
delta=error_tolerance actual_result.pose_landmarks, expected_result.pose_landmarks,
) diff_margin
)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -146,7 +148,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
(ModelFileType.FILE_NAME, (ModelFileType.FILE_NAME,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)), _get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT, (ModelFileType.FILE_CONTENT,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))) _get_expected_pose_landmarker_result(_POSE_LANDMARKS))
)
def test_detect(self, model_file_type, expected_detection_result): def test_detect(self, model_file_type, expected_detection_result):
# Creates pose landmarker. # Creates pose landmarker.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
@ -164,14 +167,229 @@ class PoseLandmarkerTest(parameterized.TestCase):
# Performs pose landmarks detection on the input. # Performs pose landmarks detection on the input.
detection_result = landmarker.detect(self.test_image) detection_result = landmarker.detect(self.test_image)
# Comparing results. # Comparing results.
self._expect_pose_landmarker_results_correct( self._expect_pose_landmarker_results_correct(
detection_result, expected_detection_result, _LANDMARKS_ERROR_TOLERANCE detection_result, expected_detection_result, _LANDMARKS_DIFF_MARGIN
) )
# Closes the pose landmarker explicitly when the pose landmarker is not used # Closes the pose landmarker explicitly when the pose landmarker is not used
# in a context. # in a context.
landmarker.close() landmarker.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
)
def test_detect_in_context(self, model_file_type, expected_detection_result):
# Creates pose landmarker.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _PoseLandmarkerOptions(base_options=base_options)
with _PoseLandmarker.create_from_options(options) as landmarker:
# Performs pose landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
# Comparing results.
self._expect_pose_landmarker_results_correct(
detection_result, expected_detection_result, _LANDMARKS_DIFF_MARGIN
)
def test_detect_fails_with_region_of_interest(self):
# Creates pose landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _PoseLandmarkerOptions(base_options=base_options)
with self.assertRaisesRegex(
ValueError, "This task doesn't support region-of-interest."):
with _PoseLandmarker.create_from_options(options) as landmarker:
# Set the `region_of_interest` parameter using `ImageProcessingOptions`.
image_processing_options = _ImageProcessingOptions(
region_of_interest=_Rect(0, 0, 1, 1))
# Attempt to perform pose landmarks detection on the cropped input.
landmarker.detect(self.test_image, image_processing_options)
def test_empty_detection_outputs(self):
# Creates pose landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _PoseLandmarkerOptions(base_options=base_options)
with _PoseLandmarker.create_from_options(options) as landmarker:
# Load an image with no poses.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_BURGER_IMAGE))
# Performs pose landmarks detection on the input.
detection_result = landmarker.detect(test_image)
# Comparing results.
self.assertEmpty(detection_result.pose_landmarks)
self.assertEmpty(detection_result.pose_world_landmarks)
self.assertEmpty(detection_result.pose_auxiliary_landmarks)
def test_missing_result_callback(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
unused_result = landmarker.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters(
(_POSE_IMAGE, 0,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0,
PoseLandmarkerResult([], [], []))
)
def test_detect_for_video(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation)
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
result = landmarker.detect_for_video(test_image, timestamp,
image_processing_options)
if result.pose_landmarks:
self._expect_pose_landmarker_results_correct(
result, expected_result, _LANDMARKS_ON_VIDEO_DIFF_MARGIN
)
else:
self.assertEqual(result, expected_result)
def test_calling_detect_in_live_stream_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _PoseLandmarker.create_from_options(options) as landmarker:
landmarker.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_async(self.test_image, 0)
@parameterized.parameters(
(_POSE_IMAGE, 0,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0,
PoseLandmarkerResult([], [], []))
)
def test_detect_async_calls(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation)
observed_timestamp_ms = -1
def check_result(result: PoseLandmarkerResult, output_image: _Image,
timestamp_ms: int):
if result.pose_landmarks:
self._expect_pose_landmarker_results_correct(
result, expected_result, _LANDMARKS_DIFF_MARGIN
)
else:
self.assertEqual(result, expected_result)
self.assertTrue(
np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _PoseLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
landmarker.detect_async(test_image, timestamp, image_processing_options)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -64,7 +64,7 @@ class PoseLandmarkerResult:
pose_world_landmarks: Detected pose landmarks in world coordinates. pose_world_landmarks: Detected pose landmarks in world coordinates.
pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving
ROI for next frame. ROI for next frame.
segmentation_masks: Segmentation masks for pose. segmentation_masks: Optional segmentation masks for pose.
""" """
pose_landmarks: List[List[landmark_module.NormalizedLandmark]] pose_landmarks: List[List[landmark_module.NormalizedLandmark]]
@ -77,7 +77,7 @@ def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet] output_packets: Mapping[str, packet_module.Packet]
) -> PoseLandmarkerResult: ) -> PoseLandmarkerResult:
"""Constructs a `PoseLandmarkerResult` from output packets.""" """Constructs a `PoseLandmarkerResult` from output packets."""
pose_landmarker_result = PoseLandmarkerResult([], [], [], []) pose_landmarker_result = PoseLandmarkerResult([], [], [])
if _SEGMENTATION_MASK_STREAM_NAME in output_packets: if _SEGMENTATION_MASK_STREAM_NAME in output_packets:
pose_landmarker_result.segmentation_masks = packet_getter.get_image_list( pose_landmarker_result.segmentation_masks = packet_getter.get_image_list(
@ -356,7 +356,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
image_processing_options: Options for image processing. image_processing_options: Options for image processing.
Returns: Returns:
The pose landmarks detection results. The pose landmarker detection results.
Raises: Raises:
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
@ -402,7 +402,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
per input image. per input image.
The `result_callback` provides: The `result_callback` provides:
- The pose landmarks detection results. - The pose landmarker detection results.
- The input image that the pose landmarker runs on. - The input image that the pose landmarker runs on.
- The input timestamp in milliseconds. - The input timestamp in milliseconds.