diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index eae05de4d..16a59741d 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -204,13 +204,8 @@ py_test( ], tags = ["not_run:arm"], deps = [ - "//mediapipe/framework/formats:classification_py_pb2", - "//mediapipe/framework/formats:landmark_py_pb2", "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2", "//mediapipe/python:_framework_bindings", - "//mediapipe/tasks/python/components/containers:category", - "//mediapipe/tasks/python/components/containers:landmark", - "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:holistic_landmarker", diff --git a/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py b/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py index 429ea9810..4b624af90 100644 --- a/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py +++ b/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py @@ -14,7 +14,6 @@ """Tests for holistic landmarker.""" import enum -from typing import List from unittest import mock from absl.testing import absltest @@ -22,13 +21,8 @@ from absl.testing import parameterized import numpy as np from google.protobuf import text_format -from mediapipe.framework.formats import classification_pb2 -from mediapipe.framework.formats import landmark_pb2 from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2 from mediapipe.python._framework_bindings import image as image_module -from mediapipe.tasks.python.components.containers import category as category_module -from mediapipe.tasks.python.components.containers import landmark as landmark_module -from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import holistic_landmarker @@ -39,10 +33,6 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult _HolisticResultProto = holistic_result_pb2.HolisticResult _BaseOptions = base_options_module.BaseOptions -_Category = category_module.Category -_Rect = rect_module.Rect -_Landmark = landmark_module.Landmark -_NormalizedLandmark = landmark_module.NormalizedLandmark _Image = image_module.Image _HolisticLandmarker = holistic_landmarker.HolisticLandmarker _HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions @@ -53,23 +43,27 @@ _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'holistic_landmarker.task' _POSE_IMAGE = 'male_full_height_hands.jpg' _CAT_IMAGE = 'cat.jpg' _EXPECTED_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt" +_IMAGE_WIDTH = 638 +_IMAGE_HEIGHT = 1000 _LANDMARKS_MARGIN = 0.03 _BLENDSHAPES_MARGIN = 0.13 +_VIDEO_LANDMARKS_MARGIN = 0.03 +_VIDEO_BLENDSHAPES_MARGIN = 0.31 +_LIVE_STREAM_LANDMARKS_MARGIN = 0.03 +_LIVE_STREAM_BLENDSHAPES_MARGIN = 0.31 def _get_expected_holistic_landmarker_result( file_path: str, ) -> HolisticLandmarkerResult: - holistic_result_file_path = test_utils.get_test_data_path( - file_path - ) + holistic_result_file_path = test_utils.get_test_data_path(file_path) with open(holistic_result_file_path, 'rb') as f: holistic_result_proto = _HolisticResultProto() # Use this if a .pb file is available. # holistic_result_proto.ParseFromString(f.read()) text_format.Parse(f.read(), holistic_result_proto) holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2( - holistic_result_proto + holistic_result_proto ) return holistic_landmarker_result @@ -108,38 +102,70 @@ class HolisticLandmarkerTest(parameterized.TestCase): for i, elem in enumerate(actual_blendshapes): self.assertEqual(elem.index, expected_blendshapes[i].index) + self.assertEqual(elem.category_name, expected_blendshapes[i].category_name) self.assertAlmostEqual( - elem.score, - expected_blendshapes[i].score, - delta=margin, + elem.score, + expected_blendshapes[i].score, + delta=margin, ) def _expect_holistic_landmarker_results_correct( self, actual_result: HolisticLandmarkerResult, expected_result: HolisticLandmarkerResult, - output_segmentation_masks: bool, + output_segmentation_mask: bool, landmarks_margin: float, blendshapes_margin: float, ): self._expect_landmarks_correct( - actual_result.pose_landmarks, expected_result.pose_landmarks, - landmarks_margin + actual_result.pose_landmarks, expected_result.pose_landmarks, + landmarks_margin ) self._expect_landmarks_correct( - actual_result.face_landmarks, expected_result.face_landmarks, - landmarks_margin + actual_result.face_landmarks, expected_result.face_landmarks, + landmarks_margin ) self._expect_blendshapes_correct( - actual_result.face_blendshapes, expected_result.face_blendshapes, - blendshapes_margin + actual_result.face_blendshapes, expected_result.face_blendshapes, + blendshapes_margin ) - if output_segmentation_masks: - self.assertIsInstance(actual_result.segmentation_masks, List) - for _, mask in enumerate(actual_result.segmentation_masks): - self.assertIsInstance(mask, _Image) + if output_segmentation_mask: + self.assertIsInstance(actual_result.segmentation_mask, _Image) + self.assertEqual(actual_result.segmentation_mask.width, _IMAGE_WIDTH) + self.assertEqual(actual_result.segmentation_mask.height, _IMAGE_HEIGHT) else: - self.assertIsNone(actual_result.segmentation_masks) + self.assertIsNone(actual_result.segmentation_mask) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _HolisticLandmarker.create_from_model_path(self.model_path) as landmarker: + self.assertIsInstance(landmarker, _HolisticLandmarker) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _HolisticLandmarkerOptions(base_options=base_options) + with _HolisticLandmarker.create_from_options(options) as landmarker: + self.assertIsInstance(landmarker, _HolisticLandmarker) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite' + ) + options = _HolisticLandmarkerOptions(base_options=base_options) + _HolisticLandmarker.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _HolisticLandmarkerOptions(base_options=base_options) + landmarker = _HolisticLandmarker.create_from_options(options) + self.assertIsInstance(landmarker, _HolisticLandmarker) @parameterized.parameters( ( @@ -154,13 +180,25 @@ class HolisticLandmarkerTest(parameterized.TestCase): False, _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) ), + ( + ModelFileType.FILE_NAME, + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + True, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ( + ModelFileType.FILE_CONTENT, + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + True, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), ) def test_detect( self, model_file_type, model_name, - output_segmentation_masks, - expected_holistic_landmarker_result: HolisticLandmarkerResult + output_segmentation_mask, + expected_holistic_landmarker_result ): # Creates holistic landmarker. model_path = test_utils.get_test_data_path(model_name) @@ -178,7 +216,7 @@ class HolisticLandmarkerTest(parameterized.TestCase): base_options=base_options, output_face_blendshapes=True if expected_holistic_landmarker_result.face_blendshapes else False, - output_segmentation_masks=output_segmentation_masks, + output_segmentation_mask=output_segmentation_mask, ) landmarker = _HolisticLandmarker.create_from_options(options) @@ -186,12 +224,294 @@ class HolisticLandmarkerTest(parameterized.TestCase): detection_result = landmarker.detect(self.test_image) self._expect_holistic_landmarker_results_correct( detection_result, expected_holistic_landmarker_result, - output_segmentation_masks, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN + output_segmentation_mask, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN ) # Closes the holistic landmarker explicitly when the holistic landmarker is # not used in a context. landmarker.close() + @parameterized.parameters( + ( + ModelFileType.FILE_NAME, + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + False, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ( + ModelFileType.FILE_CONTENT, + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + True, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ) + def test_detect_in_context( + self, + model_file_type, + model_name, + output_segmentation_mask, + expected_holistic_landmarker_result + ): + # Creates holistic landmarker. + model_path = test_utils.get_test_data_path(model_name) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _HolisticLandmarkerOptions( + base_options=base_options, + output_face_blendshapes=True + if expected_holistic_landmarker_result.face_blendshapes else False, + output_segmentation_mask=output_segmentation_mask, + ) + + with _HolisticLandmarker.create_from_options(options) as landmarker: + # Performs holistic landmarks detection on the input. + detection_result = landmarker.detect(self.test_image) + self._expect_holistic_landmarker_results_correct( + detection_result, expected_holistic_landmarker_result, + output_segmentation_mask, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN + ) + + def test_empty_detection_outputs(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path) + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + # Load the cat image. + cat_test_image = _Image.create_from_file( + test_utils.get_test_data_path(_CAT_IMAGE) + ) + # Performs holistic landmarks detection on the input. + detection_result = landmarker.detect(cat_test_image) + self.assertEmpty(detection_result.face_landmarks) + self.assertEmpty(detection_result.pose_landmarks) + self.assertEmpty(detection_result.pose_world_landmarks) + self.assertEmpty(detection_result.left_hand_landmarks) + self.assertEmpty(detection_result.left_hand_world_landmarks) + self.assertEmpty(detection_result.right_hand_landmarks) + self.assertEmpty(detection_result.right_hand_world_landmarks) + self.assertIsNone(detection_result.face_blendshapes) + self.assertIsNone(detection_result.segmentation_mask) + + def test_missing_result_callback(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): + with _HolisticLandmarker.create_from_options(options) as unused_landmarker: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): + with _HolisticLandmarker.create_from_options(options) as unused_landmarker: + pass + + def test_calling_detect_for_video_in_image_mode(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE, + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): + landmarker.detect_for_video(self.test_image, 0) + + def test_calling_detect_async_in_image_mode(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE, + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): + landmarker.detect_async(self.test_image, 0) + + def test_calling_detect_in_video_mode(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): + landmarker.detect(self.test_image) + + def test_calling_detect_async_in_video_mode(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): + landmarker.detect_async(self.test_image, 0) + + def test_detect_for_video_with_out_of_order_timestamp(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + unused_result = landmarker.detect_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing' + ): + landmarker.detect_for_video(self.test_image, 0) + + @parameterized.parameters( + ( + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + False, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ( + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + True, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ) + def test_detect_for_video( + self, + model_name, + output_segmentation_mask, + expected_holistic_landmarker_result + ): + # Creates holistic landmarker. + model_path = test_utils.get_test_data_path(model_name) + base_options = _BaseOptions(model_asset_path=model_path) + options = _HolisticLandmarkerOptions( + base_options=base_options, + running_mode=_RUNNING_MODE.VIDEO, + output_face_blendshapes=True + if expected_holistic_landmarker_result.face_blendshapes else False, + output_segmentation_mask=output_segmentation_mask, + ) + + with _HolisticLandmarker.create_from_options(options) as landmarker: + for timestamp in range(0, 300, 30): + # Performs holistic landmarks detection on the input. + detection_result = landmarker.detect_for_video( + self.test_image, timestamp + ) + # Comparing results. + self._expect_holistic_landmarker_results_correct( + detection_result, expected_holistic_landmarker_result, + output_segmentation_mask, + _VIDEO_LANDMARKS_MARGIN, _VIDEO_BLENDSHAPES_MARGIN + ) + + def test_calling_detect_in_live_stream_mode(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): + landmarker.detect(self.test_image) + + def test_calling_detect_for_video_in_live_stream_mode(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): + landmarker.detect_for_video(self.test_image, 0) + + def test_detect_async_calls_with_illegal_timestamp(self): + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + landmarker.detect_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing' + ): + landmarker.detect_async(self.test_image, 0) + + @parameterized.parameters( + ( + _POSE_IMAGE, + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + False, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ( + _POSE_IMAGE, + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + True, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) + ), + ) + def test_detect_async_calls( + self, + image_path, + model_name, + output_segmentation_mask, + expected_holistic_landmarker_result + ): + test_image = _Image.create_from_file( + test_utils.get_test_data_path(image_path) + ) + observed_timestamp_ms = -1 + + def check_result( + result: HolisticLandmarkerResult, output_image: _Image, timestamp_ms: int + ): + # Comparing results. + self._expect_holistic_landmarker_results_correct( + result, expected_holistic_landmarker_result, + output_segmentation_mask, + _LIVE_STREAM_LANDMARKS_MARGIN, _LIVE_STREAM_BLENDSHAPES_MARGIN + ) + self.assertTrue( + np.array_equal(output_image.numpy_view(), test_image.numpy_view()) + ) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + model_path = test_utils.get_test_data_path(model_name) + options = _HolisticLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + output_face_blendshapes=True + if expected_holistic_landmarker_result.face_blendshapes else False, + output_segmentation_mask=output_segmentation_mask, + result_callback=check_result, + ) + with _HolisticLandmarker.create_from_options(options) as landmarker: + for timestamp in range(0, 300, 30): + landmarker.detect_async(test_image, timestamp) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/vision/holistic_landmarker.py b/mediapipe/tasks/python/vision/holistic_landmarker.py index ac96ab7d1..315600b1c 100644 --- a/mediapipe/tasks/python/vision/holistic_landmarker.py +++ b/mediapipe/tasks/python/vision/holistic_landmarker.py @@ -51,7 +51,7 @@ _POSE_LANDMARKS_TAG_NAME = "POSE_LANDMARKS" _POSE_WORLD_LANDMARKS_STREAM_NAME = "pose_world_landmarks" _POSE_WORLD_LANDMARKS_TAG = "POSE_WORLD_LANDMARKS" _POSE_SEGMENTATION_MASK_STREAM_NAME = "pose_segmentation_mask" -_POSE_SEGMENTATION_MASK_TAG = "pose_segmentation_mask" +_POSE_SEGMENTATION_MASK_TAG = "POSE_SEGMENTATION_MASK" _FACE_LANDMARKS_STREAM_NAME = "face_landmarks" _FACE_LANDMARKS_TAG = "FACE_LANDMARKS" _FACE_BLENDSHAPES_STREAM_NAME = "extra_blendshapes" @@ -84,7 +84,7 @@ class HolisticLandmarkerResult: right_hand_landmarks: List[landmark_module.NormalizedLandmark] right_hand_world_landmarks: List[landmark_module.Landmark] face_blendshapes: Optional[List[category_module.Category]] = None - segmentation_masks: Optional[List[image_module.Image]] = None + segmentation_mask: Optional[image_module.Image] = None @classmethod @doc_controls.do_not_generate_docs @@ -96,41 +96,41 @@ class HolisticLandmarkerResult: object.""" return HolisticLandmarkerResult( face_landmarks=[ - landmark_module.NormalizedLandmark.create_from_pb2(landmark) - for landmark in pb2_obj.face_landmarks.landmark + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.face_landmarks.landmark ] if hasattr(pb2_obj, 'face_landmarks') else None, pose_landmarks=[ - landmark_module.NormalizedLandmark.create_from_pb2(landmark) - for landmark in pb2_obj.pose_landmarks.landmark + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.pose_landmarks.landmark ] if hasattr(pb2_obj, 'pose_landmarks') else None, pose_world_landmarks=[ - landmark_module.Landmark.create_from_pb2(landmark) - for landmark in pb2_obj.pose_world_landmarks.landmark + landmark_module.Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.pose_world_landmarks.landmark ] if hasattr(pb2_obj, 'pose_world_landmarks') else None, left_hand_landmarks=[ - landmark_module.NormalizedLandmark.create_from_pb2(landmark) - for landmark in pb2_obj.left_hand_landmarks.landmark + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.left_hand_landmarks.landmark ] if hasattr(pb2_obj, 'left_hand_landmarks') else None, left_hand_world_landmarks=[ - landmark_module.Landmark.create_from_pb2(landmark) - for landmark in pb2_obj.left_hand_world_landmarks.landmark + landmark_module.Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.left_hand_world_landmarks.landmark ] if hasattr(pb2_obj, 'left_hand_world_landmarks') else None, right_hand_landmarks=[ - landmark_module.NormalizedLandmark.create_from_pb2(landmark) - for landmark in pb2_obj.right_hand_landmarks.landmark + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.right_hand_landmarks.landmark ] if hasattr(pb2_obj, 'right_hand_landmarks') else None, right_hand_world_landmarks=[ - landmark_module.Landmark.create_from_pb2(landmark) - for landmark in pb2_obj.right_hand_world_landmarks.landmark + landmark_module.Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.right_hand_world_landmarks.landmark ] if hasattr(pb2_obj, 'right_hand_world_landmarks') else None, face_blendshapes=[ - category_module.Category( - score=classification.score, - index=classification.index, - category_name=classification.label, - display_name=classification.display_name - ) - for classification in pb2_obj.face_blendshapes.classification + category_module.Category( + score=classification.score, + index=classification.index, + category_name=classification.label, + display_name=classification.display_name + ) + for classification in pb2_obj.face_blendshapes.classification ] if hasattr(pb2_obj, 'face_blendshapes') else None, ) @@ -147,98 +147,98 @@ def _build_landmarker_result( ) pose_landmarks_proto_list = packet_getter.get_proto( - output_packets[_POSE_LANDMARKS_STREAM_NAME] + output_packets[_POSE_LANDMARKS_STREAM_NAME] ) pose_world_landmarks_proto_list = packet_getter.get_proto( - output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME] + output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME] ) left_hand_landmarks_proto_list = packet_getter.get_proto( - output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME] + output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME] ) left_hand_world_landmarks_proto_list = packet_getter.get_proto( - output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] + output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] ) right_hand_landmarks_proto_list = packet_getter.get_proto( - output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME] + output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME] ) right_hand_world_landmarks_proto_list = packet_getter.get_proto( - output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] + output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] ) face_landmarks = landmark_pb2.NormalizedLandmarkList() face_landmarks.MergeFrom(face_landmarks_proto_list) for face_landmark in face_landmarks.landmark: holistic_landmarker_result.face_landmarks.append( - landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) + landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) ) pose_landmarks = landmark_pb2.NormalizedLandmarkList() pose_landmarks.MergeFrom(pose_landmarks_proto_list) for pose_landmark in pose_landmarks.landmark: holistic_landmarker_result.pose_landmarks.append( - landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark) + landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark) ) pose_world_landmarks = landmark_pb2.LandmarkList() pose_world_landmarks.MergeFrom(pose_world_landmarks_proto_list) for pose_world_landmark in pose_world_landmarks.landmark: holistic_landmarker_result.pose_world_landmarks.append( - landmark_module.Landmark.create_from_pb2(pose_world_landmark) + landmark_module.Landmark.create_from_pb2(pose_world_landmark) ) left_hand_landmarks = landmark_pb2.NormalizedLandmarkList() left_hand_landmarks.MergeFrom(left_hand_landmarks_proto_list) for hand_landmark in left_hand_landmarks.landmark: holistic_landmarker_result.left_hand_landmarks.append( - landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) ) left_hand_world_landmarks = landmark_pb2.LandmarkList() left_hand_world_landmarks.MergeFrom(left_hand_world_landmarks_proto_list) for left_hand_world_landmark in left_hand_world_landmarks.landmark: holistic_landmarker_result.left_hand_world_landmarks.append( - landmark_module.Landmark.create_from_pb2(left_hand_world_landmark) + landmark_module.Landmark.create_from_pb2(left_hand_world_landmark) ) right_hand_landmarks = landmark_pb2.NormalizedLandmarkList() right_hand_landmarks.MergeFrom(right_hand_landmarks_proto_list) for hand_landmark in right_hand_landmarks.landmark: holistic_landmarker_result.right_hand_landmarks.append( - landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) ) right_hand_world_landmarks = landmark_pb2.LandmarkList() right_hand_world_landmarks.MergeFrom(right_hand_world_landmarks_proto_list) for right_hand_world_landmark in right_hand_world_landmarks.landmark: holistic_landmarker_result.right_hand_world_landmarks.append( - landmark_module.Landmark.create_from_pb2(right_hand_world_landmark) + landmark_module.Landmark.create_from_pb2(right_hand_world_landmark) ) if _FACE_BLENDSHAPES_STREAM_NAME in output_packets: face_blendshapes_proto_list = packet_getter.get_proto( - output_packets[_FACE_BLENDSHAPES_STREAM_NAME] + output_packets[_FACE_BLENDSHAPES_STREAM_NAME] ) face_blendshapes_classifications = classification_pb2.ClassificationList() face_blendshapes_classifications.MergeFrom(face_blendshapes_proto_list) holistic_landmarker_result.face_blendshapes = [] for face_blendshapes in face_blendshapes_classifications.classification: holistic_landmarker_result.face_blendshapes.append( - category_module.Category( - index=face_blendshapes.index, - score=face_blendshapes.score, - display_name=face_blendshapes.display_name, - category_name=face_blendshapes.label, - ) + category_module.Category( + index=face_blendshapes.index, + score=face_blendshapes.score, + display_name=face_blendshapes.display_name, + category_name=face_blendshapes.label, + ) ) if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets: - holistic_landmarker_result.segmentation_masks = packet_getter.get_image_list( - output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME] + holistic_landmarker_result.segmentation_mask = packet_getter.get_image( + output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME] ) return holistic_landmarker_result @@ -273,7 +273,7 @@ class HolisticLandmarkerOptions: landmark detection to be considered successful. output_face_blendshapes: Whether HolisticLandmarker outputs face blendshapes classification. Face blendshapes are used for rendering the 3D face model. - output_segmentation_masks: whether to output segmentation masks. + output_segmentation_mask: whether to output segmentation masks. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. @@ -290,7 +290,7 @@ class HolisticLandmarkerOptions: min_pose_landmarks_confidence: float = 0.5 min_hand_landmarks_confidence: float = 0.5 output_face_blendshapes: bool = False - output_segmentation_masks: bool = False + output_segmentation_mask: bool = False result_callback: Optional[ Callable[[HolisticLandmarkerResult, image_module.Image, int], None] ] = None @@ -319,17 +319,17 @@ class HolisticLandmarkerOptions: ) # Configure pose detector and pose landmarks detector options. holistic_landmarker_options_proto.pose_detector_graph_options.min_detection_confidence = ( - self.min_pose_detection_confidence + self.min_pose_detection_confidence ) holistic_landmarker_options_proto.pose_detector_graph_options.min_suppression_threshold = ( - self.min_pose_suppression_threshold + self.min_pose_suppression_threshold ) holistic_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = ( - self.min_pose_landmarks_confidence + self.min_pose_landmarks_confidence ) # Configure hand landmarks detector options. holistic_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = ( - self.min_hand_landmarks_confidence + self.min_hand_landmarks_confidence ) return holistic_landmarker_options_proto @@ -404,30 +404,34 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): ) output_streams = [ - ':'.join([_FACE_LANDMARKS_TAG, _FACE_LANDMARKS_STREAM_NAME]), - ':'.join([_POSE_LANDMARKS_TAG_NAME, _POSE_LANDMARKS_STREAM_NAME]), - ':'.join( - [_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME] - ), - ':'.join([_LEFT_HAND_LANDMARKS_TAG, _LEFT_HAND_LANDMARKS_STREAM_NAME]), - ':'.join( - [_LEFT_HAND_WORLD_LANDMARKS_TAG, _LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] - ), - ':'.join([_RIGHT_HAND_LANDMARKS_TAG, _RIGHT_HAND_LANDMARKS_STREAM_NAME]), - ':'.join( - [_RIGHT_HAND_WORLD_LANDMARKS_TAG, _RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] - ), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ':'.join([_FACE_LANDMARKS_TAG, _FACE_LANDMARKS_STREAM_NAME]), + ':'.join([_POSE_LANDMARKS_TAG_NAME, _POSE_LANDMARKS_STREAM_NAME]), + ':'.join( + [_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME] + ), + ':'.join([_LEFT_HAND_LANDMARKS_TAG, _LEFT_HAND_LANDMARKS_STREAM_NAME]), + ':'.join( + [_LEFT_HAND_WORLD_LANDMARKS_TAG, + _LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] + ), + ':'.join([_RIGHT_HAND_LANDMARKS_TAG, + _RIGHT_HAND_LANDMARKS_STREAM_NAME]), + ':'.join( + [_RIGHT_HAND_WORLD_LANDMARKS_TAG, + _RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] + ), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ] - if options.output_segmentation_masks: + if options.output_segmentation_mask: output_streams.append( - ':'.join([_POSE_SEGMENTATION_MASK_TAG, _POSE_SEGMENTATION_MASK_STREAM_NAME]) + ':'.join([_POSE_SEGMENTATION_MASK_TAG, + _POSE_SEGMENTATION_MASK_STREAM_NAME]) ) if options.output_face_blendshapes: output_streams.append( - ':'.join([_FACE_BLENDSHAPES_TAG, _FACE_BLENDSHAPES_STREAM_NAME]) + ':'.join([_FACE_BLENDSHAPES_TAG, _FACE_BLENDSHAPES_STREAM_NAME]) ) task_info = _TaskInfo(