Added more pose landmarker tests and updated face landmarker tests to cover all the results

This commit is contained in:
kinaryml 2023-04-18 22:45:46 -07:00
parent 39742b6641
commit 1688d0fa79
3 changed files with 281 additions and 57 deletions

View File

@ -58,17 +58,20 @@ _FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02
def _get_expected_face_landmarks(file_path: str):
proto_file_path = test_utils.get_test_data_path(file_path)
face_landmarks_results = []
with open(proto_file_path, 'rb') as f:
proto = landmark_pb2.NormalizedLandmarkList()
text_format.Parse(f.read(), proto)
face_landmarks = []
for landmark in proto.landmark:
face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
return face_landmarks
face_landmarks_results.append(face_landmarks)
return face_landmarks_results
def _get_expected_face_blendshapes(file_path: str):
proto_file_path = test_utils.get_test_data_path(file_path)
face_blendshapes_results = []
with open(proto_file_path, 'rb') as f:
proto = classification_pb2.ClassificationList()
text_format.Parse(f.read(), proto)
@ -84,7 +87,8 @@ def _get_expected_face_blendshapes(file_path: str):
category_name=face_blendshapes.label,
)
)
return face_blendshapes_categories
face_blendshapes_results.append(face_blendshapes_categories)
return face_blendshapes_results
def _get_expected_facial_transformation_matrixes():
@ -119,13 +123,14 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Expects to have the same number of faces detected.
self.assertLen(actual_landmarks, len(expected_landmarks))
for i, elem in enumerate(actual_landmarks):
self.assertAlmostEqual(
elem.x, expected_landmarks[i].x, delta=_LANDMARKS_DIFF_MARGIN
)
self.assertAlmostEqual(
elem.y, expected_landmarks[i].y, delta=_LANDMARKS_DIFF_MARGIN
)
for i, _ in enumerate(actual_landmarks):
for j, elem in enumerate(actual_landmarks[i]):
self.assertAlmostEqual(
elem.x, expected_landmarks[i][j].x, delta=_LANDMARKS_DIFF_MARGIN
)
self.assertAlmostEqual(
elem.y, expected_landmarks[i][j].y, delta=_LANDMARKS_DIFF_MARGIN
)
def _expect_blendshapes_correct(
self, actual_blendshapes, expected_blendshapes
@ -133,13 +138,14 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Expects to have the same number of blendshapes.
self.assertLen(actual_blendshapes, len(expected_blendshapes))
for i, elem in enumerate(actual_blendshapes):
self.assertEqual(elem.index, expected_blendshapes[i].index)
self.assertAlmostEqual(
elem.score,
expected_blendshapes[i].score,
delta=_BLENDSHAPES_DIFF_MARGIN,
)
for i, _ in enumerate(actual_blendshapes):
for j, elem in enumerate(actual_blendshapes[i]):
self.assertEqual(elem.index, expected_blendshapes[i][j].index)
self.assertAlmostEqual(
elem.score,
expected_blendshapes[i][j].score,
delta=_BLENDSHAPES_DIFF_MARGIN,
)
def _expect_facial_transformation_matrixes_correct(
self, actual_matrix_list, expected_matrix_list
@ -236,11 +242,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(
detection_result.face_landmarks[0], expected_face_landmarks
detection_result.face_landmarks, expected_face_landmarks
)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(
detection_result.face_blendshapes[0], expected_face_blendshapes
detection_result.face_blendshapes, expected_face_blendshapes
)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct(
@ -302,11 +308,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(
detection_result.face_landmarks[0], expected_face_landmarks
detection_result.face_landmarks, expected_face_landmarks
)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(
detection_result.face_blendshapes[0], expected_face_blendshapes
detection_result.face_blendshapes, expected_face_blendshapes
)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct(
@ -446,11 +452,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(
detection_result.face_landmarks[0], expected_face_landmarks
detection_result.face_landmarks, expected_face_landmarks
)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(
detection_result.face_blendshapes[0], expected_face_blendshapes
detection_result.face_blendshapes, expected_face_blendshapes
)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct(
@ -523,11 +529,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
# Comparing results.
if expected_face_landmarks is not None:
self._expect_landmarks_correct(
result.face_landmarks[0], expected_face_landmarks
result.face_landmarks, expected_face_landmarks
)
if expected_face_blendshapes is not None:
self._expect_blendshapes_correct(
result.face_blendshapes[0], expected_face_blendshapes
result.face_blendshapes, expected_face_blendshapes
)
if expected_facial_transformation_matrixes is not None:
self._expect_facial_transformation_matrixes_correct(

View File

@ -14,6 +14,7 @@
"""Tests for pose landmarker."""
import enum
from typing import List
from unittest import mock
from absl.testing import absltest
@ -49,8 +50,8 @@ _POSE_LANDMARKER_BUNDLE_ASSET_FILE = 'pose_landmarker.task'
_BURGER_IMAGE = 'burger.jpg'
_POSE_IMAGE = 'pose.jpg'
_POSE_LANDMARKS = 'pose_landmarks.pbtxt'
_LANDMARKS_ERROR_TOLERANCE = 0.03
_LANDMARKS_ON_VIDEO_ERROR_TOLERANCE = 0.03
_LANDMARKS_DIFF_MARGIN = 0.03
_LANDMARKS_ON_VIDEO_DIFF_MARGIN = 0.03
def _get_expected_pose_landmarker_result(
@ -85,33 +86,34 @@ class PoseLandmarkerTest(parameterized.TestCase):
self.model_path = test_utils.get_test_data_path(
_POSE_LANDMARKER_BUNDLE_ASSET_FILE)
def _expect_pose_landmarks_correct(
self,
actual_landmarks: List[List[landmark_module.NormalizedLandmark]],
expected_landmarks: List[List[landmark_module.NormalizedLandmark]],
diff_margin: float
):
# Expects to have the same number of poses detected.
self.assertLen(actual_landmarks, len(expected_landmarks))
for i, _ in enumerate(actual_landmarks):
for j, elem in enumerate(actual_landmarks[i]):
self.assertAlmostEqual(
elem.x, expected_landmarks[i][j].x, delta=diff_margin
)
self.assertAlmostEqual(
elem.y, expected_landmarks[i][j].y, delta=diff_margin
)
def _expect_pose_landmarker_results_correct(
self,
actual_result: PoseLandmarkerResult,
expected_result: PoseLandmarkerResult,
error_tolerance: float
diff_margin: float
):
# Expects to have the same number of poses detected.
self.assertLen(actual_result.pose_landmarks,
len(expected_result.pose_landmarks))
self.assertLen(actual_result.pose_world_landmarks,
len(expected_result.pose_world_landmarks))
self.assertLen(actual_result.pose_auxiliary_landmarks,
len(expected_result.pose_auxiliary_landmarks))
# Actual landmarks match expected landmarks.
actual_landmarks = actual_result.pose_landmarks[0]
expected_landmarks = expected_result.pose_landmarks[0]
for i, pose_landmark in enumerate(actual_landmarks):
self.assertAlmostEqual(
pose_landmark.x,
expected_landmarks[i].x,
delta=error_tolerance
)
self.assertAlmostEqual(
pose_landmark.y,
expected_landmarks[i].y,
delta=error_tolerance
)
self._expect_pose_landmarks_correct(
actual_result.pose_landmarks, expected_result.pose_landmarks,
diff_margin
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
@ -146,7 +148,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
(ModelFileType.FILE_NAME,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)))
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
)
def test_detect(self, model_file_type, expected_detection_result):
# Creates pose landmarker.
if model_file_type is ModelFileType.FILE_NAME:
@ -164,14 +167,229 @@ class PoseLandmarkerTest(parameterized.TestCase):
# Performs pose landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
# Comparing results.
self._expect_pose_landmarker_results_correct(
detection_result, expected_detection_result, _LANDMARKS_ERROR_TOLERANCE
detection_result, expected_detection_result, _LANDMARKS_DIFF_MARGIN
)
# Closes the pose landmarker explicitly when the pose landmarker is not used
# in a context.
landmarker.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(ModelFileType.FILE_CONTENT,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
)
def test_detect_in_context(self, model_file_type, expected_detection_result):
# Creates pose landmarker.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _PoseLandmarkerOptions(base_options=base_options)
with _PoseLandmarker.create_from_options(options) as landmarker:
# Performs pose landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
# Comparing results.
self._expect_pose_landmarker_results_correct(
detection_result, expected_detection_result, _LANDMARKS_DIFF_MARGIN
)
def test_detect_fails_with_region_of_interest(self):
# Creates pose landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _PoseLandmarkerOptions(base_options=base_options)
with self.assertRaisesRegex(
ValueError, "This task doesn't support region-of-interest."):
with _PoseLandmarker.create_from_options(options) as landmarker:
# Set the `region_of_interest` parameter using `ImageProcessingOptions`.
image_processing_options = _ImageProcessingOptions(
region_of_interest=_Rect(0, 0, 1, 1))
# Attempt to perform pose landmarks detection on the cropped input.
landmarker.detect(self.test_image, image_processing_options)
def test_empty_detection_outputs(self):
# Creates pose landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _PoseLandmarkerOptions(base_options=base_options)
with _PoseLandmarker.create_from_options(options) as landmarker:
# Load an image with no poses.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_BURGER_IMAGE))
# Performs pose landmarks detection on the input.
detection_result = landmarker.detect(test_image)
# Comparing results.
self.assertEmpty(detection_result.pose_landmarks)
self.assertEmpty(detection_result.pose_world_landmarks)
self.assertEmpty(detection_result.pose_auxiliary_landmarks)
def test_missing_result_callback(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
unused_result = landmarker.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters(
(_POSE_IMAGE, 0,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0,
PoseLandmarkerResult([], [], []))
)
def test_detect_for_video(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation)
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _PoseLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
result = landmarker.detect_for_video(test_image, timestamp,
image_processing_options)
if result.pose_landmarks:
self._expect_pose_landmarker_results_correct(
result, expected_result, _LANDMARKS_ON_VIDEO_DIFF_MARGIN
)
else:
self.assertEqual(result, expected_result)
def test_calling_detect_in_live_stream_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _PoseLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _PoseLandmarker.create_from_options(options) as landmarker:
landmarker.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_async(self.test_image, 0)
@parameterized.parameters(
(_POSE_IMAGE, 0,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
(_BURGER_IMAGE, 0,
PoseLandmarkerResult([], [], []))
)
def test_detect_async_calls(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation)
observed_timestamp_ms = -1
def check_result(result: PoseLandmarkerResult, output_image: _Image,
timestamp_ms: int):
if result.pose_landmarks:
self._expect_pose_landmarker_results_correct(
result, expected_result, _LANDMARKS_DIFF_MARGIN
)
else:
self.assertEqual(result, expected_result)
self.assertTrue(
np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _PoseLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _PoseLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
landmarker.detect_async(test_image, timestamp, image_processing_options)
if __name__ == '__main__':
absltest.main()

View File

@ -64,7 +64,7 @@ class PoseLandmarkerResult:
pose_world_landmarks: Detected pose landmarks in world coordinates.
pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving
ROI for next frame.
segmentation_masks: Segmentation masks for pose.
segmentation_masks: Optional segmentation masks for pose.
"""
pose_landmarks: List[List[landmark_module.NormalizedLandmark]]
@ -77,7 +77,7 @@ def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]
) -> PoseLandmarkerResult:
"""Constructs a `PoseLandmarkerResult` from output packets."""
pose_landmarker_result = PoseLandmarkerResult([], [], [], [])
pose_landmarker_result = PoseLandmarkerResult([], [], [])
if _SEGMENTATION_MASK_STREAM_NAME in output_packets:
pose_landmarker_result.segmentation_masks = packet_getter.get_image_list(
@ -356,7 +356,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
image_processing_options: Options for image processing.
Returns:
The pose landmarks detection results.
The pose landmarker detection results.
Raises:
ValueError: If any of the input arguments is invalid.
@ -402,7 +402,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
per input image.
The `result_callback` provides:
- The pose landmarks detection results.
- The pose landmarker detection results.
- The input image that the pose landmarker runs on.
- The input timestamp in milliseconds.