Revised implementation and added more tests

This commit is contained in:
Kinar 2023-12-18 02:47:28 -08:00
parent 88463aeb9e
commit 30e6b766d4
3 changed files with 425 additions and 106 deletions

View File

@ -204,13 +204,8 @@ py_test(
], ],
tags = ["not_run:arm"], tags = ["not_run:arm"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2", "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:holistic_landmarker", "//mediapipe/tasks/python/vision:holistic_landmarker",

View File

@ -14,7 +14,6 @@
"""Tests for holistic landmarker.""" """Tests for holistic landmarker."""
import enum import enum
from typing import List
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
@ -22,13 +21,8 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from google.protobuf import text_format from google.protobuf import text_format
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2 from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import holistic_landmarker from mediapipe.tasks.python.vision import holistic_landmarker
@ -39,10 +33,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult
_HolisticResultProto = holistic_result_pb2.HolisticResult _HolisticResultProto = holistic_result_pb2.HolisticResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Category = category_module.Category
_Rect = rect_module.Rect
_Landmark = landmark_module.Landmark
_NormalizedLandmark = landmark_module.NormalizedLandmark
_Image = image_module.Image _Image = image_module.Image
_HolisticLandmarker = holistic_landmarker.HolisticLandmarker _HolisticLandmarker = holistic_landmarker.HolisticLandmarker
_HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions _HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions
@ -53,23 +43,27 @@ _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'holistic_landmarker.task'
_POSE_IMAGE = 'male_full_height_hands.jpg' _POSE_IMAGE = 'male_full_height_hands.jpg'
_CAT_IMAGE = 'cat.jpg' _CAT_IMAGE = 'cat.jpg'
_EXPECTED_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt" _EXPECTED_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt"
_IMAGE_WIDTH = 638
_IMAGE_HEIGHT = 1000
_LANDMARKS_MARGIN = 0.03 _LANDMARKS_MARGIN = 0.03
_BLENDSHAPES_MARGIN = 0.13 _BLENDSHAPES_MARGIN = 0.13
_VIDEO_LANDMARKS_MARGIN = 0.03
_VIDEO_BLENDSHAPES_MARGIN = 0.31
_LIVE_STREAM_LANDMARKS_MARGIN = 0.03
_LIVE_STREAM_BLENDSHAPES_MARGIN = 0.31
def _get_expected_holistic_landmarker_result( def _get_expected_holistic_landmarker_result(
file_path: str, file_path: str,
) -> HolisticLandmarkerResult: ) -> HolisticLandmarkerResult:
holistic_result_file_path = test_utils.get_test_data_path( holistic_result_file_path = test_utils.get_test_data_path(file_path)
file_path
)
with open(holistic_result_file_path, 'rb') as f: with open(holistic_result_file_path, 'rb') as f:
holistic_result_proto = _HolisticResultProto() holistic_result_proto = _HolisticResultProto()
# Use this if a .pb file is available. # Use this if a .pb file is available.
# holistic_result_proto.ParseFromString(f.read()) # holistic_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), holistic_result_proto) text_format.Parse(f.read(), holistic_result_proto)
holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2( holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2(
holistic_result_proto holistic_result_proto
) )
return holistic_landmarker_result return holistic_landmarker_result
@ -108,38 +102,70 @@ class HolisticLandmarkerTest(parameterized.TestCase):
for i, elem in enumerate(actual_blendshapes): for i, elem in enumerate(actual_blendshapes):
self.assertEqual(elem.index, expected_blendshapes[i].index) self.assertEqual(elem.index, expected_blendshapes[i].index)
self.assertEqual(elem.category_name, expected_blendshapes[i].category_name)
self.assertAlmostEqual( self.assertAlmostEqual(
elem.score, elem.score,
expected_blendshapes[i].score, expected_blendshapes[i].score,
delta=margin, delta=margin,
) )
def _expect_holistic_landmarker_results_correct( def _expect_holistic_landmarker_results_correct(
self, self,
actual_result: HolisticLandmarkerResult, actual_result: HolisticLandmarkerResult,
expected_result: HolisticLandmarkerResult, expected_result: HolisticLandmarkerResult,
output_segmentation_masks: bool, output_segmentation_mask: bool,
landmarks_margin: float, landmarks_margin: float,
blendshapes_margin: float, blendshapes_margin: float,
): ):
self._expect_landmarks_correct( self._expect_landmarks_correct(
actual_result.pose_landmarks, expected_result.pose_landmarks, actual_result.pose_landmarks, expected_result.pose_landmarks,
landmarks_margin landmarks_margin
) )
self._expect_landmarks_correct( self._expect_landmarks_correct(
actual_result.face_landmarks, expected_result.face_landmarks, actual_result.face_landmarks, expected_result.face_landmarks,
landmarks_margin landmarks_margin
) )
self._expect_blendshapes_correct( self._expect_blendshapes_correct(
actual_result.face_blendshapes, expected_result.face_blendshapes, actual_result.face_blendshapes, expected_result.face_blendshapes,
blendshapes_margin blendshapes_margin
) )
if output_segmentation_masks: if output_segmentation_mask:
self.assertIsInstance(actual_result.segmentation_masks, List) self.assertIsInstance(actual_result.segmentation_mask, _Image)
for _, mask in enumerate(actual_result.segmentation_masks): self.assertEqual(actual_result.segmentation_mask.width, _IMAGE_WIDTH)
self.assertIsInstance(mask, _Image) self.assertEqual(actual_result.segmentation_mask.height, _IMAGE_HEIGHT)
else: else:
self.assertIsNone(actual_result.segmentation_masks) self.assertIsNone(actual_result.segmentation_mask)
def test_create_from_file_succeeds_with_valid_model_path(self):
  """Creates a landmarker from a valid model path with default options."""
  landmarker_context = _HolisticLandmarker.create_from_model_path(
      self.model_path
  )
  with landmarker_context as created:
    self.assertIsInstance(created, _HolisticLandmarker)
def test_create_from_options_succeeds_with_valid_model_path(self):
  """Creates a landmarker from options that reference the model file."""
  landmarker_options = _HolisticLandmarkerOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path)
  )
  landmarker_context = _HolisticLandmarker.create_from_options(
      landmarker_options
  )
  with landmarker_context as created:
    self.assertIsInstance(created, _HolisticLandmarker)
