diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index 1b1b818ce..07cf793e9 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -97,6 +97,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/util:graph_builder_utils", "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc index 4c734c423..01c86c122 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc @@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; // limit the number of frames in flight. CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, - bool enable_flow_limiting) { + bool enable_flow_limiting, bool output_segmentation_masks) { api2::builder::Graph graph; auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName); subgraph.GetOptions().Swap(options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> - graph.Out(kSegmentationMaskTag); subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >> graph.Out(kNormLandmarksTag); subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >> @@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig( .SetName(kPoseAuxiliaryLandmarksStreamName) >> graph.Out(kPoseAuxiliaryLandmarksTag); subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (output_segmentation_masks) { + subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> + graph.Out(kSegmentationMaskTag); + } if (enable_flow_limiting) { return tasks::core::AddFlowLimiterCalculator( graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag); @@ -187,7 +189,8 @@ absl::StatusOr> PoseLandmarker::Create( PoseLandmarkerGraphOptionsProto>( CreateGraphConfig( std::move(options_proto), - options->running_mode == core::RunningMode::LIVE_STREAM), + options->running_mode == core::RunningMode::LIVE_STREAM, + options->output_segmentation_masks), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)))); diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index ae3a7482e..456a6efd1 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs { Source> auxiliary_landmark_lists; Source> pose_rects_next_frame; Source> pose_detections; - Source> segmentation_masks; + std::optional>> segmentation_masks; Source image; }; @@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // input_stream: "IMAGE:image_in" // input_stream: "NORM_RECT:norm_rect" // output_stream: "NORM_LANDMARKS:pose_landmarks" -// output_stream: "LANDMARKS:world_landmarks" -// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks" +// output_stream: "WORLD_LANDMARKS:world_landmarks" +// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks" // output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame" // output_stream: "POSE_RECTS:pose_rects" // output_stream: "SEGMENTATION_MASK:segmentation_masks" @@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + bool output_segmentation_masks = + HasOutput(sc->OriginalNode(), kSegmentationMaskTag); if (sc->Options() .base_options() .has_model_asset()) { @@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } - ASSIGN_OR_RETURN( - auto outs, - BuildPoseLandmarkerGraph( - *sc->MutableOptions(), - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + ASSIGN_OR_RETURN(auto outs, + BuildPoseLandmarkerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], + graph, output_segmentation_masks)); outs.landmark_lists >> graph[Output>(kNormLandmarksTag)]; outs.world_landmark_lists >> @@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { kAuxiliaryLandmarksTag)]; outs.pose_rects_next_frame >> graph[Output>(kPoseRectsNextFrameTag)]; - outs.segmentation_masks >> - graph[Output>(kSegmentationMaskTag)]; outs.pose_detections >> graph[Output>(kDetectionsTag)]; outs.image >> graph[Output(kImageTag)]; + if (outs.segmentation_masks) { + *outs.segmentation_masks >> + graph[Output>(kSegmentationMaskTag)]; + } // TODO remove when support is fixed. // As mediapipe GraphBuilder currently doesn't support configuring @@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { // graph: the mediapipe graph instance to be updated. absl::StatusOr BuildPoseLandmarkerGraph( PoseLandmarkerGraphOptions& tasks_options, Source image_in, - Source norm_rect_in, Graph& graph) { + Source norm_rect_in, Graph& graph, + bool output_segmentation_masks) { const int max_num_poses = tasks_options.pose_detector_graph_options().num_poses(); @@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { auto pose_rects_for_next_frame = pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag) .Cast>(); - auto segmentation_masks = - pose_landmarks_detector_graph.Out(kSegmentationMaskTag) - .Cast>(); + std::optional>> segmentation_masks; + if (output_segmentation_masks) { + segmentation_masks = + pose_landmarks_detector_graph.Out(kSegmentationMaskTag) + .Cast>(); + } if (tasks_options.base_options().use_stream_mode()) { auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator"); diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc index c71fc2d58..f8488db02 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" +#include "mediapipe/util/graph_builder_utils.h" namespace mediapipe { namespace tasks { @@ -48,6 +49,7 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::api2::builder::Stream; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::pose_landmarker::proto:: PoseLandmarksDetectorGraphOptions; @@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs { Source pose_rect_next_frame; Source pose_presence; Source pose_presence_score; - Source segmentation_mask; + std::optional> segmentation_mask; }; struct PoseLandmarkerOutputs { @@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs { Source> pose_rects_next_frame; Source> presences; Source> presence_scores; - Source> segmentation_masks; + std::optional>> segmentation_masks; }; absl::Status SanityCheckOptions( @@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { + bool output_segmentation_mask = + HasOutput(sc->OriginalNode(), kSegmentationMaskTag); ASSIGN_OR_RETURN( const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN( - auto pose_landmark_detection_outs, - BuildSinglePoseLandmarksDetectorGraph( - sc->Options(), *model_resources, - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + ASSIGN_OR_RETURN(auto pose_landmark_detection_outs, + BuildSinglePoseLandmarksDetectorGraph( + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], + graph, output_segmentation_mask)); pose_landmark_detection_outs.pose_landmarks >> graph[Output(kLandmarksTag)]; pose_landmark_detection_outs.world_pose_landmarks >> @@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { graph[Output(kPresenceTag)]; pose_landmark_detection_outs.pose_presence_score >> graph[Output(kPresenceScoreTag)]; - pose_landmark_detection_outs.segmentation_mask >> - graph[Output(kSegmentationMaskTag)]; + if (pose_landmark_detection_outs.segmentation_mask) { + *pose_landmark_detection_outs.segmentation_mask >> + graph[Output(kSegmentationMaskTag)]; + } return graph.GetConfig(); } @@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { BuildSinglePoseLandmarksDetectorGraph( const PoseLandmarksDetectorGraphOptions& subgraph_options, const ModelResources& model_resources, Source image_in, - Source pose_rect, Graph& graph) { + Source pose_rect, Graph& graph, + bool output_segmentation_mask) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); auto& preprocessing = graph.AddNode( @@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto raw_landmarks = tensors_to_landmarks[Output(kNormLandmarksTag)]; - // Decodes the segmentation tensor into a mask image with pixel values in - // [0, 1] (1 for person and 0 for background). - auto& tensors_to_segmentation = - graph.AddNode("TensorsToSegmentationCalculator"); - ConfigureTensorsToSegmentationCalculator( - &tensors_to_segmentation - .GetOptions()); - ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag); - auto raw_segmentation_mask = - tensors_to_segmentation[Output(kMaskTag)]; - // Refines landmarks with the heatmap tensor. auto& refine_landmarks_from_heatmap = graph.AddNode("RefineLandmarksFromHeatmapCalculator"); @@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto world_projected_landmarks = world_landmarks_projection.Out(kLandmarksTag).Cast(); - // Calculates the inverse transformation matrix. - auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); - matrix >> inverse_matrix.In(kMatrixTag); - auto inverted_matrix = inverse_matrix.Out(kMatrixTag); + std::optional> segmentation_mask; + if (output_segmentation_mask) { + // Decodes the segmentation tensor into a mask image with pixel values in + // [0, 1] (1 for person and 0 for background). + auto& tensors_to_segmentation = + graph.AddNode("TensorsToSegmentationCalculator"); + ConfigureTensorsToSegmentationCalculator( + &tensors_to_segmentation.GetOptions< + mediapipe::TensorsToSegmentationCalculatorOptions>()); + ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag); + auto raw_segmentation_mask = + tensors_to_segmentation[Output(kMaskTag)]; - // Projects the segmentation mask from the letterboxed ROI back to the full - // image. - auto& warp_affine = graph.AddNode("WarpAffineCalculator"); - ConfigureWarpAffineCalculator( - &warp_affine.GetOptions()); - image_size >> warp_affine.In(kOutputSizeTag); - inverted_matrix >> warp_affine.In(kMatrixTag); - raw_segmentation_mask >> warp_affine.In(kImageTag); - auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast(); + // Calculates the inverse transformation matrix. + auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); + matrix >> inverse_matrix.In(kMatrixTag); + auto inverted_matrix = inverse_matrix.Out(kMatrixTag); + + // Projects the segmentation mask from the letterboxed ROI back to the + // full image. + auto& warp_affine = graph.AddNode("WarpAffineCalculator"); + ConfigureWarpAffineCalculator( + &warp_affine.GetOptions()); + image_size >> warp_affine.In(kOutputSizeTag); + inverted_matrix >> warp_affine.In(kMatrixTag); + raw_segmentation_mask >> warp_affine.In(kImageTag); + segmentation_mask = warp_affine.Out(kImageTag).Cast(); + } // Calculate region of interest based on auxiliary landmarks, to be used // in the next frame. Consists of LandmarksToDetection + @@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { /* pose_rect_next_frame= */ pose_rect_next_frame, /* pose_presence= */ pose_presence, /* pose_presence_score= */ pose_presence_score, - /* segmentation_mask= */ projected_segmentation_mask, + /* segmentation_mask= */ segmentation_mask, }}; } }; @@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + bool output_segmentation_masks = + HasOutput(sc->OriginalNode(), kSegmentationMaskTag); ASSIGN_OR_RETURN( auto pose_landmark_detection_outputs, BuildPoseLandmarksDetectorGraph( sc->Options(), graph[Input(kImageTag)], - graph[Input>(kNormRectTag)], graph)); + graph[Input>(kNormRectTag)], graph, + output_segmentation_masks)); pose_landmark_detection_outputs.landmark_lists >> graph[Output>(kLandmarksTag)]; pose_landmark_detection_outputs.world_landmark_lists >> @@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { graph[Output>(kPresenceTag)]; pose_landmark_detection_outputs.presence_scores >> graph[Output>(kPresenceScoreTag)]; - pose_landmark_detection_outputs.segmentation_masks >> - graph[Output>(kSegmentationMaskTag)]; + if (pose_landmark_detection_outputs.segmentation_masks) { + *pose_landmark_detection_outputs.segmentation_masks >> + graph[Output>(kSegmentationMaskTag)]; + } return graph.GetConfig(); } @@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { absl::StatusOr BuildPoseLandmarksDetectorGraph( const PoseLandmarksDetectorGraphOptions& subgraph_options, Source image_in, - Source> multi_pose_rects, Graph& graph) { + Source> multi_pose_rects, Graph& graph, + bool output_segmentation_masks) { auto& begin_loop_multi_pose_rects = graph.AddNode("BeginLoopNormalizedRectCalculator"); image_in >> begin_loop_multi_pose_rects.In("CLONE"); @@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { pose_landmark_subgraph.Out(kPoseRectNextFrameTag); auto presence = pose_landmark_subgraph.Out(kPresenceTag); auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag); - auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag); auto& end_loop_landmarks = graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator"); @@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto presence_scores = end_loop_presence_score[Output>(kIterableTag)]; - auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator"); - batch_end >> end_loop_segmentation_mask.In(kBatchEndTag); - segmentation_mask >> end_loop_segmentation_mask.In(kItemTag); - auto segmentation_masks = - end_loop_segmentation_mask[Output>(kIterableTag)]; + std::optional>> segmentation_masks_vector; + if (output_segmentation_masks) { + auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag); + auto& end_loop_segmentation_mask = + graph.AddNode("EndLoopImageCalculator"); + batch_end >> end_loop_segmentation_mask.In(kBatchEndTag); + segmentation_mask >> end_loop_segmentation_mask.In(kItemTag); + segmentation_masks_vector = + end_loop_segmentation_mask[Output>(kIterableTag)]; + } return {{ /* landmark_lists= */ landmark_lists, @@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { /* pose_rects_next_frame= */ pose_rects_next_frame, /* presences= */ presences, /* presence_scores= */ presence_scores, - /* segmentation_masks= */ segmentation_masks, + /* segmentation_masks= */ segmentation_masks_vector, }}; } };