Support both proto2 and proto3 in task subgraph options configuration, and revised the Holistic Landmarker API's implementation

This commit is contained in:
Kinar 2023-12-17 15:13:34 -08:00
parent ea95ae753d
commit 24fe8eb73a
6 changed files with 265 additions and 183 deletions

View File

@ -49,5 +49,6 @@ py_library(
"//mediapipe/calculators/core:flow_limiter_calculator_py_pb2", "//mediapipe/calculators/core:flow_limiter_calculator_py_pb2",
"//mediapipe/framework:calculator_options_py_pb2", "//mediapipe/framework:calculator_options_py_pb2",
"//mediapipe/framework:calculator_py_pb2", "//mediapipe/framework:calculator_py_pb2",
"@com_google_protobuf//:protobuf_python"
], ],
) )

View File

@ -21,6 +21,7 @@ from mediapipe.calculators.core import flow_limiter_calculator_pb2
from mediapipe.framework import calculator_options_pb2 from mediapipe.framework import calculator_options_pb2
from mediapipe.framework import calculator_pb2 from mediapipe.framework import calculator_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from google.protobuf.any_pb2 import Any
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
@ -80,22 +81,31 @@ class TaskInfo:
raise ValueError( raise ValueError(
'`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.' '`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.'
) )
task_subgraph_options = calculator_options_pb2.CalculatorOptions()
task_options_proto = self.task_options.to_pb2() task_options_proto = self.task_options.to_pb2()
# For protobuf 2 compat. node_config = calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=self.input_streams,
output_stream=self.output_streams
)
if hasattr(task_options_proto, 'ext'): if hasattr(task_options_proto, 'ext'):
# Use the extension mechanism for task_subgraph_options (proto2)
task_subgraph_options = calculator_options_pb2.CalculatorOptions()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom( task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
task_options_proto) task_options_proto)
node_config.options.CopyFrom(task_subgraph_options)
else:
# Use the Any type for task_subgraph_options (proto3)
task_subgraph_options = Any()
task_subgraph_options.Pack(self.task_options.to_pb2())
node_config.node_options.append(task_subgraph_options)
if not enable_flow_limiting: if not enable_flow_limiting:
return calculator_pb2.CalculatorGraphConfig( return calculator_pb2.CalculatorGraphConfig(
node=[ node=[
calculator_pb2.CalculatorGraphConfig.Node( node_config
calculator=self.task_graph,
input_stream=self.input_streams,
output_stream=self.output_streams,
options=task_subgraph_options)
], ],
input_stream=self.input_streams, input_stream=self.input_streams,
output_stream=self.output_streams) output_stream=self.output_streams)
@ -125,11 +135,7 @@ class TaskInfo:
options=flow_limiter_options) options=flow_limiter_options)
config = calculator_pb2.CalculatorGraphConfig( config = calculator_pb2.CalculatorGraphConfig(
node=[ node=[
calculator_pb2.CalculatorGraphConfig.Node( node_config, flow_limiter
calculator=self.task_graph,
input_stream=task_subgraph_inputs,
output_stream=self.output_streams,
options=task_subgraph_options), flow_limiter
], ],
input_stream=self.input_streams, input_stream=self.input_streams,
output_stream=self.output_streams) output_stream=self.output_streams)

View File