def test_create_from_options_fails_with_invalid_model_path(self):
  """Checks that creation raises RuntimeError for an invalid model path.

  The options are built outside the `assertRaisesRegex` context so that only
  the `create_from_options` call itself can satisfy the expected failure.
  """
  # Invalid (non-empty) model path.
  base_options = _BaseOptions(
      model_asset_path='/path/to/invalid/model.tflite'
  )
  options = _HolisticLandmarkerOptions(base_options=base_options)
  with self.assertRaisesRegex(
      RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
  ):
    _HolisticLandmarker.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
  """Creates a landmarker from options that carry the model file content.

  The landmarker is used as a context manager so that it is closed when the
  test ends rather than being left open.
  """
  # Read the model bytes first; the file handle is released before the
  # landmarker is built.
  with open(self.model_path, 'rb') as f:
    model_content = f.read()
  base_options = _BaseOptions(model_asset_buffer=model_content)
  options = _HolisticLandmarkerOptions(base_options=base_options)
  with _HolisticLandmarker.create_from_options(options) as landmarker:
    self.assertIsInstance(landmarker, _HolisticLandmarker)
@parameterized.parameters( @parameterized.parameters(
( (
@ -154,13 +180,25 @@ class HolisticLandmarkerTest(parameterized.TestCase):
False, False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT)
), ),
(
ModelFileType.FILE_NAME,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT)
),
(
ModelFileType.FILE_CONTENT,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT)
),
) )
def test_detect( def test_detect(
self, self,
model_file_type, model_file_type,
model_name, model_name,
output_segmentation_masks, output_segmentation_mask,
expected_holistic_landmarker_result: HolisticLandmarkerResult expected_holistic_landmarker_result
): ):
# Creates holistic landmarker. # Creates holistic landmarker.
model_path = test_utils.get_test_data_path(model_name) model_path = test_utils.get_test_data_path(model_name)
@ -178,7 +216,7 @@ class HolisticLandmarkerTest(parameterized.TestCase):
base_options=base_options, base_options=base_options,
output_face_blendshapes=True output_face_blendshapes=True
if expected_holistic_landmarker_result.face_blendshapes else False, if expected_holistic_landmarker_result.face_blendshapes else False,
output_segmentation_masks=output_segmentation_masks, output_segmentation_mask=output_segmentation_mask,
) )
landmarker = _HolisticLandmarker.create_from_options(options) landmarker = _HolisticLandmarker.create_from_options(options)
@ -186,12 +224,294 @@ class HolisticLandmarkerTest(parameterized.TestCase):
detection_result = landmarker.detect(self.test_image) detection_result = landmarker.detect(self.test_image)
self._expect_holistic_landmarker_results_correct( self._expect_holistic_landmarker_results_correct(
detection_result, expected_holistic_landmarker_result, detection_result, expected_holistic_landmarker_result,
output_segmentation_masks, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN output_segmentation_mask, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN
) )
# Closes the holistic landmarker explicitly when the holistic landmarker is # Closes the holistic landmarker explicitly when the holistic landmarker is
# not used in a context. # not used in a context.
landmarker.close() landmarker.close()
@parameterized.parameters(
    (
        ModelFileType.FILE_NAME,
        _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
        False,
        _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
    ),
    (
        ModelFileType.FILE_CONTENT,
        _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
        True,
        _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
    ),
)
def test_detect_in_context(
    self,
    model_file_type,
    model_name,
    output_segmentation_mask,
    expected_holistic_landmarker_result,
):
  """Runs `detect` with the landmarker used as a context manager.

  Args:
    model_file_type: ModelFileType selecting whether the model is passed by
      path or as file content.
    model_name: Name of the model asset under the test data directory.
    output_segmentation_mask: Value forwarded to the option of the same name.
    expected_holistic_landmarker_result: Result the detection is compared to.
  """
  model_path = test_utils.get_test_data_path(model_name)

  # Build the base options from either the path or the raw model bytes.
  if model_file_type is ModelFileType.FILE_NAME:
    base_options = _BaseOptions(model_asset_path=model_path)
  elif model_file_type is ModelFileType.FILE_CONTENT:
    with open(model_path, 'rb') as model_file:
      raw_model = model_file.read()
    base_options = _BaseOptions(model_asset_buffer=raw_model)
  else:
    # Should never happen
    raise ValueError('model_file_type is invalid.')

  # Blendshapes are requested only when the expected result contains them.
  wants_blendshapes = bool(expected_holistic_landmarker_result.face_blendshapes)
  options = _HolisticLandmarkerOptions(
      base_options=base_options,
      output_face_blendshapes=wants_blendshapes,
      output_segmentation_mask=output_segmentation_mask,
  )

  with _HolisticLandmarker.create_from_options(options) as landmarker:
    actual = landmarker.detect(self.test_image)
    self._expect_holistic_landmarker_results_correct(
        actual,
        expected_holistic_landmarker_result,
        output_segmentation_mask,
        _LANDMARKS_MARGIN,
        _BLENDSHAPES_MARGIN,
    )
