Internal change.

PiperOrigin-RevId: 524336899
This commit is contained in:
MediaPipe Team 2023-04-14 11:08:50 -07:00 committed by Copybara-Service
parent 534da98ccb
commit 27038f534a
4 changed files with 98 additions and 66 deletions

View File

@ -97,6 +97,7 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
) )

View File

@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
// limit the number of frames in flight. // limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options, std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
bool enable_flow_limiting) { bool enable_flow_limiting, bool output_segmentation_masks) {
api2::builder::Graph graph; api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName); auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get()); subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
graph.Out(kSegmentationMaskTag);
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >> subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
graph.Out(kNormLandmarksTag); graph.Out(kNormLandmarksTag);
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >> subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
.SetName(kPoseAuxiliaryLandmarksStreamName) >> .SetName(kPoseAuxiliaryLandmarksStreamName) >>
graph.Out(kPoseAuxiliaryLandmarksTag); graph.Out(kPoseAuxiliaryLandmarksTag);
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
if (output_segmentation_masks) {
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
graph.Out(kSegmentationMaskTag);
}
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator( return tasks::core::AddFlowLimiterCalculator(
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag); graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
PoseLandmarkerGraphOptionsProto>( PoseLandmarkerGraphOptionsProto>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM,
options->output_segmentation_masks),
std::move(options->base_options.op_resolver), options->running_mode, std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback)))); std::move(packets_callback))));

View File

