diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5b75ef8fc..6db49c668 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -155,6 +155,37 @@ cc_library( # TODO: open source hand joints graph +cc_library( + name = "hand_roi_refinement_graph", + srcs = ["hand_roi_refinement_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:detections_to_rects", + "//mediapipe/framework/api2/stream:landmarks_projection", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + cc_library( name = "hand_landmarker_result", srcs = ["hand_landmarker_result.cc"], diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc 
b/mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc new file mode 100644 index 000000000..e7e9b94d0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_roi_refinement_graph.cc @@ -0,0 +1,154 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/detections_to_rects.h" +#include "mediapipe/framework/api2/stream/landmarks_projection.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include 
"mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::ProjectLandmarks; +using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong; +using ::mediapipe::api2::builder::Stream; + +// Refine the input hand RoI with hand_roi_refinement model. +// +// Inputs: +// IMAGE - Image +// The image to preprocess. +// NORM_RECT - NormalizedRect +// Coarse RoI of hand. +// Outputs: +// NORM_RECT - NormalizedRect +// Refined RoI of hand. +class HandRoiRefinementGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* context) override { + Graph graph; + Stream image_in = graph.In("IMAGE").Cast(); + Stream roi_in = + graph.In("NORM_RECT").Cast(); + + auto& graph_options = + *context->MutableOptions(); + + MP_ASSIGN_OR_RETURN( + const auto* model_resources, + GetOrCreateModelResources( + context)); + + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + graph_options.base_options().acceleration()); + auto& image_to_tensor_options = + *preprocessing + .GetOptions() + .mutable_image_to_tensor_options(); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_REPLICATE); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( + *model_resources, use_gpu, graph_options.base_options().gpu_origin(), + &preprocessing.GetOptions())); + image_in >> preprocessing.In("IMAGE"); + roi_in 
>> preprocessing.In("NORM_RECT"); + auto tensors_in = preprocessing.Out("TENSORS"); + auto matrix = preprocessing.Out("MATRIX").Cast>(); + auto image_size = + preprocessing.Out("IMAGE_SIZE").Cast>(); + + auto& inference = AddInference( + *model_resources, graph_options.base_options().acceleration(), graph); + tensors_in >> inference.In("TENSORS"); + auto tensors_out = inference.Out("TENSORS").Cast>(); + + MP_ASSIGN_OR_RETURN(auto image_tensor_specs, + BuildInputImageTensorSpecs(*model_resources)); + + // Convert tensors to landmarks. Recrop model outputs two points, + // center point and guide point. + auto& to_landmarks = graph.AddNode("TensorsToLandmarksCalculator"); + auto& to_landmarks_opts = + to_landmarks + .GetOptions(); + to_landmarks_opts.set_num_landmarks(/*num_landmarks=*/2); + to_landmarks_opts.set_input_image_width(image_tensor_specs.image_width); + to_landmarks_opts.set_input_image_height(image_tensor_specs.image_height); + to_landmarks_opts.set_normalize_z(/*z_norm_factor=*/1.0f); + tensors_out.ConnectTo(to_landmarks.In("TENSORS")); + auto recrop_landmarks = to_landmarks.Out("NORM_LANDMARKS") + .Cast(); + + // Project landmarks. + auto projected_recrop_landmarks = + ProjectLandmarks(recrop_landmarks, matrix, graph); + + // Convert re-crop landmarks to detection. + auto recrop_detection = + ConvertLandmarksToDetection(projected_recrop_landmarks, graph); + + // Convert re-crop detection to rect. 
+ auto recrop_rect = ConvertAlignmentPointsDetectionToRect( + recrop_detection, image_size, /*start_keypoint_index=*/0, + /*end_keypoint_index=*/1, /*target_angle=*/-90, graph); + + auto refined_roi = + ScaleAndShiftAndMakeSquareLong(recrop_rect, image_size, + /*scale_x_factor=*/1.0, + /*scale_y_factor=*/1.0, /*shift_x=*/0, + /*shift_y=*/-0.1, graph); + refined_roi >> graph.Out("NORM_RECT").Cast(); + return graph.GetConfig(); + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::HandRoiRefinementGraph); + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/BUILD b/mediapipe/tasks/cc/vision/holistic_landmarker/BUILD new file mode 100644 index 000000000..446cf1e09 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/BUILD @@ -0,0 +1,152 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package( + default_visibility = ["//mediapipe/tasks:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "holistic_face_tracking", + srcs = ["holistic_face_tracking.cc"], + hdrs = ["holistic_face_tracking.h"], + deps = [ + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:detections_to_rects", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:loopback", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker:face_blendshapes_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "holistic_hand_tracking", + srcs = ["holistic_hand_tracking.cc"], + hdrs = ["holistic_hand_tracking.h"], + deps = [ + "//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator", + 
"//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator_cc_proto", + "//mediapipe/calculators/util:landmark_visibility_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:loopback", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/api2/stream:split", + "//mediapipe/framework/api2/stream:threshold", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/modules/holistic_landmark/calculators:hand_detections_from_pose_to_rects_calculator", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator", + "//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_roi_refinement_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "holistic_pose_tracking", + srcs = ["holistic_pose_tracking.cc"], + hdrs = ["holistic_pose_tracking.h"], + deps = [ + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:detections_to_rects", + "//mediapipe/framework/api2/stream:image_size", + "//mediapipe/framework/api2/stream:landmarks_to_detection", + "//mediapipe/framework/api2/stream:loopback", + 
"//mediapipe/framework/api2/stream:merge", + "//mediapipe/framework/api2/stream:presence", + "//mediapipe/framework/api2/stream:rect_transformation", + "//mediapipe/framework/api2/stream:segmentation_smoothing", + "//mediapipe/framework/api2/stream:smoothing", + "//mediapipe/framework/api2/stream:split", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "holistic_landmarker_graph", + srcs = ["holistic_landmarker_graph.cc"], + deps = [ + ":holistic_face_tracking", + ":holistic_hand_tracking", + ":holistic_pose_tracking", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2/stream:split", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + 
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/pose_landmarker:pose_topology", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc new file mode 100644 index 000000000..1116cda21 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.cc @@ -0,0 +1,260 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/detections_to_rects.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/loopback.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::builder::ConvertDetectionsToRectUsingKeypoints; +using ::mediapipe::api2::builder::ConvertDetectionToRect; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::GetLoopbackData; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Scale; +using ::mediapipe::api2::builder::ScaleAndMakeSquare; +using 
+ ::mediapipe::api2::builder::Stream; + +struct FaceLandmarksResult { + std::optional> landmarks; + std::optional> classifications; +}; + +absl::Status ValidateGraphOptions( + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request) { + if (face_detector_graph_options.num_faces() != 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Only support num_faces to be 1, but got num_faces = %d.", + face_detector_graph_options.num_faces())); + } + if (request.classifications && !face_landmarks_detector_graph_options + .has_face_blendshapes_graph_options()) { + return absl::InvalidArgumentError( + "Blendshapes detection is requested, but " + "face_blendshapes_graph_options is not configured."); + } + return absl::OkStatus(); +} + +Stream GetFaceRoiFromPoseFaceLandmarks( + Stream pose_face_landmarks, + Stream> image_size, Graph& graph) { + Stream detection = + ConvertLandmarksToDetection(pose_face_landmarks, graph); + + // Refer to the pose face landmarks indices here: + // https://developers.google.com/mediapipe/solutions/vision/pose_landmarker#pose_landmarker_model + Stream rect = ConvertDetectionToRect( + detection, image_size, /*start_keypoint_index=*/5, + /*end_keypoint_index=*/2, /*target_angle=*/0, graph); + + // Scale the face RoI from a tight rect enclosing the pose face landmarks, to + // a larger square so that the whole face is within the RoI. 
+ return ScaleAndMakeSquare(rect, image_size, + /*scale_x_factor=*/3.0, + /*scale_y_factor=*/3.0, graph); +} + +Stream GetFaceRoiFromFaceLandmarks( + Stream face_landmarks, + Stream> image_size, Graph& graph) { + Stream detection = + ConvertLandmarksToDetection(face_landmarks, graph); + + Stream rect = ConvertDetectionToRect( + detection, image_size, /*start_keypoint_index=*/33, + /*end_keypoint_index=*/263, /*target_angle=*/0, graph); + + return Scale(rect, image_size, + /*scale_x_factor=*/1.5, + /*scale_y_factor=*/1.5, graph); +} + +Stream> GetFaceDetections( + Stream image, Stream roi, + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + Graph& graph) { + auto& face_detector_graph = + graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph"); + face_detector_graph + .GetOptions() = + face_detector_graph_options; + image >> face_detector_graph.In("IMAGE"); + roi >> face_detector_graph.In("NORM_RECT"); + return face_detector_graph.Out("DETECTIONS").Cast>(); +} + +Stream GetFaceRoiFromFaceDetections( + Stream> face_detections, + Stream> image_size, Graph& graph) { + // Convert detection to rect. + Stream rect = ConvertDetectionsToRectUsingKeypoints( + face_detections, image_size, /*start_keypoint_index=*/0, + /*end_keypoint_index=*/1, /*target_angle=*/0, graph); + + return ScaleAndMakeSquare(rect, image_size, + /*scale_x_factor=*/2.0, + /*scale_y_factor=*/2.0, graph); +} + +Stream TrackFaceRoi( + Stream prev_landmarks, Stream roi, + Stream> image_size, Graph& graph) { + // Gets face ROI from previous frame face landmarks. 
+ Stream prev_roi = + GetFaceRoiFromFaceLandmarks(prev_landmarks, image_size, graph); + + auto& tracking_node = graph.AddNode("RoiTrackingCalculator"); + auto& tracking_node_opts = + tracking_node.GetOptions(); + auto* rect_requirements = tracking_node_opts.mutable_rect_requirements(); + rect_requirements->set_rotation_degrees(15.0); + rect_requirements->set_translation(0.1); + rect_requirements->set_scale(0.3); + auto* landmarks_requirements = + tracking_node_opts.mutable_landmarks_requirements(); + landmarks_requirements->set_recrop_rect_margin(-0.2); + prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS")); + prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT")); + roi.ConnectTo(tracking_node.In("RECROP_RECT")); + image_size.ConnectTo(tracking_node.In("IMAGE_SIZE")); + return tracking_node.Out("TRACKING_RECT").Cast(); +} + +FaceLandmarksResult GetFaceLandmarksDetection( + Stream image, Stream roi, + Stream> image_size, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request, Graph& graph) { + FaceLandmarksResult result; + auto& face_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker." 
+ "SingleFaceLandmarksDetectorGraph"); + face_landmarks_detector_graph + .GetOptions() = + face_landmarks_detector_graph_options; + image >> face_landmarks_detector_graph.In("IMAGE"); + roi >> face_landmarks_detector_graph.In("NORM_RECT"); + auto landmarks = face_landmarks_detector_graph.Out("NORM_LANDMARKS") + .Cast(); + result.landmarks = landmarks; + if (request.classifications) { + auto& blendshapes_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph"); + blendshapes_graph + .GetOptions() = + face_landmarks_detector_graph_options.face_blendshapes_graph_options(); + landmarks >> blendshapes_graph.In("LANDMARKS"); + image_size >> blendshapes_graph.In("IMAGE_SIZE"); + result.classifications = + blendshapes_graph.Out("BLENDSHAPES").Cast(); + } + return result; +} + +} // namespace + +absl::StatusOr TrackHolisticFace( + Stream image, Stream pose_face_landmarks, + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request, Graph& graph) { + MP_RETURN_IF_ERROR(ValidateGraphOptions(face_detector_graph_options, + face_landmarks_detector_graph_options, + request)); + + // Extracts image size from the input images. + Stream> image_size = GetImageSize(image, graph); + + // Gets face ROI from pose face landmarks. + Stream roi_from_pose = + GetFaceRoiFromPoseFaceLandmarks(pose_face_landmarks, image_size, graph); + + // Detects faces within ROI of pose face. + Stream> face_detections = GetFaceDetections( + image, roi_from_pose, face_detector_graph_options, graph); + + // Gets face ROI from face detector. + Stream roi_from_detection = + GetFaceRoiFromFaceDetections(face_detections, image_size, graph); + + // Loop for previous frame landmarks. + auto [prev_landmarks, set_prev_landmarks_fn] = + GetLoopbackData(/*tick=*/image_size, graph); + + // Tracks face ROI. 
+ auto tracking_roi = + TrackFaceRoi(prev_landmarks, roi_from_detection, image_size, graph); + + // Predicts face landmarks. + auto landmarks_detection_result = GetFaceLandmarksDetection( + image, tracking_roi, image_size, face_landmarks_detector_graph_options, + request, graph); + + // Sets previous landmarks for ROI tracking. + set_prev_landmarks_fn(landmarks_detection_result.landmarks.value()); + + return {{.landmarks = landmarks_detection_result.landmarks, + .classifications = landmarks_detection_result.classifications, + .debug_output = { + .roi_from_pose = roi_from_pose, + .roi_from_detection = roi_from_detection, + .tracking_roi = tracking_roi, + }}}; +} + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h new file mode 100644 index 000000000..835767ebc --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h @@ -0,0 +1,89 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +struct HolisticFaceTrackingRequest { + bool classifications = false; +}; + +struct HolisticFaceTrackingOutput { + std::optional> + landmarks; + std::optional> + classifications; + + struct DebugOutput { + api2::builder::Stream roi_from_pose; + api2::builder::Stream roi_from_detection; + api2::builder::Stream tracking_roi; + }; + + DebugOutput debug_output; +}; + +// Updates @graph to track a single face in @image based on pose landmarks. +// +// To track single face this subgraph uses pose face landmarks to obtain +// approximate face location, refines it with face detector model and then runs +// face landmarks model. It can also reuse face ROI from the previous frame if +// face hasn't moved too much. +// +// @image - Image to track a single face in. +// @pose_face_landmarks - Pose face landmarks to derive initial face location +// from. +// @face_detector_graph_options - face detector graph options used to detect the +// face within the RoI constructed from the pose face landmarks. 
+// @face_landmarks_detector_graph_options - face landmarks detector graph +// options used to detect face landmarks within the RoI given by the face +// detector graph. +// @request - object to request specific face tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated. +// @graph - graph to update. +absl::StatusOr TrackHolisticFace( + api2::builder::Stream image, + api2::builder::Stream + pose_face_landmarks, + const face_detector::proto::FaceDetectorGraphOptions& + face_detector_graph_options, + const face_landmarker::proto::FaceLandmarksDetectorGraphOptions& + face_landmarks_detector_graph_options, + const HolisticFaceTrackingRequest& request, + mediapipe::api2::builder::Graph& graph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_ diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc new file mode 100644 index 000000000..314c330b3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking_test.cc @@ -0,0 +1,227 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h" + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include 
"mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::Image; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::SplitToRanges; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::core::proto::ExternalFile; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.015; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kTestImageFile[] = "male_full_height_hands.jpg"; +constexpr char kHolisticResultFile[] = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr char kImageInStream[] = "image_in"; +constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in"; +constexpr char kFaceLandmarksOutStream[] = "face_landmarks_out"; +constexpr char kRenderedImageOutStream[] = "rendered_image_out"; +constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite"; +constexpr char kFaceLandmarksDetectorTFLiteName[] = + "face_landmarks_detector.tflite"; + +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDirectory, filename); +} + +mediapipe::LandmarksToRenderDataCalculatorOptions GetFaceRendererOptions() { + mediapipe::LandmarksToRenderDataCalculatorOptions render_options; + for (const auto& connection : + face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors) { + render_options.add_landmark_connections(connection[0]); + render_options.add_landmark_connections(connection[1]); + } + render_options.mutable_landmark_color()->set_r(255); + render_options.mutable_landmark_color()->set_g(255); + 
render_options.mutable_landmark_color()->set_b(255); + render_options.mutable_connection_color()->set_r(255); + render_options.mutable_connection_color()->set_g(255); + render_options.mutable_connection_color()->set_b(255); + render_options.set_thickness(0.5); + render_options.set_visualize_landmark_depth(false); + return render_options; +} + +absl::StatusOr> +CreateModelAssetBundleResources(const std::string& model_asset_filename) { + auto external_model_bundle = std::make_unique(); + external_model_bundle->set_file_name(model_asset_filename); + return ModelAssetBundleResources::Create("", + std::move(external_model_bundle)); +} + +// Helper function to create a TaskRunner. +absl::StatusOr> CreateTaskRunner() { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + Stream pose_landmarks = + graph.In("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksInStream); + Stream face_landmarks_from_pose = + SplitToRanges(pose_landmarks, {{0, 11}}, graph)[0]; + // Create face landmarker model bundle. + MP_ASSIGN_OR_RETURN( + auto model_bundle, + CreateModelAssetBundleResources(GetFilePath("face_landmarker_v2.task"))); + face_detector::proto::FaceDetectorGraphOptions detector_options; + face_landmarker::proto::FaceLandmarksDetectorGraphOptions + landmarks_detector_options; + + // Set face detection model. + MP_ASSIGN_OR_RETURN(auto face_detector_model_file, + model_bundle->GetFile(kFaceDetectorTFLiteName)); + core::proto::FilePointerMeta face_detection_file_pointer; + face_detection_file_pointer.set_pointer( + reinterpret_cast(face_detector_model_file.data())); + face_detection_file_pointer.set_length(face_detector_model_file.size()); + detector_options.mutable_base_options() + ->mutable_model_asset() + ->mutable_file_pointer_meta() + ->Swap(&face_detection_file_pointer); + detector_options.set_num_faces(1); + + // Set face landmarks model. 
+ MP_ASSIGN_OR_RETURN(auto face_landmarks_model_file, + model_bundle->GetFile(kFaceLandmarksDetectorTFLiteName)); + core::proto::FilePointerMeta face_landmarks_detector_file_pointer; + face_landmarks_detector_file_pointer.set_pointer( + reinterpret_cast(face_landmarks_model_file.data())); + face_landmarks_detector_file_pointer.set_length( + face_landmarks_model_file.size()); + landmarks_detector_options.mutable_base_options() + ->mutable_model_asset() + ->mutable_file_pointer_meta() + ->Swap(&face_landmarks_detector_file_pointer); + + // Track holistic face. + HolisticFaceTrackingRequest request; + MP_ASSIGN_OR_RETURN( + HolisticFaceTrackingOutput result, + TrackHolisticFace(image, face_landmarks_from_pose, detector_options, + landmarks_detector_options, request, graph)); + auto face_landmarks = + result.landmarks.value().SetName(kFaceLandmarksOutStream); + + auto image_size = GetImageSize(image, graph); + auto render_scale = utils::GetRenderScale( + image_size, result.debug_output.roi_from_pose, 0.0001, graph); + + auto face_landmarks_render_data = utils::RenderLandmarks( + face_landmarks, render_scale, GetFaceRendererOptions(), graph); + std::vector> render_list = { + face_landmarks_render_data}; + + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + face_landmarks >> graph.Out("FACE_LANDMARKS"); + rendered_image >> graph.Out("RENDERED_IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + return TaskRunner::Create( + config, std::make_unique()); +} + +class HolisticFaceTrackingTest : public ::testing::Test {}; + +TEST_F(HolisticFaceTrackingTest, SmokeTest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(GetFilePath(kTestImageFile))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + ::file::Defaults())); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + 
MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + task_runner->Process( + {{kImageInStream, MakePacket(image)}, + {kPoseLandmarksInStream, MakePacket( + holistic_result.pose_landmarks())}})); + ASSERT_TRUE(output_packets.find(kFaceLandmarksOutStream) != + output_packets.end()); + auto face_landmarks = output_packets.find(kFaceLandmarksOutStream) + ->second.Get(); + EXPECT_THAT( + face_landmarks, + Approximately(Partially(EqualsProto(holistic_result.face_landmarks())), + /*margin=*/kAbsMargin)); + auto rendered_image = output_packets.at(kRenderedImageOutStream).Get(); + MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(), + "holistic_face_landmarks")); +} + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc new file mode 100644 index 000000000..2c57aa059 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.cc @@ -0,0 +1,272 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.h" +#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/loopback.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/api2/stream/threshold.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::AlignHandToPoseInWorldCalculator; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::GetLoopbackData; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::IsOverThreshold; +using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong; +using ::mediapipe::api2::builder::SplitAndCombine; +using ::mediapipe::api2::builder::Stream; +using 
::mediapipe::tasks::components::utils::AllowIf; + +struct HandLandmarksResult { + std::optional> landmarks; + std::optional> world_landmarks; +}; + +Stream AlignHandToPoseInWorldCalculator( + Stream hand_world_landmarks, + Stream pose_world_landmarks, int pose_wrist_idx, + Graph& graph) { + auto& node = graph.AddNode("AlignHandToPoseInWorldCalculator"); + auto& opts = node.GetOptions(); + opts.set_hand_wrist_idx(0); + opts.set_pose_wrist_idx(pose_wrist_idx); + hand_world_landmarks.ConnectTo( + node[AlignHandToPoseInWorldCalculator::kInHandLandmarks]); + pose_world_landmarks.ConnectTo( + node[AlignHandToPoseInWorldCalculator::kInPoseLandmarks]); + return node[AlignHandToPoseInWorldCalculator::kOutHandLandmarks]; +} + +Stream GetPosePalmVisibility( + Stream pose_palm_landmarks, Graph& graph) { + // Get wrist landmark. + auto pose_wrist = SplitAndCombine(pose_palm_landmarks, {0}, graph); + + // Get visibility score. + auto& score_node = graph.AddNode("LandmarkVisibilityCalculator"); + pose_wrist.ConnectTo(score_node.In("NORM_LANDMARKS")); + Stream score = score_node.Out("VISIBILITY").Cast(); + + // Convert score into flag. + return IsOverThreshold(score, /*threshold=*/0.1, graph); +} + +Stream GetHandRoiFromPosePalmLandmarks( + Stream pose_palm_landmarks, + Stream> image_size, Graph& graph) { + // Convert pose palm landmarks to detection. + auto detection = ConvertLandmarksToDetection(pose_palm_landmarks, graph); + + // Convert detection to rect. 
+ auto& rect_node = graph.AddNode("HandDetectionsFromPoseToRectsCalculator"); + detection.ConnectTo(rect_node.In("DETECTION")); + image_size.ConnectTo(rect_node.In("IMAGE_SIZE")); + Stream rect = + rect_node.Out("NORM_RECT").Cast(); + + return ScaleAndShiftAndMakeSquareLong(rect, image_size, + /*scale_x_factor=*/2.7, + /*scale_y_factor=*/2.7, /*shift_x=*/0, + /*shift_y=*/-0.1, graph); +} + +absl::StatusOr> RefineHandRoi( + Stream image, Stream roi, + const hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinenement_graph_options, + Graph& graph) { + auto& hand_roi_refinement = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph"); + hand_roi_refinement + .GetOptions() = + hand_roi_refinenement_graph_options; + image >> hand_roi_refinement.In("IMAGE"); + roi >> hand_roi_refinement.In("NORM_RECT"); + return hand_roi_refinement.Out("NORM_RECT").Cast(); +} + +Stream TrackHandRoi( + Stream prev_landmarks, Stream roi, + Stream> image_size, Graph& graph) { + // Convert hand landmarks to tight rect. + auto& prev_rect_node = graph.AddNode("HandLandmarksToRectCalculator"); + prev_landmarks.ConnectTo(prev_rect_node.In("NORM_LANDMARKS")); + image_size.ConnectTo(prev_rect_node.In("IMAGE_SIZE")); + Stream prev_rect = + prev_rect_node.Out("NORM_RECT").Cast(); + + // Convert tight hand rect to hand roi. 
+ Stream prev_roi = + ScaleAndShiftAndMakeSquareLong(prev_rect, image_size, + /*scale_x_factor=*/2.0, + /*scale_y_factor=*/2.0, /*shift_x=*/0, + /*shift_y=*/-0.1, graph); + + auto& tracking_node = graph.AddNode("RoiTrackingCalculator"); + auto& tracking_node_opts = + tracking_node.GetOptions(); + auto* rect_requirements = tracking_node_opts.mutable_rect_requirements(); + rect_requirements->set_rotation_degrees(40.0); + rect_requirements->set_translation(0.2); + rect_requirements->set_scale(0.4); + auto* landmarks_requirements = + tracking_node_opts.mutable_landmarks_requirements(); + landmarks_requirements->set_recrop_rect_margin(-0.1); + prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS")); + prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT")); + roi.ConnectTo(tracking_node.In("RECROP_RECT")); + image_size.ConnectTo(tracking_node.In("IMAGE_SIZE")); + return tracking_node.Out("TRACKING_RECT").Cast(); +} + +HandLandmarksResult GetHandLandmarksDetection( + Stream image, Stream roi, + const hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + const HolisticHandTrackingRequest& request, Graph& graph) { + HandLandmarksResult result; + auto& hand_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); + hand_landmarks_detector_graph + .GetOptions() = + hand_landmarks_detector_graph_options; + + image >> hand_landmarks_detector_graph.In("IMAGE"); + roi >> hand_landmarks_detector_graph.In("HAND_RECT"); + + if (request.landmarks) { + result.landmarks = hand_landmarks_detector_graph.Out("LANDMARKS") + .Cast(); + } + if (request.world_landmarks) { + result.world_landmarks = + hand_landmarks_detector_graph.Out("WORLD_LANDMARKS") + .Cast(); + } + return result; +} + +} // namespace + +absl::StatusOr TrackHolisticHand( + Stream image, Stream pose_landmarks, + Stream pose_world_landmarks, + const hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + const hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinement_graph_options, + const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request, + Graph& graph) { + // Extracts pose palm landmarks. + Stream pose_palm_landmarks = SplitAndCombine( + pose_landmarks, + {pose_indices.wrist_idx, pose_indices.pinky_idx, pose_indices.index_idx}, + graph); + + // Get pose palm visibility. + Stream is_pose_palm_visible = + GetPosePalmVisibility(pose_palm_landmarks, graph); + + // Drop pose palm landmarks if pose palm is invisible. + pose_palm_landmarks = + AllowIf(pose_palm_landmarks, is_pose_palm_visible, graph); + + // Extracts image size from the input images. + Stream> image_size = GetImageSize(image, graph); + + // Get hand ROI from pose palm landmarks. + Stream roi_from_pose = + GetHandRoiFromPosePalmLandmarks(pose_palm_landmarks, image_size, graph); + + // Refine hand ROI with re-crop model. + MP_ASSIGN_OR_RETURN(Stream roi_from_recrop, + RefineHandRoi(image, roi_from_pose, + hand_roi_refinement_graph_options, graph)); + + // Loop for previous frame landmarks. + auto [prev_landmarks, set_prev_landmarks_fn] = + GetLoopbackData(/*tick=*/image_size, graph); + + // Track hand ROI. 
+ auto tracking_roi = + TrackHandRoi(prev_landmarks, roi_from_recrop, image_size, graph); + + // Predict hand landmarks. + auto landmarks_detection_result = GetHandLandmarksDetection( + image, tracking_roi, hand_landmarks_detector_graph_options, request, + graph); + + // Set previous landmarks for ROI tracking. + set_prev_landmarks_fn(landmarks_detection_result.landmarks.value()); + + // Output landmarks. + std::optional> hand_landmarks; + if (request.landmarks) { + hand_landmarks = landmarks_detection_result.landmarks; + } + + // Output world landmarks. + std::optional> hand_world_landmarks; + if (request.world_landmarks) { + hand_world_landmarks = landmarks_detection_result.world_landmarks; + + // Align hand world landmarks with pose world landmarks. + hand_world_landmarks = AlignHandToPoseInWorldCalculator( + hand_world_landmarks.value(), pose_world_landmarks, + pose_indices.wrist_idx, graph); + } + + return {{.landmarks = hand_landmarks, + .world_landmarks = hand_world_landmarks, + .debug_output = { + .roi_from_pose = roi_from_pose, + .roi_from_recrop = roi_from_recrop, + .tracking_roi = tracking_roi, + }}}; +} + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h new file mode 100644 index 000000000..463f4979b --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h @@ -0,0 +1,94 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +struct PoseIndices { + int wrist_idx; + int pinky_idx; + int index_idx; +}; + +struct HolisticHandTrackingRequest { + bool landmarks = false; + bool world_landmarks = false; +}; + +struct HolisticHandTrackingOutput { + std::optional> + landmarks; + std::optional> world_landmarks; + + struct DebugOutput { + api2::builder::Stream roi_from_pose; + api2::builder::Stream roi_from_recrop; + api2::builder::Stream tracking_roi; + }; + + DebugOutput debug_output; +}; + +// Updates @graph to track a single hand in @image based on pose landmarks. +// +// To track single hand this subgraph uses pose palm landmarks to obtain +// approximate hand location, refines it with re-crop model and then runs hand +// landmarks model. 
It can also reuse hand ROI from the previous frame if hand +// hasn't moved too much. +// +// @image - ImageFrame/GpuBuffer to track a single hand in. +// @pose_landmarks - Pose landmarks to derive initial hand location from. +// @pose_world_landmarks - Pose world landmarks to align hand world landmarks +// wrist with. +// @hand_landmarks_detector_graph_options - Options of the +// HandLandmarksDetectorGraph used to detect the hand landmarks. +// @hand_roi_refinement_graph_options - Options of HandRoiRefinementGraph used +// to refine the hand RoIs obtained from pose landmarks. +// @request - object to request specific hand tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated. +// @graph - graph to update. +absl::StatusOr TrackHolisticHand( + api2::builder::Stream image, + api2::builder::Stream pose_landmarks, + api2::builder::Stream pose_world_landmarks, + const hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + const hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinement_graph_options, + const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request, + mediapipe::api2::builder::Graph& graph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_ diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc new file mode 100644 index 000000000..4ae4a37ed --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking_test.cc @@ -0,0 +1,303 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include 
"mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::Image; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::TaskRunner; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.018; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kHolisticHandTrackingLeft[] = + "holistic_hand_tracking_left_hand_graph.pbtxt"; +constexpr char kTestImageFile[] = "male_full_height_hands.jpg"; +constexpr char kHolisticResultFile[] = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr char kImageInStream[] = "image_in"; +constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in"; +constexpr char kPoseWorldLandmarksInStream[] = "pose_world_landmarks_in"; +constexpr char kLeftHandLandmarksOutStream[] = "left_hand_landmarks_out"; +constexpr char kLeftHandWorldLandmarksOutStream[] = + "left_hand_world_landmarks_out"; +constexpr char kRightHandLandmarksOutStream[] = "right_hand_landmarks_out"; +constexpr char kRenderedImageOutStream[] = "rendered_image_out"; +constexpr char kHandLandmarksModelFile[] = "hand_landmark_full.tflite"; +constexpr char kHandRoiRefinementModelFile[] = + "handrecrop_2020_07_21_v0.f16.tflite"; + +std::string GetFilePath(const std::string& filename) { + return file::JoinPath("./", 
kTestDataDirectory, filename); +} + +mediapipe::LandmarksToRenderDataCalculatorOptions GetHandRendererOptions() { + mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options; + for (const auto& connection : hand_landmarker::kHandConnections) { + renderer_options.add_landmark_connections(connection[0]); + renderer_options.add_landmark_connections(connection[1]); + } + renderer_options.mutable_landmark_color()->set_r(255); + renderer_options.mutable_landmark_color()->set_g(255); + renderer_options.mutable_landmark_color()->set_b(255); + renderer_options.mutable_connection_color()->set_r(255); + renderer_options.mutable_connection_color()->set_g(255); + renderer_options.mutable_connection_color()->set_b(255); + renderer_options.set_thickness(0.5); + renderer_options.set_visualize_landmark_depth(false); + return renderer_options; +} + +void ConfigHandTrackingModelsOptions( + hand_landmarker::proto::HandLandmarksDetectorGraphOptions& + hand_landmarks_detector_graph_options, + hand_landmarker::proto::HandRoiRefinementGraphOptions& + hand_roi_refinement_options) { + hand_landmarks_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kHandLandmarksModelFile)); + + hand_roi_refinement_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath(kHandRoiRefinementModelFile)); +} + +// Helper function to create a TaskRunner. 
+absl::StatusOr> CreateTaskRunner() { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + Stream pose_landmarks = + graph.In("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksInStream); + Stream pose_world_landmarks = + graph.In("POSE_WORLD_LANDMARKS") + .Cast() + .SetName(kPoseWorldLandmarksInStream); + hand_landmarker::proto::HandLandmarksDetectorGraphOptions + hand_landmarks_detector_options; + hand_landmarker::proto::HandRoiRefinementGraphOptions + hand_roi_refinement_options; + ConfigHandTrackingModelsOptions(hand_landmarks_detector_options, + hand_roi_refinement_options); + HolisticHandTrackingRequest request; + request.landmarks = true; + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput left_hand_result, + TrackHolisticHand( + image, pose_landmarks, pose_world_landmarks, + hand_landmarks_detector_options, hand_roi_refinement_options, + PoseIndices{ + /*wrist_idx=*/static_cast( + pose_landmarker::PoseLandmarkName::kLeftWrist), + /*pinky_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kLeftPinky1), + /*index_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kLeftIndex1)}, + request, graph)); + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput right_hand_result, + TrackHolisticHand( + image, pose_landmarks, pose_world_landmarks, + hand_landmarks_detector_options, hand_roi_refinement_options, + PoseIndices{ + /*wrist_idx=*/static_cast( + pose_landmarker::PoseLandmarkName::kRightWrist), + /*pinky_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kRightPinky1), + /*index_idx=*/ + static_cast( + pose_landmarker::PoseLandmarkName::kRightIndex1)}, + request, graph)); + + auto image_size = GetImageSize(image, graph); + auto left_hand_landmarks_render_data = utils::RenderLandmarks( + *left_hand_result.landmarks, + utils::GetRenderScale(image_size, + left_hand_result.debug_output.roi_from_pose, 0.0001, + graph), + GetHandRendererOptions(), graph); + auto right_hand_landmarks_render_data = utils::RenderLandmarks( + 
*right_hand_result.landmarks, + utils::GetRenderScale(image_size, + right_hand_result.debug_output.roi_from_pose, + 0.0001, graph), + GetHandRendererOptions(), graph); + std::vector> render_list = { + left_hand_landmarks_render_data, right_hand_landmarks_render_data}; + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >> + graph.Out("LEFT_HAND_LANDMARKS"); + right_hand_result.landmarks->SetName(kRightHandLandmarksOutStream) >> + graph.Out("RIGHT_HAND_LANDMARKS"); + rendered_image >> graph.Out("RENDERED_IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + return TaskRunner::Create( + config, std::make_unique()); +} + +class HolisticHandTrackingTest : public ::testing::Test {}; + +TEST_F(HolisticHandTrackingTest, VerifyGraph) { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + Stream pose_landmarks = + graph.In("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksInStream); + Stream pose_world_landmarks = + graph.In("POSE_WORLD_LANDMARKS") + .Cast() + .SetName(kPoseWorldLandmarksInStream); + hand_landmarker::proto::HandLandmarksDetectorGraphOptions + hand_landmarks_detector_options; + hand_landmarker::proto::HandRoiRefinementGraphOptions + hand_roi_refinement_options; + ConfigHandTrackingModelsOptions(hand_landmarks_detector_options, + hand_roi_refinement_options); + HolisticHandTrackingRequest request; + request.landmarks = true; + request.world_landmarks = true; + MP_ASSERT_OK_AND_ASSIGN( + HolisticHandTrackingOutput left_hand_result, + TrackHolisticHand( + image, pose_landmarks, pose_world_landmarks, + hand_landmarks_detector_options, hand_roi_refinement_options, + PoseIndices{ + /*wrist_idx=*/static_cast( + pose_landmarker::PoseLandmarkName::kLeftWrist), + /*pinky_idx=*/ + static_cast(pose_landmarker::PoseLandmarkName::kLeftPinky1), + /*index_idx=*/ + 
static_cast(pose_landmarker::PoseLandmarkName::kLeftIndex1)}, + request, graph)); + left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >> + graph.Out("LEFT_HAND_LANDMARKS"); + left_hand_result.world_landmarks->SetName(kLeftHandWorldLandmarksOutStream) >> + graph.Out("LEFT_HAND_WORLD_LANDMARKS"); + + // Read the expected graph config. + std::string expected_graph_contents; + MP_ASSERT_OK(file::GetContents( + file::JoinPath("./", kTestDataDirectory, kHolisticHandTrackingLeft), + &expected_graph_contents)); + + // Need to replace the expected graph config with the test srcdir, because + // each run has different test dir on TAP. + expected_graph_contents = absl::Substitute( + expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir); + CalculatorGraphConfig expected_graph = + ParseTextProtoOrDie(expected_graph_contents); + + EXPECT_THAT(graph.GetConfig(), testing::proto::IgnoringRepeatedFieldOrdering( + testing::EqualsProto(expected_graph))); +} + +TEST_F(HolisticHandTrackingTest, SmokeTest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(GetFilePath(kTestImageFile))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + Defaults())); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + task_runner->Process( + {{kImageInStream, MakePacket(image)}, + {kPoseLandmarksInStream, MakePacket( + holistic_result.pose_landmarks())}, + {kPoseWorldLandmarksInStream, + MakePacket( + holistic_result.pose_world_landmarks())}})); + auto left_hand_landmarks = output_packets.at(kLeftHandLandmarksOutStream) + .Get(); + auto right_hand_landmarks = output_packets.at(kRightHandLandmarksOutStream) + .Get(); + EXPECT_THAT(left_hand_landmarks, + Approximately( + Partially(EqualsProto(holistic_result.left_hand_landmarks())), + /*margin=*/kAbsMargin)); + EXPECT_THAT( + right_hand_landmarks, + Approximately( + 
Partially(EqualsProto(holistic_result.right_hand_landmarks())), + /*margin=*/kAbsMargin)); + auto rendered_image = output_packets.at(kRenderedImageOutStream).Get(); + MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(), + "holistic_hand_landmarks")); +} + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc new file mode 100644 index 000000000..2de358a6c --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph.cc @@ -0,0 +1,521 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h" +#include 
"mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/util/graph_builder_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { +namespace { + +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::metadata::SetExternalFile; + +constexpr absl::string_view kHandLandmarksDetectorModelName = + "hand_landmarks_detector.tflite"; +constexpr absl::string_view kHandRoiRefinementModelName = + "hand_roi_refinement.tflite"; +constexpr absl::string_view kFaceDetectorModelName = "face_detector.tflite"; +constexpr absl::string_view kFaceLandmarksDetectorModelName = + "face_landmarks_detector.tflite"; +constexpr absl::string_view kFaceBlendshapesModelName = + "face_blendshapes.tflite"; +constexpr absl::string_view kPoseDetectorModelName = "pose_detector.tflite"; +constexpr absl::string_view kPoseLandmarksDetectorModelName = + "pose_landmarks_detector.tflite"; + +absl::Status SetGraphPoseOutputs( + const HolisticPoseTrackingRequest& pose_request, + const CalculatorGraphConfig::Node& node, + HolisticPoseTrackingOutput& pose_output, Graph& graph) { + // Main outputs. + if (pose_request.landmarks) { + RET_CHECK(pose_output.landmarks.has_value()) + << "POSE_LANDMARKS output is not supported."; + pose_output.landmarks->ConnectTo(graph.Out("POSE_LANDMARKS")); + } + if (pose_request.world_landmarks) { + RET_CHECK(pose_output.world_landmarks.has_value()) + << "POSE_WORLD_LANDMARKS output is not supported."; + pose_output.world_landmarks->ConnectTo(graph.Out("POSE_WORLD_LANDMARKS")); + } + if (pose_request.segmentation_mask) { + RET_CHECK(pose_output.segmentation_mask.has_value()) + << "POSE_SEGMENTATION_MASK output is not supported."; + pose_output.segmentation_mask->ConnectTo( + graph.Out("POSE_SEGMENTATION_MASK")); + } + + // Debug outputs. 
+ if (HasOutput(node, "POSE_AUXILIARY_LANDMARKS")) { + pose_output.debug_output.auxiliary_landmarks.ConnectTo( + graph.Out("POSE_AUXILIARY_LANDMARKS")); + } + if (HasOutput(node, "POSE_LANDMARKS_ROI")) { + pose_output.debug_output.roi_from_landmarks.ConnectTo( + graph.Out("POSE_LANDMARKS_ROI")); + } + + return absl::OkStatus(); +} + +// Sets the base options in the sub tasks. +template +absl::Status SetSubTaskBaseOptions( + const core::ModelAssetBundleResources* resources, + proto::HolisticLandmarkerGraphOptions* options, T* sub_task_options, + absl::string_view model_name, bool is_copy) { + if (!sub_task_options->base_options().has_model_asset()) { + MP_ASSIGN_OR_RETURN(const auto model_file_content, + resources->GetFile(std::string(model_name))); + SetExternalFile( + model_file_content, + sub_task_options->mutable_base_options()->mutable_model_asset(), + is_copy); + } + sub_task_options->mutable_base_options()->mutable_acceleration()->CopyFrom( + options->base_options().acceleration()); + sub_task_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + sub_task_options->mutable_base_options()->set_gpu_origin( + options->base_options().gpu_origin()); + return absl::OkStatus(); +} + +void SetGraphHandOutputs(bool is_left, const CalculatorGraphConfig::Node& node, + HolisticHandTrackingOutput& hand_output, + Graph& graph) { + const std::string hand_side = is_left ? "LEFT" : "RIGHT"; + + if (hand_output.landmarks) { + hand_output.landmarks->ConnectTo(graph.Out(hand_side + "_HAND_LANDMARKS")); + } + if (hand_output.world_landmarks) { + hand_output.world_landmarks->ConnectTo( + graph.Out(hand_side + "_HAND_WORLD_LANDMARKS")); + } + + // Debug outputs. 
+ if (HasOutput(node, hand_side + "_HAND_ROI_FROM_POSE")) { + hand_output.debug_output.roi_from_pose.ConnectTo( + graph.Out(hand_side + "_HAND_ROI_FROM_POSE")); + } + if (HasOutput(node, hand_side + "_HAND_ROI_FROM_RECROP")) { + hand_output.debug_output.roi_from_recrop.ConnectTo( + graph.Out(hand_side + "_HAND_ROI_FROM_RECROP")); + } + if (HasOutput(node, hand_side + "_HAND_TRACKING_ROI")) { + hand_output.debug_output.tracking_roi.ConnectTo( + graph.Out(hand_side + "_HAND_TRACKING_ROI")); + } +} + +void SetGraphFaceOutputs(const CalculatorGraphConfig::Node& node, + HolisticFaceTrackingOutput& face_output, + Graph& graph) { + if (face_output.landmarks) { + face_output.landmarks->ConnectTo(graph.Out("FACE_LANDMARKS")); + } + if (face_output.classifications) { + face_output.classifications->ConnectTo(graph.Out("FACE_BLENDSHAPES")); + } + + // Face detection debug outputs + if (HasOutput(node, "FACE_ROI_FROM_POSE")) { + face_output.debug_output.roi_from_pose.ConnectTo( + graph.Out("FACE_ROI_FROM_POSE")); + } + if (HasOutput(node, "FACE_ROI_FROM_DETECTION")) { + face_output.debug_output.roi_from_detection.ConnectTo( + graph.Out("FACE_ROI_FROM_DETECTION")); + } + if (HasOutput(node, "FACE_TRACKING_ROI")) { + face_output.debug_output.tracking_roi.ConnectTo( + graph.Out("FACE_TRACKING_ROI")); + } +} + +} // namespace + +// Tracks pose and detects hands and face. +// +// NOTE: for GPU works only with image having GpuOrigin::TOP_LEFT +// +// Inputs: +// IMAGE - Image +// Image to perform detection on. 
+// +// Outputs: +// POSE_LANDMARKS - NormalizedLandmarkList +// 33 landmarks (see pose_landmarker/pose_topology.h) +// 0 - nose +// 1 - left eye (inner) +// 2 - left eye +// 3 - left eye (outer) +// 4 - right eye (inner) +// 5 - right eye +// 6 - right eye (outer) +// 7 - left ear +// 8 - right ear +// 9 - mouth (left) +// 10 - mouth (right) +// 11 - left shoulder +// 12 - right shoulder +// 13 - left elbow +// 14 - right elbow +// 15 - left wrist +// 16 - right wrist +// 17 - left pinky +// 18 - right pinky +// 19 - left index +// 20 - right index +// 21 - left thumb +// 22 - right thumb +// 23 - left hip +// 24 - right hip +// 25 - left knee +// 26 - right knee +// 27 - left ankle +// 28 - right ankle +// 29 - left heel +// 30 - right heel +// 31 - left foot index +// 32 - right foot index +// POSE_WORLD_LANDMARKS - LandmarkList +// World landmarks are real world 3D coordinates with origin in hips center +// and coordinates in meters. To understand the difference: POSE_LANDMARKS +// stream provides coordinates (in pixels) of 3D object projected on a 2D +// surface of the image (check on how perspective projection works), while +// POSE_WORLD_LANDMARKS stream provides coordinates (in meters) of the 3D +// object itself. POSE_WORLD_LANDMARKS has the same landmarks topology, +// visibility and presence as POSE_LANDMARKS. +// POSE_SEGMENTATION_MASK - Image +// Separates person from background. Mask is stored as gray float32 image +// with [0.0, 1.0] range for pixels (1 for person and 0 for background) on +// CPU and, on GPU - RGBA texture with R channel indicating person vs. +// background probability. +// LEFT_HAND_LANDMARKS - NormalizedLandmarkList +// 21 left hand landmarks. +// RIGHT_HAND_LANDMARKS - NormalizedLandmarkList +// 21 right hand landmarks. +// FACE_LANDMARKS - NormalizedLandmarkList +// 468 face landmarks. +// FACE_BLENDSHAPES - ClassificationList +// Supplementary blendshape coefficients that are predicted directly from +// the input image. 
+// LEFT_HAND_WORLD_LANDMARKS - LandmarkList +// 21 left hand world 3D landmarks. +// Hand landmarks are aligned with pose landmarks: translated so that wrist +// from hand matches wrist from pose in pose coordinate system. +// RIGHT_HAND_WORLD_LANDMARKS - LandmarkList +// 21 right hand world 3D landmarks. +// Hand landmarks are aligned with pose landmarks: translated so that wrist +// from hand matches wrist from pose in pose coordinate system. +// IMAGE - Image +// The input image that the holistic landmarker runs on and has the pixel +// data stored on the target storage (CPU vs GPU). +// +// Debug outputs: +// POSE_AUXILIARY_LANDMARKS - NormalizedLandmarkList +// TODO: Return ROI rather than auxiliary landmarks +// Auxiliary landmarks for deriving the ROI in the subsequent image. +// 0 - hidden center point +// 1 - hidden scale point +// POSE_LANDMARKS_ROI - NormalizedRect +// Region of interest calculated based on landmarks. +// LEFT_HAND_ROI_FROM_POSE - NormalizedLandmarkList +// LEFT_HAND_ROI_FROM_RECROP - NormalizedLandmarkList +// LEFT_HAND_TRACKING_ROI - NormalizedLandmarkList +// RIGHT_HAND_ROI_FROM_POSE - NormalizedLandmarkList +// RIGHT_HAND_ROI_FROM_RECROP - NormalizedLandmarkList +// RIGHT_HAND_TRACKING_ROI - NormalizedLandmarkList +// FACE_ROI_FROM_POSE - NormalizedLandmarkList +// FACE_ROI_FROM_DETECTION - NormalizedLandmarkList +// FACE_TRACKING_ROI - NormalizedLandmarkList +// +// NOTE: failure is reported if some output has been requested, but specified +// model doesn't support it. +// +// NOTE: there will not be an output packet in an output stream for a +// particular timestamp if nothing is detected. However, the MediaPipe +// framework will internally inform the downstream calculators of the +// absence of this packet so that they don't wait for it unnecessarily. 
+// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph" +// input_stream: "IMAGE:input_frames_image" +// output_stream: "POSE_LANDMARKS:pose_landmarks" +// output_stream: "POSE_WORLD_LANDMARKS:pose_world_landmarks" +// output_stream: "FACE_LANDMARKS:face_landmarks" +// output_stream: "FACE_BLENDSHAPES:extra_blendshapes" +// output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" +// output_stream: "LEFT_HAND_WORLD_LANDMARKS:left_hand_world_landmarks" +// output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" +// output_stream: "RIGHT_HAND_WORLD_LANDMARKS:right_hand_world_landmarks" +// node_options { +// [type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions] +// { +// base_options { +// model_asset { +// file_name: +// "mediapipe/tasks/testdata/vision/holistic_landmarker.task" +// } +// } +// face_detector_graph_options: { +// num_faces: 1 +// } +// pose_detector_graph_options: { +// num_poses: 1 +// } +// } +// } +// } +class HolisticLandmarkerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + const auto& holistic_node = sc->OriginalNode(); + proto::HolisticLandmarkerGraphOptions* holistic_options = + sc->MutableOptions(); + const core::ModelAssetBundleResources* model_asset_bundle_resources; + if (holistic_options->base_options().has_model_asset()) { + MP_ASSIGN_OR_RETURN(model_asset_bundle_resources, + CreateModelAssetBundleResources< + proto::HolisticLandmarkerGraphOptions>(sc)); + } + // Copies the file content instead of passing the pointer of file in + // memory if the subgraph model resource service is not available. 
+ bool create_copy = + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable(); + + Stream image = graph.In("IMAGE").Cast(); + + // Check whether Hand requested + const bool is_left_hand_requested = + HasOutput(holistic_node, "LEFT_HAND_LANDMARKS"); + const bool is_right_hand_requested = + HasOutput(holistic_node, "RIGHT_HAND_LANDMARKS"); + const bool is_left_hand_world_requested = + HasOutput(holistic_node, "LEFT_HAND_WORLD_LANDMARKS"); + const bool is_right_hand_world_requested = + HasOutput(holistic_node, "RIGHT_HAND_WORLD_LANDMARKS"); + const bool hands_requested = + is_left_hand_requested || is_right_hand_requested || + is_left_hand_world_requested || is_right_hand_world_requested; + if (hands_requested) { + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_hand_landmarks_detector_graph_options(), + kHandLandmarksDetectorModelName, create_copy)); + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_hand_roi_refinement_graph_options(), + kHandRoiRefinementModelName, create_copy)); + } + + // Check whether Face requested + const bool is_face_requested = HasOutput(holistic_node, "FACE_LANDMARKS"); + const bool is_face_blendshapes_requested = + HasOutput(holistic_node, "FACE_BLENDSHAPES"); + const bool face_requested = + is_face_requested || is_face_blendshapes_requested; + if (face_requested) { + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_face_detector_graph_options(), + kFaceDetectorModelName, create_copy)); + // Forcely set num_faces to 1, because holistic landmarker only supports a + // single subject for now. 
+ holistic_options->mutable_face_detector_graph_options()->set_num_faces(1); + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_face_landmarks_detector_graph_options(), + kFaceLandmarksDetectorModelName, create_copy)); + if (is_face_blendshapes_requested) { + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_face_landmarks_detector_graph_options() + ->mutable_face_blendshapes_graph_options(), + kFaceBlendshapesModelName, create_copy)); + } + } + + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_pose_detector_graph_options(), + kPoseDetectorModelName, create_copy)); + // Forcely set num_poses to 1, because holistic landmarker sonly supports a + // single subject for now. + holistic_options->mutable_pose_detector_graph_options()->set_num_poses(1); + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + model_asset_bundle_resources, holistic_options, + holistic_options->mutable_pose_landmarks_detector_graph_options(), + kPoseLandmarksDetectorModelName, create_copy)); + + HolisticPoseTrackingRequest pose_request = { + .landmarks = HasOutput(holistic_node, "POSE_LANDMARKS") || + hands_requested || face_requested, + .world_landmarks = + HasOutput(holistic_node, "POSE_WORLD_LANDMARKS") || hands_requested, + .segmentation_mask = + HasOutput(holistic_node, "POSE_SEGMENTATION_MASK")}; + + // Detect and track pose. + MP_ASSIGN_OR_RETURN( + HolisticPoseTrackingOutput pose_output, + TrackHolisticPose( + image, holistic_options->pose_detector_graph_options(), + holistic_options->pose_landmarks_detector_graph_options(), + pose_request, graph)); + MP_RETURN_IF_ERROR( + SetGraphPoseOutputs(pose_request, holistic_node, pose_output, graph)); + + // Detect and track hand. 
+ if (hands_requested) { + if (is_left_hand_requested || is_left_hand_world_requested) { + RET_CHECK(pose_output.landmarks.has_value()); + RET_CHECK(pose_output.world_landmarks.has_value()); + + PoseIndices pose_indices = { + .wrist_idx = + static_cast(pose_landmarker::PoseLandmarkName::kLeftWrist), + .pinky_idx = static_cast( + pose_landmarker::PoseLandmarkName::kLeftPinky1), + .index_idx = static_cast( + pose_landmarker::PoseLandmarkName::kLeftIndex1), + }; + HolisticHandTrackingRequest hand_request = { + .landmarks = is_left_hand_requested, + .world_landmarks = is_left_hand_world_requested, + }; + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput hand_output, + TrackHolisticHand( + image, *pose_output.landmarks, *pose_output.world_landmarks, + holistic_options->hand_landmarks_detector_graph_options(), + holistic_options->hand_roi_refinement_graph_options(), + pose_indices, hand_request, graph + + )); + SetGraphHandOutputs(/*is_left=*/true, holistic_node, hand_output, + graph); + } + + if (is_right_hand_requested || is_right_hand_world_requested) { + RET_CHECK(pose_output.landmarks.has_value()); + RET_CHECK(pose_output.world_landmarks.has_value()); + + PoseIndices pose_indices = { + .wrist_idx = static_cast( + pose_landmarker::PoseLandmarkName::kRightWrist), + .pinky_idx = static_cast( + pose_landmarker::PoseLandmarkName::kRightPinky1), + .index_idx = static_cast( + pose_landmarker::PoseLandmarkName::kRightIndex1), + }; + HolisticHandTrackingRequest hand_request = { + .landmarks = is_right_hand_requested, + .world_landmarks = is_right_hand_world_requested, + }; + MP_ASSIGN_OR_RETURN( + HolisticHandTrackingOutput hand_output, + TrackHolisticHand( + image, *pose_output.landmarks, *pose_output.world_landmarks, + holistic_options->hand_landmarks_detector_graph_options(), + holistic_options->hand_roi_refinement_graph_options(), + pose_indices, hand_request, graph + + )); + SetGraphHandOutputs(/*is_left=*/false, holistic_node, hand_output, + graph); + } + } + + // 
Detect and track face. + if (face_requested) { + RET_CHECK(pose_output.landmarks.has_value()); + + Stream face_landmarks_from_pose = + api2::builder::SplitToRanges(*pose_output.landmarks, {{0, 11}}, + graph)[0]; + + HolisticFaceTrackingRequest face_request = { + .classifications = is_face_blendshapes_requested, + }; + MP_ASSIGN_OR_RETURN( + HolisticFaceTrackingOutput face_output, + TrackHolisticFace( + image, face_landmarks_from_pose, + holistic_options->face_detector_graph_options(), + holistic_options->face_landmarks_detector_graph_options(), + face_request, graph)); + SetGraphFaceOutputs(holistic_node, face_output, graph); + } + + auto& pass_through = graph.AddNode("PassThroughCalculator"); + image >> pass_through.In(""); + pass_through.Out("") >> graph.Out("IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + return config; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::holistic_landmarker::HolisticLandmarkerGraph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc new file mode 100644 index 000000000..c549a022b --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_landmarker_graph_test.cc @@ -0,0 +1,595 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE(review): the seven bare `#include` lines below appear to have lost
+// their <header> names during extraction — restore before building.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/flags/flag.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "file/base/helpers.h"
+#include "file/base/options.h"
+#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/stream/image_size.h"
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/tool/test_util.h"
+#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
+#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
+#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
+#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
+#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
+#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
+#include "mediapipe/util/color.pb.h"
+#include "mediapipe/util/render_data.pb.h"
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/googletest.h"
+#include "testing/base/public/gunit.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace holistic_landmarker {
+namespace {
+
+using ::mediapipe::api2::builder::GetImageSize;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Stream;
+using ::mediapipe::tasks::core::TaskRunner;
+using ::testing::TestParamInfo;
+using ::testing::TestWithParam;
+using ::testing::Values;
+using ::testing::proto::Approximately;
+using ::testing::proto::Partially;
+
+// Absolute margin used for approximate proto comparisons.
+constexpr float kAbsMargin = 0.025;
+// Test data location; joined with "./" and a file name by GetFilePath().
+constexpr absl::string_view kTestDataDirectory =
+    "/mediapipe/tasks/testdata/vision/";
+constexpr char kHolisticResultFile[] =
+    "male_full_height_hands_result_cpu.pbtxt";
+constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
+// Stream names used in the test graph.
+constexpr absl::string_view kImageInStream = "image_in";
+constexpr absl::string_view kLeftHandLandmarksStream = "left_hand_landmarks";
+constexpr absl::string_view kRightHandLandmarksStream = "right_hand_landmarks";
+constexpr absl::string_view kFaceLandmarksStream = "face_landmarks";
+constexpr absl::string_view kFaceBlendshapesStream = "face_blendshapes";
+constexpr absl::string_view kPoseLandmarksStream = "pose_landmarks";
+constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
+constexpr absl::string_view kPoseSegmentationMaskStream =
+    "pose_segmentation_mask";
+// Model files, resolved relative to kTestDataDirectory.
+constexpr absl::string_view kHolisticLandmarkerModelBundleFile =
+    "holistic_landmarker.task";
+constexpr absl::string_view kHandLandmarksModelFile =
+    "hand_landmark_full.tflite";
+constexpr absl::string_view kHandRoiRefinementModelFile =
+    "handrecrop_2020_07_21_v0.f16.tflite";
+constexpr absl::string_view kPoseDetectionModelFile = "pose_detection.tflite";
+constexpr absl::string_view kPoseLandmarksModelFile =
+    "pose_landmark_lite.tflite";
+constexpr absl::string_view kFaceDetectionModelFile =
+    "face_detection_short_range.tflite";
+constexpr absl::string_view kFaceLandmarksModelFile =
+    "facemesh2_lite_iris_faceflag_2023_02_14.tflite";
+constexpr absl::string_view kFaceBlendshapesModelFile =
+    "face_blendshapes.tflite";
+
+// Body part being rendered; selects the drawing color in GetColor().
+enum RenderPart {
+  HAND = 0,
+  POSE = 1,
+  FACE = 2,
+};
+
+// Returns the render color for `render_part`: white for HAND, green for POSE,
+// red for FACE.
+mediapipe::Color GetColor(RenderPart render_part) {
+  mediapipe::Color color;
+  switch (render_part) {
+    case HAND:
+      color.set_b(255);
+      color.set_g(255);
+      color.set_r(255);
+      break;
+    case POSE:
+      color.set_b(0);
+      color.set_g(255);
+      color.set_r(0);
+      break;
+    case FACE:
+      color.set_b(0);
+      color.set_g(0);
+      color.set_r(255);
+      break;
+  }
+  return color;
+}
+
+// Returns the path of `filename` inside the test data directory.
+std::string GetFilePath(absl::string_view filename) {
+  return file::JoinPath("./", kTestDataDirectory, filename);
+}
+
+// Builds landmark renderer options that draw every pair in `connections`,
+// using `color` for both landmarks and connections, thickness 0.5, and depth
+// visualization disabled.
+// NOTE(review): the template parameter list and the inner type of the
+// std::array parameter appear to have been stripped in extraction — confirm
+// against the original file.
+template
+mediapipe::LandmarksToRenderDataCalculatorOptions GetRendererOptions(
+    const std::array, N>& connections,
+    mediapipe::Color color) {
+  mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
+  for (const auto& connection : connections) {
+    // Connections are stored flattened: start index, then end index.
+    renderer_options.add_landmark_connections(connection[0]);
+    renderer_options.add_landmark_connections(connection[1]);
+  }
+  *renderer_options.mutable_landmark_color() = color;
+  *renderer_options.mutable_connection_color() = color;
+  renderer_options.set_thickness(0.5);
+  renderer_options.set_visualize_landmark_depth(false);
+  return renderer_options;
+}
+
+// Points the hand landmarks detector and hand ROI refinement sub-graphs at
+// their standalone test model files.
+void ConfigureHandProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
+  options.mutable_hand_landmarks_detector_graph_options()
+      ->mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kHandLandmarksModelFile));
+
+  options.mutable_hand_roi_refinement_graph_options()
+      ->mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
+}
+
+// Points the face detector, face landmarks detector and face blendshapes
+// sub-graphs at their standalone test model files; limits detection to one
+// face.
+void ConfigureFaceProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
+  // Set face detection model.
+  face_detector::proto::FaceDetectorGraphOptions& face_detector_graph_options =
+      *options.mutable_face_detector_graph_options();
+  face_detector_graph_options.mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kFaceDetectionModelFile));
+  face_detector_graph_options.set_num_faces(1);
+
+  // Set face landmarks model.
+  face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
+      face_landmarks_graph_options =
+          *options.mutable_face_landmarks_detector_graph_options();
+  face_landmarks_graph_options.mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kFaceLandmarksModelFile));
+  face_landmarks_graph_options.mutable_face_blendshapes_graph_options()
+      ->mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kFaceBlendshapesModelFile));
+}
+
+// Points the pose detector and pose landmarks detector sub-graphs at their
+// standalone test model files; limits detection to one pose.
+void ConfigurePoseProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
+  pose_detector::proto::PoseDetectorGraphOptions& pose_detector_graph_options =
+      *options.mutable_pose_detector_graph_options();
+  pose_detector_graph_options.mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kPoseDetectionModelFile));
+  pose_detector_graph_options.set_num_poses(1);
+  options.mutable_pose_landmarks_detector_graph_options()
+      ->mutable_base_options()
+      ->mutable_model_asset()
+      ->set_file_name(GetFilePath(kPoseLandmarksModelFile));
+}
+
+// Selects which optional holistic outputs a test graph requests; everything
+// defaults to not requested.
+struct HolisticRequest {
+  bool is_left_hand_requested = false;
+  bool is_right_hand_requested = false;
+  bool is_face_requested = false;
+  bool is_face_blendshapes_requested = false;
+};
+
+// Helper function to create a TaskRunner.
+absl::StatusOr> CreateTaskRunner( + bool use_model_bundle, HolisticRequest holistic_request) { + Graph graph; + + Stream image = graph.In("IMAEG").Cast().SetName(kImageInStream); + + auto& holistic_graph = graph.AddNode( + "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"); + proto::HolisticLandmarkerGraphOptions& options = + holistic_graph.GetOptions(); + if (use_model_bundle) { + options.mutable_base_options()->mutable_model_asset()->set_file_name( + GetFilePath(kHolisticLandmarkerModelBundleFile)); + } else { + ConfigureHandProtoOptions(options); + ConfigurePoseProtoOptions(options); + ConfigureFaceProtoOptions(options); + } + + std::vector> render_list; + image >> holistic_graph.In("IMAGE"); + Stream> image_size = GetImageSize(image, graph); + + if (holistic_request.is_left_hand_requested) { + Stream left_hand_landmarks = + holistic_graph.Out("LEFT_HAND_LANDMARKS") + .Cast() + .SetName(kLeftHandLandmarksStream); + Stream left_hand_tracking_roi = + holistic_graph.Out("LEFT_HAND_TRACKING_ROI").Cast(); + auto left_hand_landmarks_render_data = utils::RenderLandmarks( + left_hand_landmarks, + utils::GetRenderScale(image_size, left_hand_tracking_roi, 0.0001, + graph), + GetRendererOptions(hand_landmarker::kHandConnections, + GetColor(RenderPart::HAND)), + graph); + render_list.push_back(left_hand_landmarks_render_data); + left_hand_landmarks >> graph.Out("LEFT_HAND_LANDMARKS"); + } + if (holistic_request.is_right_hand_requested) { + Stream right_hand_landmarks = + holistic_graph.Out("RIGHT_HAND_LANDMARKS") + .Cast() + .SetName(kRightHandLandmarksStream); + Stream right_hand_tracking_roi = + holistic_graph.Out("RIGHT_HAND_TRACKING_ROI").Cast(); + auto right_hand_landmarks_render_data = utils::RenderLandmarks( + right_hand_landmarks, + utils::GetRenderScale(image_size, right_hand_tracking_roi, 0.0001, + graph), + GetRendererOptions(hand_landmarker::kHandConnections, + GetColor(RenderPart::HAND)), + graph); + 
render_list.push_back(right_hand_landmarks_render_data); + right_hand_landmarks >> graph.Out("RIGHT_HAND_LANDMARKS"); + } + if (holistic_request.is_face_requested) { + Stream face_landmarks = + holistic_graph.Out("FACE_LANDMARKS") + .Cast() + .SetName(kFaceLandmarksStream); + Stream face_tracking_roi = + holistic_graph.Out("FACE_TRACKING_ROI").Cast(); + auto face_landmarks_render_data = utils::RenderLandmarks( + face_landmarks, + utils::GetRenderScale(image_size, face_tracking_roi, 0.0001, graph), + GetRendererOptions( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors, + GetColor(RenderPart::FACE)), + graph); + render_list.push_back(face_landmarks_render_data); + face_landmarks >> graph.Out("FACE_LANDMARKS"); + } + if (holistic_request.is_face_blendshapes_requested) { + Stream face_blendshapes = + holistic_graph.Out("FACE_BLENDSHAPES") + .Cast() + .SetName(kFaceBlendshapesStream); + face_blendshapes >> graph.Out("FACE_BLENDSHAPES"); + } + Stream pose_landmarks = + holistic_graph.Out("POSE_LANDMARKS") + .Cast() + .SetName(kPoseLandmarksStream); + Stream pose_tracking_roi = + holistic_graph.Out("POSE_LANDMARKS_ROI").Cast(); + Stream pose_segmentation_mask = + holistic_graph.Out("POSE_SEGMENTATION_MASK") + .Cast() + .SetName(kPoseSegmentationMaskStream); + + auto pose_landmarks_render_data = utils::RenderLandmarks( + pose_landmarks, + utils::GetRenderScale(image_size, pose_tracking_roi, 0.0001, graph), + GetRendererOptions(pose_landmarker::kPoseLandmarksConnections, + GetColor(RenderPart::POSE)), + graph); + render_list.push_back(pose_landmarks_render_data); + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + + pose_landmarks >> graph.Out("POSE_LANDMARKS"); + pose_segmentation_mask >> graph.Out("POSE_SEGMENTATION_MASK"); + rendered_image >> graph.Out("RENDERED_IMAGE"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + return TaskRunner::Create( + config, 
std::make_unique()); +} + +template +absl::StatusOr FetchResult(const core::PacketMap& output_packets, + absl::string_view stream_name) { + auto it = output_packets.find(std::string(stream_name)); + RET_CHECK(it != output_packets.end()); + return it->second.Get(); +} + +// Remove fields not to be checked in the result, since the model +// generating expected result is different from the testing model. +void RemoveUncheckedResult(proto::HolisticResult& holistic_result) { + for (auto& landmark : + *holistic_result.mutable_pose_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } + for (auto& landmark : + *holistic_result.mutable_face_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } + for (auto& landmark : + *holistic_result.mutable_left_hand_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } + for (auto& landmark : + *holistic_result.mutable_right_hand_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } +} + +std::string RequestToString(HolisticRequest request) { + return absl::StrFormat( + "%s_%s_%s_%s", + request.is_left_hand_requested ? "left_hand" : "no_left_hand", + request.is_right_hand_requested ? "right_hand" : "no_right_hand", + request.is_face_requested ? "face" : "no_face", + request.is_face_blendshapes_requested ? "face_blendshapes" + : "no_face_blendshapes"); +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of test image. + std::string test_image_name; + // Whether to use holistic model bundle to test. + bool use_model_bundle; + // Requests of holistic parts. 
+ HolisticRequest holistic_request; +}; + +class SmokeTest : public testing::TestWithParam {}; + +TEST_P(SmokeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(GetFilePath(GetParam().test_image_name))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + ::file::Defaults())); + RemoveUncheckedResult(holistic_result); + + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, + CreateTaskRunner(GetParam().use_model_bundle, + GetParam().holistic_request)); + MP_ASSERT_OK_AND_ASSIGN(auto output_packets, + task_runner->Process({{std::string(kImageInStream), + MakePacket(image)}})); + + // Check face landmarks + if (GetParam().holistic_request.is_face_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto face_landmarks, + FetchResult( + output_packets, kFaceLandmarksStream)); + EXPECT_THAT( + face_landmarks, + Approximately(Partially(EqualsProto(holistic_result.face_landmarks())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE(output_packets.contains(std::string(kFaceLandmarksStream))); + } + + if (GetParam().holistic_request.is_face_blendshapes_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto face_blendshapes, + FetchResult( + output_packets, kFaceBlendshapesStream)); + EXPECT_THAT(face_blendshapes, + Approximately( + Partially(EqualsProto(holistic_result.face_blendshapes())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE(output_packets.contains(std::string(kFaceBlendshapesStream))); + } + + // Check Pose landmarks + MP_ASSERT_OK_AND_ASSIGN(auto pose_landmarks, + FetchResult( + output_packets, kPoseLandmarksStream)); + EXPECT_THAT( + pose_landmarks, + Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())), + /*margin=*/kAbsMargin)); + + // Check Hand landmarks + if (GetParam().holistic_request.is_left_hand_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto left_hand_landmarks, + FetchResult( + output_packets, kLeftHandLandmarksStream)); + EXPECT_THAT( + 
left_hand_landmarks, + Approximately( + Partially(EqualsProto(holistic_result.left_hand_landmarks())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE( + output_packets.contains(std::string(kLeftHandLandmarksStream))); + } + + if (GetParam().holistic_request.is_right_hand_requested) { + MP_ASSERT_OK_AND_ASSIGN(auto right_hand_landmarks, + FetchResult( + output_packets, kRightHandLandmarksStream)); + EXPECT_THAT( + right_hand_landmarks, + Approximately( + Partially(EqualsProto(holistic_result.right_hand_landmarks())), + /*margin=*/kAbsMargin)); + } else { + ASSERT_FALSE( + output_packets.contains(std::string(kRightHandLandmarksStream))); + } + + auto rendered_image = + output_packets.at(std::string(kRenderedImageOutStream)).Get(); + MP_EXPECT_OK(SavePngTestOutput( + *rendered_image.GetImageFrameSharedPtr(), + absl::StrCat("holistic_landmark_", + RequestToString(GetParam().holistic_request)))); + + auto pose_segmentation_mask = + output_packets.at(std::string(kPoseSegmentationMaskStream)).Get(); + + cv::Mat matting_mask = mediapipe::formats::MatView( + pose_segmentation_mask.GetImageFrameSharedPtr().get()); + cv::Mat visualized_mask; + matting_mask.convertTo(visualized_mask, CV_8UC1, 255); + ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8, + visualized_mask.cols, visualized_mask.rows, + visualized_mask.step, visualized_mask.data, + [visualized_mask](uint8_t[]) {}); + + MP_EXPECT_OK( + SavePngTestOutput(visualized_image, "holistic_pose_segmentation_mask")); +} + +INSTANTIATE_TEST_SUITE_P( + HolisticLandmarkerGraphTest, SmokeTest, + Values(TestParams{ + /* test_name= */ "UseModelBundle", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "UseSeparateModelFiles", + /* test_image_name= */ 
std::string(kTestImageFile), + /* use_model_bundle= */ false, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoLeftHand", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ false, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoRightHand", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ false, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoHand", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ false, + /*is_right_hand_requested= */ false, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ true, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoFace", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ false, + /*is_face_blendshapes_requested= */ false, + }, + }, + TestParams{ + /* test_name= */ "ModelBundleNoFaceBlendshapes", + /* test_image_name= */ std::string(kTestImageFile), + /* use_model_bundle= */ true, + /* holistic_request= */ + { + /*is_left_hand_requested= */ true, + /*is_right_hand_requested= */ true, + /*is_face_requested= */ true, + /*is_face_blendshapes_requested= */ false, + }, + }), + [](const TestParamInfo& info) { + return 
info.param.test_name; + }); + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc new file mode 100644 index 000000000..860035ad0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.cc @@ -0,0 +1,307 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/detections_to_rects.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/api2/stream/landmarks_to_detection.h" +#include "mediapipe/framework/api2/stream/loopback.h" +#include "mediapipe/framework/api2/stream/merge.h" +#include "mediapipe/framework/api2/stream/presence.h" +#include "mediapipe/framework/api2/stream/rect_transformation.h" +#include "mediapipe/framework/api2/stream/segmentation_smoothing.h" +#include "mediapipe/framework/api2/stream/smoothing.h" +#include "mediapipe/framework/api2/stream/split.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionsToRect; +using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect; +using ::mediapipe::api2::builder::ConvertLandmarksToDetection; +using ::mediapipe::api2::builder::GenericNode; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::GetLoopbackData; +using ::mediapipe::api2::builder::Graph; +using 
::mediapipe::api2::builder::IsPresent; +using ::mediapipe::api2::builder::Merge; +using ::mediapipe::api2::builder::ScaleAndMakeSquare; +using ::mediapipe::api2::builder::SmoothLandmarks; +using ::mediapipe::api2::builder::SmoothLandmarksVisibility; +using ::mediapipe::api2::builder::SmoothSegmentationMask; +using ::mediapipe::api2::builder::SplitToRanges; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::components::utils::DisallowIf; +using Size = std::pair; + +constexpr int kAuxLandmarksStartKeypointIndex = 0; +constexpr int kAuxLandmarksEndKeypointIndex = 1; +constexpr float kAuxLandmarksTargetAngle = 90; +constexpr float kRoiFromDetectionScaleFactor = 1.25f; +constexpr float kRoiFromLandmarksScaleFactor = 1.25f; + +Stream CalculateRoiFromDetections( + Stream> detections, Stream image_size, + Graph& graph) { + auto roi = ConvertAlignmentPointsDetectionsToRect(detections, image_size, + /*start_keypoint_index=*/0, + /*end_keypoint_index=*/1, + /*target_angle=*/90, graph); + return ScaleAndMakeSquare( + roi, image_size, /*scale_x_factor=*/kRoiFromDetectionScaleFactor, + /*scale_y_factor=*/kRoiFromDetectionScaleFactor, graph); +} + +Stream CalculateScaleRoiFromAuxiliaryLandmarks( + Stream landmarks, Stream image_size, + Graph& graph) { + // TODO: consider calculating ROI directly from landmarks. + auto detection = ConvertLandmarksToDetection(landmarks, graph); + return ConvertAlignmentPointsDetectionToRect( + detection, image_size, kAuxLandmarksStartKeypointIndex, + kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph); +} + +Stream CalculateRoiFromAuxiliaryLandmarks( + Stream landmarks, Stream image_size, + Graph& graph) { + // TODO: consider calculating ROI directly from landmarks. 
+ auto detection = ConvertLandmarksToDetection(landmarks, graph); + auto roi = ConvertAlignmentPointsDetectionToRect( + detection, image_size, kAuxLandmarksStartKeypointIndex, + kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph); + return ScaleAndMakeSquare( + roi, image_size, /*scale_x_factor=*/kRoiFromLandmarksScaleFactor, + /*scale_y_factor=*/kRoiFromLandmarksScaleFactor, graph); +} + +struct PoseLandmarksResult { + std::optional> landmarks; + std::optional> world_landmarks; + std::optional> auxiliary_landmarks; + std::optional> segmentation_mask; +}; + +PoseLandmarksResult RunLandmarksDetection( + Stream image, Stream roi, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, Graph& graph) { + GenericNode& landmarks_graph = graph.AddNode( + "mediapipe.tasks.vision.pose_landmarker." + "SinglePoseLandmarksDetectorGraph"); + landmarks_graph + .GetOptions() = + pose_landmarks_detector_graph_options; + image >> landmarks_graph.In("IMAGE"); + roi >> landmarks_graph.In("NORM_RECT"); + + PoseLandmarksResult result; + if (request.landmarks) { + result.landmarks = + landmarks_graph.Out("LANDMARKS").Cast(); + result.auxiliary_landmarks = landmarks_graph.Out("AUXILIARY_LANDMARKS") + .Cast(); + } + if (request.world_landmarks) { + result.world_landmarks = + landmarks_graph.Out("WORLD_LANDMARKS").Cast(); + } + if (request.segmentation_mask) { + result.segmentation_mask = + landmarks_graph.Out("SEGMENTATION_MASK").Cast(); + } + return result; +} + +} // namespace + +absl::StatusOr +TrackHolisticPoseUsingCustomPoseDetection( + Stream image, PoseDetectionFn pose_detection_fn, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, Graph& graph) { + // Calculate ROI from scratch (pose detection) or reuse one from the + // previous run if available. 
+ auto [previous_roi, set_previous_roi_fn] = + GetLoopbackData(/*tick=*/image, graph); + auto is_previous_roi_available = IsPresent(previous_roi, graph); + auto image_for_detection = + DisallowIf(image, is_previous_roi_available, graph); + MP_ASSIGN_OR_RETURN(auto pose_detections, + pose_detection_fn(image_for_detection, graph)); + auto roi_from_detections = CalculateRoiFromDetections( + pose_detections, GetImageSize(image_for_detection, graph), graph); + // Take first non-empty. + auto roi = Merge(roi_from_detections, previous_roi, graph); + + // Calculate landmarks and other outputs (if requested) in the specified ROI. + auto landmarks_detection_result = RunLandmarksDetection( + image, roi, pose_landmarks_detector_graph_options, + { + // Landmarks are required for tracking, hence force-requesting them. + .landmarks = true, + .world_landmarks = request.world_landmarks, + .segmentation_mask = request.segmentation_mask, + }, + graph); + RET_CHECK(landmarks_detection_result.landmarks.has_value() && + landmarks_detection_result.auxiliary_landmarks.has_value()) + << "Failed to calculate landmarks required for tracking."; + + // Unpack pose landmarks and auxiliary landmarks (presence checked above). + auto pose_landmarks_raw = *landmarks_detection_result.landmarks; + auto auxiliary_landmarks = *landmarks_detection_result.auxiliary_landmarks; + + auto image_size = GetImageSize(image, graph); + + // TODO: b/305750053 - Apply adaptive crop by adding AdaptiveCropCalculator. + + // Calculate ROI from smoothed auxiliary landmarks. + auto scale_roi = CalculateScaleRoiFromAuxiliaryLandmarks(auxiliary_landmarks, + image_size, graph); + auto auxiliary_landmarks_smoothed = SmoothLandmarks( + auxiliary_landmarks, image_size, scale_roi, + {// Min cutoff 0.01 results into ~0.002 alpha in landmark EMA filter when + // landmark is static. + .min_cutoff = 0.01, + // Beta 10.0 in combination with min_cutoff 0.01 results into ~0.68 + // alpha in landmark EMA filter when landmark is moving fast. 
+ .beta = 10.0, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity + // EMA filter. + .derivate_cutoff = 1.0}, + graph); + auto roi_from_auxiliary_landmarks = CalculateRoiFromAuxiliaryLandmarks( + auxiliary_landmarks_smoothed, image_size, graph); + + // Make ROI from auxiliary landmarks to be used as "previous" ROI for a + // subsequent run. + set_previous_roi_fn(roi_from_auxiliary_landmarks); + + // Populate and smooth pose landmarks if corresponding output has been + // requested. + std::optional> pose_landmarks; + if (request.landmarks) { + pose_landmarks = SmoothLandmarksVisibility( + pose_landmarks_raw, /*low_pass_filter_alpha=*/0.1f, graph); + pose_landmarks = SmoothLandmarks( + *pose_landmarks, image_size, scale_roi, + {// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter when + // landmark is static. + .min_cutoff = 0.05f, + // Beta 80.0 in combination with min_cutoff 0.05 results into ~0.94 + // alpha in landmark EMA filter when landmark is moving fast. + .beta = 80.0f, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity + // EMA filter. + .derivate_cutoff = 1.0f}, + graph); + } + + // Populate and smooth world landmarks if available. + std::optional> world_landmarks; + if (landmarks_detection_result.world_landmarks) { + world_landmarks = SplitToRanges(*landmarks_detection_result.world_landmarks, + /*ranges*/ {{0, 33}}, graph)[0]; + world_landmarks = SmoothLandmarksVisibility( + *world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph); + world_landmarks = SmoothLandmarks( + *world_landmarks, + /*scale_roi=*/std::nullopt, + {// Min cutoff 0.1 results into ~ 0.02 alpha in landmark EMA filter when + // landmark is static. + .min_cutoff = 0.1f, + // Beta 40.0 in combination with min_cutoff 0.1 results into ~0.8 + // alpha in landmark EMA filter when landmark is moving fast. + .beta = 40.0f, + // Derivative cutoff 1.0 results into ~0.17 alpha in landmark velocity + // EMA filter. 
+ .derivate_cutoff = 1.0f}, + graph); + } + + // Populate and smooth segmentation mask if available. + std::optional> segmentation_mask; + if (landmarks_detection_result.segmentation_mask) { + auto mask = *landmarks_detection_result.segmentation_mask; + auto [prev_mask_as_img, set_prev_mask_as_img_fn] = + GetLoopbackData( + /*tick=*/*landmarks_detection_result.segmentation_mask, graph); + auto mask_smoothed = + SmoothSegmentationMask(mask, prev_mask_as_img, + /*combine_with_previous_ratio=*/0.7f, graph); + set_prev_mask_as_img_fn(mask_smoothed); + segmentation_mask = mask_smoothed; + } + + return {{/*landmarks=*/pose_landmarks, + /*world_landmarks=*/world_landmarks, + /*segmentation_mask=*/segmentation_mask, + /*debug_output=*/ + {/*auxiliary_landmarks=*/auxiliary_landmarks_smoothed, + /*roi_from_landmarks=*/roi_from_auxiliary_landmarks, + /*detections*/ pose_detections}}}; +} + +absl::StatusOr TrackHolisticPose( + Stream image, + const pose_detector::proto::PoseDetectorGraphOptions& + pose_detector_graph_options, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, Graph& graph) { + PoseDetectionFn pose_detection_fn = [&pose_detector_graph_options]( + Stream image, Graph& graph) + -> absl::StatusOr>> { + GenericNode& pose_detector = + graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"); + pose_detector.GetOptions() = + pose_detector_graph_options; + image >> pose_detector.In("IMAGE"); + return pose_detector.Out("DETECTIONS") + .Cast>(); + }; + return TrackHolisticPoseUsingCustomPoseDetection( + image, pose_detection_fn, pose_landmarks_detector_graph_options, request, + graph); +} + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h new 
file mode 100644 index 000000000..f51ccc283 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h @@ -0,0 +1,110 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +// Type of pose detection function that can be used to customize pose tracking, +// by supplying the function to `TrackHolisticPoseUsingCustomPoseDetection`. +// +// Function should update provided graph with node/nodes that accept image +// stream and produce stream of detections. 
+using PoseDetectionFn = std::function< + absl::StatusOr>>( + api2::builder::Stream, api2::builder::Graph&)>; + +struct HolisticPoseTrackingRequest { + bool landmarks = false; + bool world_landmarks = false; + bool segmentation_mask = false; +}; + +struct HolisticPoseTrackingOutput { + std::optional> + landmarks; + std::optional> world_landmarks; + std::optional> segmentation_mask; + + struct DebugOutput { + api2::builder::Stream + auxiliary_landmarks; + api2::builder::Stream roi_from_landmarks; + api2::builder::Stream> detections; + }; + + DebugOutput debug_output; +}; + +// Updates @graph to track pose in @image. +// +// @image - ImageFrame/GpuBuffer to track pose in. +// @pose_detection_fn - pose detection function that takes @image as input and +// produces stream of pose detections. +// @pose_landmarks_detector_graph_options - options of the +// PoseLandmarksDetectorGraph used to detect the pose landmarks. +// @request - object to request specific pose tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated at all. +// @graph - graph to update. +absl::StatusOr +TrackHolisticPoseUsingCustomPoseDetection( + api2::builder::Stream image, PoseDetectionFn pose_detection_fn, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph); + +// Updates @graph to track pose in @image. +// +// @image - ImageFrame/GpuBuffer to track pose in. +// @pose_detector_graph_options - options of the PoseDetectorGraph used to +// detect the pose. +// @pose_landmarks_detector_graph_options - options of the +// PoseLandmarksDetectorGraph used to detect the pose landmarks. +// @request - object to request specific pose tracking outputs. +// NOTE: Outputs that were not requested won't be returned and corresponding +// parts of the graph won't be generated at all. +// @graph - graph to update. 
+absl::StatusOr TrackHolisticPose( + api2::builder::Stream image, + const pose_detector::proto::PoseDetectorGraphOptions& + pose_detector_graph_options, + const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions& + pose_landmarks_detector_graph_options, + const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph); + +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_ diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc new file mode 100644 index 000000000..0bf7259e8 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking_test.cc @@ -0,0 +1,243 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/stream/image_size.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/data_renderer.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" +#include "testing/base/public/googletest.h" + 
+namespace mediapipe { +namespace tasks { +namespace vision { +namespace holistic_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::Image; +using ::mediapipe::api2::builder::GetImageSize; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; +using ::mediapipe::tasks::core::TaskRunner; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr float kAbsMargin = 0.025; +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/vision/"; +constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg"; +constexpr absl::string_view kImageInStream = "image_in"; +constexpr absl::string_view kPoseLandmarksOutStream = "pose_landmarks_out"; +constexpr absl::string_view kPoseWorldLandmarksOutStream = + "pose_world_landmarks_out"; +constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out"; +constexpr absl::string_view kHolisticResultFile = + "male_full_height_hands_result_cpu.pbtxt"; +constexpr absl::string_view kHolisticPoseTrackingGraph = + "holistic_pose_tracking_graph.pbtxt"; + +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDirectory, filename); +} + +mediapipe::LandmarksToRenderDataCalculatorOptions GetPoseRendererOptions() { + mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options; + for (const auto& connection : pose_landmarker::kPoseLandmarksConnections) { + renderer_options.add_landmark_connections(connection[0]); + renderer_options.add_landmark_connections(connection[1]); + } + renderer_options.mutable_landmark_color()->set_r(255); + renderer_options.mutable_landmark_color()->set_g(255); + renderer_options.mutable_landmark_color()->set_b(255); + renderer_options.mutable_connection_color()->set_r(255); + renderer_options.mutable_connection_color()->set_g(255); + renderer_options.mutable_connection_color()->set_b(255); + 
renderer_options.set_thickness(0.5); + renderer_options.set_visualize_landmark_depth(false); + return renderer_options; +} + +// Helper function to create a TaskRunner. +absl::StatusOr> CreateTaskRunner() { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options; + pose_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_detection.tflite")); + pose_detector_graph_options.set_num_poses(1); + pose_landmarker::proto::PoseLandmarksDetectorGraphOptions + pose_landmarks_detector_graph_options; + pose_landmarks_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_landmark_lite.tflite")); + + HolisticPoseTrackingRequest request; + request.landmarks = true; + request.world_landmarks = true; + MP_ASSIGN_OR_RETURN( + HolisticPoseTrackingOutput result, + TrackHolisticPose(image, pose_detector_graph_options, + pose_landmarks_detector_graph_options, request, graph)); + + auto image_size = GetImageSize(image, graph); + auto render_data = utils::RenderLandmarks( + *result.landmarks, + utils::GetRenderScale(image_size, result.debug_output.roi_from_landmarks, + 0.0001, graph), + GetPoseRendererOptions(), graph); + std::vector> render_list = {render_data}; + auto rendered_image = + utils::Render( + image, absl::Span>(render_list), graph) + .SetName(kRenderedImageOutStream); + + rendered_image >> graph.Out("RENDERED_IMAGE"); + result.landmarks->SetName(kPoseLandmarksOutStream) >> + graph.Out("POSE_LANDMARKS"); + result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >> + graph.Out("POSE_WORLD_LANDMARKS"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + return TaskRunner::Create( + config, std::make_unique()); +} + +// Remove fields not to be checked in the result, since the model +// generating expected result is different from the testing 
model. +void RemoveUncheckedResult(proto::HolisticResult& holistic_result) { + for (auto& landmark : + *holistic_result.mutable_pose_landmarks()->mutable_landmark()) { + landmark.clear_z(); + landmark.clear_visibility(); + landmark.clear_presence(); + } +} + +class HolisticPoseTrackingTest : public testing::Test {}; + +TEST_F(HolisticPoseTrackingTest, VerifyGraph) { + Graph graph; + Stream image = graph.In("IMAGE").Cast().SetName(kImageInStream); + pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options; + pose_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_detection.tflite")); + pose_detector_graph_options.set_num_poses(1); + pose_landmarker::proto::PoseLandmarksDetectorGraphOptions + pose_landmarks_detector_graph_options; + pose_landmarks_detector_graph_options.mutable_base_options() + ->mutable_model_asset() + ->set_file_name(GetFilePath("pose_landmark_lite.tflite")); + HolisticPoseTrackingRequest request; + request.landmarks = true; + request.world_landmarks = true; + MP_ASSERT_OK_AND_ASSIGN( + HolisticPoseTrackingOutput result, + TrackHolisticPose(image, pose_detector_graph_options, + pose_landmarks_detector_graph_options, request, graph)); + result.landmarks->SetName(kPoseLandmarksOutStream) >> + graph.Out("POSE_LANDMARKS"); + result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >> + graph.Out("POSE_WORLD_LANDMARKS"); + + auto config = graph.GetConfig(); + core::FixGraphBackEdges(config); + + // Read the expected graph config. + std::string expected_graph_contents; + MP_ASSERT_OK(file::GetContents( + file::JoinPath("./", kTestDataDirectory, kHolisticPoseTrackingGraph), + &expected_graph_contents)); + + // Need to replace the expected graph config with the test srcdir, because + // each run has different test dir on TAP. 
+ expected_graph_contents = absl::Substitute( + expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir); + CalculatorGraphConfig expected_graph = + ParseTextProtoOrDie(expected_graph_contents); + + EXPECT_THAT(config, testing::proto::IgnoringRepeatedFieldOrdering( + testing::EqualsProto(expected_graph))); +} + +TEST_F(HolisticPoseTrackingTest, SmokeTest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(GetFilePath(kTestImageFile))); + + proto::HolisticResult holistic_result; + MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result, + Defaults())); + RemoveUncheckedResult(holistic_result); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + MP_ASSERT_OK_AND_ASSIGN(auto output_packets, + task_runner->Process({{std::string(kImageInStream), + MakePacket(image)}})); + auto pose_landmarks = output_packets.at(std::string(kPoseLandmarksOutStream)) + .Get(); + EXPECT_THAT( + pose_landmarks, + Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())), + /*margin=*/kAbsMargin)); + auto rendered_image = + output_packets.at(std::string(kRenderedImageOutStream)).Get(); + MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(), + "pose_landmarks")); +} + +} // namespace +} // namespace holistic_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD new file mode 100644 index 000000000..147f3cc86 --- /dev/null +++ b/mediapipe/tasks/cc/vision/holistic_landmarker/proto/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