def test_empty_detection_outputs(self):
  """Checks that detection on the cat image yields empty/None outputs."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  options = _HolisticLandmarkerOptions(base_options=base_options)
  with _HolisticLandmarker.create_from_options(options) as landmarker:
    # Load the cat image and run detection on it.
    cat_image_path = test_utils.get_test_data_path(_CAT_IMAGE)
    cat_test_image = _Image.create_from_file(cat_image_path)
    detection_result = landmarker.detect(cat_test_image)

    # Landmark fields are expected to be empty, checked in a fixed order.
    for field_name in (
        'face_landmarks',
        'pose_landmarks',
        'pose_world_landmarks',
        'left_hand_landmarks',
        'left_hand_world_landmarks',
        'right_hand_landmarks',
        'right_hand_world_landmarks',
    ):
      self.assertEmpty(getattr(detection_result, field_name))

    # Optional outputs were not requested, so they are expected to be None.
    self.assertIsNone(detection_result.face_blendshapes)
    self.assertIsNone(detection_result.segmentation_mask)
def test_missing_result_callback(self):
  """Checks that live stream mode without a result callback raises."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  live_stream_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.LIVE_STREAM,
  )
  expected_message = r'result callback must be provided'
  with self.assertRaisesRegex(ValueError, expected_message):
    with _HolisticLandmarker.create_from_options(live_stream_options):
      pass
@parameterized.parameters(_RUNNING_MODE.IMAGE, _RUNNING_MODE.VIDEO)
def test_illegal_result_callback(self, running_mode):
  """Checks that a result callback is rejected for the given running mode.

  Args:
    running_mode: Running mode the landmarker options are created with.
  """
  base_options = _BaseOptions(model_asset_path=self.model_path)
  options_with_callback = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=running_mode,
      result_callback=mock.MagicMock(),
  )
  expected_message = r'result callback should not be provided'
  with self.assertRaisesRegex(ValueError, expected_message):
    with _HolisticLandmarker.create_from_options(options_with_callback):
      pass