@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists; Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
Source<std::vector<NormalizedRect>> pose_rects_next_frame; Source<std::vector<NormalizedRect>> pose_rects_next_frame;
Source<std::vector<Detection>> pose_detections; Source<std::vector<Detection>> pose_detections;
Source<std::vector<Image>> segmentation_masks; std::optional<Source<std::vector<Image>>> segmentation_masks;
Source<Image> image; Source<Image> image;
}; };
@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// input_stream: "NORM_RECT:norm_rect" // input_stream: "NORM_RECT:norm_rect"
// output_stream: "NORM_LANDMARKS:pose_landmarks" // output_stream: "NORM_LANDMARKS:pose_landmarks"
// output_stream: "LANDMARKS:world_landmarks" // output_stream: "WORLD_LANDMARKS:world_landmarks"
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks" // output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame" // output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
// output_stream: "POSE_RECTS:pose_rects" // output_stream: "POSE_RECTS:pose_rects"
// output_stream: "SEGMENTATION_MASK:segmentation_masks" // output_stream: "SEGMENTATION_MASK:segmentation_masks"
@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
Graph graph; Graph graph;
bool output_segmentation_masks =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
if (sc->Options<PoseLandmarkerGraphOptions>() if (sc->Options<PoseLandmarkerGraphOptions>()
.base_options() .base_options()
.has_model_asset()) { .has_model_asset()) {
@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable())); .IsAvailable()));
} }
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(auto outs,
auto outs,
BuildPoseLandmarkerGraph( BuildPoseLandmarkerGraph(
*sc->MutableOptions<PoseLandmarkerGraphOptions>(), *sc->MutableOptions<PoseLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)],
graph, output_segmentation_masks));
outs.landmark_lists >> outs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)]; graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
outs.world_landmark_lists >> outs.world_landmark_lists >>
@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
kAuxiliaryLandmarksTag)]; kAuxiliaryLandmarksTag)];
outs.pose_rects_next_frame >> outs.pose_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)]; graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
outs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
outs.pose_detections >> outs.pose_detections >>
graph[Output<std::vector<Detection>>(kDetectionsTag)]; graph[Output<std::vector<Detection>>(kDetectionsTag)];
outs.image >> graph[Output<Image>(kImageTag)]; outs.image >> graph[Output<Image>(kImageTag)];
if (outs.segmentation_masks) {
*outs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
}
// TODO remove when support is fixed. // TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring // As mediapipe GraphBuilder currently doesn't support configuring
@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
// graph: the mediapipe graph instance to be updated. // graph: the mediapipe graph instance to be updated.
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph( absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in, PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in, Graph& graph,
bool output_segmentation_masks) {
const int max_num_poses = const int max_num_poses =
tasks_options.pose_detector_graph_options().num_poses(); tasks_options.pose_detector_graph_options().num_poses();
@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
auto pose_rects_for_next_frame = auto pose_rects_for_next_frame =
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag) pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
.Cast<std::vector<NormalizedRect>>(); .Cast<std::vector<NormalizedRect>>();
auto segmentation_masks = std::optional<Source<std::vector<Image>>> segmentation_masks;
if (output_segmentation_masks) {
segmentation_masks =
pose_landmarks_detector_graph.Out(kSegmentationMaskTag) pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
.Cast<std::vector<Image>>(); .Cast<std::vector<Image>>();
}
if (tasks_options.base_options().use_stream_mode()) { if (tasks_options.base_options().use_stream_mode()) {
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator"); auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -48,6 +49,7 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::pose_landmarker::proto:: using ::mediapipe::tasks::vision::pose_landmarker::proto::
PoseLandmarksDetectorGraphOptions; PoseLandmarksDetectorGraphOptions;
@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs {
Source<NormalizedRect> pose_rect_next_frame; Source<NormalizedRect> pose_rect_next_frame;
Source<bool> pose_presence; Source<bool> pose_presence;
Source<float> pose_presence_score; Source<float> pose_presence_score;
Source<Image> segmentation_mask; std::optional<Source<Image>> segmentation_mask;
}; };
struct PoseLandmarkerOutputs { struct PoseLandmarkerOutputs {
@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs {
Source<std::vector<NormalizedRect>> pose_rects_next_frame; Source<std::vector<NormalizedRect>> pose_rects_next_frame;
Source<std::vector<bool>> presences; Source<std::vector<bool>> presences;
Source<std::vector<float>> presence_scores; Source<std::vector<float>> presence_scores;
Source<std::vector<Image>> segmentation_masks; std::optional<Source<std::vector<Image>>> segmentation_masks;
}; };
absl::Status SanityCheckOptions( absl::Status SanityCheckOptions(
@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
bool output_segmentation_mask =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
const auto* model_resources, const auto* model_resources,
CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc)); CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(auto pose_landmark_detection_outs,
auto pose_landmark_detection_outs,
BuildSinglePoseLandmarksDetectorGraph( BuildSinglePoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources, sc->Options<PoseLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)], *model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)],
graph, output_segmentation_mask));
pose_landmark_detection_outs.pose_landmarks >> pose_landmark_detection_outs.pose_landmarks >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)]; graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
pose_landmark_detection_outs.world_pose_landmarks >> pose_landmark_detection_outs.world_pose_landmarks >>
@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
graph[Output<bool>(kPresenceTag)]; graph[Output<bool>(kPresenceTag)];
pose_landmark_detection_outs.pose_presence_score >> pose_landmark_detection_outs.pose_presence_score >>
graph[Output<float>(kPresenceScoreTag)]; graph[Output<float>(kPresenceScoreTag)];
pose_landmark_detection_outs.segmentation_mask >> if (pose_landmark_detection_outs.segmentation_mask) {
*pose_landmark_detection_outs.segmentation_mask >>
graph[Output<Image>(kSegmentationMaskTag)]; graph[Output<Image>(kSegmentationMaskTag)];
}
return graph.GetConfig(); return graph.GetConfig();
} }
@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
BuildSinglePoseLandmarksDetectorGraph( BuildSinglePoseLandmarksDetectorGraph(
const PoseLandmarksDetectorGraphOptions& subgraph_options, const PoseLandmarksDetectorGraphOptions& subgraph_options,
const ModelResources& model_resources, Source<Image> image_in, const ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> pose_rect, Graph& graph) { Source<NormalizedRect> pose_rect, Graph& graph,
bool output_segmentation_mask) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
auto& preprocessing = graph.AddNode( auto& preprocessing = graph.AddNode(
@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto raw_landmarks = auto raw_landmarks =
tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)]; tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
// Decodes the segmentation tensor into a mask image with pixel values in
// [0, 1] (1 for person and 0 for background).
auto& tensors_to_segmentation =
graph.AddNode("TensorsToSegmentationCalculator");
ConfigureTensorsToSegmentationCalculator(
&tensors_to_segmentation
.GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
auto raw_segmentation_mask =
tensors_to_segmentation[Output<Image>(kMaskTag)];
// Refines landmarks with the heatmap tensor. // Refines landmarks with the heatmap tensor.
auto& refine_landmarks_from_heatmap = auto& refine_landmarks_from_heatmap =
graph.AddNode("RefineLandmarksFromHeatmapCalculator"); graph.AddNode("RefineLandmarksFromHeatmapCalculator");
@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto world_projected_landmarks = auto world_projected_landmarks =
world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>(); world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
std::optional<Stream<Image>> segmentation_mask;
if (output_segmentation_mask) {
// Decodes the segmentation tensor into a mask image with pixel values in
// [0, 1] (1 for person and 0 for background).
auto& tensors_to_segmentation =
graph.AddNode("TensorsToSegmentationCalculator");
ConfigureTensorsToSegmentationCalculator(
&tensors_to_segmentation.GetOptions<
mediapipe::TensorsToSegmentationCalculatorOptions>());
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
auto raw_segmentation_mask =
tensors_to_segmentation[Output<Image>(kMaskTag)];
// Calculates the inverse transformation matrix. // Calculates the inverse transformation matrix.
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
matrix >> inverse_matrix.In(kMatrixTag); matrix >> inverse_matrix.In(kMatrixTag);
auto inverted_matrix = inverse_matrix.Out(kMatrixTag); auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
// Projects the segmentation mask from the letterboxed ROI back to the full // Projects the segmentation mask from the letterboxed ROI back to the
// image. // full image.
auto& warp_affine = graph.AddNode("WarpAffineCalculator"); auto& warp_affine = graph.AddNode("WarpAffineCalculator");
ConfigureWarpAffineCalculator( ConfigureWarpAffineCalculator(
&warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>()); &warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
image_size >> warp_affine.In(kOutputSizeTag); image_size >> warp_affine.In(kOutputSizeTag);
inverted_matrix >> warp_affine.In(kMatrixTag); inverted_matrix >> warp_affine.In(kMatrixTag);
raw_segmentation_mask >> warp_affine.In(kImageTag); raw_segmentation_mask >> warp_affine.In(kImageTag);
auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>(); segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
}
// Calculate region of interest based on auxiliary landmarks, to be used // Calculate region of interest based on auxiliary landmarks, to be used
// in the next frame. Consists of LandmarksToDetection + // in the next frame. Consists of LandmarksToDetection +
@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
/* pose_rect_next_frame= */ pose_rect_next_frame, /* pose_rect_next_frame= */ pose_rect_next_frame,
/* pose_presence= */ pose_presence, /* pose_presence= */ pose_presence,
/* pose_presence_score= */ pose_presence_score, /* pose_presence_score= */ pose_presence_score,
/* segmentation_mask= */ projected_segmentation_mask, /* segmentation_mask= */ segmentation_mask,
}}; }};
} }
}; };
@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
Graph graph; Graph graph;
bool output_segmentation_masks =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto pose_landmark_detection_outputs, auto pose_landmark_detection_outputs,
BuildPoseLandmarksDetectorGraph( BuildPoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(), sc->Options<PoseLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph)); graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph,
output_segmentation_masks));
pose_landmark_detection_outputs.landmark_lists >> pose_landmark_detection_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)]; graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
pose_landmark_detection_outputs.world_landmark_lists >> pose_landmark_detection_outputs.world_landmark_lists >>
@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
graph[Output<std::vector<bool>>(kPresenceTag)]; graph[Output<std::vector<bool>>(kPresenceTag)];
pose_landmark_detection_outputs.presence_scores >> pose_landmark_detection_outputs.presence_scores >>
graph[Output<std::vector<float>>(kPresenceScoreTag)]; graph[Output<std::vector<float>>(kPresenceScoreTag)];
pose_landmark_detection_outputs.segmentation_masks >> if (pose_landmark_detection_outputs.segmentation_masks) {
*pose_landmark_detection_outputs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)]; graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
}
return graph.GetConfig(); return graph.GetConfig();
} }
@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph( absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
const PoseLandmarksDetectorGraphOptions& subgraph_options, const PoseLandmarksDetectorGraphOptions& subgraph_options,
Source<Image> image_in, Source<Image> image_in,
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) { Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph,
bool output_segmentation_masks) {
auto& begin_loop_multi_pose_rects = auto& begin_loop_multi_pose_rects =
graph.AddNode("BeginLoopNormalizedRectCalculator"); graph.AddNode("BeginLoopNormalizedRectCalculator");
image_in >> begin_loop_multi_pose_rects.In("CLONE"); image_in >> begin_loop_multi_pose_rects.In("CLONE");
@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
pose_landmark_subgraph.Out(kPoseRectNextFrameTag); pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
auto presence = pose_landmark_subgraph.Out(kPresenceTag); auto presence = pose_landmark_subgraph.Out(kPresenceTag);
auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag); auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
auto& end_loop_landmarks = auto& end_loop_landmarks =
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator"); graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto presence_scores = auto presence_scores =
end_loop_presence_score[Output<std::vector<float>>(kIterableTag)]; end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator"); std::optional<Stream<std::vector<Image>>> segmentation_masks_vector;
if (output_segmentation_masks) {
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
auto& end_loop_segmentation_mask =
graph.AddNode("EndLoopImageCalculator");
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag); batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag); segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
auto segmentation_masks = segmentation_masks_vector =
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)]; end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
}
return {{ return {{
/* landmark_lists= */ landmark_lists, /* landmark_lists= */ landmark_lists,
@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
/* pose_rects_next_frame= */ pose_rects_next_frame, /* pose_rects_next_frame= */ pose_rects_next_frame,
/* presences= */ presences, /* presences= */ presences,
/* presence_scores= */ presence_scores, /* presence_scores= */ presence_scores,
/* segmentation_masks= */ segmentation_masks, /* segmentation_masks= */ segmentation_masks_vector,
}}; }};
} }
}; };