# Targets in this package are visible to //mediapipe/tasks:internal unless a
# target overrides its visibility.
package(default_visibility = [
    "//mediapipe/tasks:internal",
])

licenses(["notice"])

# Proto library for holistic_result.proto, which uses the framework's
# classification and landmark format protos.
mediapipe_proto_library(
    name = "holistic_result_proto",
    srcs = ["holistic_result.proto"],
    deps = [
        "//mediapipe/framework/formats:classification_proto",
        "//mediapipe/framework/formats:landmark_proto",
    ],
)

# Proto library for holistic_landmarker_graph_options.proto. The deps list
# mirrors the imports of that file: base options plus the options protos of the
# face, hand and pose graphs.
mediapipe_proto_library(
    name = "holistic_landmarker_graph_options_proto",
    srcs = ["holistic_landmarker_graph_options.proto"],
    deps = [
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_proto",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_proto",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_proto",
    ],
)
// Options for the holistic landmarker graph: the shared base options plus one
// options message for each of the hand, face and pose graphs listed below.
message HolisticLandmarkerGraphOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // asset bundle file with metadata, accelerator options, etc.
  core.proto.BaseOptions base_options = 1;

  // Options for hand landmarks detector graph.
  hand_landmarker.proto.HandLandmarksDetectorGraphOptions
      hand_landmarks_detector_graph_options = 2;

  // Options for hand roi refinement graph.
  hand_landmarker.proto.HandRoiRefinementGraphOptions
      hand_roi_refinement_graph_options = 3;

  // Options for face detector graph.
  face_detector.proto.FaceDetectorGraphOptions face_detector_graph_options = 4;

  // Options for face landmarks detector graph.
  face_landmarker.proto.FaceLandmarksDetectorGraphOptions
      face_landmarks_detector_graph_options = 5;

  // Options for pose detector graph.
  pose_detector.proto.PoseDetectorGraphOptions pose_detector_graph_options = 6;

  // Options for pose landmarks detector graph.
  pose_landmarker.proto.PoseLandmarksDetectorGraphOptions
      pose_landmarks_detector_graph_options = 7;
}
// Landmarks and classifications that make up one holistic result.
//
// Field numbers do not follow declaration order (pose_world_landmarks = 7,
// face_blendshapes = 6, auxiliary_landmarks = 5). Field numbers identify the
// fields in the serialized form, so they must not be renumbered.
message HolisticResult {
  // Pose landmarks as a normalized landmark list.
  mediapipe.NormalizedLandmarkList pose_landmarks = 1;
  // Pose landmarks as a (non-normalized) landmark list; presumably world
  // coordinates, per the field name -- confirm with the producing graph.
  mediapipe.LandmarkList pose_world_landmarks = 7;
  // Landmarks of the left hand as a normalized landmark list.
  mediapipe.NormalizedLandmarkList left_hand_landmarks = 2;
  // Landmarks of the right hand as a normalized landmark list.
  mediapipe.NormalizedLandmarkList right_hand_landmarks = 3;
  // Face landmarks as a normalized landmark list.
  mediapipe.NormalizedLandmarkList face_landmarks = 4;
  // Face blendshapes as a classification list.
  mediapipe.ClassificationList face_blendshapes = 6;
  // NOTE(review): the meaning of the auxiliary landmarks is not evident from
  // this file; confirm which graph output populates this field.
  mediapipe.NormalizedLandmarkList auxiliary_landmarks = 5;
}