Internal change.
PiperOrigin-RevId: 524336899
This commit is contained in:
		
							parent
							
								
									534da98ccb
								
							
						
					
					
						commit
						27038f534a
					
				| 
						 | 
					@ -97,6 +97,7 @@ cc_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
					        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
 | 
					        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
					        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
				
			||||||
 | 
					        "//mediapipe/util:graph_builder_utils",
 | 
				
			||||||
        "@com_google_absl//absl/status:statusor",
 | 
					        "@com_google_absl//absl/status:statusor",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
 | 
				
			||||||
// limit the number of frames in flight.
 | 
					// limit the number of frames in flight.
 | 
				
			||||||
CalculatorGraphConfig CreateGraphConfig(
 | 
					CalculatorGraphConfig CreateGraphConfig(
 | 
				
			||||||
    std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
 | 
					    std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
 | 
				
			||||||
    bool enable_flow_limiting) {
 | 
					    bool enable_flow_limiting, bool output_segmentation_masks) {
 | 
				
			||||||
  api2::builder::Graph graph;
 | 
					  api2::builder::Graph graph;
 | 
				
			||||||
  auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
 | 
					  auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
 | 
				
			||||||
  subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
 | 
					  subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
 | 
				
			||||||
  graph.In(kImageTag).SetName(kImageInStreamName);
 | 
					  graph.In(kImageTag).SetName(kImageInStreamName);
 | 
				
			||||||
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
 | 
					  graph.In(kNormRectTag).SetName(kNormRectStreamName);
 | 
				
			||||||
  subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
 | 
					 | 
				
			||||||
      graph.Out(kSegmentationMaskTag);
 | 
					 | 
				
			||||||
  subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
 | 
					  subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
 | 
				
			||||||
      graph.Out(kNormLandmarksTag);
 | 
					      graph.Out(kNormLandmarksTag);
 | 
				
			||||||
  subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
 | 
					  subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
 | 
				
			||||||
| 
						 | 
					@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
 | 
				
			||||||
          .SetName(kPoseAuxiliaryLandmarksStreamName) >>
 | 
					          .SetName(kPoseAuxiliaryLandmarksStreamName) >>
 | 
				
			||||||
      graph.Out(kPoseAuxiliaryLandmarksTag);
 | 
					      graph.Out(kPoseAuxiliaryLandmarksTag);
 | 
				
			||||||
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
 | 
					  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
 | 
				
			||||||
 | 
					  if (output_segmentation_masks) {
 | 
				
			||||||
 | 
					    subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
 | 
				
			||||||
 | 
					        graph.Out(kSegmentationMaskTag);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  if (enable_flow_limiting) {
 | 
					  if (enable_flow_limiting) {
 | 
				
			||||||
    return tasks::core::AddFlowLimiterCalculator(
 | 
					    return tasks::core::AddFlowLimiterCalculator(
 | 
				
			||||||
        graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
 | 
					        graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
 | 
				
			||||||
| 
						 | 
					@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
 | 
				
			||||||
                                          PoseLandmarkerGraphOptionsProto>(
 | 
					                                          PoseLandmarkerGraphOptionsProto>(
 | 
				
			||||||
          CreateGraphConfig(
 | 
					          CreateGraphConfig(
 | 
				
			||||||
              std::move(options_proto),
 | 
					              std::move(options_proto),
 | 
				
			||||||
              options->running_mode == core::RunningMode::LIVE_STREAM),
 | 
					              options->running_mode == core::RunningMode::LIVE_STREAM,
 | 
				
			||||||
 | 
					              options->output_segmentation_masks),
 | 
				
			||||||
          std::move(options->base_options.op_resolver), options->running_mode,
 | 
					          std::move(options->base_options.op_resolver), options->running_mode,
 | 
				
			||||||
          std::move(packets_callback))));
 | 
					          std::move(packets_callback))));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
 | 
				
			||||||
  Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
 | 
					  Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
 | 
				
			||||||
  Source<std::vector<NormalizedRect>> pose_rects_next_frame;
 | 
					  Source<std::vector<NormalizedRect>> pose_rects_next_frame;
 | 
				
			||||||
  Source<std::vector<Detection>> pose_detections;
 | 
					  Source<std::vector<Detection>> pose_detections;
 | 
				
			||||||
  Source<std::vector<Image>> segmentation_masks;
 | 
					  std::optional<Source<std::vector<Image>>> segmentation_masks;
 | 
				
			||||||
  Source<Image> image;
 | 
					  Source<Image> image;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
 | 
				
			||||||
//   input_stream: "IMAGE:image_in"
 | 
					//   input_stream: "IMAGE:image_in"
 | 
				
			||||||
//   input_stream: "NORM_RECT:norm_rect"
 | 
					//   input_stream: "NORM_RECT:norm_rect"
 | 
				
			||||||
//   output_stream: "NORM_LANDMARKS:pose_landmarks"
 | 
					//   output_stream: "NORM_LANDMARKS:pose_landmarks"
 | 
				
			||||||
//   output_stream: "LANDMARKS:world_landmarks"
 | 
					//   output_stream: "WORLD_LANDMARKS:world_landmarks"
 | 
				
			||||||
//   output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
 | 
					//   output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
 | 
				
			||||||
//   output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
 | 
					//   output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
 | 
				
			||||||
//   output_stream: "POSE_RECTS:pose_rects"
 | 
					//   output_stream: "POSE_RECTS:pose_rects"
 | 
				
			||||||
//   output_stream: "SEGMENTATION_MASK:segmentation_masks"
 | 
					//   output_stream: "SEGMENTATION_MASK:segmentation_masks"
 | 
				
			||||||
| 
						 | 
					@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
					  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
				
			||||||
      SubgraphContext* sc) override {
 | 
					      SubgraphContext* sc) override {
 | 
				
			||||||
    Graph graph;
 | 
					    Graph graph;
 | 
				
			||||||
 | 
					    bool output_segmentation_masks =
 | 
				
			||||||
 | 
					        HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
 | 
				
			||||||
    if (sc->Options<PoseLandmarkerGraphOptions>()
 | 
					    if (sc->Options<PoseLandmarkerGraphOptions>()
 | 
				
			||||||
            .base_options()
 | 
					            .base_options()
 | 
				
			||||||
            .has_model_asset()) {
 | 
					            .has_model_asset()) {
 | 
				
			||||||
| 
						 | 
					@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
          !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
 | 
					          !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
 | 
				
			||||||
               .IsAvailable()));
 | 
					               .IsAvailable()));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    ASSIGN_OR_RETURN(
 | 
					    ASSIGN_OR_RETURN(auto outs,
 | 
				
			||||||
        auto outs,
 | 
					                     BuildPoseLandmarkerGraph(
 | 
				
			||||||
        BuildPoseLandmarkerGraph(
 | 
					                         *sc->MutableOptions<PoseLandmarkerGraphOptions>(),
 | 
				
			||||||
            *sc->MutableOptions<PoseLandmarkerGraphOptions>(),
 | 
					                         graph[Input<Image>(kImageTag)],
 | 
				
			||||||
            graph[Input<Image>(kImageTag)],
 | 
					                         graph[Input<NormalizedRect>::Optional(kNormRectTag)],
 | 
				
			||||||
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
 | 
					                         graph, output_segmentation_masks));
 | 
				
			||||||
    outs.landmark_lists >>
 | 
					    outs.landmark_lists >>
 | 
				
			||||||
        graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
 | 
					        graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
 | 
				
			||||||
    outs.world_landmark_lists >>
 | 
					    outs.world_landmark_lists >>
 | 
				
			||||||
| 
						 | 
					@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
            kAuxiliaryLandmarksTag)];
 | 
					            kAuxiliaryLandmarksTag)];
 | 
				
			||||||
    outs.pose_rects_next_frame >>
 | 
					    outs.pose_rects_next_frame >>
 | 
				
			||||||
        graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
 | 
					        graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
 | 
				
			||||||
    outs.segmentation_masks >>
 | 
					 | 
				
			||||||
        graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
 | 
					 | 
				
			||||||
    outs.pose_detections >>
 | 
					    outs.pose_detections >>
 | 
				
			||||||
        graph[Output<std::vector<Detection>>(kDetectionsTag)];
 | 
					        graph[Output<std::vector<Detection>>(kDetectionsTag)];
 | 
				
			||||||
    outs.image >> graph[Output<Image>(kImageTag)];
 | 
					    outs.image >> graph[Output<Image>(kImageTag)];
 | 
				
			||||||
 | 
					    if (outs.segmentation_masks) {
 | 
				
			||||||
 | 
					      *outs.segmentation_masks >>
 | 
				
			||||||
 | 
					          graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO remove when support is fixed.
 | 
					    // TODO remove when support is fixed.
 | 
				
			||||||
    // As mediapipe GraphBuilder currently doesn't support configuring
 | 
					    // As mediapipe GraphBuilder currently doesn't support configuring
 | 
				
			||||||
| 
						 | 
					@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
  // graph: the mediapipe graph instance to be updated.
 | 
					  // graph: the mediapipe graph instance to be updated.
 | 
				
			||||||
  absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
 | 
					  absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
 | 
				
			||||||
      PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
 | 
					      PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
 | 
				
			||||||
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
 | 
					      Source<NormalizedRect> norm_rect_in, Graph& graph,
 | 
				
			||||||
 | 
					      bool output_segmentation_masks) {
 | 
				
			||||||
    const int max_num_poses =
 | 
					    const int max_num_poses =
 | 
				
			||||||
        tasks_options.pose_detector_graph_options().num_poses();
 | 
					        tasks_options.pose_detector_graph_options().num_poses();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
    auto pose_rects_for_next_frame =
 | 
					    auto pose_rects_for_next_frame =
 | 
				
			||||||
        pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
 | 
					        pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
 | 
				
			||||||
            .Cast<std::vector<NormalizedRect>>();
 | 
					            .Cast<std::vector<NormalizedRect>>();
 | 
				
			||||||
    auto segmentation_masks =
 | 
					    std::optional<Source<std::vector<Image>>> segmentation_masks;
 | 
				
			||||||
        pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
 | 
					    if (output_segmentation_masks) {
 | 
				
			||||||
            .Cast<std::vector<Image>>();
 | 
					      segmentation_masks =
 | 
				
			||||||
 | 
					          pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
 | 
				
			||||||
 | 
					              .Cast<std::vector<Image>>();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (tasks_options.base_options().use_stream_mode()) {
 | 
					    if (tasks_options.base_options().use_stream_mode()) {
 | 
				
			||||||
      auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
 | 
					      auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,6 +37,7 @@ limitations under the License.
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
 | 
					#include "mediapipe/tasks/cc/core/model_task_graph.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
 | 
					#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
					#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 | 
				
			||||||
 | 
					#include "mediapipe/util/graph_builder_utils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
namespace tasks {
 | 
					namespace tasks {
 | 
				
			||||||
| 
						 | 
					@ -48,6 +49,7 @@ using ::mediapipe::api2::Input;
 | 
				
			||||||
using ::mediapipe::api2::Output;
 | 
					using ::mediapipe::api2::Output;
 | 
				
			||||||
using ::mediapipe::api2::builder::Graph;
 | 
					using ::mediapipe::api2::builder::Graph;
 | 
				
			||||||
using ::mediapipe::api2::builder::Source;
 | 
					using ::mediapipe::api2::builder::Source;
 | 
				
			||||||
 | 
					using ::mediapipe::api2::builder::Stream;
 | 
				
			||||||
using ::mediapipe::tasks::core::ModelResources;
 | 
					using ::mediapipe::tasks::core::ModelResources;
 | 
				
			||||||
using ::mediapipe::tasks::vision::pose_landmarker::proto::
 | 
					using ::mediapipe::tasks::vision::pose_landmarker::proto::
 | 
				
			||||||
    PoseLandmarksDetectorGraphOptions;
 | 
					    PoseLandmarksDetectorGraphOptions;
 | 
				
			||||||
| 
						 | 
					@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs {
 | 
				
			||||||
  Source<NormalizedRect> pose_rect_next_frame;
 | 
					  Source<NormalizedRect> pose_rect_next_frame;
 | 
				
			||||||
  Source<bool> pose_presence;
 | 
					  Source<bool> pose_presence;
 | 
				
			||||||
  Source<float> pose_presence_score;
 | 
					  Source<float> pose_presence_score;
 | 
				
			||||||
  Source<Image> segmentation_mask;
 | 
					  std::optional<Source<Image>> segmentation_mask;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct PoseLandmarkerOutputs {
 | 
					struct PoseLandmarkerOutputs {
 | 
				
			||||||
| 
						 | 
					@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs {
 | 
				
			||||||
  Source<std::vector<NormalizedRect>> pose_rects_next_frame;
 | 
					  Source<std::vector<NormalizedRect>> pose_rects_next_frame;
 | 
				
			||||||
  Source<std::vector<bool>> presences;
 | 
					  Source<std::vector<bool>> presences;
 | 
				
			||||||
  Source<std::vector<float>> presence_scores;
 | 
					  Source<std::vector<float>> presence_scores;
 | 
				
			||||||
  Source<std::vector<Image>> segmentation_masks;
 | 
					  std::optional<Source<std::vector<Image>>> segmentation_masks;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::Status SanityCheckOptions(
 | 
					absl::Status SanityCheckOptions(
 | 
				
			||||||
| 
						 | 
					@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
					  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
				
			||||||
      SubgraphContext* sc) override {
 | 
					      SubgraphContext* sc) override {
 | 
				
			||||||
 | 
					    bool output_segmentation_mask =
 | 
				
			||||||
 | 
					        HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
 | 
				
			||||||
    ASSIGN_OR_RETURN(
 | 
					    ASSIGN_OR_RETURN(
 | 
				
			||||||
        const auto* model_resources,
 | 
					        const auto* model_resources,
 | 
				
			||||||
        CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
 | 
					        CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
 | 
				
			||||||
    Graph graph;
 | 
					    Graph graph;
 | 
				
			||||||
    ASSIGN_OR_RETURN(
 | 
					    ASSIGN_OR_RETURN(auto pose_landmark_detection_outs,
 | 
				
			||||||
        auto pose_landmark_detection_outs,
 | 
					                     BuildSinglePoseLandmarksDetectorGraph(
 | 
				
			||||||
        BuildSinglePoseLandmarksDetectorGraph(
 | 
					                         sc->Options<PoseLandmarksDetectorGraphOptions>(),
 | 
				
			||||||
            sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources,
 | 
					                         *model_resources, graph[Input<Image>(kImageTag)],
 | 
				
			||||||
            graph[Input<Image>(kImageTag)],
 | 
					                         graph[Input<NormalizedRect>::Optional(kNormRectTag)],
 | 
				
			||||||
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
 | 
					                         graph, output_segmentation_mask));
 | 
				
			||||||
    pose_landmark_detection_outs.pose_landmarks >>
 | 
					    pose_landmark_detection_outs.pose_landmarks >>
 | 
				
			||||||
        graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
 | 
					        graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
 | 
				
			||||||
    pose_landmark_detection_outs.world_pose_landmarks >>
 | 
					    pose_landmark_detection_outs.world_pose_landmarks >>
 | 
				
			||||||
| 
						 | 
					@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        graph[Output<bool>(kPresenceTag)];
 | 
					        graph[Output<bool>(kPresenceTag)];
 | 
				
			||||||
    pose_landmark_detection_outs.pose_presence_score >>
 | 
					    pose_landmark_detection_outs.pose_presence_score >>
 | 
				
			||||||
        graph[Output<float>(kPresenceScoreTag)];
 | 
					        graph[Output<float>(kPresenceScoreTag)];
 | 
				
			||||||
    pose_landmark_detection_outs.segmentation_mask >>
 | 
					    if (pose_landmark_detection_outs.segmentation_mask) {
 | 
				
			||||||
        graph[Output<Image>(kSegmentationMaskTag)];
 | 
					      *pose_landmark_detection_outs.segmentation_mask >>
 | 
				
			||||||
 | 
					          graph[Output<Image>(kSegmentationMaskTag)];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return graph.GetConfig();
 | 
					    return graph.GetConfig();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
  BuildSinglePoseLandmarksDetectorGraph(
 | 
					  BuildSinglePoseLandmarksDetectorGraph(
 | 
				
			||||||
      const PoseLandmarksDetectorGraphOptions& subgraph_options,
 | 
					      const PoseLandmarksDetectorGraphOptions& subgraph_options,
 | 
				
			||||||
      const ModelResources& model_resources, Source<Image> image_in,
 | 
					      const ModelResources& model_resources, Source<Image> image_in,
 | 
				
			||||||
      Source<NormalizedRect> pose_rect, Graph& graph) {
 | 
					      Source<NormalizedRect> pose_rect, Graph& graph,
 | 
				
			||||||
 | 
					      bool output_segmentation_mask) {
 | 
				
			||||||
    MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
 | 
					    MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto& preprocessing = graph.AddNode(
 | 
					    auto& preprocessing = graph.AddNode(
 | 
				
			||||||
| 
						 | 
					@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
    auto raw_landmarks =
 | 
					    auto raw_landmarks =
 | 
				
			||||||
        tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
 | 
					        tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //  Decodes the segmentation tensor into a mask image with pixel values in
 | 
					 | 
				
			||||||
    //  [0, 1] (1 for person and 0 for background).
 | 
					 | 
				
			||||||
    auto& tensors_to_segmentation =
 | 
					 | 
				
			||||||
        graph.AddNode("TensorsToSegmentationCalculator");
 | 
					 | 
				
			||||||
    ConfigureTensorsToSegmentationCalculator(
 | 
					 | 
				
			||||||
        &tensors_to_segmentation
 | 
					 | 
				
			||||||
             .GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
 | 
					 | 
				
			||||||
    ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
 | 
					 | 
				
			||||||
    auto raw_segmentation_mask =
 | 
					 | 
				
			||||||
        tensors_to_segmentation[Output<Image>(kMaskTag)];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Refines landmarks with the heatmap tensor.
 | 
					    // Refines landmarks with the heatmap tensor.
 | 
				
			||||||
    auto& refine_landmarks_from_heatmap =
 | 
					    auto& refine_landmarks_from_heatmap =
 | 
				
			||||||
        graph.AddNode("RefineLandmarksFromHeatmapCalculator");
 | 
					        graph.AddNode("RefineLandmarksFromHeatmapCalculator");
 | 
				
			||||||
| 
						 | 
					@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
    auto world_projected_landmarks =
 | 
					    auto world_projected_landmarks =
 | 
				
			||||||
        world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
 | 
					        world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Calculates the inverse transformation matrix.
 | 
					    std::optional<Stream<Image>> segmentation_mask;
 | 
				
			||||||
    auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
 | 
					    if (output_segmentation_mask) {
 | 
				
			||||||
    matrix >> inverse_matrix.In(kMatrixTag);
 | 
					      //  Decodes the segmentation tensor into a mask image with pixel values in
 | 
				
			||||||
    auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
 | 
					      //  [0, 1] (1 for person and 0 for background).
 | 
				
			||||||
 | 
					      auto& tensors_to_segmentation =
 | 
				
			||||||
 | 
					          graph.AddNode("TensorsToSegmentationCalculator");
 | 
				
			||||||
 | 
					      ConfigureTensorsToSegmentationCalculator(
 | 
				
			||||||
 | 
					          &tensors_to_segmentation.GetOptions<
 | 
				
			||||||
 | 
					              mediapipe::TensorsToSegmentationCalculatorOptions>());
 | 
				
			||||||
 | 
					      ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
 | 
				
			||||||
 | 
					      auto raw_segmentation_mask =
 | 
				
			||||||
 | 
					          tensors_to_segmentation[Output<Image>(kMaskTag)];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Projects the segmentation mask from the letterboxed ROI back to the full
 | 
					      // Calculates the inverse transformation matrix.
 | 
				
			||||||
    // image.
 | 
					      auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
 | 
				
			||||||
    auto& warp_affine = graph.AddNode("WarpAffineCalculator");
 | 
					      matrix >> inverse_matrix.In(kMatrixTag);
 | 
				
			||||||
    ConfigureWarpAffineCalculator(
 | 
					      auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
 | 
				
			||||||
        &warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
 | 
					
 | 
				
			||||||
    image_size >> warp_affine.In(kOutputSizeTag);
 | 
					      // Projects the segmentation mask from the letterboxed ROI back to the
 | 
				
			||||||
    inverted_matrix >> warp_affine.In(kMatrixTag);
 | 
					      // full image.
 | 
				
			||||||
    raw_segmentation_mask >> warp_affine.In(kImageTag);
 | 
					      auto& warp_affine = graph.AddNode("WarpAffineCalculator");
 | 
				
			||||||
    auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
 | 
					      ConfigureWarpAffineCalculator(
 | 
				
			||||||
 | 
					          &warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
 | 
				
			||||||
 | 
					      image_size >> warp_affine.In(kOutputSizeTag);
 | 
				
			||||||
 | 
					      inverted_matrix >> warp_affine.In(kMatrixTag);
 | 
				
			||||||
 | 
					      raw_segmentation_mask >> warp_affine.In(kImageTag);
 | 
				
			||||||
 | 
					      segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Calculate region of interest based on auxiliary landmarks, to be used
 | 
					    // Calculate region of interest based on auxiliary landmarks, to be used
 | 
				
			||||||
    // in the next frame. Consists of LandmarksToDetection +
 | 
					    // in the next frame. Consists of LandmarksToDetection +
 | 
				
			||||||
| 
						 | 
					@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        /* pose_rect_next_frame= */ pose_rect_next_frame,
 | 
					        /* pose_rect_next_frame= */ pose_rect_next_frame,
 | 
				
			||||||
        /* pose_presence= */ pose_presence,
 | 
					        /* pose_presence= */ pose_presence,
 | 
				
			||||||
        /* pose_presence_score= */ pose_presence_score,
 | 
					        /* pose_presence_score= */ pose_presence_score,
 | 
				
			||||||
        /* segmentation_mask= */ projected_segmentation_mask,
 | 
					        /* segmentation_mask= */ segmentation_mask,
 | 
				
			||||||
    }};
 | 
					    }};
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
| 
						 | 
					@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
					  absl::StatusOr<CalculatorGraphConfig> GetConfig(
 | 
				
			||||||
      SubgraphContext* sc) override {
 | 
					      SubgraphContext* sc) override {
 | 
				
			||||||
    Graph graph;
 | 
					    Graph graph;
 | 
				
			||||||
 | 
					    bool output_segmentation_masks =
 | 
				
			||||||
 | 
					        HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
 | 
				
			||||||
    ASSIGN_OR_RETURN(
 | 
					    ASSIGN_OR_RETURN(
 | 
				
			||||||
        auto pose_landmark_detection_outputs,
 | 
					        auto pose_landmark_detection_outputs,
 | 
				
			||||||
        BuildPoseLandmarksDetectorGraph(
 | 
					        BuildPoseLandmarksDetectorGraph(
 | 
				
			||||||
            sc->Options<PoseLandmarksDetectorGraphOptions>(),
 | 
					            sc->Options<PoseLandmarksDetectorGraphOptions>(),
 | 
				
			||||||
            graph[Input<Image>(kImageTag)],
 | 
					            graph[Input<Image>(kImageTag)],
 | 
				
			||||||
            graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
 | 
					            graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph,
 | 
				
			||||||
 | 
					            output_segmentation_masks));
 | 
				
			||||||
    pose_landmark_detection_outputs.landmark_lists >>
 | 
					    pose_landmark_detection_outputs.landmark_lists >>
 | 
				
			||||||
        graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
 | 
					        graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
 | 
				
			||||||
    pose_landmark_detection_outputs.world_landmark_lists >>
 | 
					    pose_landmark_detection_outputs.world_landmark_lists >>
 | 
				
			||||||
| 
						 | 
					@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        graph[Output<std::vector<bool>>(kPresenceTag)];
 | 
					        graph[Output<std::vector<bool>>(kPresenceTag)];
 | 
				
			||||||
    pose_landmark_detection_outputs.presence_scores >>
 | 
					    pose_landmark_detection_outputs.presence_scores >>
 | 
				
			||||||
        graph[Output<std::vector<float>>(kPresenceScoreTag)];
 | 
					        graph[Output<std::vector<float>>(kPresenceScoreTag)];
 | 
				
			||||||
    pose_landmark_detection_outputs.segmentation_masks >>
 | 
					    if (pose_landmark_detection_outputs.segmentation_masks) {
 | 
				
			||||||
        graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
 | 
					      *pose_landmark_detection_outputs.segmentation_masks >>
 | 
				
			||||||
 | 
					          graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return graph.GetConfig();
 | 
					    return graph.GetConfig();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
  absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
 | 
					  absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
 | 
				
			||||||
      const PoseLandmarksDetectorGraphOptions& subgraph_options,
 | 
					      const PoseLandmarksDetectorGraphOptions& subgraph_options,
 | 
				
			||||||
      Source<Image> image_in,
 | 
					      Source<Image> image_in,
 | 
				
			||||||
      Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) {
 | 
					      Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph,
 | 
				
			||||||
 | 
					      bool output_segmentation_masks) {
 | 
				
			||||||
    auto& begin_loop_multi_pose_rects =
 | 
					    auto& begin_loop_multi_pose_rects =
 | 
				
			||||||
        graph.AddNode("BeginLoopNormalizedRectCalculator");
 | 
					        graph.AddNode("BeginLoopNormalizedRectCalculator");
 | 
				
			||||||
    image_in >> begin_loop_multi_pose_rects.In("CLONE");
 | 
					    image_in >> begin_loop_multi_pose_rects.In("CLONE");
 | 
				
			||||||
| 
						 | 
					@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
 | 
					        pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
 | 
				
			||||||
    auto presence = pose_landmark_subgraph.Out(kPresenceTag);
 | 
					    auto presence = pose_landmark_subgraph.Out(kPresenceTag);
 | 
				
			||||||
    auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
 | 
					    auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
 | 
				
			||||||
    auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto& end_loop_landmarks =
 | 
					    auto& end_loop_landmarks =
 | 
				
			||||||
        graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
 | 
					        graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
 | 
				
			||||||
| 
						 | 
					@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
    auto presence_scores =
 | 
					    auto presence_scores =
 | 
				
			||||||
        end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
 | 
					        end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator");
 | 
					    std::optional<Stream<std::vector<Image>>> segmentation_masks_vector;
 | 
				
			||||||
    batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
 | 
					    if (output_segmentation_masks) {
 | 
				
			||||||
    segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
 | 
					      auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
 | 
				
			||||||
    auto segmentation_masks =
 | 
					      auto& end_loop_segmentation_mask =
 | 
				
			||||||
        end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
 | 
					          graph.AddNode("EndLoopImageCalculator");
 | 
				
			||||||
 | 
					      batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
 | 
				
			||||||
 | 
					      segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
 | 
				
			||||||
 | 
					      segmentation_masks_vector =
 | 
				
			||||||
 | 
					          end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return {{
 | 
					    return {{
 | 
				
			||||||
        /* landmark_lists= */ landmark_lists,
 | 
					        /* landmark_lists= */ landmark_lists,
 | 
				
			||||||
| 
						 | 
					@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        /* pose_rects_next_frame= */ pose_rects_next_frame,
 | 
					        /* pose_rects_next_frame= */ pose_rects_next_frame,
 | 
				
			||||||
        /* presences= */ presences,
 | 
					        /* presences= */ presences,
 | 
				
			||||||
        /* presence_scores= */ presence_scores,
 | 
					        /* presence_scores= */ presence_scores,
 | 
				
			||||||
        /* segmentation_masks= */ segmentation_masks,
 | 
					        /* segmentation_masks= */ segmentation_masks_vector,
 | 
				
			||||||
    }};
 | 
					    }};
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user