Internal change.
PiperOrigin-RevId: 524336899
This commit is contained in:
parent
534da98ccb
commit
27038f534a
|
@ -97,6 +97,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
|
"//mediapipe/util:graph_builder_utils",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
// limit the number of frames in flight.
|
// limit the number of frames in flight.
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
|
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
|
||||||
bool enable_flow_limiting) {
|
bool enable_flow_limiting, bool output_segmentation_masks) {
|
||||||
api2::builder::Graph graph;
|
api2::builder::Graph graph;
|
||||||
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
|
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
|
||||||
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
|
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
|
||||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
|
||||||
graph.Out(kSegmentationMaskTag);
|
|
||||||
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
|
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
|
||||||
graph.Out(kNormLandmarksTag);
|
graph.Out(kNormLandmarksTag);
|
||||||
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
||||||
|
@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
||||||
graph.Out(kPoseAuxiliaryLandmarksTag);
|
graph.Out(kPoseAuxiliaryLandmarksTag);
|
||||||
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
||||||
|
if (output_segmentation_masks) {
|
||||||
|
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||||
|
graph.Out(kSegmentationMaskTag);
|
||||||
|
}
|
||||||
if (enable_flow_limiting) {
|
if (enable_flow_limiting) {
|
||||||
return tasks::core::AddFlowLimiterCalculator(
|
return tasks::core::AddFlowLimiterCalculator(
|
||||||
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
|
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
|
||||||
|
@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
|
||||||
PoseLandmarkerGraphOptionsProto>(
|
PoseLandmarkerGraphOptionsProto>(
|
||||||
CreateGraphConfig(
|
CreateGraphConfig(
|
||||||
std::move(options_proto),
|
std::move(options_proto),
|
||||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
options->running_mode == core::RunningMode::LIVE_STREAM,
|
||||||
|
options->output_segmentation_masks),
|
||||||
std::move(options->base_options.op_resolver), options->running_mode,
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
std::move(packets_callback))));
|
std::move(packets_callback))));
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
|
||||||
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
|
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
|
||||||
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
||||||
Source<std::vector<Detection>> pose_detections;
|
Source<std::vector<Detection>> pose_detections;
|
||||||
Source<std::vector<Image>> segmentation_masks;
|
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||||
Source<Image> image;
|
Source<Image> image;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
||||||
// input_stream: "IMAGE:image_in"
|
// input_stream: "IMAGE:image_in"
|
||||||
// input_stream: "NORM_RECT:norm_rect"
|
// input_stream: "NORM_RECT:norm_rect"
|
||||||
// output_stream: "NORM_LANDMARKS:pose_landmarks"
|
// output_stream: "NORM_LANDMARKS:pose_landmarks"
|
||||||
// output_stream: "LANDMARKS:world_landmarks"
|
// output_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||||
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
|
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
|
||||||
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
|
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
|
||||||
// output_stream: "POSE_RECTS:pose_rects"
|
// output_stream: "POSE_RECTS:pose_rects"
|
||||||
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
|
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
|
||||||
|
@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
bool output_segmentation_masks =
|
||||||
|
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||||
if (sc->Options<PoseLandmarkerGraphOptions>()
|
if (sc->Options<PoseLandmarkerGraphOptions>()
|
||||||
.base_options()
|
.base_options()
|
||||||
.has_model_asset()) {
|
.has_model_asset()) {
|
||||||
|
@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||||
.IsAvailable()));
|
.IsAvailable()));
|
||||||
}
|
}
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(auto outs,
|
||||||
auto outs,
|
|
||||||
BuildPoseLandmarkerGraph(
|
BuildPoseLandmarkerGraph(
|
||||||
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||||
|
graph, output_segmentation_masks));
|
||||||
outs.landmark_lists >>
|
outs.landmark_lists >>
|
||||||
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
||||||
outs.world_landmark_lists >>
|
outs.world_landmark_lists >>
|
||||||
|
@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
kAuxiliaryLandmarksTag)];
|
kAuxiliaryLandmarksTag)];
|
||||||
outs.pose_rects_next_frame >>
|
outs.pose_rects_next_frame >>
|
||||||
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
|
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
|
||||||
outs.segmentation_masks >>
|
|
||||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
|
||||||
outs.pose_detections >>
|
outs.pose_detections >>
|
||||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||||
outs.image >> graph[Output<Image>(kImageTag)];
|
outs.image >> graph[Output<Image>(kImageTag)];
|
||||||
|
if (outs.segmentation_masks) {
|
||||||
|
*outs.segmentation_masks >>
|
||||||
|
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||||
|
}
|
||||||
|
|
||||||
// TODO remove when support is fixed.
|
// TODO remove when support is fixed.
|
||||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
// As mediapipe GraphBuilder currently doesn't support configuring
|
||||||
|
@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
// graph: the mediapipe graph instance to be updated.
|
// graph: the mediapipe graph instance to be updated.
|
||||||
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
|
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
|
||||||
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
||||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
Source<NormalizedRect> norm_rect_in, Graph& graph,
|
||||||
|
bool output_segmentation_masks) {
|
||||||
const int max_num_poses =
|
const int max_num_poses =
|
||||||
tasks_options.pose_detector_graph_options().num_poses();
|
tasks_options.pose_detector_graph_options().num_poses();
|
||||||
|
|
||||||
|
@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
auto pose_rects_for_next_frame =
|
auto pose_rects_for_next_frame =
|
||||||
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
|
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
|
||||||
.Cast<std::vector<NormalizedRect>>();
|
.Cast<std::vector<NormalizedRect>>();
|
||||||
auto segmentation_masks =
|
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||||
|
if (output_segmentation_masks) {
|
||||||
|
segmentation_masks =
|
||||||
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
||||||
.Cast<std::vector<Image>>();
|
.Cast<std::vector<Image>>();
|
||||||
|
}
|
||||||
|
|
||||||
if (tasks_options.base_options().use_stream_mode()) {
|
if (tasks_options.base_options().use_stream_mode()) {
|
||||||
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
|
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
|
||||||
|
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||||
|
#include "mediapipe/util/graph_builder_utils.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -48,6 +49,7 @@ using ::mediapipe::api2::Input;
|
||||||
using ::mediapipe::api2::Output;
|
using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::api2::builder::Stream;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
using ::mediapipe::tasks::vision::pose_landmarker::proto::
|
using ::mediapipe::tasks::vision::pose_landmarker::proto::
|
||||||
PoseLandmarksDetectorGraphOptions;
|
PoseLandmarksDetectorGraphOptions;
|
||||||
|
@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs {
|
||||||
Source<NormalizedRect> pose_rect_next_frame;
|
Source<NormalizedRect> pose_rect_next_frame;
|
||||||
Source<bool> pose_presence;
|
Source<bool> pose_presence;
|
||||||
Source<float> pose_presence_score;
|
Source<float> pose_presence_score;
|
||||||
Source<Image> segmentation_mask;
|
std::optional<Source<Image>> segmentation_mask;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PoseLandmarkerOutputs {
|
struct PoseLandmarkerOutputs {
|
||||||
|
@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs {
|
||||||
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
||||||
Source<std::vector<bool>> presences;
|
Source<std::vector<bool>> presences;
|
||||||
Source<std::vector<float>> presence_scores;
|
Source<std::vector<float>> presence_scores;
|
||||||
Source<std::vector<Image>> segmentation_masks;
|
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status SanityCheckOptions(
|
absl::Status SanityCheckOptions(
|
||||||
|
@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
|
bool output_segmentation_mask =
|
||||||
|
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
const auto* model_resources,
|
const auto* model_resources,
|
||||||
CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
|
CreateModelResources<PoseLandmarksDetectorGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(auto pose_landmark_detection_outs,
|
||||||
auto pose_landmark_detection_outs,
|
|
||||||
BuildSinglePoseLandmarksDetectorGraph(
|
BuildSinglePoseLandmarksDetectorGraph(
|
||||||
sc->Options<PoseLandmarksDetectorGraphOptions>(), *model_resources,
|
sc->Options<PoseLandmarksDetectorGraphOptions>(),
|
||||||
graph[Input<Image>(kImageTag)],
|
*model_resources, graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||||
|
graph, output_segmentation_mask));
|
||||||
pose_landmark_detection_outs.pose_landmarks >>
|
pose_landmark_detection_outs.pose_landmarks >>
|
||||||
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
|
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
|
||||||
pose_landmark_detection_outs.world_pose_landmarks >>
|
pose_landmark_detection_outs.world_pose_landmarks >>
|
||||||
|
@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
graph[Output<bool>(kPresenceTag)];
|
graph[Output<bool>(kPresenceTag)];
|
||||||
pose_landmark_detection_outs.pose_presence_score >>
|
pose_landmark_detection_outs.pose_presence_score >>
|
||||||
graph[Output<float>(kPresenceScoreTag)];
|
graph[Output<float>(kPresenceScoreTag)];
|
||||||
pose_landmark_detection_outs.segmentation_mask >>
|
if (pose_landmark_detection_outs.segmentation_mask) {
|
||||||
|
*pose_landmark_detection_outs.segmentation_mask >>
|
||||||
graph[Output<Image>(kSegmentationMaskTag)];
|
graph[Output<Image>(kSegmentationMaskTag)];
|
||||||
|
}
|
||||||
|
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
BuildSinglePoseLandmarksDetectorGraph(
|
BuildSinglePoseLandmarksDetectorGraph(
|
||||||
const PoseLandmarksDetectorGraphOptions& subgraph_options,
|
const PoseLandmarksDetectorGraphOptions& subgraph_options,
|
||||||
const ModelResources& model_resources, Source<Image> image_in,
|
const ModelResources& model_resources, Source<Image> image_in,
|
||||||
Source<NormalizedRect> pose_rect, Graph& graph) {
|
Source<NormalizedRect> pose_rect, Graph& graph,
|
||||||
|
bool output_segmentation_mask) {
|
||||||
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
|
MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options));
|
||||||
|
|
||||||
auto& preprocessing = graph.AddNode(
|
auto& preprocessing = graph.AddNode(
|
||||||
|
@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
auto raw_landmarks =
|
auto raw_landmarks =
|
||||||
tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
|
tensors_to_landmarks[Output<NormalizedLandmarkList>(kNormLandmarksTag)];
|
||||||
|
|
||||||
// Decodes the segmentation tensor into a mask image with pixel values in
|
|
||||||
// [0, 1] (1 for person and 0 for background).
|
|
||||||
auto& tensors_to_segmentation =
|
|
||||||
graph.AddNode("TensorsToSegmentationCalculator");
|
|
||||||
ConfigureTensorsToSegmentationCalculator(
|
|
||||||
&tensors_to_segmentation
|
|
||||||
.GetOptions<mediapipe::TensorsToSegmentationCalculatorOptions>());
|
|
||||||
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
|
|
||||||
auto raw_segmentation_mask =
|
|
||||||
tensors_to_segmentation[Output<Image>(kMaskTag)];
|
|
||||||
|
|
||||||
// Refines landmarks with the heatmap tensor.
|
// Refines landmarks with the heatmap tensor.
|
||||||
auto& refine_landmarks_from_heatmap =
|
auto& refine_landmarks_from_heatmap =
|
||||||
graph.AddNode("RefineLandmarksFromHeatmapCalculator");
|
graph.AddNode("RefineLandmarksFromHeatmapCalculator");
|
||||||
|
@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
auto world_projected_landmarks =
|
auto world_projected_landmarks =
|
||||||
world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
|
world_landmarks_projection.Out(kLandmarksTag).Cast<LandmarkList>();
|
||||||
|
|
||||||
|
std::optional<Stream<Image>> segmentation_mask;
|
||||||
|
if (output_segmentation_mask) {
|
||||||
|
// Decodes the segmentation tensor into a mask image with pixel values in
|
||||||
|
// [0, 1] (1 for person and 0 for background).
|
||||||
|
auto& tensors_to_segmentation =
|
||||||
|
graph.AddNode("TensorsToSegmentationCalculator");
|
||||||
|
ConfigureTensorsToSegmentationCalculator(
|
||||||
|
&tensors_to_segmentation.GetOptions<
|
||||||
|
mediapipe::TensorsToSegmentationCalculatorOptions>());
|
||||||
|
ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag);
|
||||||
|
auto raw_segmentation_mask =
|
||||||
|
tensors_to_segmentation[Output<Image>(kMaskTag)];
|
||||||
|
|
||||||
// Calculates the inverse transformation matrix.
|
// Calculates the inverse transformation matrix.
|
||||||
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
|
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
|
||||||
matrix >> inverse_matrix.In(kMatrixTag);
|
matrix >> inverse_matrix.In(kMatrixTag);
|
||||||
auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
|
auto inverted_matrix = inverse_matrix.Out(kMatrixTag);
|
||||||
|
|
||||||
// Projects the segmentation mask from the letterboxed ROI back to the full
|
// Projects the segmentation mask from the letterboxed ROI back to the
|
||||||
// image.
|
// full image.
|
||||||
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
|
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
|
||||||
ConfigureWarpAffineCalculator(
|
ConfigureWarpAffineCalculator(
|
||||||
&warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
|
&warp_affine.GetOptions<mediapipe::WarpAffineCalculatorOptions>());
|
||||||
image_size >> warp_affine.In(kOutputSizeTag);
|
image_size >> warp_affine.In(kOutputSizeTag);
|
||||||
inverted_matrix >> warp_affine.In(kMatrixTag);
|
inverted_matrix >> warp_affine.In(kMatrixTag);
|
||||||
raw_segmentation_mask >> warp_affine.In(kImageTag);
|
raw_segmentation_mask >> warp_affine.In(kImageTag);
|
||||||
auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
|
segmentation_mask = warp_affine.Out(kImageTag).Cast<Image>();
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate region of interest based on auxiliary landmarks, to be used
|
// Calculate region of interest based on auxiliary landmarks, to be used
|
||||||
// in the next frame. Consists of LandmarksToDetection +
|
// in the next frame. Consists of LandmarksToDetection +
|
||||||
|
@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
/* pose_rect_next_frame= */ pose_rect_next_frame,
|
/* pose_rect_next_frame= */ pose_rect_next_frame,
|
||||||
/* pose_presence= */ pose_presence,
|
/* pose_presence= */ pose_presence,
|
||||||
/* pose_presence_score= */ pose_presence_score,
|
/* pose_presence_score= */ pose_presence_score,
|
||||||
/* segmentation_mask= */ projected_segmentation_mask,
|
/* segmentation_mask= */ segmentation_mask,
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
bool output_segmentation_masks =
|
||||||
|
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto pose_landmark_detection_outputs,
|
auto pose_landmark_detection_outputs,
|
||||||
BuildPoseLandmarksDetectorGraph(
|
BuildPoseLandmarksDetectorGraph(
|
||||||
sc->Options<PoseLandmarksDetectorGraphOptions>(),
|
sc->Options<PoseLandmarksDetectorGraphOptions>(),
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph));
|
graph[Input<std::vector<NormalizedRect>>(kNormRectTag)], graph,
|
||||||
|
output_segmentation_masks));
|
||||||
pose_landmark_detection_outputs.landmark_lists >>
|
pose_landmark_detection_outputs.landmark_lists >>
|
||||||
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
||||||
pose_landmark_detection_outputs.world_landmark_lists >>
|
pose_landmark_detection_outputs.world_landmark_lists >>
|
||||||
|
@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
graph[Output<std::vector<bool>>(kPresenceTag)];
|
graph[Output<std::vector<bool>>(kPresenceTag)];
|
||||||
pose_landmark_detection_outputs.presence_scores >>
|
pose_landmark_detection_outputs.presence_scores >>
|
||||||
graph[Output<std::vector<float>>(kPresenceScoreTag)];
|
graph[Output<std::vector<float>>(kPresenceScoreTag)];
|
||||||
pose_landmark_detection_outputs.segmentation_masks >>
|
if (pose_landmark_detection_outputs.segmentation_masks) {
|
||||||
|
*pose_landmark_detection_outputs.segmentation_masks >>
|
||||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||||
|
}
|
||||||
|
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
|
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarksDetectorGraph(
|
||||||
const PoseLandmarksDetectorGraphOptions& subgraph_options,
|
const PoseLandmarksDetectorGraphOptions& subgraph_options,
|
||||||
Source<Image> image_in,
|
Source<Image> image_in,
|
||||||
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph) {
|
Source<std::vector<NormalizedRect>> multi_pose_rects, Graph& graph,
|
||||||
|
bool output_segmentation_masks) {
|
||||||
auto& begin_loop_multi_pose_rects =
|
auto& begin_loop_multi_pose_rects =
|
||||||
graph.AddNode("BeginLoopNormalizedRectCalculator");
|
graph.AddNode("BeginLoopNormalizedRectCalculator");
|
||||||
image_in >> begin_loop_multi_pose_rects.In("CLONE");
|
image_in >> begin_loop_multi_pose_rects.In("CLONE");
|
||||||
|
@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
|
pose_landmark_subgraph.Out(kPoseRectNextFrameTag);
|
||||||
auto presence = pose_landmark_subgraph.Out(kPresenceTag);
|
auto presence = pose_landmark_subgraph.Out(kPresenceTag);
|
||||||
auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
|
auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag);
|
||||||
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
|
|
||||||
|
|
||||||
auto& end_loop_landmarks =
|
auto& end_loop_landmarks =
|
||||||
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
|
graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator");
|
||||||
|
@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
auto presence_scores =
|
auto presence_scores =
|
||||||
end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
|
end_loop_presence_score[Output<std::vector<float>>(kIterableTag)];
|
||||||
|
|
||||||
auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator");
|
std::optional<Stream<std::vector<Image>>> segmentation_masks_vector;
|
||||||
|
if (output_segmentation_masks) {
|
||||||
|
auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag);
|
||||||
|
auto& end_loop_segmentation_mask =
|
||||||
|
graph.AddNode("EndLoopImageCalculator");
|
||||||
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
|
batch_end >> end_loop_segmentation_mask.In(kBatchEndTag);
|
||||||
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
|
segmentation_mask >> end_loop_segmentation_mask.In(kItemTag);
|
||||||
auto segmentation_masks =
|
segmentation_masks_vector =
|
||||||
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
|
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
|
||||||
|
}
|
||||||
|
|
||||||
return {{
|
return {{
|
||||||
/* landmark_lists= */ landmark_lists,
|
/* landmark_lists= */ landmark_lists,
|
||||||
|
@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
||||||
/* pose_rects_next_frame= */ pose_rects_next_frame,
|
/* pose_rects_next_frame= */ pose_rects_next_frame,
|
||||||
/* presences= */ presences,
|
/* presences= */ presences,
|
||||||
/* presence_scores= */ presence_scores,
|
/* presence_scores= */ presence_scores,
|
||||||
/* segmentation_masks= */ segmentation_masks,
|
/* segmentation_masks= */ segmentation_masks_vector,
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue
Block a user