def test_calling_detect_for_video_in_image_mode(self):
  """Checks `detect_for_video` raises when created in image mode."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  image_mode_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.IMAGE,
  )
  expected_message = r'not initialized with the video mode'
  with _HolisticLandmarker.create_from_options(image_mode_options) as landmarker:
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
  """Checks `detect_async` raises when created in image mode."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  image_mode_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.IMAGE,
  )
  expected_message = r'not initialized with the live stream mode'
  with _HolisticLandmarker.create_from_options(image_mode_options) as landmarker:
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
  """Checks `detect` raises when created in video mode."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  video_mode_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.VIDEO,
  )
  expected_message = r'not initialized with the image mode'
  with _HolisticLandmarker.create_from_options(video_mode_options) as landmarker:
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
  """Checks `detect_async` raises when created in video mode."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  video_mode_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.VIDEO,
  )
  expected_message = r'not initialized with the live stream mode'
  with _HolisticLandmarker.create_from_options(video_mode_options) as landmarker:
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
  """Checks `detect_for_video` raises when the timestamp goes backwards."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  video_mode_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.VIDEO,
  )
  expected_message = r'Input timestamp must be monotonically increasing'
  with _HolisticLandmarker.create_from_options(video_mode_options) as landmarker:
    # First call at timestamp 1; its result is not inspected.
    landmarker.detect_for_video(self.test_image, 1)
    # A second call at the earlier timestamp 0 must be rejected.
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters(
    (
        _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
        False,
        _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
    ),
    (
        _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
        True,
        _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
    ),
)
def test_detect_for_video(
    self,
    model_name,
    output_segmentation_mask,
    expected_holistic_landmarker_result,
):
  """Runs `detect_for_video` on the test image at increasing timestamps.

  Args:
    model_name: Name of the model asset under the test data directory.
    output_segmentation_mask: Value forwarded to the option of the same name.
    expected_holistic_landmarker_result: Result each detection is compared to.
  """
  model_path = test_utils.get_test_data_path(model_name)
  # Blendshapes are requested only when the expected result contains them.
  wants_blendshapes = bool(expected_holistic_landmarker_result.face_blendshapes)
  options = _HolisticLandmarkerOptions(
      base_options=_BaseOptions(model_asset_path=model_path),
      running_mode=_RUNNING_MODE.VIDEO,
      output_face_blendshapes=wants_blendshapes,
      output_segmentation_mask=output_segmentation_mask,
  )

  with _HolisticLandmarker.create_from_options(options) as landmarker:
    for timestamp in range(0, 300, 30):
      actual = landmarker.detect_for_video(self.test_image, timestamp)
      # Each frame is compared using the video-specific margins.
      self._expect_holistic_landmarker_results_correct(
          actual,
          expected_holistic_landmarker_result,
          output_segmentation_mask,
          _VIDEO_LANDMARKS_MARGIN,
          _VIDEO_BLENDSHAPES_MARGIN,
      )
def test_calling_detect_in_live_stream_mode(self):
  """Checks `detect` raises when created in live stream mode."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  live_stream_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock(),
  )
  expected_message = r'not initialized with the image mode'
  with _HolisticLandmarker.create_from_options(live_stream_options) as landmarker:
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
  """Checks `detect_for_video` raises when created in live stream mode."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  live_stream_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock(),
  )
  expected_message = r'not initialized with the video mode'
  with _HolisticLandmarker.create_from_options(live_stream_options) as landmarker:
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
  """Checks `detect_async` raises when the timestamp goes backwards."""
  base_options = _BaseOptions(model_asset_path=self.model_path)
  live_stream_options = _HolisticLandmarkerOptions(
      base_options=base_options,
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock(),
  )
  expected_message = r'Input timestamp must be monotonically increasing'
  with _HolisticLandmarker.create_from_options(live_stream_options) as landmarker:
    # First call at timestamp 100.
    landmarker.detect_async(self.test_image, 100)
    # A second call at the earlier timestamp 0 must be rejected.
    with self.assertRaisesRegex(ValueError, expected_message):
      landmarker.detect_async(self.test_image, 0)
@parameterized.parameters(
    (
        _POSE_IMAGE,
        _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
        False,
        _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
    ),
    (
        _POSE_IMAGE,
        _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
        True,
        _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
    ),
)
def test_detect_async_calls(
    self,
    image_path,
    model_name,
    output_segmentation_mask,
    expected_holistic_landmarker_result,
):
  """Runs `detect_async` at increasing timestamps and checks each callback.

  Args:
    image_path: Name of the input image under the test data directory.
    model_name: Name of the model asset under the test data directory.
    output_segmentation_mask: Value forwarded to the option of the same name.
    expected_holistic_landmarker_result: Result each callback is compared to.
  """
  test_image = _Image.create_from_file(
      test_utils.get_test_data_path(image_path)
  )
  # The last observed timestamp lives on the test instance so the callback
  # reads and writes the same value. A plain local would stay at -1 and the
  # ordering assertion below would never compare against an earlier frame.
  self.observed_timestamp_ms = -1

  def check_result(
      result: HolisticLandmarkerResult, output_image: _Image, timestamp_ms: int
  ):
    # Comparing results.
    self._expect_holistic_landmarker_results_correct(
        result,
        expected_holistic_landmarker_result,
        output_segmentation_mask,
        _LIVE_STREAM_LANDMARKS_MARGIN,
        _LIVE_STREAM_BLENDSHAPES_MARGIN,
    )
    # The image handed to the callback must match the submitted image.
    self.assertTrue(
        np.array_equal(output_image.numpy_view(), test_image.numpy_view())
    )
    # Timestamps must strictly increase from one callback to the next.
    self.assertLess(self.observed_timestamp_ms, timestamp_ms)
    self.observed_timestamp_ms = timestamp_ms

  model_path = test_utils.get_test_data_path(model_name)
  options = _HolisticLandmarkerOptions(
      base_options=_BaseOptions(model_asset_path=model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      # Blendshapes are requested only when the expected result has them.
      output_face_blendshapes=bool(
          expected_holistic_landmarker_result.face_blendshapes
      ),
      output_segmentation_mask=output_segmentation_mask,
      result_callback=check_result,
  )
  with _HolisticLandmarker.create_from_options(options) as landmarker:
    for timestamp in range(0, 300, 30):
      landmarker.detect_async(test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -51,7 +51,7 @@ _POSE_LANDMARKS_TAG_NAME = "POSE_LANDMARKS"
_POSE_WORLD_LANDMARKS_STREAM_NAME = "pose_world_landmarks" _POSE_WORLD_LANDMARKS_STREAM_NAME = "pose_world_landmarks"
_POSE_WORLD_LANDMARKS_TAG = "POSE_WORLD_LANDMARKS" _POSE_WORLD_LANDMARKS_TAG = "POSE_WORLD_LANDMARKS"
_POSE_SEGMENTATION_MASK_STREAM_NAME = "pose_segmentation_mask" _POSE_SEGMENTATION_MASK_STREAM_NAME = "pose_segmentation_mask"
_POSE_SEGMENTATION_MASK_TAG = "pose_segmentation_mask" _POSE_SEGMENTATION_MASK_TAG = "POSE_SEGMENTATION_MASK"
_FACE_LANDMARKS_STREAM_NAME = "face_landmarks" _FACE_LANDMARKS_STREAM_NAME = "face_landmarks"
_FACE_LANDMARKS_TAG = "FACE_LANDMARKS" _FACE_LANDMARKS_TAG = "FACE_LANDMARKS"
_FACE_BLENDSHAPES_STREAM_NAME = "extra_blendshapes" _FACE_BLENDSHAPES_STREAM_NAME = "extra_blendshapes"
@ -84,7 +84,7 @@ class HolisticLandmarkerResult:
right_hand_landmarks: List[landmark_module.NormalizedLandmark] right_hand_landmarks: List[landmark_module.NormalizedLandmark]
right_hand_world_landmarks: List[landmark_module.Landmark] right_hand_world_landmarks: List[landmark_module.Landmark]
face_blendshapes: Optional[List[category_module.Category]] = None face_blendshapes: Optional[List[category_module.Category]] = None
segmentation_masks: Optional[List[image_module.Image]] = None segmentation_mask: Optional[image_module.Image] = None
@classmethod @classmethod
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -96,41 +96,41 @@ class HolisticLandmarkerResult:
object.""" object."""
return HolisticLandmarkerResult( return HolisticLandmarkerResult(
face_landmarks=[ face_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark) landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.face_landmarks.landmark for landmark in pb2_obj.face_landmarks.landmark
] if hasattr(pb2_obj, 'face_landmarks') else None, ] if hasattr(pb2_obj, 'face_landmarks') else None,
pose_landmarks=[ pose_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark) landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.pose_landmarks.landmark for landmark in pb2_obj.pose_landmarks.landmark
] if hasattr(pb2_obj, 'pose_landmarks') else None, ] if hasattr(pb2_obj, 'pose_landmarks') else None,
pose_world_landmarks=[ pose_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark) landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.pose_world_landmarks.landmark for landmark in pb2_obj.pose_world_landmarks.landmark
] if hasattr(pb2_obj, 'pose_world_landmarks') else None, ] if hasattr(pb2_obj, 'pose_world_landmarks') else None,
left_hand_landmarks=[ left_hand_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark) landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.left_hand_landmarks.landmark for landmark in pb2_obj.left_hand_landmarks.landmark
] if hasattr(pb2_obj, 'left_hand_landmarks') else None, ] if hasattr(pb2_obj, 'left_hand_landmarks') else None,
left_hand_world_landmarks=[ left_hand_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark) landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.left_hand_world_landmarks.landmark for landmark in pb2_obj.left_hand_world_landmarks.landmark
] if hasattr(pb2_obj, 'left_hand_world_landmarks') else None, ] if hasattr(pb2_obj, 'left_hand_world_landmarks') else None,
right_hand_landmarks=[ right_hand_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark) landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.right_hand_landmarks.landmark for landmark in pb2_obj.right_hand_landmarks.landmark
] if hasattr(pb2_obj, 'right_hand_landmarks') else None, ] if hasattr(pb2_obj, 'right_hand_landmarks') else None,
right_hand_world_landmarks=[ right_hand_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark) landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.right_hand_world_landmarks.landmark for landmark in pb2_obj.right_hand_world_landmarks.landmark
] if hasattr(pb2_obj, 'right_hand_world_landmarks') else None, ] if hasattr(pb2_obj, 'right_hand_world_landmarks') else None,
face_blendshapes=[ face_blendshapes=[
category_module.Category( category_module.Category(
score=classification.score, score=classification.score,
index=classification.index, index=classification.index,
category_name=classification.label, category_name=classification.label,
display_name=classification.display_name display_name=classification.display_name
) )
for classification in pb2_obj.face_blendshapes.classification for classification in pb2_obj.face_blendshapes.classification
] if hasattr(pb2_obj, 'face_blendshapes') else None, ] if hasattr(pb2_obj, 'face_blendshapes') else None,
) )
@ -147,98 +147,98 @@ def _build_landmarker_result(
) )
pose_landmarks_proto_list = packet_getter.get_proto( pose_landmarks_proto_list = packet_getter.get_proto(
output_packets[_POSE_LANDMARKS_STREAM_NAME] output_packets[_POSE_LANDMARKS_STREAM_NAME]
) )
pose_world_landmarks_proto_list = packet_getter.get_proto( pose_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME] output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME]
) )
left_hand_landmarks_proto_list = packet_getter.get_proto( left_hand_landmarks_proto_list = packet_getter.get_proto(
output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME] output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME]
) )
left_hand_world_landmarks_proto_list = packet_getter.get_proto( left_hand_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME]
) )
right_hand_landmarks_proto_list = packet_getter.get_proto( right_hand_landmarks_proto_list = packet_getter.get_proto(
output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME] output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME]
) )
right_hand_world_landmarks_proto_list = packet_getter.get_proto( right_hand_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME]
) )
face_landmarks = landmark_pb2.NormalizedLandmarkList() face_landmarks = landmark_pb2.NormalizedLandmarkList()
face_landmarks.MergeFrom(face_landmarks_proto_list) face_landmarks.MergeFrom(face_landmarks_proto_list)
for face_landmark in face_landmarks.landmark: for face_landmark in face_landmarks.landmark:
holistic_landmarker_result.face_landmarks.append( holistic_landmarker_result.face_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
) )
pose_landmarks = landmark_pb2.NormalizedLandmarkList() pose_landmarks = landmark_pb2.NormalizedLandmarkList()
pose_landmarks.MergeFrom(pose_landmarks_proto_list) pose_landmarks.MergeFrom(pose_landmarks_proto_list)
for pose_landmark in pose_landmarks.landmark: for pose_landmark in pose_landmarks.landmark:
holistic_landmarker_result.pose_landmarks.append( holistic_landmarker_result.pose_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark) landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark)
) )
pose_world_landmarks = landmark_pb2.LandmarkList() pose_world_landmarks = landmark_pb2.LandmarkList()
pose_world_landmarks.MergeFrom(pose_world_landmarks_proto_list) pose_world_landmarks.MergeFrom(pose_world_landmarks_proto_list)
for pose_world_landmark in pose_world_landmarks.landmark: for pose_world_landmark in pose_world_landmarks.landmark:
holistic_landmarker_result.pose_world_landmarks.append( holistic_landmarker_result.pose_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(pose_world_landmark) landmark_module.Landmark.create_from_pb2(pose_world_landmark)
) )
left_hand_landmarks = landmark_pb2.NormalizedLandmarkList() left_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
left_hand_landmarks.MergeFrom(left_hand_landmarks_proto_list) left_hand_landmarks.MergeFrom(left_hand_landmarks_proto_list)
for hand_landmark in left_hand_landmarks.landmark: for hand_landmark in left_hand_landmarks.landmark:
holistic_landmarker_result.left_hand_landmarks.append( holistic_landmarker_result.left_hand_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
) )
left_hand_world_landmarks = landmark_pb2.LandmarkList() left_hand_world_landmarks = landmark_pb2.LandmarkList()
left_hand_world_landmarks.MergeFrom(left_hand_world_landmarks_proto_list) left_hand_world_landmarks.MergeFrom(left_hand_world_landmarks_proto_list)
for left_hand_world_landmark in left_hand_world_landmarks.landmark: for left_hand_world_landmark in left_hand_world_landmarks.landmark:
holistic_landmarker_result.left_hand_world_landmarks.append( holistic_landmarker_result.left_hand_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(left_hand_world_landmark) landmark_module.Landmark.create_from_pb2(left_hand_world_landmark)
) )
right_hand_landmarks = landmark_pb2.NormalizedLandmarkList() right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
right_hand_landmarks.MergeFrom(right_hand_landmarks_proto_list) right_hand_landmarks.MergeFrom(right_hand_landmarks_proto_list)
for hand_landmark in right_hand_landmarks.landmark: for hand_landmark in right_hand_landmarks.landmark:
holistic_landmarker_result.right_hand_landmarks.append( holistic_landmarker_result.right_hand_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
) )
right_hand_world_landmarks = landmark_pb2.LandmarkList() right_hand_world_landmarks = landmark_pb2.LandmarkList()
right_hand_world_landmarks.MergeFrom(right_hand_world_landmarks_proto_list) right_hand_world_landmarks.MergeFrom(right_hand_world_landmarks_proto_list)
for right_hand_world_landmark in right_hand_world_landmarks.landmark: for right_hand_world_landmark in right_hand_world_landmarks.landmark:
holistic_landmarker_result.right_hand_world_landmarks.append( holistic_landmarker_result.right_hand_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(right_hand_world_landmark) landmark_module.Landmark.create_from_pb2(right_hand_world_landmark)
) )
if _FACE_BLENDSHAPES_STREAM_NAME in output_packets: if _FACE_BLENDSHAPES_STREAM_NAME in output_packets:
face_blendshapes_proto_list = packet_getter.get_proto( face_blendshapes_proto_list = packet_getter.get_proto(
output_packets[_FACE_BLENDSHAPES_STREAM_NAME] output_packets[_FACE_BLENDSHAPES_STREAM_NAME]
) )
face_blendshapes_classifications = classification_pb2.ClassificationList() face_blendshapes_classifications = classification_pb2.ClassificationList()
face_blendshapes_classifications.MergeFrom(face_blendshapes_proto_list) face_blendshapes_classifications.MergeFrom(face_blendshapes_proto_list)
holistic_landmarker_result.face_blendshapes = [] holistic_landmarker_result.face_blendshapes = []
for face_blendshapes in face_blendshapes_classifications.classification: for face_blendshapes in face_blendshapes_classifications.classification:
holistic_landmarker_result.face_blendshapes.append( holistic_landmarker_result.face_blendshapes.append(
category_module.Category( category_module.Category(
index=face_blendshapes.index, index=face_blendshapes.index,
score=face_blendshapes.score, score=face_blendshapes.score,
display_name=face_blendshapes.display_name, display_name=face_blendshapes.display_name,
category_name=face_blendshapes.label, category_name=face_blendshapes.label,
) )
) )
if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets: if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets:
holistic_landmarker_result.segmentation_masks = packet_getter.get_image_list( holistic_landmarker_result.segmentation_mask = packet_getter.get_image(
output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME] output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME]
) )
return holistic_landmarker_result return holistic_landmarker_result
@ -273,7 +273,7 @@ class HolisticLandmarkerOptions:
landmark detection to be considered successful. landmark detection to be considered successful.
output_face_blendshapes: Whether HolisticLandmarker outputs face blendshapes output_face_blendshapes: Whether HolisticLandmarker outputs face blendshapes
classification. Face blendshapes are used for rendering the 3D face model. classification. Face blendshapes are used for rendering the 3D face model.
output_segmentation_masks: whether to output segmentation masks. output_segmentation_mask: whether to output segmentation masks.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
@ -290,7 +290,7 @@ class HolisticLandmarkerOptions:
min_pose_landmarks_confidence: float = 0.5 min_pose_landmarks_confidence: float = 0.5
min_hand_landmarks_confidence: float = 0.5 min_hand_landmarks_confidence: float = 0.5
output_face_blendshapes: bool = False output_face_blendshapes: bool = False
output_segmentation_masks: bool = False output_segmentation_mask: bool = False
result_callback: Optional[ result_callback: Optional[
Callable[[HolisticLandmarkerResult, image_module.Image, int], None] Callable[[HolisticLandmarkerResult, image_module.Image, int], None]
] = None ] = None
@ -319,17 +319,17 @@ class HolisticLandmarkerOptions:
) )
# Configure pose detector and pose landmarks detector options. # Configure pose detector and pose landmarks detector options.
holistic_landmarker_options_proto.pose_detector_graph_options.min_detection_confidence = ( holistic_landmarker_options_proto.pose_detector_graph_options.min_detection_confidence = (
self.min_pose_detection_confidence self.min_pose_detection_confidence
) )
holistic_landmarker_options_proto.pose_detector_graph_options.min_suppression_threshold = ( holistic_landmarker_options_proto.pose_detector_graph_options.min_suppression_threshold = (
self.min_pose_suppression_threshold self.min_pose_suppression_threshold
) )
holistic_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = ( holistic_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = (
self.min_pose_landmarks_confidence self.min_pose_landmarks_confidence
) )
# Configure hand landmarks detector options. # Configure hand landmarks detector options.
holistic_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = ( holistic_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_landmarks_confidence self.min_hand_landmarks_confidence
) )
return holistic_landmarker_options_proto return holistic_landmarker_options_proto
@ -404,30 +404,34 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
) )
output_streams = [ output_streams = [
':'.join([_FACE_LANDMARKS_TAG, _FACE_LANDMARKS_STREAM_NAME]), ':'.join([_FACE_LANDMARKS_TAG, _FACE_LANDMARKS_STREAM_NAME]),
':'.join([_POSE_LANDMARKS_TAG_NAME, _POSE_LANDMARKS_STREAM_NAME]), ':'.join([_POSE_LANDMARKS_TAG_NAME, _POSE_LANDMARKS_STREAM_NAME]),
':'.join( ':'.join(
[_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME] [_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME]
), ),
':'.join([_LEFT_HAND_LANDMARKS_TAG, _LEFT_HAND_LANDMARKS_STREAM_NAME]), ':'.join([_LEFT_HAND_LANDMARKS_TAG, _LEFT_HAND_LANDMARKS_STREAM_NAME]),
':'.join( ':'.join(
[_LEFT_HAND_WORLD_LANDMARKS_TAG, _LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] [_LEFT_HAND_WORLD_LANDMARKS_TAG,
), _LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME]
':'.join([_RIGHT_HAND_LANDMARKS_TAG, _RIGHT_HAND_LANDMARKS_STREAM_NAME]), ),
':'.join( ':'.join([_RIGHT_HAND_LANDMARKS_TAG,
[_RIGHT_HAND_WORLD_LANDMARKS_TAG, _RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] _RIGHT_HAND_LANDMARKS_STREAM_NAME]),
), ':'.join(
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), [_RIGHT_HAND_WORLD_LANDMARKS_TAG,
_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
] ]
if options.output_segmentation_masks: if options.output_segmentation_mask:
output_streams.append( output_streams.append(
':'.join([_POSE_SEGMENTATION_MASK_TAG, _POSE_SEGMENTATION_MASK_STREAM_NAME]) ':'.join([_POSE_SEGMENTATION_MASK_TAG,
_POSE_SEGMENTATION_MASK_STREAM_NAME])
) )
if options.output_face_blendshapes: if options.output_face_blendshapes:
output_streams.append( output_streams.append(
':'.join([_FACE_BLENDSHAPES_TAG, _FACE_BLENDSHAPES_STREAM_NAME]) ':'.join([_FACE_BLENDSHAPES_TAG, _FACE_BLENDSHAPES_STREAM_NAME])
) )
task_info = _TaskInfo( task_info = _TaskInfo(