Smooth pose landmarks

PiperOrigin-RevId: 570441366
This commit is contained in:
MediaPipe Team 2023-10-03 11:09:56 -07:00 committed by Copybara-Service
parent a72839ef99
commit da8fcb6bb2
5 changed files with 134 additions and 20 deletions

View File

@ -54,8 +54,12 @@ cc_library(
srcs = ["pose_landmarks_detector_graph.cc"],
deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
"//mediapipe/calculators/core:split_proto_list_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
@ -87,6 +91,9 @@ cc_library(
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/api2/stream:get_vector_item",
"//mediapipe/framework/api2/stream:image_size",
"//mediapipe/framework/api2/stream:smoothing",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
@ -98,6 +105,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
@ -125,21 +133,18 @@ cc_library(
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)

View File

@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/strings/str_format.h"
#include "absl/status/status.h"
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/calculators/core/gate_calculator.pb.h"
#include "mediapipe/calculators/util/association_calculator.pb.h"
@ -29,14 +26,11 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
@ -292,7 +286,9 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
auto& pose_detector =
graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
pose_detector.GetOptions<PoseDetectorGraphOptions>().Swap(
auto& pose_detector_options =
pose_detector.GetOptions<PoseDetectorGraphOptions>();
pose_detector_options.Swap(
tasks_options.mutable_pose_detector_graph_options());
auto& clip_pose_rects =
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
@ -303,9 +299,23 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
auto& pose_landmarks_detector_graph = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"MultiplePoseLandmarksDetectorGraph");
auto& pose_landmarks_detector_graph_options =
pose_landmarks_detector_graph
.GetOptions<PoseLandmarksDetectorGraphOptions>()
.Swap(tasks_options.mutable_pose_landmarks_detector_graph_options());
.GetOptions<PoseLandmarksDetectorGraphOptions>();
pose_landmarks_detector_graph_options.Swap(
tasks_options.mutable_pose_landmarks_detector_graph_options());
// Apply smoothing filter only on the single pose landmarks, because
// landmarks smoothing calculator doesn't support multiple landmarks yet.
if (pose_detector_options.num_poses() == 1) {
pose_landmarks_detector_graph_options.set_smooth_landmarks(
tasks_options.base_options().use_stream_mode());
} else if (pose_detector_options.num_poses() > 1 &&
pose_landmarks_detector_graph_options.smooth_landmarks()) {
return absl::InvalidArgumentError(
"Currently pose landmarks smoothing only supports a single pose.");
}
image_in >> pose_landmarks_detector_graph.In(kImageTag);
clipped_pose_rects >> pose_landmarks_detector_graph.In(kNormRectTag);

View File

@ -240,7 +240,7 @@ TEST_P(ImageModeTest, Succeeds) {
}
INSTANTIATE_TEST_SUITE_P(
PoseGestureTest, ImageModeTest,
PoseTest, ImageModeTest,
Values(TestParams{
/* test_name= */ "Pose",
/* test_image_name= */ kPoseImage,
@ -328,7 +328,7 @@ TEST_P(VideoModeTest, Succeeds) {
// TODO Investigate PoseLandmarker performance in VideoMode.
INSTANTIATE_TEST_SUITE_P(
PoseGestureTest, VideoModeTest,
PoseTest, VideoModeTest,
Values(TestParams{
/* test_name= */ "Pose",
/* test_image_name= */ kPoseImage,
@ -444,7 +444,7 @@ TEST_P(LiveStreamModeTest, Succeeds) {
// Investigate PoseLandmarker performance in LiveStreamMode.
INSTANTIATE_TEST_SUITE_P(
PoseGestureTest, LiveStreamModeTest,
PoseTest, LiveStreamModeTest,
Values(TestParams{
/* test_name= */ "Pose",
/* test_image_name= */ kPoseImage,

View File

@ -13,7 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
@ -26,6 +31,9 @@ limitations under the License.
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/api2/stream/get_vector_item.h"
#include "mediapipe/framework/api2/stream/image_size.h"
#include "mediapipe/framework/api2/stream/smoothing.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -47,7 +55,10 @@ namespace pose_landmarker {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GetImageSize;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SmoothLandmarks;
using ::mediapipe::api2::builder::SmoothLandmarksVisibility;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::api2::builder::Stream;
using ::mediapipe::tasks::core::ModelResources;
@ -213,6 +224,23 @@ void ConfigureWarpAffineCalculator(
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
}
template <typename TickT>
Stream<int> CreateIntConstantStream(Stream<TickT> tick_stream, int constant_int,
Graph& graph) {
auto& constant_side_packet_node =
graph.AddNode("ConstantSidePacketCalculator");
constant_side_packet_node
.GetOptions<mediapipe::ConstantSidePacketCalculatorOptions>()
.add_packet()
->set_int_value(constant_int);
auto side_packet = constant_side_packet_node.SideOut("PACKET");
auto& side_packet_to_stream = graph.AddNode("SidePacketToStreamCalculator");
tick_stream.ConnectTo(side_packet_to_stream.In("TICK"));
side_packet.ConnectTo(side_packet_to_stream.SideIn(""));
return side_packet_to_stream.Out("AT_TICK").Cast<int>();
}
// A "mediapipe.tasks.vision.pose_landmarker.SinglePoseLandmarksDetectorGraph"
// performs pose landmarks detection.
// - Accepts CPU input images and outputs Landmark on CPU.
@ -669,8 +697,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
auto& pose_landmark_subgraph = graph.AddNode(
"mediapipe.tasks.vision.pose_landmarker."
"SinglePoseLandmarksDetectorGraph");
pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>()
.CopyFrom(subgraph_options);
pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>() =
subgraph_options;
image >> pose_landmark_subgraph.In(kImageTag);
pose_rect >> pose_landmark_subgraph.In(kNormRectTag);
auto landmarks = pose_landmark_subgraph.Out(kLandmarksTag);
@ -734,6 +762,70 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
}
// Apply smoothing filter only on the single pose landmarks, because
// landmarks smoothing calculator doesn't support multiple landmarks yet.
// Notice the landmarks smoothing calculator cannot be put inside the for
// loop calculator, because the smoothing calculator utilize the timestamp
// to smoote landmarks across frames but the for loop calculator makes fake
// timestamps for the streams.
if (subgraph_options.smooth_landmarks()) {
Stream<std::pair<int, int>> image_size = GetImageSize(image_in, graph);
Stream<int> zero_index =
CreateIntConstantStream(landmark_lists, 0, graph);
Stream<NormalizedLandmarkList> landmarks =
GetItem(landmark_lists, zero_index, graph);
Stream<LandmarkList> world_landmarks =
GetItem(world_landmark_lists, zero_index, graph);
Stream<NormalizedRect> roi =
GetItem(pose_rects_next_frame, zero_index, graph);
// Apply smoothing filter on pose landmarks.
landmarks = SmoothLandmarksVisibility(
landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
landmarks = SmoothLandmarks(
landmarks, image_size, roi,
{// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter
// when landmark is static.
/*min_cutoff=*/0.05f,
// Beta 80.0 in combination with min_cutoff 0.05 results into ~0.94
// alpha in landmark EMA filter when landmark is moving fast.
/*beta=*/80.0f,
// Derivative cutoff 1.0 results into ~0.17 alpha in landmark
// velocity EMA filter.
/*derivate_cutoff=*/1.0f},
graph);
// Apply smoothing filter on pose world landmarks.
world_landmarks = SmoothLandmarksVisibility(
world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
world_landmarks = SmoothLandmarks(
world_landmarks,
/*scale_roi=*/std::nullopt,
{// Min cutoff 0.1 results into ~ 0.02 alpha in landmark EMA filter
// when landmark is static.
/*min_cutoff=*/0.1f,
// Beta 40.0 in combination with min_cutoff 0.1 results into ~0.8
// alpha in landmark EMA filter when landmark is moving fast.
/*beta=*/40.0f,
// Derivative cutoff 1.0 results into ~0.17 alpha in landmark
// velocity EMA filter.
/*derivate_cutoff=*/1.0f},
graph);
// Wrap the single pose landmarks into a vector of landmarks.
auto& concat_landmarks =
graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator");
landmarks >> concat_landmarks.In("");
landmark_lists =
concat_landmarks.Out("").Cast<std::vector<NormalizedLandmarkList>>();
auto& concat_world_landmarks =
graph.AddNode("ConcatenateLandmarkListVectorCalculator");
world_landmarks >> concat_world_landmarks.In("");
world_landmark_lists =
concat_world_landmarks.Out("").Cast<std::vector<LandmarkList>>();
}
return {{
/* landmark_lists= */ landmark_lists,
/* world_landmark_lists= */ world_landmark_lists,

View File

@ -35,4 +35,11 @@ message PoseLandmarksDetectorGraphOptions {
// Minimum confidence value ([0.0, 1.0]) for pose presence score to be
// considered successfully detecting a pose in the image.
optional float min_detection_confidence = 2 [default = 0.5];
// Whether to smooth the detected landmarks over timestamps. Note that
// landmarks smoothing is only applicable for a single pose. If multiple poses
// landmarks are given, and smooth_landmarks is true, only the first pose
// landmarks would be smoothed, and the remaining landmarks are discarded in
// the returned landmarks list.
optional bool smooth_landmarks = 3;
}