Holistic Landmarker C++ Graph

PiperOrigin-RevId: 586105983
Sebastian Schmidt 2023-11-28 14:33:15 -08:00 committed by Copybara-Service
parent 95601ff98b
commit a898215c52
17 changed files with 3493 additions and 0 deletions

View File

@@ -155,6 +155,37 @@ cc_library(
# TODO: open source hand joints graph
cc_library(
name = "hand_roi_refinement_graph",
srcs = ["hand_roi_refinement_graph.cc"],
deps = [
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator",
"//mediapipe/calculators/tensor:tensors_to_landmarks_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:detections_to_rects",
"//mediapipe/framework/api2/stream:landmarks_projection",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(
name = "hand_landmarker_result",
srcs = ["hand_landmarker_result.cc"],

View File

@@ -0,0 +1,154 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_landmarks_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/landmarks_projection.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace hand_landmarker {
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::ProjectLandmarks;
using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong;
using ::mediapipe::api2::builder::Stream;
// Refine the input hand RoI with the hand_roi_refinement model.
//
// Inputs:
// IMAGE - Image
// The image to preprocess.
// NORM_RECT - NormalizedRect
// Coarse RoI of hand.
// Outputs:
// NORM_RECT - NormalizedRect
// Refined RoI of hand.
class HandRoiRefinementGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* context) override {
Graph graph;
Stream<Image> image_in = graph.In("IMAGE").Cast<Image>();
Stream<NormalizedRect> roi_in =
graph.In("NORM_RECT").Cast<NormalizedRect>();
auto& graph_options =
*context->MutableOptions<proto::HandRoiRefinementGraphOptions>();
MP_ASSIGN_OR_RETURN(
const auto* model_resources,
GetOrCreateModelResources<proto::HandRoiRefinementGraphOptions>(
context));
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
graph_options.base_options().acceleration());
auto& image_to_tensor_options =
*preprocessing
.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()
.mutable_image_to_tensor_options();
image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_REPLICATE);
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
*model_resources, use_gpu, graph_options.base_options().gpu_origin(),
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In("IMAGE");
roi_in >> preprocessing.In("NORM_RECT");
auto tensors_in = preprocessing.Out("TENSORS");
auto matrix = preprocessing.Out("MATRIX").Cast<std::array<float, 16>>();
auto image_size =
preprocessing.Out("IMAGE_SIZE").Cast<std::pair<int, int>>();
auto& inference = AddInference(
*model_resources, graph_options.base_options().acceleration(), graph);
tensors_in >> inference.In("TENSORS");
auto tensors_out = inference.Out("TENSORS").Cast<std::vector<Tensor>>();
MP_ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildInputImageTensorSpecs(*model_resources));
// Convert tensors to landmarks. The re-crop model outputs two points: a
// center point and a guide point.
auto& to_landmarks = graph.AddNode("TensorsToLandmarksCalculator");
auto& to_landmarks_opts =
to_landmarks
.GetOptions<mediapipe::TensorsToLandmarksCalculatorOptions>();
to_landmarks_opts.set_num_landmarks(/*num_landmarks=*/2);
to_landmarks_opts.set_input_image_width(image_tensor_specs.image_width);
to_landmarks_opts.set_input_image_height(image_tensor_specs.image_height);
to_landmarks_opts.set_normalize_z(/*z_norm_factor=*/1.0f);
tensors_out.ConnectTo(to_landmarks.In("TENSORS"));
auto recrop_landmarks = to_landmarks.Out("NORM_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>();
// Project landmarks.
auto projected_recrop_landmarks =
ProjectLandmarks(recrop_landmarks, matrix, graph);
// Convert re-crop landmarks to detection.
auto recrop_detection =
ConvertLandmarksToDetection(projected_recrop_landmarks, graph);
// Convert re-crop detection to rect.
auto recrop_rect = ConvertAlignmentPointsDetectionToRect(
recrop_detection, image_size, /*start_keypoint_index=*/0,
/*end_keypoint_index=*/1, /*target_angle=*/-90, graph);
auto refined_roi =
ScaleAndShiftAndMakeSquareLong(recrop_rect, image_size,
/*scale_x_factor=*/1.0,
/*scale_y_factor=*/1.0, /*shift_x=*/0,
/*shift_y=*/-0.1, graph);
refined_roi >> graph.Out("NORM_RECT").Cast<NormalizedRect>();
return graph.GetConfig();
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::hand_landmarker::HandRoiRefinementGraph);
} // namespace hand_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
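For reference, a minimal sketch of wiring the registered HandRoiRefinementGraph into a caller's graph with the same api2 builder API; the model file name is illustrative, not part of this commit:

Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>();
Stream<NormalizedRect> coarse_roi =
    graph.In("NORM_RECT").Cast<NormalizedRect>();
auto& refiner = graph.AddNode(
    "mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph");
refiner.GetOptions<hand_landmarker::proto::HandRoiRefinementGraphOptions>()
    .mutable_base_options()
    ->mutable_model_asset()
    ->set_file_name("handrecrop.tflite");  // illustrative model path
image >> refiner.In("IMAGE");
coarse_roi >> refiner.In("NORM_RECT");
Stream<NormalizedRect> refined_roi =
    refiner.Out("NORM_RECT").Cast<NormalizedRect>();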

View File

@@ -0,0 +1,152 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/tasks:internal"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "holistic_face_tracking",
srcs = ["holistic_face_tracking.cc"],
hdrs = ["holistic_face_tracking.h"],
deps = [
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:detections_to_rects",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:loopback",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker:face_blendshapes_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
],
)
cc_library(
name = "holistic_hand_tracking",
srcs = ["holistic_hand_tracking.cc"],
hdrs = ["holistic_hand_tracking.h"],
deps = [
"//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator",
"//mediapipe/calculators/util:align_hand_to_pose_in_world_calculator_cc_proto",
"//mediapipe/calculators/util:landmark_visibility_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:loopback",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/api2/stream:split",
"//mediapipe/framework/api2/stream:threshold",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/modules/holistic_landmark/calculators:hand_detections_from_pose_to_rects_calculator",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator",
"//mediapipe/modules/holistic_landmark/calculators:roi_tracking_calculator_cc_proto",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_roi_refinement_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "holistic_pose_tracking",
srcs = ["holistic_pose_tracking.cc"],
hdrs = ["holistic_pose_tracking.h"],
deps = [
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:detections_to_rects",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:landmarks_to_detection",
"//mediapipe/framework/api2/stream:loopback",
"//mediapipe/framework/api2/stream:merge",
"//mediapipe/framework/api2/stream:presence",
"//mediapipe/framework/api2/stream:rect_transformation",
"//mediapipe/framework/api2/stream:segmentation_smoothing",
"//mediapipe/framework/api2/stream:smoothing",
"//mediapipe/framework/api2/stream:split",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "holistic_landmarker_graph",
srcs = ["holistic_landmarker_graph.cc"],
deps = [
":holistic_face_tracking",
":holistic_hand_tracking",
":holistic_pose_tracking",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2/stream:split",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_topology",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
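These targets build with a standard Bazel invocation from the repository root, e.g. (a sketch; exact flags depend on the local setup):

bazel build //mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph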

View File

@@ -0,0 +1,260 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include <functional>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::ConvertDetectionsToRectUsingKeypoints;
using ::mediapipe::api2::builder::ConvertDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Scale;
using ::mediapipe::api2::builder::ScaleAndMakeSquare;
using ::mediapipe::api2::builder::Stream;
struct FaceLandmarksResult {
std::optional<Stream<NormalizedLandmarkList>> landmarks;
std::optional<Stream<ClassificationList>> classifications;
};
absl::Status ValidateGraphOptions(
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request) {
if (face_detector_graph_options.num_faces() != 1) {
return absl::InvalidArgumentError(absl::StrFormat(
"Only support num_faces to be 1, but got num_faces = %d.",
face_detector_graph_options.num_faces()));
}
if (request.classifications && !face_landmarks_detector_graph_options
.has_face_blendshapes_graph_options()) {
return absl::InvalidArgumentError(
"Blendshapes detection is requested, but "
"face_blendshapes_graph_options is not configured.");
}
return absl::OkStatus();
}
Stream<NormalizedRect> GetFaceRoiFromPoseFaceLandmarks(
Stream<NormalizedLandmarkList> pose_face_landmarks,
Stream<std::pair<int, int>> image_size, Graph& graph) {
Stream<mediapipe::Detection> detection =
ConvertLandmarksToDetection(pose_face_landmarks, graph);
// See the pose face landmark indices here:
// https://developers.google.com/mediapipe/solutions/vision/pose_landmarker#pose_landmarker_model
Stream<NormalizedRect> rect = ConvertDetectionToRect(
detection, image_size, /*start_keypoint_index=*/5,
/*end_keypoint_index=*/2, /*target_angle=*/0, graph);
// Scale the face RoI from a tight rect enclosing the pose face landmarks to
// a larger square so that the whole face is within the RoI.
return ScaleAndMakeSquare(rect, image_size,
/*scale_x_factor=*/3.0,
/*scale_y_factor=*/3.0, graph);
}
Stream<NormalizedRect> GetFaceRoiFromFaceLandmarks(
Stream<NormalizedLandmarkList> face_landmarks,
Stream<std::pair<int, int>> image_size, Graph& graph) {
Stream<mediapipe::Detection> detection =
ConvertLandmarksToDetection(face_landmarks, graph);
Stream<NormalizedRect> rect = ConvertDetectionToRect(
detection, image_size, /*start_keypoint_index=*/33,
/*end_keypoint_index=*/263, /*target_angle=*/0, graph);
return Scale(rect, image_size,
/*scale_x_factor=*/1.5,
/*scale_y_factor=*/1.5, graph);
}
Stream<std::vector<Detection>> GetFaceDetections(
Stream<Image> image, Stream<NormalizedRect> roi,
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
Graph& graph) {
auto& face_detector_graph =
graph.AddNode("mediapipe.tasks.vision.face_detector.FaceDetectorGraph");
face_detector_graph
.GetOptions<face_detector::proto::FaceDetectorGraphOptions>() =
face_detector_graph_options;
image >> face_detector_graph.In("IMAGE");
roi >> face_detector_graph.In("NORM_RECT");
return face_detector_graph.Out("DETECTIONS").Cast<std::vector<Detection>>();
}
Stream<NormalizedRect> GetFaceRoiFromFaceDetections(
Stream<std::vector<Detection>> face_detections,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Convert detection to rect.
Stream<NormalizedRect> rect = ConvertDetectionsToRectUsingKeypoints(
face_detections, image_size, /*start_keypoint_index=*/0,
/*end_keypoint_index=*/1, /*target_angle=*/0, graph);
return ScaleAndMakeSquare(rect, image_size,
/*scale_x_factor=*/2.0,
/*scale_y_factor=*/2.0, graph);
}
Stream<NormalizedRect> TrackFaceRoi(
Stream<NormalizedLandmarkList> prev_landmarks, Stream<NormalizedRect> roi,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Gets face ROI from previous frame face landmarks.
Stream<NormalizedRect> prev_roi =
GetFaceRoiFromFaceLandmarks(prev_landmarks, image_size, graph);
auto& tracking_node = graph.AddNode("RoiTrackingCalculator");
auto& tracking_node_opts =
tracking_node.GetOptions<RoiTrackingCalculatorOptions>();
auto* rect_requirements = tracking_node_opts.mutable_rect_requirements();
rect_requirements->set_rotation_degrees(15.0);
rect_requirements->set_translation(0.1);
rect_requirements->set_scale(0.3);
auto* landmarks_requirements =
tracking_node_opts.mutable_landmarks_requirements();
landmarks_requirements->set_recrop_rect_margin(-0.2);
prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS"));
prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT"));
roi.ConnectTo(tracking_node.In("RECROP_RECT"));
image_size.ConnectTo(tracking_node.In("IMAGE_SIZE"));
return tracking_node.Out("TRACKING_RECT").Cast<NormalizedRect>();
}
FaceLandmarksResult GetFaceLandmarksDetection(
Stream<Image> image, Stream<NormalizedRect> roi,
Stream<std::pair<int, int>> image_size,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request, Graph& graph) {
FaceLandmarksResult result;
auto& face_landmarks_detector_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker."
"SingleFaceLandmarksDetectorGraph");
face_landmarks_detector_graph
.GetOptions<face_landmarker::proto::FaceLandmarksDetectorGraphOptions>() =
face_landmarks_detector_graph_options;
image >> face_landmarks_detector_graph.In("IMAGE");
roi >> face_landmarks_detector_graph.In("NORM_RECT");
auto landmarks = face_landmarks_detector_graph.Out("NORM_LANDMARKS")
.Cast<NormalizedLandmarkList>();
result.landmarks = landmarks;
if (request.classifications) {
auto& blendshapes_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker.FaceBlendshapesGraph");
blendshapes_graph
.GetOptions<face_landmarker::proto::FaceBlendshapesGraphOptions>() =
face_landmarks_detector_graph_options.face_blendshapes_graph_options();
landmarks >> blendshapes_graph.In("LANDMARKS");
image_size >> blendshapes_graph.In("IMAGE_SIZE");
result.classifications =
blendshapes_graph.Out("BLENDSHAPES").Cast<ClassificationList>();
}
return result;
}
} // namespace
absl::StatusOr<HolisticFaceTrackingOutput> TrackHolisticFace(
Stream<Image> image, Stream<NormalizedLandmarkList> pose_face_landmarks,
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request, Graph& graph) {
MP_RETURN_IF_ERROR(ValidateGraphOptions(face_detector_graph_options,
face_landmarks_detector_graph_options,
request));
// Extracts image size from the input images.
Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);
// Gets face ROI from pose face landmarks.
Stream<NormalizedRect> roi_from_pose =
GetFaceRoiFromPoseFaceLandmarks(pose_face_landmarks, image_size, graph);
// Detects faces within ROI of pose face.
Stream<std::vector<Detection>> face_detections = GetFaceDetections(
image, roi_from_pose, face_detector_graph_options, graph);
// Gets face ROI from face detector.
Stream<NormalizedRect> roi_from_detection =
GetFaceRoiFromFaceDetections(face_detections, image_size, graph);
// Loop for previous frame landmarks.
auto [prev_landmarks, set_prev_landmarks_fn] =
GetLoopbackData<NormalizedLandmarkList>(/*tick=*/image_size, graph);
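// The loopback forms a back edge: the landmarks passed to
// set_prev_landmarks_fn below become available here on the next tick (driven
// by image_size), giving the tracker the previous frame's landmarks.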
// Tracks face ROI.
auto tracking_roi =
TrackFaceRoi(prev_landmarks, roi_from_detection, image_size, graph);
// Predicts face landmarks.
auto landmarks_detection_result = GetFaceLandmarksDetection(
image, tracking_roi, image_size, face_landmarks_detector_graph_options,
request, graph);
// Sets previous landmarks for ROI tracking.
set_prev_landmarks_fn(landmarks_detection_result.landmarks.value());
return {{.landmarks = landmarks_detection_result.landmarks,
.classifications = landmarks_detection_result.classifications,
.debug_output = {
.roi_from_pose = roi_from_pose,
.roi_from_detection = roi_from_detection,
.tracking_roi = tracking_roi,
}}};
}
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@@ -0,0 +1,89 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
#include <optional>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
struct HolisticFaceTrackingRequest {
bool classifications = false;
};
struct HolisticFaceTrackingOutput {
std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
landmarks;
std::optional<api2::builder::Stream<mediapipe::ClassificationList>>
classifications;
struct DebugOutput {
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_pose;
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_detection;
api2::builder::Stream<mediapipe::NormalizedRect> tracking_roi;
};
DebugOutput debug_output;
};
// Updates @graph to track a single face in @image based on pose landmarks.
//
// To track a single face, this subgraph uses the pose face landmarks to
// obtain an approximate face location, refines it with the face detector
// model, and then runs the face landmarks model. It can also reuse the face
// ROI from the previous frame if the face hasn't moved too much.
//
// @image - Image to track a single face in.
// @pose_face_landmarks - Pose face landmarks to derive initial face location
// from.
// @face_detector_graph_options - face detector graph options used to detect the
// face within the RoI constructed from the pose face landmarks.
// @face_landmarks_detector_graph_options - face landmarks detector graph
// options used to detect face landmarks within the RoI given by the face
// detector graph.
// @request - object to request specific face tracking outputs.
// NOTE: Outputs that were not requested won't be returned and the
// corresponding parts of the graph won't be generated.
// @graph - graph to update.
absl::StatusOr<HolisticFaceTrackingOutput> TrackHolisticFace(
api2::builder::Stream<Image> image,
api2::builder::Stream<mediapipe::NormalizedLandmarkList>
pose_face_landmarks,
const face_detector::proto::FaceDetectorGraphOptions&
face_detector_graph_options,
const face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_detector_graph_options,
const HolisticFaceTrackingRequest& request,
mediapipe::api2::builder::Graph& graph);
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_FACE_TRACKING_H_
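A minimal usage sketch of TrackHolisticFace (model assets omitted; the test later in this commit shows how they are populated from a face landmarker bundle, and the sketch assumes it runs inside a function that can use MP_ASSIGN_OR_RETURN):

Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>();
Stream<NormalizedLandmarkList> pose_face_landmarks =
    graph.In("POSE_FACE_LANDMARKS").Cast<NormalizedLandmarkList>();
face_detector::proto::FaceDetectorGraphOptions detector_options;
detector_options.set_num_faces(1);  // the only supported value
face_landmarker::proto::FaceLandmarksDetectorGraphOptions landmarks_options;
HolisticFaceTrackingRequest request;  // landmarks only; no blendshapes
MP_ASSIGN_OR_RETURN(
    HolisticFaceTrackingOutput output,
    TrackHolisticFace(image, pose_face_landmarks, detector_options,
                      landmarks_options, request, graph));
output.landmarks.value() >> graph.Out("FACE_LANDMARKS");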

View File

@@ -0,0 +1,227 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/calculators/util/rect_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SplitToRanges;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::core::proto::ExternalFile;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.015;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kTestImageFile[] = "male_full_height_hands.jpg";
constexpr char kHolisticResultFile[] =
"male_full_height_hands_result_cpu.pbtxt";
constexpr char kImageInStream[] = "image_in";
constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in";
constexpr char kFaceLandmarksOutStream[] = "face_landmarks_out";
constexpr char kRenderedImageOutStream[] = "rendered_image_out";
constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
constexpr char kFaceLandmarksDetectorTFLiteName[] =
"face_landmarks_detector.tflite";
std::string GetFilePath(absl::string_view filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
mediapipe::LandmarksToRenderDataCalculatorOptions GetFaceRendererOptions() {
mediapipe::LandmarksToRenderDataCalculatorOptions render_options;
for (const auto& connection :
face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors) {
render_options.add_landmark_connections(connection[0]);
render_options.add_landmark_connections(connection[1]);
}
render_options.mutable_landmark_color()->set_r(255);
render_options.mutable_landmark_color()->set_g(255);
render_options.mutable_landmark_color()->set_b(255);
render_options.mutable_connection_color()->set_r(255);
render_options.mutable_connection_color()->set_g(255);
render_options.mutable_connection_color()->set_b(255);
render_options.set_thickness(0.5);
render_options.set_visualize_landmark_depth(false);
return render_options;
}
absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>>
CreateModelAssetBundleResources(const std::string& model_asset_filename) {
auto external_model_bundle = std::make_unique<ExternalFile>();
external_model_bundle->set_file_name(model_asset_filename);
return ModelAssetBundleResources::Create("",
std::move(external_model_bundle));
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
graph.In("POSE_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>()
.SetName(kPoseLandmarksInStream);
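// The first 11 of the 33 pose landmarks (indices 0-10: nose, eyes, ears,
// mouth) describe the face, so range [0, 11) yields the pose face landmarks.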
Stream<NormalizedLandmarkList> face_landmarks_from_pose =
SplitToRanges(pose_landmarks, {{0, 11}}, graph)[0];
// Create face landmarker model bundle.
MP_ASSIGN_OR_RETURN(
auto model_bundle,
CreateModelAssetBundleResources(GetFilePath("face_landmarker_v2.task")));
face_detector::proto::FaceDetectorGraphOptions detector_options;
face_landmarker::proto::FaceLandmarksDetectorGraphOptions
landmarks_detector_options;
// Set face detection model.
MP_ASSIGN_OR_RETURN(auto face_detector_model_file,
model_bundle->GetFile(kFaceDetectorTFLiteName));
core::proto::FilePointerMeta face_detection_file_pointer;
face_detection_file_pointer.set_pointer(
reinterpret_cast<uint64_t>(face_detector_model_file.data()));
face_detection_file_pointer.set_length(face_detector_model_file.size());
detector_options.mutable_base_options()
->mutable_model_asset()
->mutable_file_pointer_meta()
->Swap(&face_detection_file_pointer);
detector_options.set_num_faces(1);
// Set face landmarks model.
MP_ASSIGN_OR_RETURN(auto face_landmarks_model_file,
model_bundle->GetFile(kFaceLandmarksDetectorTFLiteName));
core::proto::FilePointerMeta face_landmarks_detector_file_pointer;
face_landmarks_detector_file_pointer.set_pointer(
reinterpret_cast<uint64_t>(face_landmarks_model_file.data()));
face_landmarks_detector_file_pointer.set_length(
face_landmarks_model_file.size());
landmarks_detector_options.mutable_base_options()
->mutable_model_asset()
->mutable_file_pointer_meta()
->Swap(&face_landmarks_detector_file_pointer);
// Track holistic face.
HolisticFaceTrackingRequest request;
MP_ASSIGN_OR_RETURN(
HolisticFaceTrackingOutput result,
TrackHolisticFace(image, face_landmarks_from_pose, detector_options,
landmarks_detector_options, request, graph));
auto face_landmarks =
result.landmarks.value().SetName(kFaceLandmarksOutStream);
auto image_size = GetImageSize(image, graph);
auto render_scale = utils::GetRenderScale(
image_size, result.debug_output.roi_from_pose, 0.0001, graph);
auto face_landmarks_render_data = utils::RenderLandmarks(
face_landmarks, render_scale, GetFaceRendererOptions(), graph);
std::vector<Stream<mediapipe::RenderData>> render_list = {
face_landmarks_render_data};
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
face_landmarks >> graph.Out("FACE_LANDMARKS");
rendered_image >> graph.Out("RENDERED_IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
class HolisticFaceTrackingTest : public ::testing::Test {};
TEST_F(HolisticFaceTrackingTest, SmokeTest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(GetFilePath(kTestImageFile)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
::file::Defaults()));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
MP_ASSERT_OK_AND_ASSIGN(
auto output_packets,
task_runner->Process(
{{kImageInStream, MakePacket<Image>(image)},
{kPoseLandmarksInStream, MakePacket<NormalizedLandmarkList>(
holistic_result.pose_landmarks())}}));
ASSERT_TRUE(output_packets.find(kFaceLandmarksOutStream) !=
output_packets.end());
auto face_landmarks = output_packets.find(kFaceLandmarksOutStream)
->second.Get<NormalizedLandmarkList>();
EXPECT_THAT(
face_landmarks,
Approximately(Partially(EqualsProto(holistic_result.face_landmarks())),
/*margin=*/kAbsMargin));
auto rendered_image = output_packets.at(kRenderedImageOutStream).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
"holistic_face_landmarks"));
}
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@@ -0,0 +1,272 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include <functional>
#include <optional>
#include <utility>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.h"
#include "mediapipe/calculators/util/align_hand_to_pose_in_world_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/api2/stream/threshold.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.pb.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::AlignHandToPoseInWorldCalculator;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::IsOverThreshold;
using ::mediapipe::api2::builder::ScaleAndShiftAndMakeSquareLong;
using ::mediapipe::api2::builder::SplitAndCombine;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::AllowIf;
struct HandLandmarksResult {
std::optional<Stream<NormalizedLandmarkList>> landmarks;
std::optional<Stream<LandmarkList>> world_landmarks;
};
Stream<LandmarkList> AlignHandToPoseInWorldCalculator(
Stream<LandmarkList> hand_world_landmarks,
Stream<LandmarkList> pose_world_landmarks, int pose_wrist_idx,
Graph& graph) {
auto& node = graph.AddNode("AlignHandToPoseInWorldCalculator");
auto& opts = node.GetOptions<AlignHandToPoseInWorldCalculatorOptions>();
opts.set_hand_wrist_idx(0);
opts.set_pose_wrist_idx(pose_wrist_idx);
hand_world_landmarks.ConnectTo(
node[AlignHandToPoseInWorldCalculator::kInHandLandmarks]);
pose_world_landmarks.ConnectTo(
node[AlignHandToPoseInWorldCalculator::kInPoseLandmarks]);
return node[AlignHandToPoseInWorldCalculator::kOutHandLandmarks];
}
Stream<bool> GetPosePalmVisibility(
Stream<NormalizedLandmarkList> pose_palm_landmarks, Graph& graph) {
// Get wrist landmark.
auto pose_wrist = SplitAndCombine(pose_palm_landmarks, {0}, graph);
// Get visibility score.
auto& score_node = graph.AddNode("LandmarkVisibilityCalculator");
pose_wrist.ConnectTo(score_node.In("NORM_LANDMARKS"));
Stream<float> score = score_node.Out("VISIBILITY").Cast<float>();
// Convert score into flag.
return IsOverThreshold(score, /*threshold=*/0.1, graph);
}
Stream<NormalizedRect> GetHandRoiFromPosePalmLandmarks(
Stream<NormalizedLandmarkList> pose_palm_landmarks,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Convert pose palm landmarks to detection.
auto detection = ConvertLandmarksToDetection(pose_palm_landmarks, graph);
// Convert detection to rect.
auto& rect_node = graph.AddNode("HandDetectionsFromPoseToRectsCalculator");
detection.ConnectTo(rect_node.In("DETECTION"));
image_size.ConnectTo(rect_node.In("IMAGE_SIZE"));
Stream<NormalizedRect> rect =
rect_node.Out("NORM_RECT").Cast<NormalizedRect>();
return ScaleAndShiftAndMakeSquareLong(rect, image_size,
/*scale_x_factor=*/2.7,
/*scale_y_factor=*/2.7, /*shift_x=*/0,
/*shift_y=*/-0.1, graph);
}
absl::StatusOr<Stream<NormalizedRect>> RefineHandRoi(
Stream<Image> image, Stream<NormalizedRect> roi,
const hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_graph_options,
Graph& graph) {
auto& hand_roi_refinement = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker.HandRoiRefinementGraph");
hand_roi_refinement
.GetOptions<hand_landmarker::proto::HandRoiRefinementGraphOptions>() =
hand_roi_refinement_graph_options;
image >> hand_roi_refinement.In("IMAGE");
roi >> hand_roi_refinement.In("NORM_RECT");
return hand_roi_refinement.Out("NORM_RECT").Cast<NormalizedRect>();
}
Stream<NormalizedRect> TrackHandRoi(
Stream<NormalizedLandmarkList> prev_landmarks, Stream<NormalizedRect> roi,
Stream<std::pair<int, int>> image_size, Graph& graph) {
// Convert hand landmarks to tight rect.
auto& prev_rect_node = graph.AddNode("HandLandmarksToRectCalculator");
prev_landmarks.ConnectTo(prev_rect_node.In("NORM_LANDMARKS"));
image_size.ConnectTo(prev_rect_node.In("IMAGE_SIZE"));
Stream<NormalizedRect> prev_rect =
prev_rect_node.Out("NORM_RECT").Cast<NormalizedRect>();
// Convert tight hand rect to hand roi.
Stream<NormalizedRect> prev_roi =
ScaleAndShiftAndMakeSquareLong(prev_rect, image_size,
/*scale_x_factor=*/2.0,
/*scale_y_factor=*/2.0, /*shift_x=*/0,
/*shift_y=*/-0.1, graph);
auto& tracking_node = graph.AddNode("RoiTrackingCalculator");
auto& tracking_node_opts =
tracking_node.GetOptions<RoiTrackingCalculatorOptions>();
auto* rect_requirements = tracking_node_opts.mutable_rect_requirements();
rect_requirements->set_rotation_degrees(40.0);
rect_requirements->set_translation(0.2);
rect_requirements->set_scale(0.4);
auto* landmarks_requirements =
tracking_node_opts.mutable_landmarks_requirements();
landmarks_requirements->set_recrop_rect_margin(-0.1);
prev_landmarks.ConnectTo(tracking_node.In("PREV_LANDMARKS"));
prev_roi.ConnectTo(tracking_node.In("PREV_LANDMARKS_RECT"));
roi.ConnectTo(tracking_node.In("RECROP_RECT"));
image_size.ConnectTo(tracking_node.In("IMAGE_SIZE"));
return tracking_node.Out("TRACKING_RECT").Cast<NormalizedRect>();
}
HandLandmarksResult GetHandLandmarksDetection(
Stream<Image> image, Stream<NormalizedRect> roi,
const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
const HolisticHandTrackingRequest& request, Graph& graph) {
HandLandmarksResult result;
auto& hand_landmarks_detector_graph = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker."
"SingleHandLandmarksDetectorGraph");
hand_landmarks_detector_graph
.GetOptions<hand_landmarker::proto::HandLandmarksDetectorGraphOptions>() =
hand_landmarks_detector_graph_options;
image >> hand_landmarks_detector_graph.In("IMAGE");
roi >> hand_landmarks_detector_graph.In("HAND_RECT");
if (request.landmarks) {
result.landmarks = hand_landmarks_detector_graph.Out("LANDMARKS")
.Cast<NormalizedLandmarkList>();
}
if (request.world_landmarks) {
result.world_landmarks =
hand_landmarks_detector_graph.Out("WORLD_LANDMARKS")
.Cast<LandmarkList>();
}
return result;
}
} // namespace
absl::StatusOr<HolisticHandTrackingOutput> TrackHolisticHand(
Stream<Image> image, Stream<NormalizedLandmarkList> pose_landmarks,
Stream<LandmarkList> pose_world_landmarks,
const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
const hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_graph_options,
const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request,
Graph& graph) {
// Extracts pose palm landmarks.
Stream<NormalizedLandmarkList> pose_palm_landmarks = SplitAndCombine(
pose_landmarks,
{pose_indices.wrist_idx, pose_indices.pinky_idx, pose_indices.index_idx},
graph);
// Get pose palm visibility.
Stream<bool> is_pose_palm_visible =
GetPosePalmVisibility(pose_palm_landmarks, graph);
// Drop pose palm landmarks if pose palm is invisible.
pose_palm_landmarks =
AllowIf(pose_palm_landmarks, is_pose_palm_visible, graph);
// Extracts image size from the input images.
Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);
// Get hand ROI from pose palm landmarks.
Stream<NormalizedRect> roi_from_pose =
GetHandRoiFromPosePalmLandmarks(pose_palm_landmarks, image_size, graph);
// Refine hand ROI with re-crop model.
MP_ASSIGN_OR_RETURN(Stream<NormalizedRect> roi_from_recrop,
RefineHandRoi(image, roi_from_pose,
hand_roi_refinement_graph_options, graph));
// Loop for previous frame landmarks.
auto [prev_landmarks, set_prev_landmarks_fn] =
GetLoopbackData<NormalizedLandmarkList>(/*tick=*/image_size, graph);
// Track hand ROI.
auto tracking_roi =
TrackHandRoi(prev_landmarks, roi_from_recrop, image_size, graph);
// Predict hand landmarks.
auto landmarks_detection_result = GetHandLandmarksDetection(
image, tracking_roi, hand_landmarks_detector_graph_options, request,
graph);
// Set previous landmarks for ROI tracking.
set_prev_landmarks_fn(landmarks_detection_result.landmarks.value());
// Output landmarks.
std::optional<Stream<NormalizedLandmarkList>> hand_landmarks;
if (request.landmarks) {
hand_landmarks = landmarks_detection_result.landmarks;
}
// Output world landmarks.
std::optional<Stream<LandmarkList>> hand_world_landmarks;
if (request.world_landmarks) {
hand_world_landmarks = landmarks_detection_result.world_landmarks;
// Align hand world landmarks with pose world landmarks.
hand_world_landmarks = AlignHandToPoseInWorldCalculator(
hand_world_landmarks.value(), pose_world_landmarks,
pose_indices.wrist_idx, graph);
}
return {{.landmarks = hand_landmarks,
.world_landmarks = hand_world_landmarks,
.debug_output = {
.roi_from_pose = roi_from_pose,
.roi_from_recrop = roi_from_recrop,
.tracking_roi = tracking_roi,
}}};
}
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
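For intuition, the wrist alignment requested from AlignHandToPoseInWorldCalculator amounts to translating the hand world landmarks so that the hand wrist coincides with the pose wrist. A rough standalone sketch of that idea (the calculator itself is the source of truth and may do more):

// Hypothetical illustration, not the calculator implementation.
void AlignHandWristToPoseWrist(LandmarkList& hand, const LandmarkList& pose,
                               int hand_wrist_idx, int pose_wrist_idx) {
  const Landmark& hand_wrist = hand.landmark(hand_wrist_idx);
  const Landmark& pose_wrist = pose.landmark(pose_wrist_idx);
  const float dx = pose_wrist.x() - hand_wrist.x();
  const float dy = pose_wrist.y() - hand_wrist.y();
  const float dz = pose_wrist.z() - hand_wrist.z();
  for (Landmark& lm : *hand.mutable_landmark()) {
    lm.set_x(lm.x() + dx);
    lm.set_y(lm.y() + dy);
    lm.set_z(lm.z() + dz);
  }
}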

View File

@@ -0,0 +1,94 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_
#include <optional>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
struct PoseIndices {
int wrist_idx;
int pinky_idx;
int index_idx;
};
struct HolisticHandTrackingRequest {
bool landmarks = false;
bool world_landmarks = false;
};
struct HolisticHandTrackingOutput {
std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
landmarks;
std::optional<api2::builder::Stream<mediapipe::LandmarkList>> world_landmarks;
struct DebugOutput {
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_pose;
api2::builder::Stream<mediapipe::NormalizedRect> roi_from_recrop;
api2::builder::Stream<mediapipe::NormalizedRect> tracking_roi;
};
DebugOutput debug_output;
};
// Updates @graph to track a single hand in @image based on pose landmarks.
//
// To track a single hand, this subgraph uses the pose palm landmarks to
// obtain an approximate hand location, refines it with the re-crop model, and
// then runs the hand landmarks model. It can also reuse the hand ROI from the
// previous frame if the hand hasn't moved too much.
//
// @image - Image to track a single hand in.
// @pose_landmarks - Pose landmarks to derive initial hand location from.
// @pose_world_landmarks - Pose world landmarks to align hand world landmarks
// wrist with.
// @hand_landmarks_detector_graph_options - Options of the
// HandLandmarksDetectorGraph used to detect the hand landmarks.
// @hand_roi_refinement_graph_options - Options of the HandRoiRefinementGraph
// used to refine the hand RoIs derived from the pose landmarks.
// @request - object to request specific hand tracking outputs.
// NOTE: Outputs that were not requested won't be returned and the
// corresponding parts of the graph won't be generated.
// @graph - graph to update.
absl::StatusOr<HolisticHandTrackingOutput> TrackHolisticHand(
api2::builder::Stream<Image> image,
api2::builder::Stream<mediapipe::NormalizedLandmarkList> pose_landmarks,
api2::builder::Stream<mediapipe::LandmarkList> pose_world_landmarks,
const hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
const hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_graph_options,
const PoseIndices& pose_indices, const HolisticHandTrackingRequest& request,
mediapipe::api2::builder::Graph& graph);
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_HAND_TRACKING_H_
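With the standard 33-point pose topology (left wrist/pinky/index are landmarks 15/17/19; right are 16/18/20), PoseIndices for the left hand could be built as in this sketch; the test below does the same via the PoseLandmarkName enum:

const PoseIndices left_hand_indices{/*wrist_idx=*/15, /*pinky_idx=*/17,
                                    /*index_idx=*/19};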

View File

@@ -0,0 +1,303 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.018;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHolisticHandTrackingLeft[] =
"holistic_hand_tracking_left_hand_graph.pbtxt";
constexpr char kTestImageFile[] = "male_full_height_hands.jpg";
constexpr char kHolisticResultFile[] =
"male_full_height_hands_result_cpu.pbtxt";
constexpr char kImageInStream[] = "image_in";
constexpr char kPoseLandmarksInStream[] = "pose_landmarks_in";
constexpr char kPoseWorldLandmarksInStream[] = "pose_world_landmarks_in";
constexpr char kLeftHandLandmarksOutStream[] = "left_hand_landmarks_out";
constexpr char kLeftHandWorldLandmarksOutStream[] =
"left_hand_world_landmarks_out";
constexpr char kRightHandLandmarksOutStream[] = "right_hand_landmarks_out";
constexpr char kRenderedImageOutStream[] = "rendered_image_out";
constexpr char kHandLandmarksModelFile[] = "hand_landmark_full.tflite";
constexpr char kHandRoiRefinementModelFile[] =
"handrecrop_2020_07_21_v0.f16.tflite";
std::string GetFilePath(const std::string& filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
mediapipe::LandmarksToRenderDataCalculatorOptions GetHandRendererOptions() {
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
for (const auto& connection : hand_landmarker::kHandConnections) {
renderer_options.add_landmark_connections(connection[0]);
renderer_options.add_landmark_connections(connection[1]);
}
renderer_options.mutable_landmark_color()->set_r(255);
renderer_options.mutable_landmark_color()->set_g(255);
renderer_options.mutable_landmark_color()->set_b(255);
renderer_options.mutable_connection_color()->set_r(255);
renderer_options.mutable_connection_color()->set_g(255);
renderer_options.mutable_connection_color()->set_b(255);
renderer_options.set_thickness(0.5);
renderer_options.set_visualize_landmark_depth(false);
return renderer_options;
}
void ConfigHandTrackingModelsOptions(
hand_landmarker::proto::HandLandmarksDetectorGraphOptions&
hand_landmarks_detector_graph_options,
hand_landmarker::proto::HandRoiRefinementGraphOptions&
hand_roi_refinement_options) {
hand_landmarks_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandLandmarksModelFile));
hand_roi_refinement_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
graph.In("POSE_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>()
.SetName(kPoseLandmarksInStream);
Stream<mediapipe::LandmarkList> pose_world_landmarks =
graph.In("POSE_WORLD_LANDMARKS")
.Cast<mediapipe::LandmarkList>()
.SetName(kPoseWorldLandmarksInStream);
hand_landmarker::proto::HandLandmarksDetectorGraphOptions
hand_landmarks_detector_options;
hand_landmarker::proto::HandRoiRefinementGraphOptions
hand_roi_refinement_options;
ConfigHandTrackingModelsOptions(hand_landmarks_detector_options,
hand_roi_refinement_options);
HolisticHandTrackingRequest request;
request.landmarks = true;
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput left_hand_result,
TrackHolisticHand(
image, pose_landmarks, pose_world_landmarks,
hand_landmarks_detector_options, hand_roi_refinement_options,
PoseIndices{
/*wrist_idx=*/static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftWrist),
/*pinky_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftPinky1),
/*index_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftIndex1)},
request, graph));
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput right_hand_result,
TrackHolisticHand(
image, pose_landmarks, pose_world_landmarks,
hand_landmarks_detector_options, hand_roi_refinement_options,
PoseIndices{
/*wrist_idx=*/static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightWrist),
/*pinky_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kRightPinky1),
/*index_idx=*/
static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightIndex1)},
request, graph));
auto image_size = GetImageSize(image, graph);
auto left_hand_landmarks_render_data = utils::RenderLandmarks(
*left_hand_result.landmarks,
utils::GetRenderScale(image_size,
left_hand_result.debug_output.roi_from_pose, 0.0001,
graph),
GetHandRendererOptions(), graph);
auto right_hand_landmarks_render_data = utils::RenderLandmarks(
*right_hand_result.landmarks,
utils::GetRenderScale(image_size,
right_hand_result.debug_output.roi_from_pose,
0.0001, graph),
GetHandRendererOptions(), graph);
std::vector<Stream<mediapipe::RenderData>> render_list = {
left_hand_landmarks_render_data, right_hand_landmarks_render_data};
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >>
graph.Out("LEFT_HAND_LANDMARKS");
right_hand_result.landmarks->SetName(kRightHandLandmarksOutStream) >>
graph.Out("RIGHT_HAND_LANDMARKS");
rendered_image >> graph.Out("RENDERED_IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
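// The two TrackHolisticHand calls above differ only in the side-specific
// pose indices. A minimal sketch of one way to factor that out (hypothetical
// helper, not part of the original test):
PoseIndices GetPoseIndices(bool is_left) {
  using pose_landmarker::PoseLandmarkName;
  return PoseIndices{
      /*wrist_idx=*/static_cast<int>(is_left ? PoseLandmarkName::kLeftWrist
                                             : PoseLandmarkName::kRightWrist),
      /*pinky_idx=*/static_cast<int>(is_left ? PoseLandmarkName::kLeftPinky1
                                             : PoseLandmarkName::kRightPinky1),
      /*index_idx=*/static_cast<int>(is_left ? PoseLandmarkName::kLeftIndex1
                                             : PoseLandmarkName::kRightIndex1)};
}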
class HolisticHandTrackingTest : public ::testing::Test {};
TEST_F(HolisticHandTrackingTest, VerifyGraph) {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
Stream<mediapipe::NormalizedLandmarkList> pose_landmarks =
graph.In("POSE_LANDMARKS")
.Cast<mediapipe::NormalizedLandmarkList>()
.SetName(kPoseLandmarksInStream);
Stream<mediapipe::LandmarkList> pose_world_landmarks =
graph.In("POSE_WORLD_LANDMARKS")
.Cast<mediapipe::LandmarkList>()
.SetName(kPoseWorldLandmarksInStream);
hand_landmarker::proto::HandLandmarksDetectorGraphOptions
hand_landmarks_detector_options;
hand_landmarker::proto::HandRoiRefinementGraphOptions
hand_roi_refinement_options;
ConfigHandTrackingModelsOptions(hand_landmarks_detector_options,
hand_roi_refinement_options);
HolisticHandTrackingRequest request;
request.landmarks = true;
request.world_landmarks = true;
MP_ASSERT_OK_AND_ASSIGN(
HolisticHandTrackingOutput left_hand_result,
TrackHolisticHand(
image, pose_landmarks, pose_world_landmarks,
hand_landmarks_detector_options, hand_roi_refinement_options,
PoseIndices{
/*wrist_idx=*/static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftWrist),
/*pinky_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftPinky1),
/*index_idx=*/
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftIndex1)},
request, graph));
left_hand_result.landmarks->SetName(kLeftHandLandmarksOutStream) >>
graph.Out("LEFT_HAND_LANDMARKS");
left_hand_result.world_landmarks->SetName(kLeftHandWorldLandmarksOutStream) >>
graph.Out("LEFT_HAND_WORLD_LANDMARKS");
// Read the expected graph config.
std::string expected_graph_contents;
MP_ASSERT_OK(file::GetContents(
file::JoinPath("./", kTestDataDirectory, kHolisticHandTrackingLeft),
&expected_graph_contents));
// Substitute the test srcdir into the expected graph config, because each
// run has a different test dir on TAP.
expected_graph_contents = absl::Substitute(
expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir);
CalculatorGraphConfig expected_graph =
ParseTextProtoOrDie<CalculatorGraphConfig>(expected_graph_contents);
EXPECT_THAT(graph.GetConfig(), testing::proto::IgnoringRepeatedFieldOrdering(
testing::EqualsProto(expected_graph)));
}
TEST_F(HolisticHandTrackingTest, SmokeTest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(GetFilePath(kTestImageFile)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
Defaults()));
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
MP_ASSERT_OK_AND_ASSIGN(
auto output_packets,
task_runner->Process(
{{kImageInStream, MakePacket<Image>(image)},
{kPoseLandmarksInStream, MakePacket<NormalizedLandmarkList>(
holistic_result.pose_landmarks())},
{kPoseWorldLandmarksInStream,
MakePacket<LandmarkList>(
holistic_result.pose_world_landmarks())}}));
auto left_hand_landmarks = output_packets.at(kLeftHandLandmarksOutStream)
.Get<NormalizedLandmarkList>();
auto right_hand_landmarks = output_packets.at(kRightHandLandmarksOutStream)
.Get<NormalizedLandmarkList>();
EXPECT_THAT(left_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.left_hand_landmarks())),
/*margin=*/kAbsMargin));
EXPECT_THAT(
right_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.right_hand_landmarks())),
/*margin=*/kAbsMargin));
auto rendered_image = output_packets.at(kRenderedImageOutStream).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
"holistic_hand_landmarks"));
}
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,521 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_face_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_hand_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_topology.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::metadata::SetExternalFile;
constexpr absl::string_view kHandLandmarksDetectorModelName =
"hand_landmarks_detector.tflite";
constexpr absl::string_view kHandRoiRefinementModelName =
"hand_roi_refinement.tflite";
constexpr absl::string_view kFaceDetectorModelName = "face_detector.tflite";
constexpr absl::string_view kFaceLandmarksDetectorModelName =
"face_landmarks_detector.tflite";
constexpr absl::string_view kFaceBlendshapesModelName =
"face_blendshapes.tflite";
constexpr absl::string_view kPoseDetectorModelName = "pose_detector.tflite";
constexpr absl::string_view kPoseLandmarksDetectorModelName =
"pose_landmarks_detector.tflite";
absl::Status SetGraphPoseOutputs(
const HolisticPoseTrackingRequest& pose_request,
const CalculatorGraphConfig::Node& node,
HolisticPoseTrackingOutput& pose_output, Graph& graph) {
// Main outputs.
if (pose_request.landmarks) {
RET_CHECK(pose_output.landmarks.has_value())
<< "POSE_LANDMARKS output is not supported.";
pose_output.landmarks->ConnectTo(graph.Out("POSE_LANDMARKS"));
}
if (pose_request.world_landmarks) {
RET_CHECK(pose_output.world_landmarks.has_value())
<< "POSE_WORLD_LANDMARKS output is not supported.";
pose_output.world_landmarks->ConnectTo(graph.Out("POSE_WORLD_LANDMARKS"));
}
if (pose_request.segmentation_mask) {
RET_CHECK(pose_output.segmentation_mask.has_value())
<< "POSE_SEGMENTATION_MASK output is not supported.";
pose_output.segmentation_mask->ConnectTo(
graph.Out("POSE_SEGMENTATION_MASK"));
}
// Debug outputs.
if (HasOutput(node, "POSE_AUXILIARY_LANDMARKS")) {
pose_output.debug_output.auxiliary_landmarks.ConnectTo(
graph.Out("POSE_AUXILIARY_LANDMARKS"));
}
if (HasOutput(node, "POSE_LANDMARKS_ROI")) {
pose_output.debug_output.roi_from_landmarks.ConnectTo(
graph.Out("POSE_LANDMARKS_ROI"));
}
return absl::OkStatus();
}
// Sets the base options in the subtasks.
template <typename T>
absl::Status SetSubTaskBaseOptions(
const core::ModelAssetBundleResources* resources,
proto::HolisticLandmarkerGraphOptions* options, T* sub_task_options,
absl::string_view model_name, bool is_copy) {
if (!sub_task_options->base_options().has_model_asset()) {
MP_ASSIGN_OR_RETURN(const auto model_file_content,
resources->GetFile(std::string(model_name)));
SetExternalFile(
model_file_content,
sub_task_options->mutable_base_options()->mutable_model_asset(),
is_copy);
}
sub_task_options->mutable_base_options()->mutable_acceleration()->CopyFrom(
options->base_options().acceleration());
sub_task_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
sub_task_options->mutable_base_options()->set_gpu_origin(
options->base_options().gpu_origin());
return absl::OkStatus();
}
void SetGraphHandOutputs(bool is_left, const CalculatorGraphConfig::Node& node,
HolisticHandTrackingOutput& hand_output,
Graph& graph) {
const std::string hand_side = is_left ? "LEFT" : "RIGHT";
if (hand_output.landmarks) {
hand_output.landmarks->ConnectTo(graph.Out(hand_side + "_HAND_LANDMARKS"));
}
if (hand_output.world_landmarks) {
hand_output.world_landmarks->ConnectTo(
graph.Out(hand_side + "_HAND_WORLD_LANDMARKS"));
}
// Debug outputs.
if (HasOutput(node, hand_side + "_HAND_ROI_FROM_POSE")) {
hand_output.debug_output.roi_from_pose.ConnectTo(
graph.Out(hand_side + "_HAND_ROI_FROM_POSE"));
}
if (HasOutput(node, hand_side + "_HAND_ROI_FROM_RECROP")) {
hand_output.debug_output.roi_from_recrop.ConnectTo(
graph.Out(hand_side + "_HAND_ROI_FROM_RECROP"));
}
if (HasOutput(node, hand_side + "_HAND_TRACKING_ROI")) {
hand_output.debug_output.tracking_roi.ConnectTo(
graph.Out(hand_side + "_HAND_TRACKING_ROI"));
}
}
void SetGraphFaceOutputs(const CalculatorGraphConfig::Node& node,
HolisticFaceTrackingOutput& face_output,
Graph& graph) {
if (face_output.landmarks) {
face_output.landmarks->ConnectTo(graph.Out("FACE_LANDMARKS"));
}
if (face_output.classifications) {
face_output.classifications->ConnectTo(graph.Out("FACE_BLENDSHAPES"));
}
// Face detection debug outputs.
if (HasOutput(node, "FACE_ROI_FROM_POSE")) {
face_output.debug_output.roi_from_pose.ConnectTo(
graph.Out("FACE_ROI_FROM_POSE"));
}
if (HasOutput(node, "FACE_ROI_FROM_DETECTION")) {
face_output.debug_output.roi_from_detection.ConnectTo(
graph.Out("FACE_ROI_FROM_DETECTION"));
}
if (HasOutput(node, "FACE_TRACKING_ROI")) {
face_output.debug_output.tracking_roi.ConnectTo(
graph.Out("FACE_TRACKING_ROI"));
}
}
} // namespace
// Tracks pose and detects hands and face.
//
// NOTE: for GPU, this graph works only with images having GpuOrigin::TOP_LEFT.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
//
// Outputs:
// POSE_LANDMARKS - NormalizedLandmarkList
// 33 landmarks (see pose_landmarker/pose_topology.h)
// 0 - nose
// 1 - left eye (inner)
// 2 - left eye
// 3 - left eye (outer)
// 4 - right eye (inner)
// 5 - right eye
// 6 - right eye (outer)
// 7 - left ear
// 8 - right ear
// 9 - mouth (left)
// 10 - mouth (right)
// 11 - left shoulder
// 12 - right shoulder
// 13 - left elbow
// 14 - right elbow
// 15 - left wrist
// 16 - right wrist
// 17 - left pinky
// 18 - right pinky
// 19 - left index
// 20 - right index
// 21 - left thumb
// 22 - right thumb
// 23 - left hip
// 24 - right hip
// 25 - left knee
// 26 - right knee
// 27 - left ankle
// 28 - right ankle
// 29 - left heel
// 30 - right heel
// 31 - left foot index
// 32 - right foot index
// POSE_WORLD_LANDMARKS - LandmarkList
// World landmarks are real-world 3D coordinates in meters, with the origin
// at the center of the hips. To understand the difference: the
// POSE_LANDMARKS stream provides coordinates (in pixels) of the 3D object
// projected onto the 2D surface of the image (see how perspective
// projection works), while the POSE_WORLD_LANDMARKS stream provides
// coordinates (in meters) of the 3D object itself. POSE_WORLD_LANDMARKS
// has the same landmark topology, visibility, and presence as
// POSE_LANDMARKS.
// POSE_SEGMENTATION_MASK - Image
// Separates the person from the background. On CPU, the mask is stored as
// a gray float32 image with per-pixel values in the [0.0, 1.0] range
// (1 for person, 0 for background); on GPU, it is an RGBA texture whose R
// channel indicates the person vs. background probability.
// LEFT_HAND_LANDMARKS - NormalizedLandmarkList
// 21 left hand landmarks.
// RIGHT_HAND_LANDMARKS - NormalizedLandmarkList
// 21 right hand landmarks.
// FACE_LANDMARKS - NormalizedLandmarkList
// 468 face landmarks.
// FACE_BLENDSHAPES - ClassificationList
// Supplementary blendshape coefficients that are predicted directly from
// the input image.
// LEFT_HAND_WORLD_LANDMARKS - LandmarkList
// 21 left hand world 3D landmarks.
// Hand landmarks are aligned with pose landmarks: translated so that the
// wrist from the hand matches the wrist from the pose, in the pose
// coordinate system.
// RIGHT_HAND_WORLD_LANDMARKS - LandmarkList
// 21 right hand world 3D landmarks.
// Hand landmarks are aligned with pose landmarks: translated so that the
// wrist from the hand matches the wrist from the pose, in the pose
// coordinate system.
// IMAGE - Image
// The input image that the holistic landmarker runs on, with the pixel
// data stored on the target storage (CPU vs. GPU).
//
// Debug outputs:
// POSE_AUXILIARY_LANDMARKS - NormalizedLandmarkList
// TODO: Return ROI rather than auxiliary landmarks
// Auxiliary landmarks for deriving the ROI in the subsequent image.
// 0 - hidden center point
// 1 - hidden scale point
// POSE_LANDMARKS_ROI - NormalizedRect
// Region of interest calculated based on landmarks.
// LEFT_HAND_ROI_FROM_POSE - NormalizedRect
// LEFT_HAND_ROI_FROM_RECROP - NormalizedRect
// LEFT_HAND_TRACKING_ROI - NormalizedRect
// RIGHT_HAND_ROI_FROM_POSE - NormalizedRect
// RIGHT_HAND_ROI_FROM_RECROP - NormalizedRect
// RIGHT_HAND_TRACKING_ROI - NormalizedRect
// FACE_ROI_FROM_POSE - NormalizedRect
// FACE_ROI_FROM_DETECTION - NormalizedRect
// FACE_TRACKING_ROI - NormalizedRect
//
// NOTE: failure is reported if some output has been requested but the
// specified model doesn't support it.
//
// NOTE: there will not be an output packet in an output stream for a
// particular timestamp if nothing is detected. However, the MediaPipe
// framework will internally inform the downstream calculators of the
// absence of this packet so that they don't wait for it unnecessarily.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"
// input_stream: "IMAGE:input_frames_image"
// output_stream: "POSE_LANDMARKS:pose_landmarks"
// output_stream: "POSE_WORLD_LANDMARKS:pose_world_landmarks"
// output_stream: "FACE_LANDMARKS:face_landmarks"
// output_stream: "FACE_BLENDSHAPES:extra_blendshapes"
// output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks"
// output_stream: "LEFT_HAND_WORLD_LANDMARKS:left_hand_world_landmarks"
// output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks"
// output_stream: "RIGHT_HAND_WORLD_LANDMARKS:right_hand_world_landmarks"
// node_options {
// [type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions]
// {
// base_options {
// model_asset {
// file_name:
// "mediapipe/tasks/testdata/vision/holistic_landmarker.task"
// }
// }
// face_detector_graph_options: {
// num_faces: 1
// }
// pose_detector_graph_options: {
// num_poses: 1
// }
// }
// }
// }
class HolisticLandmarkerGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
const auto& holistic_node = sc->OriginalNode();
proto::HolisticLandmarkerGraphOptions* holistic_options =
sc->MutableOptions<proto::HolisticLandmarkerGraphOptions>();
const core::ModelAssetBundleResources* model_asset_bundle_resources = nullptr;
if (holistic_options->base_options().has_model_asset()) {
MP_ASSIGN_OR_RETURN(model_asset_bundle_resources,
CreateModelAssetBundleResources<
proto::HolisticLandmarkerGraphOptions>(sc));
}
// Copies the file content instead of passing a pointer to the file in
// memory if the subgraph model resource service is not available.
bool create_copy =
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable();
Stream<Image> image = graph.In("IMAGE").Cast<Image>();
// Check whether hand outputs are requested.
const bool is_left_hand_requested =
HasOutput(holistic_node, "LEFT_HAND_LANDMARKS");
const bool is_right_hand_requested =
HasOutput(holistic_node, "RIGHT_HAND_LANDMARKS");
const bool is_left_hand_world_requested =
HasOutput(holistic_node, "LEFT_HAND_WORLD_LANDMARKS");
const bool is_right_hand_world_requested =
HasOutput(holistic_node, "RIGHT_HAND_WORLD_LANDMARKS");
const bool hands_requested =
is_left_hand_requested || is_right_hand_requested ||
is_left_hand_world_requested || is_right_hand_world_requested;
if (hands_requested) {
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_hand_landmarks_detector_graph_options(),
kHandLandmarksDetectorModelName, create_copy));
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_hand_roi_refinement_graph_options(),
kHandRoiRefinementModelName, create_copy));
}
// Check whether face outputs are requested.
const bool is_face_requested = HasOutput(holistic_node, "FACE_LANDMARKS");
const bool is_face_blendshapes_requested =
HasOutput(holistic_node, "FACE_BLENDSHAPES");
const bool face_requested =
is_face_requested || is_face_blendshapes_requested;
if (face_requested) {
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_face_detector_graph_options(),
kFaceDetectorModelName, create_copy));
// Forcibly set num_faces to 1, because the holistic landmarker only
// supports a single subject for now.
holistic_options->mutable_face_detector_graph_options()->set_num_faces(1);
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_face_landmarks_detector_graph_options(),
kFaceLandmarksDetectorModelName, create_copy));
if (is_face_blendshapes_requested) {
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_face_landmarks_detector_graph_options()
->mutable_face_blendshapes_graph_options(),
kFaceBlendshapesModelName, create_copy));
}
}
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_pose_detector_graph_options(),
kPoseDetectorModelName, create_copy));
// Forcibly set num_poses to 1, because the holistic landmarker only
// supports a single subject for now.
holistic_options->mutable_pose_detector_graph_options()->set_num_poses(1);
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
model_asset_bundle_resources, holistic_options,
holistic_options->mutable_pose_landmarks_detector_graph_options(),
kPoseLandmarksDetectorModelName, create_copy));
HolisticPoseTrackingRequest pose_request = {
.landmarks = HasOutput(holistic_node, "POSE_LANDMARKS") ||
hands_requested || face_requested,
.world_landmarks =
HasOutput(holistic_node, "POSE_WORLD_LANDMARKS") || hands_requested,
.segmentation_mask =
HasOutput(holistic_node, "POSE_SEGMENTATION_MASK")};
// Detect and track pose.
MP_ASSIGN_OR_RETURN(
HolisticPoseTrackingOutput pose_output,
TrackHolisticPose(
image, holistic_options->pose_detector_graph_options(),
holistic_options->pose_landmarks_detector_graph_options(),
pose_request, graph));
MP_RETURN_IF_ERROR(
SetGraphPoseOutputs(pose_request, holistic_node, pose_output, graph));
// Detect and track hand.
if (hands_requested) {
if (is_left_hand_requested || is_left_hand_world_requested) {
RET_CHECK(pose_output.landmarks.has_value());
RET_CHECK(pose_output.world_landmarks.has_value());
PoseIndices pose_indices = {
.wrist_idx =
static_cast<int>(pose_landmarker::PoseLandmarkName::kLeftWrist),
.pinky_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftPinky1),
.index_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kLeftIndex1),
};
HolisticHandTrackingRequest hand_request = {
.landmarks = is_left_hand_requested,
.world_landmarks = is_left_hand_world_requested,
};
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput hand_output,
TrackHolisticHand(
image, *pose_output.landmarks, *pose_output.world_landmarks,
holistic_options->hand_landmarks_detector_graph_options(),
holistic_options->hand_roi_refinement_graph_options(),
pose_indices, hand_request, graph
));
SetGraphHandOutputs(/*is_left=*/true, holistic_node, hand_output,
graph);
}
if (is_right_hand_requested || is_right_hand_world_requested) {
RET_CHECK(pose_output.landmarks.has_value());
RET_CHECK(pose_output.world_landmarks.has_value());
PoseIndices pose_indices = {
.wrist_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightWrist),
.pinky_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightPinky1),
.index_idx = static_cast<int>(
pose_landmarker::PoseLandmarkName::kRightIndex1),
};
HolisticHandTrackingRequest hand_request = {
.landmarks = is_right_hand_requested,
.world_landmarks = is_right_hand_world_requested,
};
MP_ASSIGN_OR_RETURN(
HolisticHandTrackingOutput hand_output,
TrackHolisticHand(
image, *pose_output.landmarks, *pose_output.world_landmarks,
holistic_options->hand_landmarks_detector_graph_options(),
holistic_options->hand_roi_refinement_graph_options(),
pose_indices, hand_request, graph
));
SetGraphHandOutputs(/*is_left=*/false, holistic_node, hand_output,
graph);
}
}
// Detect and track face.
if (face_requested) {
RET_CHECK(pose_output.landmarks.has_value());
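// Pose landmarks 0-10 are the face-related points (nose, eyes, ears,
// mouth; see the output topology above), so the first 11 landmarks seed
// the face ROI.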
Stream<mediapipe::NormalizedLandmarkList> face_landmarks_from_pose =
api2::builder::SplitToRanges(*pose_output.landmarks, {{0, 11}},
graph)[0];
HolisticFaceTrackingRequest face_request = {
.classifications = is_face_blendshapes_requested,
};
MP_ASSIGN_OR_RETURN(
HolisticFaceTrackingOutput face_output,
TrackHolisticFace(
image, face_landmarks_from_pose,
holistic_options->face_detector_graph_options(),
holistic_options->face_landmarks_detector_graph_options(),
face_request, graph));
SetGraphFaceOutputs(holistic_node, face_output, graph);
}
auto& pass_through = graph.AddNode("PassThroughCalculator");
image >> pass_through.In("");
pass_through.Out("") >> graph.Out("IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return config;
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::holistic_landmarker::HolisticLandmarkerGraph);
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
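For reference, a minimal sketch of driving the registered HolisticLandmarkerGraph with the stock CalculatorGraph API; the config mirrors the example in the doc comment above, and the model path, stream names, and output handling here are illustrative assumptions, not part of this change:

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"

absl::Status RunHolisticOnce(const mediapipe::Image& image) {
  mediapipe::CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        input_stream: "image_in"
        output_stream: "pose_landmarks"
        node {
          calculator: "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"
          input_stream: "IMAGE:image_in"
          output_stream: "POSE_LANDMARKS:pose_landmarks"
          node_options {
            [type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions] {
              base_options {
                model_asset { file_name: "holistic_landmarker.task" }
              }
            }
          }
        }
      )pb");
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  MP_RETURN_IF_ERROR(graph.ObserveOutputStream(
      "pose_landmarks", [](const mediapipe::Packet& packet) {
        // Consume packet.Get<mediapipe::NormalizedLandmarkList>() here.
        return absl::OkStatus();
      }));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "image_in", mediapipe::MakePacket<mediapipe::Image>(image).At(
                      mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseInputStream("image_in"));
  return graph.WaitUntilDone();
}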

View File

@ -0,0 +1,595 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/googletest.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.025;
constexpr absl::string_view kTestDataDirectory =
"/mediapipe/tasks/testdata/vision/";
constexpr char kHolisticResultFile[] =
"male_full_height_hands_result_cpu.pbtxt";
constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
constexpr absl::string_view kImageInStream = "image_in";
constexpr absl::string_view kLeftHandLandmarksStream = "left_hand_landmarks";
constexpr absl::string_view kRightHandLandmarksStream = "right_hand_landmarks";
constexpr absl::string_view kFaceLandmarksStream = "face_landmarks";
constexpr absl::string_view kFaceBlendshapesStream = "face_blendshapes";
constexpr absl::string_view kPoseLandmarksStream = "pose_landmarks";
constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
constexpr absl::string_view kPoseSegmentationMaskStream =
"pose_segmentation_mask";
constexpr absl::string_view kHolisticLandmarkerModelBundleFile =
"holistic_landmarker.task";
constexpr absl::string_view kHandLandmarksModelFile =
"hand_landmark_full.tflite";
constexpr absl::string_view kHandRoiRefinementModelFile =
"handrecrop_2020_07_21_v0.f16.tflite";
constexpr absl::string_view kPoseDetectionModelFile = "pose_detection.tflite";
constexpr absl::string_view kPoseLandmarksModelFile =
"pose_landmark_lite.tflite";
constexpr absl::string_view kFaceDetectionModelFile =
"face_detection_short_range.tflite";
constexpr absl::string_view kFaceLandmarksModelFile =
"facemesh2_lite_iris_faceflag_2023_02_14.tflite";
constexpr absl::string_view kFaceBlendshapesModelFile =
"face_blendshapes.tflite";
enum RenderPart {
HAND = 0,
POSE = 1,
FACE = 2,
};
mediapipe::Color GetColor(RenderPart render_part) {
mediapipe::Color color;
switch (render_part) {
case HAND:
color.set_b(255);
color.set_g(255);
color.set_r(255);
break;
case POSE:
color.set_b(0);
color.set_g(255);
color.set_r(0);
break;
case FACE:
color.set_b(0);
color.set_g(0);
color.set_r(255);
break;
}
return color;
}
std::string GetFilePath(absl::string_view filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
template <std::size_t N>
mediapipe::LandmarksToRenderDataCalculatorOptions GetRendererOptions(
const std::array<std::array<int, 2>, N>& connections,
mediapipe::Color color) {
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
for (const auto& connection : connections) {
renderer_options.add_landmark_connections(connection[0]);
renderer_options.add_landmark_connections(connection[1]);
}
*renderer_options.mutable_landmark_color() = color;
*renderer_options.mutable_connection_color() = color;
renderer_options.set_thickness(0.5);
renderer_options.set_visualize_landmark_depth(false);
return renderer_options;
}
void ConfigureHandProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
options.mutable_hand_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandLandmarksModelFile));
options.mutable_hand_roi_refinement_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kHandRoiRefinementModelFile));
}
void ConfigureFaceProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
// Set face detection model.
face_detector::proto::FaceDetectorGraphOptions& face_detector_graph_options =
*options.mutable_face_detector_graph_options();
face_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kFaceDetectionModelFile));
face_detector_graph_options.set_num_faces(1);
// Set face landmarks model.
face_landmarker::proto::FaceLandmarksDetectorGraphOptions&
face_landmarks_graph_options =
*options.mutable_face_landmarks_detector_graph_options();
face_landmarks_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kFaceLandmarksModelFile));
face_landmarks_graph_options.mutable_face_blendshapes_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kFaceBlendshapesModelFile));
}
void ConfigurePoseProtoOptions(proto::HolisticLandmarkerGraphOptions& options) {
pose_detector::proto::PoseDetectorGraphOptions& pose_detector_graph_options =
*options.mutable_pose_detector_graph_options();
pose_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kPoseDetectionModelFile));
pose_detector_graph_options.set_num_poses(1);
options.mutable_pose_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath(kPoseLandmarksModelFile));
}
struct HolisticRequest {
bool is_left_hand_requested = false;
bool is_right_hand_requested = false;
bool is_face_requested = false;
bool is_face_blendshapes_requested = false;
};
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner(
bool use_model_bundle, HolisticRequest holistic_request) {
Graph graph;
Stream<Image> image = graph.In("IMAEG").Cast<Image>().SetName(kImageInStream);
auto& holistic_graph = graph.AddNode(
"mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph");
proto::HolisticLandmarkerGraphOptions& options =
holistic_graph.GetOptions<proto::HolisticLandmarkerGraphOptions>();
if (use_model_bundle) {
options.mutable_base_options()->mutable_model_asset()->set_file_name(
GetFilePath(kHolisticLandmarkerModelBundleFile));
} else {
ConfigureHandProtoOptions(options);
ConfigurePoseProtoOptions(options);
ConfigureFaceProtoOptions(options);
}
std::vector<Stream<mediapipe::RenderData>> render_list;
image >> holistic_graph.In("IMAGE");
Stream<std::pair<int, int>> image_size = GetImageSize(image, graph);
if (holistic_request.is_left_hand_requested) {
Stream<NormalizedLandmarkList> left_hand_landmarks =
holistic_graph.Out("LEFT_HAND_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kLeftHandLandmarksStream);
Stream<NormalizedRect> left_hand_tracking_roi =
holistic_graph.Out("LEFT_HAND_TRACKING_ROI").Cast<NormalizedRect>();
auto left_hand_landmarks_render_data = utils::RenderLandmarks(
left_hand_landmarks,
utils::GetRenderScale(image_size, left_hand_tracking_roi, 0.0001,
graph),
GetRendererOptions(hand_landmarker::kHandConnections,
GetColor(RenderPart::HAND)),
graph);
render_list.push_back(left_hand_landmarks_render_data);
left_hand_landmarks >> graph.Out("LEFT_HAND_LANDMARKS");
}
if (holistic_request.is_right_hand_requested) {
Stream<NormalizedLandmarkList> right_hand_landmarks =
holistic_graph.Out("RIGHT_HAND_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kRightHandLandmarksStream);
Stream<NormalizedRect> right_hand_tracking_roi =
holistic_graph.Out("RIGHT_HAND_TRACKING_ROI").Cast<NormalizedRect>();
auto right_hand_landmarks_render_data = utils::RenderLandmarks(
right_hand_landmarks,
utils::GetRenderScale(image_size, right_hand_tracking_roi, 0.0001,
graph),
GetRendererOptions(hand_landmarker::kHandConnections,
GetColor(RenderPart::HAND)),
graph);
render_list.push_back(right_hand_landmarks_render_data);
right_hand_landmarks >> graph.Out("RIGHT_HAND_LANDMARKS");
}
if (holistic_request.is_face_requested) {
Stream<NormalizedLandmarkList> face_landmarks =
holistic_graph.Out("FACE_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kFaceLandmarksStream);
Stream<NormalizedRect> face_tracking_roi =
holistic_graph.Out("FACE_TRACKING_ROI").Cast<NormalizedRect>();
auto face_landmarks_render_data = utils::RenderLandmarks(
face_landmarks,
utils::GetRenderScale(image_size, face_tracking_roi, 0.0001, graph),
GetRendererOptions(
face_landmarker::FaceLandmarksConnections::kFaceLandmarksConnectors,
GetColor(RenderPart::FACE)),
graph);
render_list.push_back(face_landmarks_render_data);
face_landmarks >> graph.Out("FACE_LANDMARKS");
}
if (holistic_request.is_face_blendshapes_requested) {
Stream<ClassificationList> face_blendshapes =
holistic_graph.Out("FACE_BLENDSHAPES")
.Cast<ClassificationList>()
.SetName(kFaceBlendshapesStream);
face_blendshapes >> graph.Out("FACE_BLENDSHAPES");
}
Stream<NormalizedLandmarkList> pose_landmarks =
holistic_graph.Out("POSE_LANDMARKS")
.Cast<NormalizedLandmarkList>()
.SetName(kPoseLandmarksStream);
Stream<NormalizedRect> pose_tracking_roi =
holistic_graph.Out("POSE_LANDMARKS_ROI").Cast<NormalizedRect>();
Stream<Image> pose_segmentation_mask =
holistic_graph.Out("POSE_SEGMENTATION_MASK")
.Cast<Image>()
.SetName(kPoseSegmentationMaskStream);
auto pose_landmarks_render_data = utils::RenderLandmarks(
pose_landmarks,
utils::GetRenderScale(image_size, pose_tracking_roi, 0.0001, graph),
GetRendererOptions(pose_landmarker::kPoseLandmarksConnections,
GetColor(RenderPart::POSE)),
graph);
render_list.push_back(pose_landmarks_render_data);
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
pose_landmarks >> graph.Out("POSE_LANDMARKS");
pose_segmentation_mask >> graph.Out("POSE_SEGMENTATION_MASK");
rendered_image >> graph.Out("RENDERED_IMAGE");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
template <typename T>
absl::StatusOr<T> FetchResult(const core::PacketMap& output_packets,
absl::string_view stream_name) {
auto it = output_packets.find(std::string(stream_name));
RET_CHECK(it != output_packets.end());
return it->second.Get<T>();
}
// Remove fields that are not to be checked in the result, since the model
// generating the expected result differs from the model under test.
void RemoveUncheckedResult(proto::HolisticResult& holistic_result) {
for (auto& landmark :
*holistic_result.mutable_pose_landmarks()->mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
for (auto& landmark :
*holistic_result.mutable_face_landmarks()->mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
for (auto& landmark :
*holistic_result.mutable_left_hand_landmarks()->mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
for (auto& landmark :
*holistic_result.mutable_right_hand_landmarks()->mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
}
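// The four loops in RemoveUncheckedResult are identical; a sketch of one
// way to factor them out (hypothetical helper, not used below):
void ClearUncheckedLandmarkFields(NormalizedLandmarkList& landmarks) {
  for (auto& landmark : *landmarks.mutable_landmark()) {
    landmark.clear_z();
    landmark.clear_visibility();
    landmark.clear_presence();
  }
}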
std::string RequestToString(HolisticRequest request) {
return absl::StrFormat(
"%s_%s_%s_%s",
request.is_left_hand_requested ? "left_hand" : "no_left_hand",
request.is_right_hand_requested ? "right_hand" : "no_right_hand",
request.is_face_requested ? "face" : "no_face",
request.is_face_blendshapes_requested ? "face_blendshapes"
: "no_face_blendshapes");
}
struct TestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of test image.
std::string test_image_name;
// Whether to use holistic model bundle to test.
bool use_model_bundle;
// Requests of holistic parts.
HolisticRequest holistic_request;
};
class SmokeTest : public testing::TestWithParam<TestParams> {};
TEST_P(SmokeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(GetFilePath(GetParam().test_image_name)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
::file::Defaults()));
RemoveUncheckedResult(holistic_result);
MP_ASSERT_OK_AND_ASSIGN(auto task_runner,
CreateTaskRunner(GetParam().use_model_bundle,
GetParam().holistic_request));
MP_ASSERT_OK_AND_ASSIGN(auto output_packets,
task_runner->Process({{std::string(kImageInStream),
MakePacket<Image>(image)}}));
// Check face landmarks.
if (GetParam().holistic_request.is_face_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto face_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kFaceLandmarksStream));
EXPECT_THAT(
face_landmarks,
Approximately(Partially(EqualsProto(holistic_result.face_landmarks())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(output_packets.contains(std::string(kFaceLandmarksStream)));
}
if (GetParam().holistic_request.is_face_blendshapes_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto face_blendshapes,
FetchResult<ClassificationList>(
output_packets, kFaceBlendshapesStream));
EXPECT_THAT(face_blendshapes,
Approximately(
Partially(EqualsProto(holistic_result.face_blendshapes())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(output_packets.contains(std::string(kFaceBlendshapesStream)));
}
// Check pose landmarks.
MP_ASSERT_OK_AND_ASSIGN(auto pose_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kPoseLandmarksStream));
EXPECT_THAT(
pose_landmarks,
Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())),
/*margin=*/kAbsMargin));
// Check hand landmarks.
if (GetParam().holistic_request.is_left_hand_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto left_hand_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kLeftHandLandmarksStream));
EXPECT_THAT(
left_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.left_hand_landmarks())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(
output_packets.contains(std::string(kLeftHandLandmarksStream)));
}
if (GetParam().holistic_request.is_right_hand_requested) {
MP_ASSERT_OK_AND_ASSIGN(auto right_hand_landmarks,
FetchResult<NormalizedLandmarkList>(
output_packets, kRightHandLandmarksStream));
EXPECT_THAT(
right_hand_landmarks,
Approximately(
Partially(EqualsProto(holistic_result.right_hand_landmarks())),
/*margin=*/kAbsMargin));
} else {
ASSERT_FALSE(
output_packets.contains(std::string(kRightHandLandmarksStream)));
}
auto rendered_image =
output_packets.at(std::string(kRenderedImageOutStream)).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(
*rendered_image.GetImageFrameSharedPtr(),
absl::StrCat("holistic_landmark_",
RequestToString(GetParam().holistic_request))));
auto pose_segmentation_mask =
output_packets.at(std::string(kPoseSegmentationMaskStream)).Get<Image>();
cv::Mat matting_mask = mediapipe::formats::MatView(
pose_segmentation_mask.GetImageFrameSharedPtr().get());
cv::Mat visualized_mask;
matting_mask.convertTo(visualized_mask, CV_8UC1, 255);
ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
visualized_mask.cols, visualized_mask.rows,
visualized_mask.step, visualized_mask.data,
[visualized_mask](uint8_t[]) {});
MP_EXPECT_OK(
SavePngTestOutput(visualized_image, "holistic_pose_segmentation_mask"));
}
INSTANTIATE_TEST_SUITE_P(
HolisticLandmarkerGraphTest, SmokeTest,
Values(TestParams{
/* test_name= */ "UseModelBundle",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "UseSeparateModelFiles",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ false,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoLeftHand",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ false,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoRightHand",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ false,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoHand",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ false,
/*is_right_hand_requested= */ false,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ true,
},
},
TestParams{
/* test_name= */ "ModelBundleNoFace",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ false,
/*is_face_blendshapes_requested= */ false,
},
},
TestParams{
/* test_name= */ "ModelBundleNoFaceBlendshapes",
/* test_image_name= */ std::string(kTestImageFile),
/* use_model_bundle= */ true,
/* holistic_request= */
{
/*is_left_hand_requested= */ true,
/*is_right_hand_requested= */ true,
/*is_face_requested= */ true,
/*is_face_blendshapes_requested= */ false,
},
}),
[](const TestParamInfo<SmokeTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,307 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/landmarks_to_detection.h"
#include "mediapipe/framework/api2/stream/loopback.h"
#include "mediapipe/framework/api2/stream/merge.h"
#include "mediapipe/framework/api2/stream/presence.h"
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include "mediapipe/framework/api2/stream/segmentation_smoothing.h"
#include "mediapipe/framework/api2/stream/smoothing.h"
#include "mediapipe/framework/api2/stream/split.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionsToRect;
using ::mediapipe::api2::builder::ConvertAlignmentPointsDetectionToRect;
using ::mediapipe::api2::builder::ConvertLandmarksToDetection;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::GetLoopbackData;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::IsPresent;
using ::mediapipe::api2::builder::Merge;
using ::mediapipe::api2::builder::ScaleAndMakeSquare;
using ::mediapipe::api2::builder::SmoothLandmarks;
using ::mediapipe::api2::builder::SmoothLandmarksVisibility;
using ::mediapipe::api2::builder::SmoothSegmentationMask;
using ::mediapipe::api2::builder::SplitToRanges;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::components::utils::DisallowIf;
using Size = std::pair<int, int>;
constexpr int kAuxLandmarksStartKeypointIndex = 0;
constexpr int kAuxLandmarksEndKeypointIndex = 1;
constexpr float kAuxLandmarksTargetAngle = 90;
constexpr float kRoiFromDetectionScaleFactor = 1.25f;
constexpr float kRoiFromLandmarksScaleFactor = 1.25f;
Stream<NormalizedRect> CalculateRoiFromDetections(
Stream<std::vector<Detection>> detections, Stream<Size> image_size,
Graph& graph) {
auto roi = ConvertAlignmentPointsDetectionsToRect(detections, image_size,
/*start_keypoint_index=*/0,
/*end_keypoint_index=*/1,
/*target_angle=*/90, graph);
return ScaleAndMakeSquare(
roi, image_size, /*scale_x_factor=*/kRoiFromDetectionScaleFactor,
/*scale_y_factor=*/kRoiFromDetectionScaleFactor, graph);
}
Stream<NormalizedRect> CalculateScaleRoiFromAuxiliaryLandmarks(
Stream<NormalizedLandmarkList> landmarks, Stream<Size> image_size,
Graph& graph) {
// TODO: consider calculating ROI directly from landmarks.
auto detection = ConvertLandmarksToDetection(landmarks, graph);
return ConvertAlignmentPointsDetectionToRect(
detection, image_size, kAuxLandmarksStartKeypointIndex,
kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph);
}
Stream<NormalizedRect> CalculateRoiFromAuxiliaryLandmarks(
Stream<NormalizedLandmarkList> landmarks, Stream<Size> image_size,
Graph& graph) {
// TODO: consider calculating ROI directly from landmarks.
auto detection = ConvertLandmarksToDetection(landmarks, graph);
auto roi = ConvertAlignmentPointsDetectionToRect(
detection, image_size, kAuxLandmarksStartKeypointIndex,
kAuxLandmarksEndKeypointIndex, kAuxLandmarksTargetAngle, graph);
return ScaleAndMakeSquare(
roi, image_size, /*scale_x_factor=*/kRoiFromLandmarksScaleFactor,
/*scale_y_factor=*/kRoiFromLandmarksScaleFactor, graph);
}
struct PoseLandmarksResult {
std::optional<Stream<NormalizedLandmarkList>> landmarks;
std::optional<Stream<LandmarkList>> world_landmarks;
std::optional<Stream<NormalizedLandmarkList>> auxiliary_landmarks;
std::optional<Stream<Image>> segmentation_mask;
};
PoseLandmarksResult RunLandmarksDetection(
Stream<Image> image, Stream<NormalizedRect> roi,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, Graph& graph) {
GenericNode& landmarks_graph = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"SinglePoseLandmarksDetectorGraph");
landmarks_graph
.GetOptions<pose_landmarker::proto::PoseLandmarksDetectorGraphOptions>() =
pose_landmarks_detector_graph_options;
image >> landmarks_graph.In("IMAGE");
roi >> landmarks_graph.In("NORM_RECT");
PoseLandmarksResult result;
if (request.landmarks) {
result.landmarks =
landmarks_graph.Out("LANDMARKS").Cast<NormalizedLandmarkList>();
result.auxiliary_landmarks = landmarks_graph.Out("AUXILIARY_LANDMARKS")
.Cast<NormalizedLandmarkList>();
}
if (request.world_landmarks) {
result.world_landmarks =
landmarks_graph.Out("WORLD_LANDMARKS").Cast<LandmarkList>();
}
if (request.segmentation_mask) {
result.segmentation_mask =
landmarks_graph.Out("SEGMENTATION_MASK").Cast<Image>();
}
return result;
}
} // namespace
absl::StatusOr<HolisticPoseTrackingOutput>
TrackHolisticPoseUsingCustomPoseDetection(
Stream<Image> image, PoseDetectionFn pose_detection_fn,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, Graph& graph) {
// Calculate ROI from scratch (pose detection) or reuse one from the
// previous run if available.
auto [previous_roi, set_previous_roi_fn] =
GetLoopbackData<NormalizedRect>(/*tick=*/image, graph);
auto is_previous_roi_available = IsPresent(previous_roi, graph);
auto image_for_detection =
DisallowIf(image, is_previous_roi_available, graph);
MP_ASSIGN_OR_RETURN(auto pose_detections,
pose_detection_fn(image_for_detection, graph));
auto roi_from_detections = CalculateRoiFromDetections(
pose_detections, GetImageSize(image_for_detection, graph), graph);
// Take first non-empty.
auto roi = Merge(roi_from_detections, previous_roi, graph);
// Calculate landmarks and other outputs (if requested) in the specified ROI.
auto landmarks_detection_result = RunLandmarksDetection(
image, roi, pose_landmarks_detector_graph_options,
{
// Landmarks are required for tracking, hence force-requesting them.
.landmarks = true,
.world_landmarks = request.world_landmarks,
.segmentation_mask = request.segmentation_mask,
},
graph);
RET_CHECK(landmarks_detection_result.landmarks.has_value() &&
landmarks_detection_result.auxiliary_landmarks.has_value())
<< "Failed to calculate landmarks required for tracking.";
// Split landmarks into pose landmarks and auxiliary landmarks.
auto pose_landmarks_raw = *landmarks_detection_result.landmarks;
auto auxiliary_landmarks = *landmarks_detection_result.auxiliary_landmarks;
auto image_size = GetImageSize(image, graph);
// TODO: b/305750053 - Apply adaptive crop by adding AdaptiveCropCalculator.
// Calculate ROI from smoothed auxiliary landmarks.
auto scale_roi = CalculateScaleRoiFromAuxiliaryLandmarks(auxiliary_landmarks,
image_size, graph);
auto auxiliary_landmarks_smoothed = SmoothLandmarks(
auxiliary_landmarks, image_size, scale_roi,
{// Min cutoff 0.01 results in ~0.002 alpha in the landmark EMA filter
 // when the landmark is static.
 .min_cutoff = 0.01,
 // Beta 10.0 in combination with min_cutoff 0.01 results in ~0.68 alpha
 // in the landmark EMA filter when the landmark is moving fast.
 .beta = 10.0,
 // Derivative cutoff 1.0 results in ~0.17 alpha in the landmark velocity
 // EMA filter.
 .derivate_cutoff = 1.0},
graph);
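// For reference (assuming the standard OneEuro formulation behind
// MediaPipe's landmark smoothing): the EMA alpha for a frame interval Te is
// alpha = 1 / (1 + tau / Te), with tau = 1 / (2 * pi * fc) and cutoff
// fc = min_cutoff + beta * |velocity|. E.g. a static landmark at ~30 FPS
// with min_cutoff 0.01 gives fc ~= 0.01 and alpha ~= 0.002, matching the
// number quoted above.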
auto roi_from_auxiliary_landmarks = CalculateRoiFromAuxiliaryLandmarks(
auxiliary_landmarks_smoothed, image_size, graph);
// Feed the ROI from auxiliary landmarks back as the "previous" ROI for the
// subsequent run.
set_previous_roi_fn(roi_from_auxiliary_landmarks);
// Populate and smooth pose landmarks if the corresponding output has been
// requested.
std::optional<Stream<NormalizedLandmarkList>> pose_landmarks;
if (request.landmarks) {
pose_landmarks = SmoothLandmarksVisibility(
pose_landmarks_raw, /*low_pass_filter_alpha=*/0.1f, graph);
pose_landmarks = SmoothLandmarks(
*pose_landmarks, image_size, scale_roi,
{// Min cutoff 0.05 results in ~0.01 alpha in the landmark EMA filter
 // when the landmark is static.
 .min_cutoff = 0.05f,
 // Beta 80.0 in combination with min_cutoff 0.05 results in ~0.94 alpha
 // in the landmark EMA filter when the landmark is moving fast.
 .beta = 80.0f,
 // Derivative cutoff 1.0 results in ~0.17 alpha in the landmark velocity
 // EMA filter.
 .derivate_cutoff = 1.0f},
graph);
}
// Populate and smooth world landmarks if available.
std::optional<Stream<LandmarkList>> world_landmarks;
if (landmarks_detection_result.world_landmarks) {
world_landmarks = SplitToRanges(*landmarks_detection_result.world_landmarks,
/*ranges=*/{{0, 33}}, graph)[0];
world_landmarks = SmoothLandmarksVisibility(
*world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
world_landmarks = SmoothLandmarks(
*world_landmarks,
/*scale_roi=*/std::nullopt,
{// Min cutoff 0.1 results in ~0.02 alpha in landmark EMA filter when
// the landmark is static.
.min_cutoff = 0.1f,
// Beta 40.0 in combination with min_cutoff 0.1 results in ~0.8
// alpha in landmark EMA filter when the landmark is moving fast.
.beta = 40.0f,
// Derivative cutoff 1.0 results in ~0.17 alpha in landmark velocity
// EMA filter.
.derivate_cutoff = 1.0f},
graph);
}
// Populate and smooth segmentation mask if available.
std::optional<Stream<Image>> segmentation_mask;
if (landmarks_detection_result.segmentation_mask) {
auto mask = *landmarks_detection_result.segmentation_mask;
auto [prev_mask_as_img, set_prev_mask_as_img_fn] =
GetLoopbackData<mediapipe::Image>(
/*tick=*/*landmarks_detection_result.segmentation_mask, graph);
auto mask_smoothed =
SmoothSegmentationMask(mask, prev_mask_as_img,
/*combine_with_previous_ratio=*/0.7f, graph);
set_prev_mask_as_img_fn(mask_smoothed);
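// Note: feeding the smoothed mask back through the loopback makes the result
// for frame N the "previous" mask for frame N + 1, forming a temporal
// feedback loop analogous to the ROI loopback above.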
segmentation_mask = mask_smoothed;
}
return {{/*landmarks=*/pose_landmarks,
/*world_landmarks=*/world_landmarks,
/*segmentation_mask=*/segmentation_mask,
/*debug_output=*/
{/*auxiliary_landmarks=*/auxiliary_landmarks_smoothed,
/*roi_from_landmarks=*/roi_from_auxiliary_landmarks,
/*detections=*/pose_detections}}};
}
absl::StatusOr<HolisticPoseTrackingOutput> TrackHolisticPose(
Stream<Image> image,
const pose_detector::proto::PoseDetectorGraphOptions&
pose_detector_graph_options,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, Graph& graph) {
PoseDetectionFn pose_detection_fn = [&pose_detector_graph_options](
Stream<Image> image, Graph& graph)
-> absl::StatusOr<Stream<std::vector<mediapipe::Detection>>> {
GenericNode& pose_detector =
graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
pose_detector.GetOptions<pose_detector::proto::PoseDetectorGraphOptions>() =
pose_detector_graph_options;
image >> pose_detector.In("IMAGE");
return pose_detector.Out("DETECTIONS")
.Cast<std::vector<mediapipe::Detection>>();
};
return TrackHolisticPoseUsingCustomPoseDetection(
image, pose_detection_fn, pose_landmarks_detector_graph_options, request,
graph);
}
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,110 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_
#define MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_
#include <functional>
#include <optional>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
// Type of pose detection function that can be used to customize pose
// tracking, by supplying the function to the
// `TrackHolisticPoseUsingCustomPoseDetection` overload.
//
// The function should update the provided graph with a node (or nodes) that
// accepts an image stream and produces a stream of detections.
using PoseDetectionFn = std::function<
absl::StatusOr<api2::builder::Stream<std::vector<mediapipe::Detection>>>(
api2::builder::Stream<Image>, api2::builder::Graph&)>;
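// A minimal sketch of a custom `PoseDetectionFn` (the graph name
// "MyPoseDetectorGraph" and its stream tags are hypothetical placeholders):
//
//   PoseDetectionFn pose_detection_fn =
//       [](api2::builder::Stream<Image> image, api2::builder::Graph& graph)
//       -> absl::StatusOr<
//           api2::builder::Stream<std::vector<mediapipe::Detection>>> {
//     auto& detector = graph.AddNode("MyPoseDetectorGraph");
//     image >> detector.In("IMAGE");
//     return detector.Out("DETECTIONS")
//         .Cast<std::vector<mediapipe::Detection>>();
//   };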
struct HolisticPoseTrackingRequest {
bool landmarks = false;
bool world_landmarks = false;
bool segmentation_mask = false;
};
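// For example, to request landmarks and world landmarks only (mirroring the
// setup used in the tests):
//   HolisticPoseTrackingRequest request;
//   request.landmarks = true;
//   request.world_landmarks = true;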
struct HolisticPoseTrackingOutput {
std::optional<api2::builder::Stream<mediapipe::NormalizedLandmarkList>>
landmarks;
std::optional<api2::builder::Stream<mediapipe::LandmarkList>> world_landmarks;
std::optional<api2::builder::Stream<Image>> segmentation_mask;
struct DebugOutput {
api2::builder::Stream<mediapipe::NormalizedLandmarkList>
auxiliary_landmarks;
api2::builder::Stream<NormalizedRect> roi_from_landmarks;
api2::builder::Stream<std::vector<mediapipe::Detection>> detections;
};
DebugOutput debug_output;
};
// Updates @graph to track pose in @image.
//
// @image - ImageFrame/GpuBuffer to track pose in.
// @pose_detection_fn - pose detection function that takes @image as input and
// produces a stream of pose detections.
// @pose_landmarks_detector_graph_options - options of the
// PoseLandmarksDetectorGraph used to detect the pose landmarks.
// @request - object to request specific pose tracking outputs.
// NOTE: Outputs that were not requested won't be returned and the
// corresponding parts of the graph won't be generated at all.
// @graph - graph to update.
absl::StatusOr<HolisticPoseTrackingOutput>
TrackHolisticPoseUsingCustomPoseDetection(
api2::builder::Stream<Image> image, PoseDetectionFn pose_detection_fn,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph);
// Updates @graph to track pose in @image.
//
// @image - ImageFrame/GpuBuffer to track pose in.
// @pose_detector_graph_options - options of the PoseDetectorGraph used to
// detect the pose.
// @pose_landmarks_detector_graph_options - options of the
// PoseLandmarksDetectorGraph used to detect the pose landmarks.
// @request - object to request specific pose tracking outputs.
// NOTE: Outputs that were not requested won't be returned and the
// corresponding parts of the graph won't be generated at all.
// @graph - graph to update.
absl::StatusOr<HolisticPoseTrackingOutput> TrackHolisticPose(
api2::builder::Stream<Image> image,
const pose_detector::proto::PoseDetectorGraphOptions&
pose_detector_graph_options,
const pose_landmarker::proto::PoseLandmarksDetectorGraphOptions&
pose_landmarks_detector_graph_options,
const HolisticPoseTrackingRequest& request, api2::builder::Graph& graph);
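// A usage sketch (options configuration elided; see
// `HolisticPoseTrackingRequest` above for the request fields):
//
//   api2::builder::Graph graph;
//   auto image = graph.In("IMAGE").Cast<Image>();
//   pose_detector::proto::PoseDetectorGraphOptions detector_options;
//   pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
//       landmarker_options;
//   HolisticPoseTrackingRequest request{.landmarks = true};
//   MP_ASSIGN_OR_RETURN(
//       HolisticPoseTrackingOutput output,
//       TrackHolisticPose(image, detector_options, landmarker_options,
//                         request, graph));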
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HOLISTIC_LANDMARKER_HOLISTIC_POSE_TRACKING_H_

View File

@ -0,0 +1,243 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/holistic_landmarker/holistic_pose_tracking.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result.pb.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/data_renderer.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
#include "testing/base/public/googletest.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace holistic_landmarker {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::Image;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::TaskRunner;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr float kAbsMargin = 0.025;
constexpr absl::string_view kTestDataDirectory =
"/mediapipe/tasks/testdata/vision/";
constexpr absl::string_view kTestImageFile = "male_full_height_hands.jpg";
constexpr absl::string_view kImageInStream = "image_in";
constexpr absl::string_view kPoseLandmarksOutStream = "pose_landmarks_out";
constexpr absl::string_view kPoseWorldLandmarksOutStream =
"pose_world_landmarks_out";
constexpr absl::string_view kRenderedImageOutStream = "rendered_image_out";
constexpr absl::string_view kHolisticResultFile =
"male_full_height_hands_result_cpu.pbtxt";
constexpr absl::string_view kHolisticPoseTrackingGraph =
"holistic_pose_tracking_graph.pbtxt";
std::string GetFilePath(absl::string_view filename) {
return file::JoinPath("./", kTestDataDirectory, filename);
}
mediapipe::LandmarksToRenderDataCalculatorOptions GetPoseRendererOptions() {
mediapipe::LandmarksToRenderDataCalculatorOptions renderer_options;
for (const auto& connection : pose_landmarker::kPoseLandmarksConnections) {
renderer_options.add_landmark_connections(connection[0]);
renderer_options.add_landmark_connections(connection[1]);
}
renderer_options.mutable_landmark_color()->set_r(255);
renderer_options.mutable_landmark_color()->set_g(255);
renderer_options.mutable_landmark_color()->set_b(255);
renderer_options.mutable_connection_color()->set_r(255);
renderer_options.mutable_connection_color()->set_g(255);
renderer_options.mutable_connection_color()->set_b(255);
renderer_options.set_thickness(0.5);
renderer_options.set_visualize_landmark_depth(false);
return renderer_options;
}
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<tasks::core::TaskRunner>> CreateTaskRunner() {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options;
pose_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_detection.tflite"));
pose_detector_graph_options.set_num_poses(1);
pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
pose_landmarks_detector_graph_options;
pose_landmarks_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_landmark_lite.tflite"));
HolisticPoseTrackingRequest request;
request.landmarks = true;
request.world_landmarks = true;
MP_ASSIGN_OR_RETURN(
HolisticPoseTrackingOutput result,
TrackHolisticPose(image, pose_detector_graph_options,
pose_landmarks_detector_graph_options, request, graph));
auto image_size = GetImageSize(image, graph);
auto render_data = utils::RenderLandmarks(
*result.landmarks,
utils::GetRenderScale(image_size, result.debug_output.roi_from_landmarks,
0.0001, graph),
GetPoseRendererOptions(), graph);
std::vector<Stream<mediapipe::RenderData>> render_list = {render_data};
auto rendered_image =
utils::Render(
image, absl::Span<Stream<mediapipe::RenderData>>(render_list), graph)
.SetName(kRenderedImageOutStream);
rendered_image >> graph.Out("RENDERED_IMAGE");
result.landmarks->SetName(kPoseLandmarksOutStream) >>
graph.Out("POSE_LANDMARKS");
result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >>
graph.Out("POSE_WORLD_LANDMARKS");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
return TaskRunner::Create(
config, std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
// Remove fields that should not be checked in the result, since the model
// generating the expected result differs from the model under test.
void RemoveUncheckedResult(proto::HolisticResult& holistic_result) {
for (auto& landmark :
*holistic_result.mutable_pose_landmarks()->mutable_landmark()) {
landmark.clear_z();
landmark.clear_visibility();
landmark.clear_presence();
}
}
class HolisticPoseTrackingTest : public testing::Test {};
TEST_F(HolisticPoseTrackingTest, VerifyGraph) {
Graph graph;
Stream<Image> image = graph.In("IMAGE").Cast<Image>().SetName(kImageInStream);
pose_detector::proto::PoseDetectorGraphOptions pose_detector_graph_options;
pose_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_detection.tflite"));
pose_detector_graph_options.set_num_poses(1);
pose_landmarker::proto::PoseLandmarksDetectorGraphOptions
pose_landmarks_detector_graph_options;
pose_landmarks_detector_graph_options.mutable_base_options()
->mutable_model_asset()
->set_file_name(GetFilePath("pose_landmark_lite.tflite"));
HolisticPoseTrackingRequest request;
request.landmarks = true;
request.world_landmarks = true;
MP_ASSERT_OK_AND_ASSIGN(
HolisticPoseTrackingOutput result,
TrackHolisticPose(image, pose_detector_graph_options,
pose_landmarks_detector_graph_options, request, graph));
result.landmarks->SetName(kPoseLandmarksOutStream) >>
graph.Out("POSE_LANDMARKS");
result.world_landmarks->SetName(kPoseWorldLandmarksOutStream) >>
graph.Out("POSE_WORLD_LANDMARKS");
auto config = graph.GetConfig();
core::FixGraphBackEdges(config);
// Read the expected graph config.
std::string expected_graph_contents;
MP_ASSERT_OK(file::GetContents(
file::JoinPath("./", kTestDataDirectory, kHolisticPoseTrackingGraph),
&expected_graph_contents));
// Substitute the test srcdir into the expected graph config, because each
// run has a different test directory on TAP.
expected_graph_contents = absl::Substitute(
expected_graph_contents, FLAGS_test_srcdir, FLAGS_test_srcdir);
CalculatorGraphConfig expected_graph =
ParseTextProtoOrDie<CalculatorGraphConfig>(expected_graph_contents);
EXPECT_THAT(config, testing::proto::IgnoringRepeatedFieldOrdering(
testing::EqualsProto(expected_graph)));
}
TEST_F(HolisticPoseTrackingTest, SmokeTest) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(GetFilePath(kTestImageFile)));
proto::HolisticResult holistic_result;
MP_ASSERT_OK(GetTextProto(GetFilePath(kHolisticResultFile), &holistic_result,
Defaults()));
RemoveUncheckedResult(holistic_result);
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
MP_ASSERT_OK_AND_ASSIGN(auto output_packets,
task_runner->Process({{std::string(kImageInStream),
MakePacket<Image>(image)}}));
auto pose_landmarks = output_packets.at(std::string(kPoseLandmarksOutStream))
.Get<NormalizedLandmarkList>();
EXPECT_THAT(
pose_landmarks,
Approximately(Partially(EqualsProto(holistic_result.pose_landmarks())),
/*margin=*/kAbsMargin));
auto rendered_image =
output_packets.at(std::string(kRenderedImageOutStream)).Get<Image>();
MP_EXPECT_OK(SavePngTestOutput(*rendered_image.GetImageFrameSharedPtr(),
"pose_landmarks"));
}
} // namespace
} // namespace holistic_landmarker
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,44 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "holistic_result_proto",
srcs = ["holistic_result.proto"],
deps = [
"//mediapipe/framework/formats:classification_proto",
"//mediapipe/framework/formats:landmark_proto",
],
)
mediapipe_proto_library(
name = "holistic_landmarker_graph_options_proto",
srcs = ["holistic_landmarker_graph_options.proto"],
deps = [
"//mediapipe/tasks/cc/core/proto:base_options_proto",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_proto",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_proto",
],
)

View File

@ -0,0 +1,57 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.holistic_landmarker.proto;
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options.proto";
import "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker.proto";
option java_outer_classname = "HolisticLandmarkerGraphOptionsProto";
message HolisticLandmarkerGraphOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// asset bundle file with metadata, accelerator options, etc.
core.proto.BaseOptions base_options = 1;
// Options for hand landmarks graph.
hand_landmarker.proto.HandLandmarksDetectorGraphOptions
hand_landmarks_detector_graph_options = 2;
// Options for hand roi refinement graph.
hand_landmarker.proto.HandRoiRefinementGraphOptions
hand_roi_refinement_graph_options = 3;
// Options for face detector graph.
face_detector.proto.FaceDetectorGraphOptions face_detector_graph_options = 4;
// Options for face landmarks detector graph.
face_landmarker.proto.FaceLandmarksDetectorGraphOptions
face_landmarks_detector_graph_options = 5;
// Options for pose detector graph.
pose_detector.proto.PoseDetectorGraphOptions pose_detector_graph_options = 6;
// Options for pose landmarks detector graph.
pose_landmarker.proto.PoseLandmarksDetectorGraphOptions
pose_landmarks_detector_graph_options = 7;
}

View File

@ -0,0 +1,34 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.holistic_landmarker.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto";
option java_package = "com.google.mediapipe.tasks.vision.holisticlandmarker";
option java_outer_classname = "HolisticResultProto";
message HolisticResult {
mediapipe.NormalizedLandmarkList pose_landmarks = 1;
mediapipe.LandmarkList pose_world_landmarks = 7;
mediapipe.NormalizedLandmarkList left_hand_landmarks = 2;
mediapipe.NormalizedLandmarkList right_hand_landmarks = 3;
mediapipe.NormalizedLandmarkList face_landmarks = 4;
mediapipe.ClassificationList face_blendshapes = 6;
mediapipe.NormalizedLandmarkList auxiliary_landmarks = 5;
}