@ -206,6 +206,7 @@ py_test(
deps = [ deps = [
"//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2", "//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark",

View File

@ -14,6 +14,7 @@
"""Tests for holistic landmarker.""" """Tests for holistic landmarker."""
import enum import enum
from typing import List
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
@ -23,6 +24,7 @@ import numpy as np
from google.protobuf import text_format from google.protobuf import text_format
from mediapipe.framework.formats import classification_pb2 from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2 from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark as landmark_module
@ -35,6 +37,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult
_HolisticResultProto = holistic_result_pb2.HolisticResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Category = category_module.Category _Category = category_module.Category
_Rect = rect_module.Rect _Rect = rect_module.Rect
@ -46,14 +49,31 @@ _HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'face_landmarker.task' _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'holistic_landmarker.task'
_POSE_IMAGE = 'male_full_height_hands.jpg' _POSE_IMAGE = 'male_full_height_hands.jpg'
_CAT_IMAGE = 'cat.jpg' _CAT_IMAGE = 'cat.jpg'
_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt" _EXPECTED_HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pbtxt"
_LANDMARKS_MARGIN = 0.03 _LANDMARKS_MARGIN = 0.03
_BLENDSHAPES_MARGIN = 0.13 _BLENDSHAPES_MARGIN = 0.13
def _get_expected_holistic_landmarker_result(
file_path: str,
) -> HolisticLandmarkerResult:
holistic_result_file_path = test_utils.get_test_data_path(
file_path
)
with open(holistic_result_file_path, 'rb') as f:
holistic_result_proto = _HolisticResultProto()
# Use this if a .pb file is available.
# holistic_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), holistic_result_proto)
holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2(
holistic_result_proto
)
return holistic_landmarker_result
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):
FILE_CONTENT = 1 FILE_CONTENT = 1
FILE_NAME = 2 FILE_NAME = 2
@ -70,20 +90,77 @@ class HolisticLandmarkerTest(parameterized.TestCase):
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE
) )
def _expect_landmarks_correct(
self, actual_landmarks, expected_landmarks, margin
):
# Expects to have the same number of poses detected.
self.assertLen(actual_landmarks, len(expected_landmarks))
for i, elem in enumerate(actual_landmarks):
self.assertAlmostEqual(elem.x, expected_landmarks[i].x, delta=margin)
self.assertAlmostEqual(elem.y, expected_landmarks[i].y, delta=margin)
def _expect_blendshapes_correct(
self, actual_blendshapes, expected_blendshapes, margin
):
# Expects to have the same number of blendshapes.
self.assertLen(actual_blendshapes, len(expected_blendshapes))
for i, elem in enumerate(actual_blendshapes):
self.assertEqual(elem.index, expected_blendshapes[i].index)
self.assertAlmostEqual(
elem.score,
expected_blendshapes[i].score,
delta=margin,
)
def _expect_holistic_landmarker_results_correct(
self,
actual_result: HolisticLandmarkerResult,
expected_result: HolisticLandmarkerResult,
output_segmentation_masks: bool,
landmarks_margin: float,
blendshapes_margin: float,
):
self._expect_landmarks_correct(
actual_result.pose_landmarks, expected_result.pose_landmarks,
landmarks_margin
)
self._expect_landmarks_correct(
actual_result.face_landmarks, expected_result.face_landmarks,
landmarks_margin
)
self._expect_blendshapes_correct(
actual_result.face_blendshapes, expected_result.face_blendshapes,
blendshapes_margin
)
if output_segmentation_masks:
self.assertIsInstance(actual_result.segmentation_masks, List)
for _, mask in enumerate(actual_result.segmentation_masks):
self.assertIsInstance(mask, _Image)
else:
self.assertIsNone(actual_result.segmentation_masks)
@parameterized.parameters( @parameterized.parameters(
( (
ModelFileType.FILE_NAME, ModelFileType.FILE_NAME,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT)
), ),
( (
ModelFileType.FILE_CONTENT, ModelFileType.FILE_CONTENT,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE _HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT)
), ),
) )
def test_detect( def test_detect(
self, self,
model_file_type, model_file_type,
model_name model_name,
output_segmentation_masks,
expected_holistic_landmarker_result: HolisticLandmarkerResult
): ):
# Creates holistic landmarker. # Creates holistic landmarker.
model_path = test_utils.get_test_data_path(model_name) model_path = test_utils.get_test_data_path(model_name)
@ -98,15 +175,21 @@ class HolisticLandmarkerTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _HolisticLandmarkerOptions( options = _HolisticLandmarkerOptions(
base_options=base_options base_options=base_options,
output_face_blendshapes=True
if expected_holistic_landmarker_result.face_blendshapes else False,
output_segmentation_masks=output_segmentation_masks,
) )
landmarker = _HolisticLandmarker.create_from_options(options) landmarker = _HolisticLandmarker.create_from_options(options)
# Performs holistic landmarks detection on the input. # Performs holistic landmarks detection on the input.
detection_result = landmarker.detect(self.test_image) detection_result = landmarker.detect(self.test_image)
self._expect_holistic_landmarker_results_correct(
# Closes the holistic landmarker explicitly when the holistic landmarker is not used detection_result, expected_holistic_landmarker_result,
# in a context. output_segmentation_masks, _LANDMARKS_MARGIN, _BLENDSHAPES_MARGIN
)
# Closes the holistic landmarker explicitly when the holistic landmarker is
# not used in a context.
landmarker.close() landmarker.close()

View File

