diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 9d2dc3f0b..f256430c3 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -49,5 +49,6 @@ py_library( "//mediapipe/calculators/core:flow_limiter_calculator_py_pb2", "//mediapipe/framework:calculator_options_py_pb2", "//mediapipe/framework:calculator_py_pb2", + "@com_google_protobuf//:protobuf_python" ], ) diff --git a/mediapipe/tasks/python/core/task_info.py b/mediapipe/tasks/python/core/task_info.py index 894103361..5d039a034 100644 --- a/mediapipe/tasks/python/core/task_info.py +++ b/mediapipe/tasks/python/core/task_info.py @@ -21,6 +21,7 @@ from mediapipe.calculators.core import flow_limiter_calculator_pb2 from mediapipe.framework import calculator_options_pb2 from mediapipe.framework import calculator_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from google.protobuf.any_pb2 import Any @doc_controls.do_not_generate_docs @@ -80,22 +81,31 @@ class TaskInfo: raise ValueError( '`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.' ) - task_subgraph_options = calculator_options_pb2.CalculatorOptions() + task_options_proto = self.task_options.to_pb2() - # For protobuf 2 compat. + node_config = calculator_pb2.CalculatorGraphConfig.Node( + calculator=self.task_graph, + input_stream=self.input_streams, + output_stream=self.output_streams + ) + if hasattr(task_options_proto, 'ext'): + # Use the extension mechanism for task_subgraph_options (proto2) + task_subgraph_options = calculator_options_pb2.CalculatorOptions() task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom( task_options_proto) + node_config.options.CopyFrom(task_subgraph_options) + else: + # Use the Any type for task_subgraph_options (proto3) + task_subgraph_options = Any() + task_subgraph_options.Pack(self.task_options.to_pb2()) + node_config.node_options.append(task_subgraph_options) if not enable_flow_limiting: return calculator_pb2.CalculatorGraphConfig( node=[ - calculator_pb2.CalculatorGraphConfig.Node( - calculator=self.task_graph, - input_stream=self.input_streams, - output_stream=self.output_streams, - options=task_subgraph_options) + node_config ], input_stream=self.input_streams, output_stream=self.output_streams) @@ -125,11 +135,7 @@ class TaskInfo: options=flow_limiter_options) config = calculator_pb2.CalculatorGraphConfig( node=[ - calculator_pb2.CalculatorGraphConfig.Node( - calculator=self.task_graph, - input_stream=task_subgraph_inputs, - output_stream=self.output_streams, - options=task_subgraph_options), flow_limiter + node_config, flow_limiter ], input_stream=self.input_streams, output_stream=self.output_streams) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 374ba689c..eae05de4d 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -206,6 +206,7 @@ py_test( deps = [ "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/framework/formats:landmark_py_pb2", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2", "//mediapipe/python:_framework_bindings", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", diff --git a/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py b/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py index 0c9179301..6f27c903d 100644 --- a/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py +++ b/mediapipe/tasks/python/test/vision/holistic_landmarker_test.py @@ -14,6 +14,7 @@ """Tests for holistic landmarker.""" import enum +from typing import List from unittest import mock from absl.testing import absltest @@ -23,6 +24,7 @@ import numpy as np from google.protobuf import text_format from mediapipe.framework.formats import classification_pb2 from mediapipe.framework.formats import landmark_pb2 +from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2 from mediapipe.python._framework_bindings import image as image_module from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module @@ -35,6 +37,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult +_HolisticResultProto = holistic_result_pb2.HolisticResult _BaseOptions = base_options_module.BaseOptions _Category = category_module.Category _Rect = rect_module.Rect @@ -46,14 +49,31 @@ _HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'face_landmarker.task' +_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'holistic_landmarker.task' _POSE_IMAGE = 'male_full_height_hands.jpg' _CAT_IMAGE = 'cat.jpg' -_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt" +_EXPECTED_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt" _LANDMARKS_MARGIN = 0.03 _BLENDSHAPES_MARGIN = 0.13 +def _get_expected_holistic_landmarker_result( + file_path: str, +) -> HolisticLandmarkerResult: + holistic_result_file_path = test_utils.get_test_data_path( + file_path + ) + with open(holistic_result_file_path, 'rb') as f: + holistic_result_proto = _HolisticResultProto() + # Use this if a .pb file is available. + # holistic_result_proto.ParseFromString(f.read()) + text_format.Parse(f.read(), holistic_result_proto) + holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2( + holistic_result_proto + ) + return holistic_landmarker_result + + class ModelFileType(enum.Enum): FILE_CONTENT = 1 FILE_NAME = 2 @@ -70,20 +90,77 @@ class HolisticLandmarkerTest(parameterized.TestCase): _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE ) + def _expect_landmarks_correct( + self, actual_landmarks, expected_landmarks, margin + ): + # Expects to have the same number of poses detected. + self.assertLen(actual_landmarks, len(expected_landmarks)) + + for i, elem in enumerate(actual_landmarks): + self.assertAlmostEqual(elem.x, expected_landmarks[i].x, delta=margin) + self.assertAlmostEqual(elem.y, expected_landmarks[i].y, delta=margin) + + def _expect_blendshapes_correct( + self, actual_blendshapes, expected_blendshapes, margin + ): + # Expects to have the same number of blendshapes. + self.assertLen(actual_blendshapes, len(expected_blendshapes)) + + for i, elem in enumerate(actual_blendshapes): + self.assertEqual(elem.index, expected_blendshapes[i].index) + self.assertAlmostEqual( + elem.score, + expected_blendshapes[i].score, + delta=margin, + ) + + def _expect_holistic_landmarker_results_correct( + self, + actual_result: HolisticLandmarkerResult, + expected_result: HolisticLandmarkerResult, + output_segmentation_masks: bool, + landmarks_margin: float, + blendshapes_margin: float, + ): + self._expect_landmarks_correct( + actual_result.pose_landmarks, expected_result.pose_landmarks, + landmarks_margin + ) + self._expect_landmarks_correct( + actual_result.face_landmarks, expected_result.face_landmarks, + landmarks_margin + ) + self._expect_blendshapes_correct( + actual_result.face_blendshapes, expected_result.face_blendshapes, + blendshapes_margin + ) + if output_segmentation_masks: + self.assertIsInstance(actual_result.segmentation_masks, List) + for _, mask in enumerate(actual_result.segmentation_masks): + self.assertIsInstance(mask, _Image) + else: + self.assertIsNone(actual_result.segmentation_masks) + @parameterized.parameters( ( ModelFileType.FILE_NAME, - _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + False, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) ), ( ModelFileType.FILE_CONTENT, - _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE + _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE, + False, + _get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT) ), ) def test_detect( self, model_file_type, - model_name + model_name, + output_segmentation_masks, + expected_holistic_landmarker_result: HolisticLandmarkerResult ): # Creates holistic landmarker. model_path = test_utils.get_test_data_path(model_name) @@ -98,15 +175,21 @@ class HolisticLandmarkerTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _HolisticLandmarkerOptions( - base_options=base_options + base_options=base_options, + output_face_blendshapes=True + if expected_holistic_landmarker_result.face_blendshapes else False, + output_segmentation_masks=output_segmentation_masks, ) landmarker = _HolisticLandmarker.create_from_options(options) # Performs holistic landmarks detection on the input. detection_result = landmarker.detect(self.test_image) - - # Closes the holistic landmarker explicitly when the holistic landmarker is not used - # in a context. + self._expect_holistic_landmarker_results_correct( + detection_result, expected_holistic_landmarker_result, + output_segmentation_masks, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN + ) + # Closes the holistic landmarker explicitly when the holistic landmarker is + # not used in a context. landmarker.close() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 8253a9232..1b0a1454b 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -254,6 +254,7 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2", "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", diff --git a/mediapipe/tasks/python/vision/holistic_landmarker.py b/mediapipe/tasks/python/vision/holistic_landmarker.py index 940663230..8edfaf445 100644 --- a/mediapipe/tasks/python/vision/holistic_landmarker.py +++ b/mediapipe/tasks/python/vision/holistic_landmarker.py @@ -22,6 +22,7 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2 from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_landmarker_graph_options_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module @@ -33,6 +34,7 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions +_HolisticResultProto = holistic_result_pb2.HolisticResult _HolisticLandmarkerGraphOptionsProto = ( holistic_landmarker_graph_options_pb2.HolisticLandmarkerGraphOptions ) @@ -43,9 +45,6 @@ _TaskInfo = task_info_module.TaskInfo _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' -_NORM_RECT_STREAM_NAME = 'norm_rect_in' -_NORM_RECT_TAG = 'NORM_RECT' - _POSE_LANDMARKS_STREAM_NAME = "pose_landmarks" _POSE_LANDMARKS_TAG_NAME = "POSE_LANDMARKS" @@ -77,16 +76,64 @@ class HolisticLandmarkerResult: Attributes: TODO """ - face_landmarks: List[List[landmark_module.NormalizedLandmark]] - pose_landmarks: List[List[landmark_module.NormalizedLandmark]] - pose_world_landmarks: List[List[landmark_module.Landmark]] - left_hand_landmarks: List[List[landmark_module.NormalizedLandmark]] - left_hand_world_landmarks: List[List[landmark_module.Landmark]] - right_hand_landmarks: List[List[landmark_module.NormalizedLandmark]] - right_hand_world_landmarks: List[List[landmark_module.Landmark]] - face_blendshapes: Optional[List[List[category_module.Category]]] = None + face_landmarks: List[landmark_module.NormalizedLandmark] + pose_landmarks: List[landmark_module.NormalizedLandmark] + pose_world_landmarks:List[landmark_module.Landmark] + left_hand_landmarks: List[landmark_module.NormalizedLandmark] + left_hand_world_landmarks: List[landmark_module.Landmark] + right_hand_landmarks: List[landmark_module.NormalizedLandmark] + right_hand_world_landmarks: List[landmark_module.Landmark] + face_blendshapes: Optional[List[category_module.Category]] = None segmentation_masks: Optional[List[image_module.Image]] = None + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _HolisticResultProto + ) -> 'HolisticLandmarkerResult': + """Creates a `HolisticLandmarkerResult` object from the given protobuf + object.""" + return HolisticLandmarkerResult( + face_landmarks=[ + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.face_landmarks.landmark + ] if hasattr(pb2_obj, 'face_landmarks') else None, + pose_landmarks=[ + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.pose_landmarks.landmark + ] if hasattr(pb2_obj, 'pose_landmarks') else None, + pose_world_landmarks=[ + landmark_module.Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.pose_world_landmarks.landmark + ] if hasattr(pb2_obj, 'pose_world_landmarks') else None, + left_hand_landmarks=[ + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.left_hand_landmarks.landmark + ] if hasattr(pb2_obj, 'left_hand_landmarks') else None, + left_hand_world_landmarks=[ + landmark_module.Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.left_hand_world_landmarks.landmark + ] if hasattr(pb2_obj, 'left_hand_world_landmarks') else None, + right_hand_landmarks=[ + landmark_module.NormalizedLandmark.create_from_pb2(landmark) + for landmark in pb2_obj.right_hand_landmarks.landmark + ] if hasattr(pb2_obj, 'right_hand_landmarks') else None, + right_hand_world_landmarks=[ + landmark_module.Landmark.create_from_pb2(landmark) + for landmark in pb2_obj.right_hand_world_landmarks.landmark + ] if hasattr(pb2_obj, 'right_hand_world_landmarks') else None, + face_blendshapes=[ + category_module.Category( + score=classification.score, + index=classification.index, + category_name=classification.label, + display_name=classification.display_name + ) + for classification in pb2_obj.face_blendshapes.classification + ] if hasattr(pb2_obj, 'face_blendshapes') else None, + ) + def _build_landmarker_result( output_packets: Mapping[str, packet_module.Packet] @@ -95,140 +142,105 @@ def _build_landmarker_result( holistic_landmarker_result = HolisticLandmarkerResult([], [], [], [], [], [], []) - face_landmarks_proto_list = packet_getter.get_proto_list( + face_landmarks_proto_list = packet_getter.get_proto( output_packets[_FACE_LANDMARKS_STREAM_NAME] ) + pose_landmarks_proto_list = packet_getter.get_proto( + output_packets[_POSE_LANDMARKS_STREAM_NAME] + ) + + pose_world_landmarks_proto_list = packet_getter.get_proto( + output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME] + ) + + left_hand_landmarks_proto_list = packet_getter.get_proto( + output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME] + ) + + left_hand_world_landmarks_proto_list = packet_getter.get_proto( + output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] + ) + + right_hand_landmarks_proto_list = packet_getter.get_proto( + output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME] + ) + + right_hand_world_landmarks_proto_list = packet_getter.get_proto( + output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] + ) + + face_landmarks = landmark_pb2.NormalizedLandmarkList() + face_landmarks.MergeFrom(face_landmarks_proto_list) + for face_landmark in face_landmarks.landmark: + holistic_landmarker_result.face_landmarks.append( + landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) + ) + + pose_landmarks = landmark_pb2.NormalizedLandmarkList() + pose_landmarks.MergeFrom(pose_landmarks_proto_list) + for pose_landmark in pose_landmarks.landmark: + holistic_landmarker_result.pose_landmarks.append( + landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark) + ) + + pose_world_landmarks = landmark_pb2.LandmarkList() + pose_world_landmarks.MergeFrom(pose_world_landmarks_proto_list) + for pose_world_landmark in pose_world_landmarks.landmark: + holistic_landmarker_result.pose_world_landmarks.append( + landmark_module.Landmark.create_from_pb2(pose_world_landmark) + ) + + left_hand_landmarks = landmark_pb2.NormalizedLandmarkList() + left_hand_landmarks.MergeFrom(left_hand_landmarks_proto_list) + for hand_landmark in left_hand_landmarks.landmark: + holistic_landmarker_result.left_hand_landmarks.append( + landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + ) + + left_hand_world_landmarks = landmark_pb2.LandmarkList() + left_hand_world_landmarks.MergeFrom(left_hand_world_landmarks_proto_list) + for left_hand_world_landmark in left_hand_world_landmarks.landmark: + holistic_landmarker_result.left_hand_world_landmarks.append( + landmark_module.Landmark.create_from_pb2(left_hand_world_landmark) + ) + + right_hand_landmarks = landmark_pb2.NormalizedLandmarkList() + right_hand_landmarks.MergeFrom(right_hand_landmarks_proto_list) + for hand_landmark in right_hand_landmarks.landmark: + holistic_landmarker_result.right_hand_landmarks.append( + landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + ) + + right_hand_world_landmarks = landmark_pb2.LandmarkList() + right_hand_world_landmarks.MergeFrom(right_hand_world_landmarks_proto_list) + for right_hand_world_landmark in right_hand_world_landmarks.landmark: + holistic_landmarker_result.right_hand_world_landmarks.append( + landmark_module.Landmark.create_from_pb2(right_hand_world_landmark) + ) + + if _FACE_BLENDSHAPES_STREAM_NAME in output_packets: + face_blendshapes_proto_list = packet_getter.get_proto( + output_packets[_FACE_BLENDSHAPES_STREAM_NAME] + ) + face_blendshapes_classifications = classification_pb2.ClassificationList() + face_blendshapes_classifications.MergeFrom(face_blendshapes_proto_list) + holistic_landmarker_result.face_blendshapes = [] + for face_blendshapes in face_blendshapes_classifications.classification: + holistic_landmarker_result.face_blendshapes.append( + category_module.Category( + index=face_blendshapes.index, + score=face_blendshapes.score, + display_name=face_blendshapes.display_name, + category_name=face_blendshapes.label, + ) + ) + if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets: holistic_landmarker_result.segmentation_masks = packet_getter.get_image_list( output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME] ) - pose_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_POSE_LANDMARKS_STREAM_NAME] - ) - - pose_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME] - ) - - left_hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME] - ) - - left_hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] - ) - - right_hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME] - ) - - right_hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] - ) - - face_landmarks_results = [] - for proto in face_landmarks_proto_list: - face_landmarks = landmark_pb2.NormalizedLandmarkList() - face_landmarks.MergeFrom(proto) - face_landmarks_list = [] - for face_landmark in face_landmarks.landmark: - face_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) - ) - face_landmarks_results.append(face_landmarks_list) - - face_blendshapes_results = [] - if _FACE_BLENDSHAPES_STREAM_NAME in output_packets: - face_blendshapes_proto_list = packet_getter.get_proto_list( - output_packets[_FACE_BLENDSHAPES_STREAM_NAME] - ) - for proto in face_blendshapes_proto_list: - face_blendshapes_categories = [] - face_blendshapes_classifications = classification_pb2.ClassificationList() - face_blendshapes_classifications.MergeFrom(proto) - for face_blendshapes in face_blendshapes_classifications.classification: - face_blendshapes_categories.append( - category_module.Category( - index=face_blendshapes.index, - score=face_blendshapes.score, - display_name=face_blendshapes.display_name, - category_name=face_blendshapes.label, - ) - ) - face_blendshapes_results.append(face_blendshapes_categories) - - for proto in pose_landmarks_proto_list: - pose_landmarks = landmark_pb2.NormalizedLandmarkList() - pose_landmarks.MergeFrom(proto) - pose_landmarks_list = [] - for pose_landmark in pose_landmarks.landmark: - pose_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark) - ) - holistic_landmarker_result.pose_landmarks.append(pose_landmarks_list) - - for proto in pose_world_landmarks_proto_list: - pose_world_landmarks = landmark_pb2.LandmarkList() - pose_world_landmarks.MergeFrom(proto) - pose_world_landmarks_list = [] - for pose_world_landmark in pose_world_landmarks.landmark: - pose_world_landmarks_list.append( - landmark_module.Landmark.create_from_pb2(pose_world_landmark) - ) - holistic_landmarker_result.pose_world_landmarks.append( - pose_world_landmarks_list - ) - - for proto in left_hand_landmarks_proto_list: - left_hand_landmarks = landmark_pb2.NormalizedLandmarkList() - left_hand_landmarks.MergeFrom(proto) - left_hand_landmarks_list = [] - for hand_landmark in left_hand_landmarks.landmark: - left_hand_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - ) - holistic_landmarker_result.left_hand_landmarks.append( - left_hand_landmarks_list - ) - - for proto in left_hand_world_landmarks_proto_list: - left_hand_world_landmarks = landmark_pb2.LandmarkList() - left_hand_world_landmarks.MergeFrom(proto) - left_hand_world_landmarks_list = [] - for left_hand_world_landmark in left_hand_world_landmarks.landmark: - left_hand_world_landmarks_list.append( - landmark_module.Landmark.create_from_pb2(left_hand_world_landmark) - ) - holistic_landmarker_result.left_hand_world_landmarks.append( - left_hand_world_landmarks_list - ) - - for proto in right_hand_landmarks_proto_list: - right_hand_landmarks = landmark_pb2.NormalizedLandmarkList() - right_hand_landmarks.MergeFrom(proto) - right_hand_landmarks_list = [] - for hand_landmark in right_hand_landmarks.landmark: - right_hand_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - ) - holistic_landmarker_result.right_hand_landmarks.append( - right_hand_landmarks_list - ) - - for proto in right_hand_world_landmarks_proto_list: - right_hand_world_landmarks = landmark_pb2.LandmarkList() - right_hand_world_landmarks.MergeFrom(proto) - right_hand_world_landmarks_list = [] - for right_hand_world_landmark in right_hand_world_landmarks.landmark: - right_hand_world_landmarks_list.append( - landmark_module.Landmark.create_from_pb2(right_hand_world_landmark) - ) - holistic_landmarker_result.right_hand_world_landmarks.append( - right_hand_world_landmarks_list - ) - return holistic_landmarker_result @@ -259,6 +271,9 @@ class HolisticLandmarkerOptions: landmark detection to be considered successful. min_hand_landmarks_confidence: The minimum confidence score for the hand landmark detection to be considered successful. + output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes + classification. Face blendshapes are used for rendering the 3D face model. + output_segmentation_masks: whether to output segmentation masks. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. @@ -419,7 +434,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), - ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], output_streams=output_streams, task_options=options, @@ -436,7 +450,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): def detect( self, image: image_module.Image, - image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> HolisticLandmarkerResult: """Performs holistic landmarks detection on the given image. @@ -449,7 +462,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. - image_processing_options: Options for image processing. Returns: The holistic landmarks detection results. @@ -458,14 +470,8 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If holistic landmarker detection failed to run. """ - normalized_rect = self.convert_to_normalized_rect( - image_processing_options, image, roi_allowed=False - ) output_packets = self._process_image_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), - _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - normalized_rect.to_pb2() - ), }) if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty(): @@ -477,7 +483,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> HolisticLandmarkerResult: """Performs holistic landmarks detection on the provided video frame. @@ -492,7 +497,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. - image_processing_options: Options for image processing. Returns: The holistic landmarks detection results. @@ -501,16 +505,10 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If holistic landmarker detection failed to run. """ - normalized_rect = self.convert_to_normalized_rect( - image_processing_options, image, roi_allowed=False - ) output_packets = self._process_video_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND ), - _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - normalized_rect.to_pb2() - ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty(): @@ -522,7 +520,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> None: """Sends live image data to perform holistic landmarks detection. @@ -548,20 +545,13 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. - image_processing_options: Options for image processing. Raises: ValueError: If the current input timestamp is smaller than what the holistic landmarker has already processed. """ - normalized_rect = self.convert_to_normalized_rect( - image_processing_options, image, roi_allowed=False - ) self._send_live_stream_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND ), - _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - normalized_rect.to_pb2() - ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), })