Added more pose landmarker tests and updated face landmarker tests to cover all the results
This commit is contained in:
parent
39742b6641
commit
1688d0fa79
|
@ -58,17 +58,20 @@ _FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02
|
||||||
|
|
||||||
def _get_expected_face_landmarks(file_path: str):
|
def _get_expected_face_landmarks(file_path: str):
|
||||||
proto_file_path = test_utils.get_test_data_path(file_path)
|
proto_file_path = test_utils.get_test_data_path(file_path)
|
||||||
|
face_landmarks_results = []
|
||||||
with open(proto_file_path, 'rb') as f:
|
with open(proto_file_path, 'rb') as f:
|
||||||
proto = landmark_pb2.NormalizedLandmarkList()
|
proto = landmark_pb2.NormalizedLandmarkList()
|
||||||
text_format.Parse(f.read(), proto)
|
text_format.Parse(f.read(), proto)
|
||||||
face_landmarks = []
|
face_landmarks = []
|
||||||
for landmark in proto.landmark:
|
for landmark in proto.landmark:
|
||||||
face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
|
face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
|
||||||
return face_landmarks
|
face_landmarks_results.append(face_landmarks)
|
||||||
|
return face_landmarks_results
|
||||||
|
|
||||||
|
|
||||||
def _get_expected_face_blendshapes(file_path: str):
|
def _get_expected_face_blendshapes(file_path: str):
|
||||||
proto_file_path = test_utils.get_test_data_path(file_path)
|
proto_file_path = test_utils.get_test_data_path(file_path)
|
||||||
|
face_blendshapes_results = []
|
||||||
with open(proto_file_path, 'rb') as f:
|
with open(proto_file_path, 'rb') as f:
|
||||||
proto = classification_pb2.ClassificationList()
|
proto = classification_pb2.ClassificationList()
|
||||||
text_format.Parse(f.read(), proto)
|
text_format.Parse(f.read(), proto)
|
||||||
|
@ -84,7 +87,8 @@ def _get_expected_face_blendshapes(file_path: str):
|
||||||
category_name=face_blendshapes.label,
|
category_name=face_blendshapes.label,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return face_blendshapes_categories
|
face_blendshapes_results.append(face_blendshapes_categories)
|
||||||
|
return face_blendshapes_results
|
||||||
|
|
||||||
|
|
||||||
def _get_expected_facial_transformation_matrixes():
|
def _get_expected_facial_transformation_matrixes():
|
||||||
|
@ -119,13 +123,14 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
# Expects to have the same number of faces detected.
|
# Expects to have the same number of faces detected.
|
||||||
self.assertLen(actual_landmarks, len(expected_landmarks))
|
self.assertLen(actual_landmarks, len(expected_landmarks))
|
||||||
|
|
||||||
for i, elem in enumerate(actual_landmarks):
|
for i, _ in enumerate(actual_landmarks):
|
||||||
self.assertAlmostEqual(
|
for j, elem in enumerate(actual_landmarks[i]):
|
||||||
elem.x, expected_landmarks[i].x, delta=_LANDMARKS_DIFF_MARGIN
|
self.assertAlmostEqual(
|
||||||
)
|
elem.x, expected_landmarks[i][j].x, delta=_LANDMARKS_DIFF_MARGIN
|
||||||
self.assertAlmostEqual(
|
)
|
||||||
elem.y, expected_landmarks[i].y, delta=_LANDMARKS_DIFF_MARGIN
|
self.assertAlmostEqual(
|
||||||
)
|
elem.y, expected_landmarks[i][j].y, delta=_LANDMARKS_DIFF_MARGIN
|
||||||
|
)
|
||||||
|
|
||||||
def _expect_blendshapes_correct(
|
def _expect_blendshapes_correct(
|
||||||
self, actual_blendshapes, expected_blendshapes
|
self, actual_blendshapes, expected_blendshapes
|
||||||
|
@ -133,13 +138,14 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
# Expects to have the same number of blendshapes.
|
# Expects to have the same number of blendshapes.
|
||||||
self.assertLen(actual_blendshapes, len(expected_blendshapes))
|
self.assertLen(actual_blendshapes, len(expected_blendshapes))
|
||||||
|
|
||||||
for i, elem in enumerate(actual_blendshapes):
|
for i, _ in enumerate(actual_blendshapes):
|
||||||
self.assertEqual(elem.index, expected_blendshapes[i].index)
|
for j, elem in enumerate(actual_blendshapes[i]):
|
||||||
self.assertAlmostEqual(
|
self.assertEqual(elem.index, expected_blendshapes[i][j].index)
|
||||||
elem.score,
|
self.assertAlmostEqual(
|
||||||
expected_blendshapes[i].score,
|
elem.score,
|
||||||
delta=_BLENDSHAPES_DIFF_MARGIN,
|
expected_blendshapes[i][j].score,
|
||||||
)
|
delta=_BLENDSHAPES_DIFF_MARGIN,
|
||||||
|
)
|
||||||
|
|
||||||
def _expect_facial_transformation_matrixes_correct(
|
def _expect_facial_transformation_matrixes_correct(
|
||||||
self, actual_matrix_list, expected_matrix_list
|
self, actual_matrix_list, expected_matrix_list
|
||||||
|
@ -236,11 +242,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
if expected_face_landmarks is not None:
|
if expected_face_landmarks is not None:
|
||||||
self._expect_landmarks_correct(
|
self._expect_landmarks_correct(
|
||||||
detection_result.face_landmarks[0], expected_face_landmarks
|
detection_result.face_landmarks, expected_face_landmarks
|
||||||
)
|
)
|
||||||
if expected_face_blendshapes is not None:
|
if expected_face_blendshapes is not None:
|
||||||
self._expect_blendshapes_correct(
|
self._expect_blendshapes_correct(
|
||||||
detection_result.face_blendshapes[0], expected_face_blendshapes
|
detection_result.face_blendshapes, expected_face_blendshapes
|
||||||
)
|
)
|
||||||
if expected_facial_transformation_matrixes is not None:
|
if expected_facial_transformation_matrixes is not None:
|
||||||
self._expect_facial_transformation_matrixes_correct(
|
self._expect_facial_transformation_matrixes_correct(
|
||||||
|
@ -302,11 +308,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
if expected_face_landmarks is not None:
|
if expected_face_landmarks is not None:
|
||||||
self._expect_landmarks_correct(
|
self._expect_landmarks_correct(
|
||||||
detection_result.face_landmarks[0], expected_face_landmarks
|
detection_result.face_landmarks, expected_face_landmarks
|
||||||
)
|
)
|
||||||
if expected_face_blendshapes is not None:
|
if expected_face_blendshapes is not None:
|
||||||
self._expect_blendshapes_correct(
|
self._expect_blendshapes_correct(
|
||||||
detection_result.face_blendshapes[0], expected_face_blendshapes
|
detection_result.face_blendshapes, expected_face_blendshapes
|
||||||
)
|
)
|
||||||
if expected_facial_transformation_matrixes is not None:
|
if expected_facial_transformation_matrixes is not None:
|
||||||
self._expect_facial_transformation_matrixes_correct(
|
self._expect_facial_transformation_matrixes_correct(
|
||||||
|
@ -446,11 +452,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
if expected_face_landmarks is not None:
|
if expected_face_landmarks is not None:
|
||||||
self._expect_landmarks_correct(
|
self._expect_landmarks_correct(
|
||||||
detection_result.face_landmarks[0], expected_face_landmarks
|
detection_result.face_landmarks, expected_face_landmarks
|
||||||
)
|
)
|
||||||
if expected_face_blendshapes is not None:
|
if expected_face_blendshapes is not None:
|
||||||
self._expect_blendshapes_correct(
|
self._expect_blendshapes_correct(
|
||||||
detection_result.face_blendshapes[0], expected_face_blendshapes
|
detection_result.face_blendshapes, expected_face_blendshapes
|
||||||
)
|
)
|
||||||
if expected_facial_transformation_matrixes is not None:
|
if expected_facial_transformation_matrixes is not None:
|
||||||
self._expect_facial_transformation_matrixes_correct(
|
self._expect_facial_transformation_matrixes_correct(
|
||||||
|
@ -523,11 +529,11 @@ class FaceLandmarkerTest(parameterized.TestCase):
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
if expected_face_landmarks is not None:
|
if expected_face_landmarks is not None:
|
||||||
self._expect_landmarks_correct(
|
self._expect_landmarks_correct(
|
||||||
result.face_landmarks[0], expected_face_landmarks
|
result.face_landmarks, expected_face_landmarks
|
||||||
)
|
)
|
||||||
if expected_face_blendshapes is not None:
|
if expected_face_blendshapes is not None:
|
||||||
self._expect_blendshapes_correct(
|
self._expect_blendshapes_correct(
|
||||||
result.face_blendshapes[0], expected_face_blendshapes
|
result.face_blendshapes, expected_face_blendshapes
|
||||||
)
|
)
|
||||||
if expected_facial_transformation_matrixes is not None:
|
if expected_facial_transformation_matrixes is not None:
|
||||||
self._expect_facial_transformation_matrixes_correct(
|
self._expect_facial_transformation_matrixes_correct(
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
"""Tests for pose landmarker."""
|
"""Tests for pose landmarker."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
from typing import List
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
@ -49,8 +50,8 @@ _POSE_LANDMARKER_BUNDLE_ASSET_FILE = 'pose_landmarker.task'
|
||||||
_BURGER_IMAGE = 'burger.jpg'
|
_BURGER_IMAGE = 'burger.jpg'
|
||||||
_POSE_IMAGE = 'pose.jpg'
|
_POSE_IMAGE = 'pose.jpg'
|
||||||
_POSE_LANDMARKS = 'pose_landmarks.pbtxt'
|
_POSE_LANDMARKS = 'pose_landmarks.pbtxt'
|
||||||
_LANDMARKS_ERROR_TOLERANCE = 0.03
|
_LANDMARKS_DIFF_MARGIN = 0.03
|
||||||
_LANDMARKS_ON_VIDEO_ERROR_TOLERANCE = 0.03
|
_LANDMARKS_ON_VIDEO_DIFF_MARGIN = 0.03
|
||||||
|
|
||||||
|
|
||||||
def _get_expected_pose_landmarker_result(
|
def _get_expected_pose_landmarker_result(
|
||||||
|
@ -85,33 +86,34 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
self.model_path = test_utils.get_test_data_path(
|
self.model_path = test_utils.get_test_data_path(
|
||||||
_POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
_POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||||
|
|
||||||
|
def _expect_pose_landmarks_correct(
|
||||||
|
self,
|
||||||
|
actual_landmarks: List[List[landmark_module.NormalizedLandmark]],
|
||||||
|
expected_landmarks: List[List[landmark_module.NormalizedLandmark]],
|
||||||
|
diff_margin: float
|
||||||
|
):
|
||||||
|
# Expects to have the same number of poses detected.
|
||||||
|
self.assertLen(actual_landmarks, len(expected_landmarks))
|
||||||
|
|
||||||
|
for i, _ in enumerate(actual_landmarks):
|
||||||
|
for j, elem in enumerate(actual_landmarks[i]):
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
elem.x, expected_landmarks[i][j].x, delta=diff_margin
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
elem.y, expected_landmarks[i][j].y, delta=diff_margin
|
||||||
|
)
|
||||||
|
|
||||||
def _expect_pose_landmarker_results_correct(
|
def _expect_pose_landmarker_results_correct(
|
||||||
self,
|
self,
|
||||||
actual_result: PoseLandmarkerResult,
|
actual_result: PoseLandmarkerResult,
|
||||||
expected_result: PoseLandmarkerResult,
|
expected_result: PoseLandmarkerResult,
|
||||||
error_tolerance: float
|
diff_margin: float
|
||||||
):
|
):
|
||||||
# Expects to have the same number of poses detected.
|
self._expect_pose_landmarks_correct(
|
||||||
self.assertLen(actual_result.pose_landmarks,
|
actual_result.pose_landmarks, expected_result.pose_landmarks,
|
||||||
len(expected_result.pose_landmarks))
|
diff_margin
|
||||||
self.assertLen(actual_result.pose_world_landmarks,
|
)
|
||||||
len(expected_result.pose_world_landmarks))
|
|
||||||
self.assertLen(actual_result.pose_auxiliary_landmarks,
|
|
||||||
len(expected_result.pose_auxiliary_landmarks))
|
|
||||||
# Actual landmarks match expected landmarks.
|
|
||||||
actual_landmarks = actual_result.pose_landmarks[0]
|
|
||||||
expected_landmarks = expected_result.pose_landmarks[0]
|
|
||||||
for i, pose_landmark in enumerate(actual_landmarks):
|
|
||||||
self.assertAlmostEqual(
|
|
||||||
pose_landmark.x,
|
|
||||||
expected_landmarks[i].x,
|
|
||||||
delta=error_tolerance
|
|
||||||
)
|
|
||||||
self.assertAlmostEqual(
|
|
||||||
pose_landmark.y,
|
|
||||||
expected_landmarks[i].y,
|
|
||||||
delta=error_tolerance
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
@ -146,7 +148,8 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
(ModelFileType.FILE_NAME,
|
(ModelFileType.FILE_NAME,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
(ModelFileType.FILE_CONTENT,
|
(ModelFileType.FILE_CONTENT,
|
||||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)))
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
|
||||||
|
)
|
||||||
def test_detect(self, model_file_type, expected_detection_result):
|
def test_detect(self, model_file_type, expected_detection_result):
|
||||||
# Creates pose landmarker.
|
# Creates pose landmarker.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
@ -164,14 +167,229 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Performs pose landmarks detection on the input.
|
# Performs pose landmarks detection on the input.
|
||||||
detection_result = landmarker.detect(self.test_image)
|
detection_result = landmarker.detect(self.test_image)
|
||||||
|
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self._expect_pose_landmarker_results_correct(
|
self._expect_pose_landmarker_results_correct(
|
||||||
detection_result, expected_detection_result, _LANDMARKS_ERROR_TOLERANCE
|
detection_result, expected_detection_result, _LANDMARKS_DIFF_MARGIN
|
||||||
)
|
)
|
||||||
# Closes the pose landmarker explicitly when the pose landmarker is not used
|
# Closes the pose landmarker explicitly when the pose landmarker is not used
|
||||||
# in a context.
|
# in a context.
|
||||||
landmarker.close()
|
landmarker.close()
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(ModelFileType.FILE_CONTENT,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS))
|
||||||
|
)
|
||||||
|
def test_detect_in_context(self, model_file_type, expected_detection_result):
|
||||||
|
# Creates pose landmarker.
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Performs pose landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(self.test_image)
|
||||||
|
|
||||||
|
# Comparing results.
|
||||||
|
self._expect_pose_landmarker_results_correct(
|
||||||
|
detection_result, expected_detection_result, _LANDMARKS_DIFF_MARGIN
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_detect_fails_with_region_of_interest(self):
|
||||||
|
# Creates pose landmarker.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "This task doesn't support region-of-interest."):
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Set the `region_of_interest` parameter using `ImageProcessingOptions`.
|
||||||
|
image_processing_options = _ImageProcessingOptions(
|
||||||
|
region_of_interest=_Rect(0, 0, 1, 1))
|
||||||
|
# Attempt to perform pose landmarks detection on the cropped input.
|
||||||
|
landmarker.detect(self.test_image, image_processing_options)
|
||||||
|
|
||||||
|
def test_empty_detection_outputs(self):
|
||||||
|
# Creates pose landmarker.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Load an image with no poses.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(_BURGER_IMAGE))
|
||||||
|
# Performs pose landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(test_image)
|
||||||
|
# Comparing results.
|
||||||
|
self.assertEmpty(detection_result.pose_landmarks)
|
||||||
|
self.assertEmpty(detection_result.pose_world_landmarks)
|
||||||
|
self.assertEmpty(detection_result.pose_auxiliary_landmarks)
|
||||||
|
|
||||||
|
def test_missing_result_callback(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'result callback must be provided'):
|
||||||
|
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||||
|
def test_illegal_result_callback(self, running_mode):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=running_mode,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'result callback should not be provided'):
|
||||||
|
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_calling_detect_for_video_in_image_mode(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.IMAGE)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the video mode'):
|
||||||
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_detect_async_in_image_mode(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.IMAGE)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the live stream mode'):
|
||||||
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_detect_in_video_mode(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the image mode'):
|
||||||
|
landmarker.detect(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_detect_async_in_video_mode(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the live stream mode'):
|
||||||
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
unused_result = landmarker.detect_for_video(self.test_image, 1)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(_POSE_IMAGE, 0,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(_BURGER_IMAGE, 0,
|
||||||
|
PoseLandmarkerResult([], [], []))
|
||||||
|
)
|
||||||
|
def test_detect_for_video(self, image_path, rotation, expected_result):
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(image_path))
|
||||||
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
image_processing_options = _ImageProcessingOptions(
|
||||||
|
rotation_degrees=rotation)
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
result = landmarker.detect_for_video(test_image, timestamp,
|
||||||
|
image_processing_options)
|
||||||
|
if result.pose_landmarks:
|
||||||
|
self._expect_pose_landmarker_results_correct(
|
||||||
|
result, expected_result, _LANDMARKS_ON_VIDEO_DIFF_MARGIN
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_calling_detect_in_live_stream_mode(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the image mode'):
|
||||||
|
landmarker.detect(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the video mode'):
|
||||||
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
landmarker.detect_async(self.test_image, 100)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(_POSE_IMAGE, 0,
|
||||||
|
_get_expected_pose_landmarker_result(_POSE_LANDMARKS)),
|
||||||
|
(_BURGER_IMAGE, 0,
|
||||||
|
PoseLandmarkerResult([], [], []))
|
||||||
|
)
|
||||||
|
def test_detect_async_calls(self, image_path, rotation, expected_result):
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(image_path))
|
||||||
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
image_processing_options = _ImageProcessingOptions(
|
||||||
|
rotation_degrees=rotation)
|
||||||
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
|
def check_result(result: PoseLandmarkerResult, output_image: _Image,
|
||||||
|
timestamp_ms: int):
|
||||||
|
if result.pose_landmarks:
|
||||||
|
self._expect_pose_landmarker_results_correct(
|
||||||
|
result, expected_result, _LANDMARKS_DIFF_MARGIN
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
|
||||||
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
|
options = _PoseLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=check_result)
|
||||||
|
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
landmarker.detect_async(test_image, timestamp, image_processing_options)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -64,7 +64,7 @@ class PoseLandmarkerResult:
|
||||||
pose_world_landmarks: Detected pose landmarks in world coordinates.
|
pose_world_landmarks: Detected pose landmarks in world coordinates.
|
||||||
pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving
|
pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving
|
||||||
ROI for next frame.
|
ROI for next frame.
|
||||||
segmentation_masks: Segmentation masks for pose.
|
segmentation_masks: Optional segmentation masks for pose.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pose_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
pose_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||||
|
@ -77,7 +77,7 @@ def _build_landmarker_result(
|
||||||
output_packets: Mapping[str, packet_module.Packet]
|
output_packets: Mapping[str, packet_module.Packet]
|
||||||
) -> PoseLandmarkerResult:
|
) -> PoseLandmarkerResult:
|
||||||
"""Constructs a `PoseLandmarkerResult` from output packets."""
|
"""Constructs a `PoseLandmarkerResult` from output packets."""
|
||||||
pose_landmarker_result = PoseLandmarkerResult([], [], [], [])
|
pose_landmarker_result = PoseLandmarkerResult([], [], [])
|
||||||
|
|
||||||
if _SEGMENTATION_MASK_STREAM_NAME in output_packets:
|
if _SEGMENTATION_MASK_STREAM_NAME in output_packets:
|
||||||
pose_landmarker_result.segmentation_masks = packet_getter.get_image_list(
|
pose_landmarker_result.segmentation_masks = packet_getter.get_image_list(
|
||||||
|
@ -356,7 +356,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
||||||
image_processing_options: Options for image processing.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The pose landmarks detection results.
|
The pose landmarker detection results.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If any of the input arguments is invalid.
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
@ -402,7 +402,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
||||||
per input image.
|
per input image.
|
||||||
|
|
||||||
The `result_callback` provides:
|
The `result_callback` provides:
|
||||||
- The pose landmarks detection results.
|
- The pose landmarker detection results.
|
||||||
- The input image that the pose landmarker runs on.
|
- The input image that the pose landmarker runs on.
|
||||||
- The input timestamp in milliseconds.
|
- The input timestamp in milliseconds.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user