@ -254,6 +254,7 @@ py_library(
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark",

View File

@ -22,6 +22,7 @@ from mediapipe.python import packet_creator
from mediapipe.python import packet_getter from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_landmarker_graph_options_pb2 from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_landmarker_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark as landmark_module
@ -33,6 +34,7 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_HolisticResultProto = holistic_result_pb2.HolisticResult
_HolisticLandmarkerGraphOptionsProto = ( _HolisticLandmarkerGraphOptionsProto = (
holistic_landmarker_graph_options_pb2.HolisticLandmarkerGraphOptions holistic_landmarker_graph_options_pb2.HolisticLandmarkerGraphOptions
) )
@ -43,9 +45,6 @@ _TaskInfo = task_info_module.TaskInfo
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_POSE_LANDMARKS_STREAM_NAME = "pose_landmarks" _POSE_LANDMARKS_STREAM_NAME = "pose_landmarks"
_POSE_LANDMARKS_TAG_NAME = "POSE_LANDMARKS" _POSE_LANDMARKS_TAG_NAME = "POSE_LANDMARKS"
@ -77,16 +76,64 @@ class HolisticLandmarkerResult:
Attributes: Attributes:
TODO TODO
""" """
face_landmarks: List[List[landmark_module.NormalizedLandmark]] face_landmarks: List[landmark_module.NormalizedLandmark]
pose_landmarks: List[List[landmark_module.NormalizedLandmark]] pose_landmarks: List[landmark_module.NormalizedLandmark]
pose_world_landmarks: List[List[landmark_module.Landmark]] pose_world_landmarks:List[landmark_module.Landmark]
left_hand_landmarks: List[List[landmark_module.NormalizedLandmark]] left_hand_landmarks: List[landmark_module.NormalizedLandmark]
left_hand_world_landmarks: List[List[landmark_module.Landmark]] left_hand_world_landmarks: List[landmark_module.Landmark]
right_hand_landmarks: List[List[landmark_module.NormalizedLandmark]] right_hand_landmarks: List[landmark_module.NormalizedLandmark]
right_hand_world_landmarks: List[List[landmark_module.Landmark]] right_hand_world_landmarks: List[landmark_module.Landmark]
face_blendshapes: Optional[List[List[category_module.Category]]] = None face_blendshapes: Optional[List[category_module.Category]] = None
segmentation_masks: Optional[List[image_module.Image]] = None segmentation_masks: Optional[List[image_module.Image]] = None
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _HolisticResultProto
) -> 'HolisticLandmarkerResult':
"""Creates a `HolisticLandmarkerResult` object from the given protobuf
object."""
return HolisticLandmarkerResult(
face_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.face_landmarks.landmark
] if hasattr(pb2_obj, 'face_landmarks') else None,
pose_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.pose_landmarks.landmark
] if hasattr(pb2_obj, 'pose_landmarks') else None,
pose_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.pose_world_landmarks.landmark
] if hasattr(pb2_obj, 'pose_world_landmarks') else None,
left_hand_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.left_hand_landmarks.landmark
] if hasattr(pb2_obj, 'left_hand_landmarks') else None,
left_hand_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.left_hand_world_landmarks.landmark
] if hasattr(pb2_obj, 'left_hand_world_landmarks') else None,
right_hand_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.right_hand_landmarks.landmark
] if hasattr(pb2_obj, 'right_hand_landmarks') else None,
right_hand_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.right_hand_world_landmarks.landmark
] if hasattr(pb2_obj, 'right_hand_world_landmarks') else None,
face_blendshapes=[
category_module.Category(
score=classification.score,
index=classification.index,
category_name=classification.label,
display_name=classification.display_name
)
for classification in pb2_obj.face_blendshapes.classification
] if hasattr(pb2_obj, 'face_blendshapes') else None,
)
def _build_landmarker_result( def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet] output_packets: Mapping[str, packet_module.Packet]
@ -95,61 +142,92 @@ def _build_landmarker_result(
holistic_landmarker_result = HolisticLandmarkerResult([], [], [], [], [], [], holistic_landmarker_result = HolisticLandmarkerResult([], [], [], [], [], [],
[]) [])
face_landmarks_proto_list = packet_getter.get_proto_list( face_landmarks_proto_list = packet_getter.get_proto(
output_packets[_FACE_LANDMARKS_STREAM_NAME] output_packets[_FACE_LANDMARKS_STREAM_NAME]
) )
if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets: pose_landmarks_proto_list = packet_getter.get_proto(
holistic_landmarker_result.segmentation_masks = packet_getter.get_image_list(
output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME]
)
pose_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_POSE_LANDMARKS_STREAM_NAME] output_packets[_POSE_LANDMARKS_STREAM_NAME]
) )
pose_world_landmarks_proto_list = packet_getter.get_proto_list( pose_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME] output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME]
) )
left_hand_landmarks_proto_list = packet_getter.get_proto_list( left_hand_landmarks_proto_list = packet_getter.get_proto(
output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME] output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME]
) )
left_hand_world_landmarks_proto_list = packet_getter.get_proto_list( left_hand_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME] output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME]
) )
right_hand_landmarks_proto_list = packet_getter.get_proto_list( right_hand_landmarks_proto_list = packet_getter.get_proto(
output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME] output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME]
) )
right_hand_world_landmarks_proto_list = packet_getter.get_proto_list( right_hand_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME] output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME]
) )
face_landmarks_results = []
for proto in face_landmarks_proto_list:
face_landmarks = landmark_pb2.NormalizedLandmarkList() face_landmarks = landmark_pb2.NormalizedLandmarkList()
face_landmarks.MergeFrom(proto) face_landmarks.MergeFrom(face_landmarks_proto_list)
face_landmarks_list = []
for face_landmark in face_landmarks.landmark: for face_landmark in face_landmarks.landmark:
face_landmarks_list.append( holistic_landmarker_result.face_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
) )
face_landmarks_results.append(face_landmarks_list)
face_blendshapes_results = [] pose_landmarks = landmark_pb2.NormalizedLandmarkList()
pose_landmarks.MergeFrom(pose_landmarks_proto_list)
for pose_landmark in pose_landmarks.landmark:
holistic_landmarker_result.pose_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark)
)
pose_world_landmarks = landmark_pb2.LandmarkList()
pose_world_landmarks.MergeFrom(pose_world_landmarks_proto_list)
for pose_world_landmark in pose_world_landmarks.landmark:
holistic_landmarker_result.pose_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(pose_world_landmark)
)
left_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
left_hand_landmarks.MergeFrom(left_hand_landmarks_proto_list)
for hand_landmark in left_hand_landmarks.landmark:
holistic_landmarker_result.left_hand_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
left_hand_world_landmarks = landmark_pb2.LandmarkList()
left_hand_world_landmarks.MergeFrom(left_hand_world_landmarks_proto_list)
for left_hand_world_landmark in left_hand_world_landmarks.landmark:
holistic_landmarker_result.left_hand_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(left_hand_world_landmark)
)
right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
right_hand_landmarks.MergeFrom(right_hand_landmarks_proto_list)
for hand_landmark in right_hand_landmarks.landmark:
holistic_landmarker_result.right_hand_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
right_hand_world_landmarks = landmark_pb2.LandmarkList()
right_hand_world_landmarks.MergeFrom(right_hand_world_landmarks_proto_list)
for right_hand_world_landmark in right_hand_world_landmarks.landmark:
holistic_landmarker_result.right_hand_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(right_hand_world_landmark)
)
if _FACE_BLENDSHAPES_STREAM_NAME in output_packets: if _FACE_BLENDSHAPES_STREAM_NAME in output_packets:
face_blendshapes_proto_list = packet_getter.get_proto_list( face_blendshapes_proto_list = packet_getter.get_proto(
output_packets[_FACE_BLENDSHAPES_STREAM_NAME] output_packets[_FACE_BLENDSHAPES_STREAM_NAME]
) )
for proto in face_blendshapes_proto_list:
face_blendshapes_categories = []
face_blendshapes_classifications = classification_pb2.ClassificationList() face_blendshapes_classifications = classification_pb2.ClassificationList()
face_blendshapes_classifications.MergeFrom(proto) face_blendshapes_classifications.MergeFrom(face_blendshapes_proto_list)
holistic_landmarker_result.face_blendshapes = []
for face_blendshapes in face_blendshapes_classifications.classification: for face_blendshapes in face_blendshapes_classifications.classification:
face_blendshapes_categories.append( holistic_landmarker_result.face_blendshapes.append(
category_module.Category( category_module.Category(
index=face_blendshapes.index, index=face_blendshapes.index,
score=face_blendshapes.score, score=face_blendshapes.score,
@ -157,76 +235,10 @@ def _build_landmarker_result(
category_name=face_blendshapes.label, category_name=face_blendshapes.label,
) )
) )
face_blendshapes_results.append(face_blendshapes_categories)
for proto in pose_landmarks_proto_list: if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets:
pose_landmarks = landmark_pb2.NormalizedLandmarkList() holistic_landmarker_result.segmentation_masks = packet_getter.get_image_list(
pose_landmarks.MergeFrom(proto) output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME]
pose_landmarks_list = []
for pose_landmark in pose_landmarks.landmark:
pose_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark)
)
holistic_landmarker_result.pose_landmarks.append(pose_landmarks_list)
for proto in pose_world_landmarks_proto_list:
pose_world_landmarks = landmark_pb2.LandmarkList()
pose_world_landmarks.MergeFrom(proto)
pose_world_landmarks_list = []
for pose_world_landmark in pose_world_landmarks.landmark:
pose_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(pose_world_landmark)
)
holistic_landmarker_result.pose_world_landmarks.append(
pose_world_landmarks_list
)
for proto in left_hand_landmarks_proto_list:
left_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
left_hand_landmarks.MergeFrom(proto)
left_hand_landmarks_list = []
for hand_landmark in left_hand_landmarks.landmark:
left_hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
holistic_landmarker_result.left_hand_landmarks.append(
left_hand_landmarks_list
)
for proto in left_hand_world_landmarks_proto_list:
left_hand_world_landmarks = landmark_pb2.LandmarkList()
left_hand_world_landmarks.MergeFrom(proto)
left_hand_world_landmarks_list = []
for left_hand_world_landmark in left_hand_world_landmarks.landmark:
left_hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(left_hand_world_landmark)
)
holistic_landmarker_result.left_hand_world_landmarks.append(
left_hand_world_landmarks_list
)
for proto in right_hand_landmarks_proto_list:
right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
right_hand_landmarks.MergeFrom(proto)
right_hand_landmarks_list = []
for hand_landmark in right_hand_landmarks.landmark:
right_hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
holistic_landmarker_result.right_hand_landmarks.append(
right_hand_landmarks_list
)
for proto in right_hand_world_landmarks_proto_list:
right_hand_world_landmarks = landmark_pb2.LandmarkList()
right_hand_world_landmarks.MergeFrom(proto)
right_hand_world_landmarks_list = []
for right_hand_world_landmark in right_hand_world_landmarks.landmark:
right_hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(right_hand_world_landmark)
)
holistic_landmarker_result.right_hand_world_landmarks.append(
right_hand_world_landmarks_list
) )
return holistic_landmarker_result return holistic_landmarker_result
@ -259,6 +271,9 @@ class HolisticLandmarkerOptions:
landmark detection to be considered successful. landmark detection to be considered successful.
min_hand_landmarks_confidence: The minimum confidence score for the hand min_hand_landmarks_confidence: The minimum confidence score for the hand
landmark detection to be considered successful. landmark detection to be considered successful.
output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes
classification. Face blendshapes are used for rendering the 3D face model.
output_segmentation_masks: whether to output segmentation masks.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
@ -419,7 +434,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=output_streams, output_streams=output_streams,
task_options=options, task_options=options,
@ -436,7 +450,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
def detect( def detect(
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HolisticLandmarkerResult: ) -> HolisticLandmarkerResult:
"""Performs holistic landmarks detection on the given image. """Performs holistic landmarks detection on the given image.
@ -449,7 +462,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns: Returns:
The holistic landmarks detection results. The holistic landmarks detection results.
@ -458,14 +470,8 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If holistic landmarker detection failed to run. RuntimeError: If holistic landmarker detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
}) })
if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty():
@ -477,7 +483,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HolisticLandmarkerResult: ) -> HolisticLandmarkerResult:
"""Performs holistic landmarks detection on the provided video frame. """Performs holistic landmarks detection on the provided video frame.
@ -492,7 +497,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds. timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns: Returns:
The holistic landmarks detection results. The holistic landmarks detection results.
@ -501,16 +505,10 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If holistic landmarker detection failed to run. RuntimeError: If holistic landmarker detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
), ),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty():
@ -522,7 +520,6 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None: ) -> None:
"""Sends live image data to perform holistic landmarks detection. """Sends live image data to perform holistic landmarks detection.
@ -548,20 +545,13 @@ class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds. timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.
Raises: Raises:
ValueError: If the current input timestamp is smaller than what the ValueError: If the current input timestamp is smaller than what the
holistic landmarker has already processed. holistic landmarker has already processed.
""" """
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
), ),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })