Internal change.
PiperOrigin-RevId: 524336899
This commit is contained in:
parent
534da98ccb
commit
27038f534a
|
@ -97,6 +97,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"//mediapipe/util:graph_builder_utils",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
|
|||
// limit the number of frames in flight.
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
|
||||
bool enable_flow_limiting) {
|
||||
bool enable_flow_limiting, bool output_segmentation_masks) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
|
||||
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||
graph.Out(kSegmentationMaskTag);
|
||||
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
|
||||
graph.Out(kNormLandmarksTag);
|
||||
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
||||
|
@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
||||
graph.Out(kPoseAuxiliaryLandmarksTag);
|
||||
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
||||
if (output_segmentation_masks) {
|
||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||
graph.Out(kSegmentationMaskTag);
|
||||
}
|
||||
if (enable_flow_limiting) {
|
||||
return tasks::core::AddFlowLimiterCalculator(
|
||||
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
|
||||
|
@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
|
|||
PoseLandmarkerGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM,
|
||||
options->output_segmentation_masks),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback))));
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
|
|||
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
|
||||
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
||||
Source<std::vector<Detection>> pose_detections;
|
||||
Source<std::vector<Image>> segmentation_masks;
|
||||
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
|
@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
// input_stream: "IMAGE:image_in"
|
||||
// input_stream: "NORM_RECT:norm_rect"
|
||||
// output_stream: "NORM_LANDMARKS:pose_landmarks"
|
||||
// output_stream: "LANDMARKS:world_landmarks"
|
||||
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
|
||||
// output_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
|
||||
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
|
||||
// output_stream: "POSE_RECTS:pose_rects"
|
||||
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
|
||||
|
@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
bool output_segmentation_masks =
|
||||
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||
if (sc->Options<PoseLandmarkerGraphOptions>()
|
||||
.base_options()
|
||||
.has_model_asset()) {
|
||||
|
@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||
.IsAvailable()));
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto outs,
|
||||
ASSIGN_OR_RETURN(auto outs,
|
||||
BuildPoseLandmarkerGraph(
|
||||
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||
graph, output_segmentation_masks));
|
||||
outs.landmark_lists >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
||||
outs.world_landmark_lists >>
|
||||
|
@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
kAuxiliaryLandmarksTag)];
|
||||
outs.pose_rects_next_frame >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
|
||||
outs.segmentation_masks >>
|
||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||
outs.pose_detections >>
|
||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||
outs.image >> graph[Output<Image>(kImageTag)];
|
||||
if (outs.segmentation_masks) {
|
||||
*outs.segmentation_masks >>
|
||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||
}
|
||||
|
||||
// TODO remove when support is fixed.
|
||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
||||
|
@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
// graph: the mediapipe graph instance to be updated.
|
||||
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
|
||||
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph,
|
||||
bool output_segmentation_masks) {
|
||||
const int max_num_poses =
|
||||
tasks_options.pose_detector_graph_options().num_poses();
|
||||
|
||||
|
@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
auto pose_rects_for_next_frame =
|
||||
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
|
||||
.Cast<std::vector<NormalizedRect>>();
|
||||
auto segmentation_masks =
|
||||
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||
if (output_segmentation_masks) {
|
||||
segmentation_masks =
|
||||
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
||||
.Cast<std::vector<Image>>();
|
||||
}
|
||||
|
||||
if (tasks_options.base_options().use_stream_mode()) {
|
||||
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
|
||||
|
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "mediapipe/util/graph_builder_utils.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -48,6 +49,7 @@ using ::mediapipe::api2::Input;
|
|||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::api2::builder::Stream;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::mediapipe::tasks::vision::pose_landmarker::proto::
|
||||
PoseLandmarksDetectorGraphOptions;
|
||||
|
@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs {
|
|||
Source<NormalizedRect> pose_rect_next_frame;
|
||||
Source<bool> pose_presence;
|
||||
Source<float> pose_presence_score;
|
||||
Source<Image> segmentation_mask;
|
||||
std::optional<Source<Image>> segmentation_mask;
|
||||
};
|
||||
|
||||
struct PoseLandmarkerOutputs {
|
||||
|
@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs {
|
|||
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
||||
Source<std::vector<bool>> presences;
|
||||
Source<std::vector<float>> presence_scores;
|
||||
Source<std::vector<Image>> segmentation_masks;
|
||||
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||
};
|
||||
|
||||
absl::Status SanityCheckOptions(
|
||||
|
@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
bool output_segmentation_mask =
|
||||
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto pose_landmark_detection_outs,
|
||||
ASSIGN_OR_RETURN(auto pose_landmark_detection_outs,
|
||||
BuildSinglePoseLandmarksDetectorGraph(
|
||||
sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources,
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
sc->Options<PoseLandmarksDetectorGraphOptions>(),
|
||||
*model_resources, graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||
graph, output_segmentation_mask));
|
||||
pose_landmark_detection_outs.pose_landmarks >>
|
||||
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
|
||||
pose_landmark_detection_outs.world_pose_landmarks >>
|
||||
|
@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
graph[Output<bool>(kPresenceTag)];
|
||||
pose_landmark_detection_outs.pose_presence_score >>
|
||||
graph[Output<float>(kPresenceScoreTag)];
|
||||
pose_landmark_detection_outs.segmentation_mask >>
|
||||
if (pose_landmark_detection_outs.segmentation_mask) {
|
||||
*pose_landmark_detection_outs.segmentation_mask >>
|
||||
graph[Output<Image>(kSegmentationMaskTag)];
|
||||
}
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
BuildSinglePoseLandmarksDetectorGraph(
|
||||
const PoseLandmarksDetectorGraphOptions& subgraph_options,
|
||||
const ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> pose_rect, Graph& graph) {
|
||||
Source<NormalizedRect> pose_rect, Graph& graph,
|
||||
bool output_segmentation_mask) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
|
||||
|
||||
auto& preprocessing = graph.AddNode(
|
||||
|
@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
auto raw_landmarks =
|
||||
tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
|
||||
|
||||
// Decodes the segmentation tensor into a mask image with pixel values in
|
||||
// [0, 1] (1 for person and 0 for background).
|
||||
auto& tensors_to_segmentation =
|
||||
graph.AddNode("TensorsToSegmentationCalculator");
|
||||
ConfigureTensorsToSegmentationCalculator(
|
||||
&tensors_to_segmentation
|
||||
.GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
|
||||
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
|
||||
auto raw_segmentation_mask =
|
||||
tensors_to_segmentation[Output<Image>(kMaskTag)];
|
||||
|
||||
// Refines landmarks with the heatmap tensor.
|
||||
auto& refine_landmarks_from_heatmap =
|
||||
graph.AddNode("RefineLandmarksFromHeatmapCalculator");
|
||||
|
@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
auto world_projected_landmarks =
|
||||
world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
|
||||
|
||||
std::optional<Stream<Image>> segmentation_mask;
|
||||
if (output_segmentation_mask) {
|
||||
// Decodes the segmentation tensor into a mask image with pixel values in
|
||||
// [0, 1] (1 for person and 0 for background).
|
||||
auto& tensors_to_segmentation =
|
||||
graph.AddNode("TensorsToSegmentationCalculator");
|
||||
ConfigureTensorsToSegmentationCalculator(
|
||||
&tensors_to_segmentation.GetOptions<
|
||||
mediapipe::TensorsToSegmentationCalculatorOptions>());
|
||||
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
|
||||
auto raw_segmentation_mask =
|
||||
tensors_to_segmentation[Output<Image>(kMaskTag)];
|
||||
|
||||
// Calculates the inverse transformation matrix.
|
||||
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
|
||||
matrix >> inverse_matrix.In(kMatrixTag);
|
||||
auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
|
||||
|
||||
// Projects the segmentation mask from the letterboxed ROI back to the full
|
||||
// image.
|
||||
// Projects the segmentation mask from the letterboxed ROI back to the
|
||||
// full image.
|
||||
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
|
||||
ConfigureWarpAffineCalculator(
|
||||
&warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
|
||||
image_size >> warp_affine.In(kOutputSizeTag);
|
||||
inverted_matrix >> warp_affine.In(kMatrixTag);
|
||||
raw_segmentation_mask >> warp_affine.In(kImageTag);
|
||||
auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
|
||||
segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
|
||||
}
|
||||
|
||||
// Calculate region of interest based on auxiliary landmarks, to be used
|
||||
// in the next frame. Consists of LandmarksToDetection +
|
||||
|
@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
/* pose_rect_next_frame= */ pose_rect_next_frame,
|
||||
/* pose_presence= */ pose_presence,
|
||||
/* pose_presence_score= */ pose_presence_score,
|
||||
/* segmentation_mask= */ projected_segmentation_mask,
|
||||
/* segmentation_mask= */ segmentation_mask,
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
bool output_segmentation_masks =
|
||||
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||
ASSIGN_OR_RETURN(
|
||||
auto pose_landmark_detection_outputs,
|
||||
BuildPoseLandmarksDetectorGraph(
|
||||
sc->Options<PoseLandmarksDetectorGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
|
||||
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph,
|
||||
output_segmentation_masks));
|
||||
pose_landmark_detection_outputs.landmark_lists >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
||||
pose_landmark_detection_outputs.world_landmark_lists >>
|
||||
|
@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
graph[Output<std::vector<bool>>(kPresenceTag)];
|
||||
pose_landmark_detection_outputs.presence_scores >>
|
||||
graph[Output<std::vector<float>>(kPresenceScoreTag)];
|
||||
pose_landmark_detection_outputs.segmentation_masks >>
|
||||
if (pose_landmark_detection_outputs.segmentation_masks) {
|
||||
*pose_landmark_detection_outputs.segmentation_masks >>
|
||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||
}
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
|
||||
const PoseLandmarksDetectorGraphOptions& subgraph_options,
|
||||
Source<Image> image_in,
|
||||
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) {
|
||||
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph,
|
||||
bool output_segmentation_masks) {
|
||||
auto& begin_loop_multi_pose_rects =
|
||||
graph.AddNode("BeginLoopNormalizedRectCalculator");
|
||||
image_in >> begin_loop_multi_pose_rects.In("CLONE");
|
||||
|
@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
|
||||
auto presence = pose_landmark_subgraph.Out(kPresenceTag);
|
||||
auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
|
||||
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
|
||||
|
||||
auto& end_loop_landmarks =
|
||||
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
|
||||
|
@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
auto presence_scores =
|
||||
end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
|
||||
|
||||
auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator");
|
||||
std::optional<Stream<std::vector<Image>>> segmentation_masks_vector;
|
||||
if (output_segmentation_masks) {
|
||||
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
|
||||
auto& end_loop_segmentation_mask =
|
||||
graph.AddNode("EndLoopImageCalculator");
|
||||
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
|
||||
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
|
||||
auto segmentation_masks =
|
||||
segmentation_masks_vector =
|
||||
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
|
||||
}
|
||||
|
||||
return {{
|
||||
/* landmark_lists= */ landmark_lists,
|
||||
|
@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
/* pose_rects_next_frame= */ pose_rects_next_frame,
|
||||
/* presences= */ presences,
|
||||
/* presence_scores= */ presence_scores,
|
||||
/* segmentation_masks= */ segmentation_masks,
|
||||
/* segmentation_masks= */ segmentation_masks_vector,
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue
Block a user