Internal change.

PiperOrigin-RevId: 524336899
This commit is contained in:
MediaPipe Team 2023-04-14 11:08:50 -07:00 committed by Copybara-Service
parent 534da98ccb
commit 27038f534a
4 changed files with 98 additions and 66 deletions

View File

@@ -97,6 +97,7 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status:statusor",
],
)

View File

@@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
// limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
bool enable_flow_limiting) {
bool enable_flow_limiting, bool output_segmentation_masks) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
graph.Out(kSegmentationMaskTag);
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
graph.Out(kNormLandmarksTag);
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
@@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
graph.Out(kPoseAuxiliaryLandmarksTag);
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
if (output_segmentation_masks) {
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
graph.Out(kSegmentationMaskTag);
}
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
@@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
PoseLandmarkerGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
options->running_mode == core::RunningMode::LIVE_STREAM,
options->output_segmentation_masks),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback))));

View File

@@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
Source<std::vector<Detection>> pose_detections;
Source<std::vector<Image>> segmentation_masks;
std::optional<Source<std::vector<Image>>> segmentation_masks;
Source<Image> image;
};
@@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// input_stream: "IMAGE:image_in"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "NORM_LANDMARKS:pose_landmarks"
// output_stream: "LANDMARKS:world_landmarks"
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
// output_stream: "WORLD_LANDMARKS:world_landmarks"
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
// output_stream: "POSE_RECTS:pose_rects"
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
@@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
bool output_segmentation_masks =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
if (sc->Options<PoseLandmarkerGraphOptions>()
.base_options()
.has_model_asset()) {
@@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(
auto outs,
BuildPoseLandmarkerGraph(
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
ASSIGN_OR_RETURN(auto outs,
BuildPoseLandmarkerGraph(
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
graph, output_segmentation_masks));
outs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
outs.world_landmark_lists >>
@@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
kAuxiliaryLandmarksTag)];
outs.pose_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
outs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
outs.pose_detections >>
graph[Output<std::vector<Detection>>(kDetectionsTag)];
outs.image >> graph[Output<Image>(kImageTag)];
if (outs.segmentation_masks) {
*outs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
}
// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
@@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph,
bool output_segmentation_masks) {
const int max_num_poses =
tasks_options.pose_detector_graph_options().num_poses();
@@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
auto pose_rects_for_next_frame =
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
.Cast<std::vector<NormalizedRect>>();
auto segmentation_masks =
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
.Cast<std::vector<Image>>();
std::optional<Source<std::vector<Image>>> segmentation_masks;
if (output_segmentation_masks) {
segmentation_masks =
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
.Cast<std::vector<Image>>();
}
if (tasks_options.base_options().use_stream_mode()) {
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");

View File

@@ -37,6 +37,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe {
namespace tasks {
@@ -48,6 +49,7 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::pose_landmarker::proto::
PoseLandmarksDetectorGraphOptions;
@@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs {
Source<NormalizedRect> pose_rect_next_frame;
Source<bool> pose_presence;
Source<float> pose_presence_score;
Source<Image> segmentation_mask;
std::optional<Source<Image>> segmentation_mask;
};
struct PoseLandmarkerOutputs {
@@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs {
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
Source<std::vector<bool>> presences;
Source<std::vector<float>> presence_scores;
Source<std::vector<Image>> segmentation_masks;
std::optional<Source<std::vector<Image>>> segmentation_masks;
};
absl::Status SanityCheckOptions(
@@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
bool output_segmentation_mask =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto pose_landmark_detection_outs,
BuildSinglePoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
ASSIGN_OR_RETURN(auto pose_landmark_detection_outs,
BuildSinglePoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
graph, output_segmentation_mask));
pose_landmark_detection_outs.pose_landmarks >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
pose_landmark_detection_outs.world_pose_landmarks >>
@@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
graph[Output<bool>(kPresenceTag)];
pose_landmark_detection_outs.pose_presence_score >>
graph[Output<float>(kPresenceScoreTag)];
pose_landmark_detection_outs.segmentation_mask >>
graph[Output<Image>(kSegmentationMaskTag)];
if (pose_landmark_detection_outs.segmentation_mask) {
*pose_landmark_detection_outs.segmentation_mask >>
graph[Output<Image>(kSegmentationMaskTag)];
}
return graph.GetConfig();
}
@@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
BuildSinglePoseLandmarksDetectorGraph(
const PoseLandmarksDetectorGraphOptions& subgraph_options,
const ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> pose_rect, Graph& graph) {
Source<NormalizedRect> pose_rect, Graph& graph,
bool output_segmentation_mask) {
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
auto& preprocessing = graph.AddNode(
@@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto raw_landmarks =
tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
// Decodes the segmentation tensor into a mask image with pixel values in
// [0, 1] (1 for person and 0 for background).
auto& tensors_to_segmentation =
graph.AddNode("TensorsToSegmentationCalculator");
ConfigureTensorsToSegmentationCalculator(
&tensors_to_segmentation
.GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
auto raw_segmentation_mask =
tensors_to_segmentation[Output<Image>(kMaskTag)];
// Refines landmarks with the heatmap tensor.
auto& refine_landmarks_from_heatmap =
graph.AddNode("RefineLandmarksFromHeatmapCalculator");
@@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto world_projected_landmarks =
world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
// Calculates the inverse transformation matrix.
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
matrix >> inverse_matrix.In(kMatrixTag);
auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
std::optional<Stream<Image>> segmentation_mask;
if (output_segmentation_mask) {
// Decodes the segmentation tensor into a mask image with pixel values in
// [0, 1] (1 for person and 0 for background).
auto& tensors_to_segmentation =
graph.AddNode("TensorsToSegmentationCalculator");
ConfigureTensorsToSegmentationCalculator(
&tensors_to_segmentation.GetOptions<
mediapipe::TensorsToSegmentationCalculatorOptions>());
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
auto raw_segmentation_mask =
tensors_to_segmentation[Output<Image>(kMaskTag)];
// Projects the segmentation mask from the letterboxed ROI back to the full
// image.
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
ConfigureWarpAffineCalculator(
&warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
image_size >> warp_affine.In(kOutputSizeTag);
inverted_matrix >> warp_affine.In(kMatrixTag);
raw_segmentation_mask >> warp_affine.In(kImageTag);
auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
// Calculates the inverse transformation matrix.
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
matrix >> inverse_matrix.In(kMatrixTag);
auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
// Projects the segmentation mask from the letterboxed ROI back to the
// full image.
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
ConfigureWarpAffineCalculator(
&warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
image_size >> warp_affine.In(kOutputSizeTag);
inverted_matrix >> warp_affine.In(kMatrixTag);
raw_segmentation_mask >> warp_affine.In(kImageTag);
segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
}
// Calculate region of interest based on auxiliary landmarks, to be used
// in the next frame. Consists of LandmarksToDetection +
@@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
/* pose_rect_next_frame= */ pose_rect_next_frame,
/* pose_presence= */ pose_presence,
/* pose_presence_score= */ pose_presence_score,
/* segmentation_mask= */ projected_segmentation_mask,
/* segmentation_mask= */ segmentation_mask,
}};
}
};
@@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
bool output_segmentation_masks =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
ASSIGN_OR_RETURN(
auto pose_landmark_detection_outputs,
BuildPoseLandmarksDetectorGraph(
sc->Options<PoseLandmarksDetectorGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph,
output_segmentation_masks));
pose_landmark_detection_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
pose_landmark_detection_outputs.world_landmark_lists >>
@@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
graph[Output<std::vector<bool>>(kPresenceTag)];
pose_landmark_detection_outputs.presence_scores >>
graph[Output<std::vector<float>>(kPresenceScoreTag)];
pose_landmark_detection_outputs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
if (pose_landmark_detection_outputs.segmentation_masks) {
*pose_landmark_detection_outputs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
}
return graph.GetConfig();
}
@@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
const PoseLandmarksDetectorGraphOptions& subgraph_options,
Source<Image> image_in,
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) {
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph,
bool output_segmentation_masks) {
auto& begin_loop_multi_pose_rects =
graph.AddNode("BeginLoopNormalizedRectCalculator");
image_in >> begin_loop_multi_pose_rects.In("CLONE");
@@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
auto presence = pose_landmark_subgraph.Out(kPresenceTag);
auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
auto& end_loop_landmarks =
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
@@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto presence_scores =
end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator");
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
auto segmentation_masks =
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
std::optional<Stream<std::vector<Image>>> segmentation_masks_vector;
if (output_segmentation_masks) {
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
auto& end_loop_segmentation_mask =
graph.AddNode("EndLoopImageCalculator");
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
segmentation_masks_vector =
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
}
return {{
/* landmark_lists= */ landmark_lists,
@@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
/* pose_rects_next_frame= */ pose_rects_next_frame,
/* presences= */ presences,
/* presence_scores= */ presence_scores,
/* segmentation_masks= */ segmentation_masks,
/* segmentation_masks= */ segmentation_masks_vector,
}};
}
};