Smooth pose landmarks
PiperOrigin-RevId: 570441366
This commit is contained in:
parent
a72839ef99
commit
da8fcb6bb2
|
@ -54,8 +54,12 @@ cc_library(
|
|||
srcs = ["pose_landmarks_detector_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||
"//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
|
||||
"//mediapipe/calculators/core:end_loop_calculator",
|
||||
"//mediapipe/calculators/core:gate_calculator",
|
||||
"//mediapipe/calculators/core:side_packet_to_stream_calculator",
|
||||
"//mediapipe/calculators/core:split_proto_list_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
|
@ -87,6 +91,9 @@ cc_library(
|
|||
"//mediapipe/framework:subgraph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/api2/stream:get_vector_item",
|
||||
"//mediapipe/framework/api2/stream:image_size",
|
||||
"//mediapipe/framework/api2/stream:smoothing",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
|
@ -98,6 +105,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"//mediapipe/util:graph_builder_utils",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -125,21 +133,18 @@ cc_library(
|
|||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/utils:gate",
|
||||
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
|
||||
"//mediapipe/tasks/cc/core:model_resources_cache",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/util:graph_builder_utils",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
|
||||
#include "mediapipe/calculators/core/gate_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/association_calculator.pb.h"
|
||||
|
@ -29,14 +26,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/gate.h"
|
||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
|
||||
|
@ -292,7 +286,9 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
|
||||
auto& pose_detector =
|
||||
graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
|
||||
pose_detector.GetOptions<PoseDetectorGraphOptions>().Swap(
|
||||
auto& pose_detector_options =
|
||||
pose_detector.GetOptions<PoseDetectorGraphOptions>();
|
||||
pose_detector_options.Swap(
|
||||
tasks_options.mutable_pose_detector_graph_options());
|
||||
auto& clip_pose_rects =
|
||||
graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
|
||||
|
@ -303,9 +299,23 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
auto& pose_landmarks_detector_graph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.pose_landmarker."
|
||||
"MultiplePoseLandmarksDetectorGraph");
|
||||
auto& pose_landmarks_detector_graph_options =
|
||||
pose_landmarks_detector_graph
|
||||
.GetOptions<PoseLandmarksDetectorGraphOptions>()
|
||||
.Swap(tasks_options.mutable_pose_landmarks_detector_graph_options());
|
||||
.GetOptions<PoseLandmarksDetectorGraphOptions>();
|
||||
pose_landmarks_detector_graph_options.Swap(
|
||||
tasks_options.mutable_pose_landmarks_detector_graph_options());
|
||||
|
||||
// Apply smoothing filter only on the single pose landmarks, because
|
||||
// landmarks smoothing calculator doesn't support multiple landmarks yet.
|
||||
if (pose_detector_options.num_poses() == 1) {
|
||||
pose_landmarks_detector_graph_options.set_smooth_landmarks(
|
||||
tasks_options.base_options().use_stream_mode());
|
||||
} else if (pose_detector_options.num_poses() > 1 &&
|
||||
pose_landmarks_detector_graph_options.smooth_landmarks()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Currently pose landmarks smoothing only supports a single pose.");
|
||||
}
|
||||
|
||||
image_in >> pose_landmarks_detector_graph.In(kImageTag);
|
||||
clipped_pose_rects >> pose_landmarks_detector_graph.In(kNormRectTag);
|
||||
|
||||
|
|
|
@ -240,7 +240,7 @@ TEST_P(ImageModeTest, Succeeds) {
|
|||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
PoseGestureTest, ImageModeTest,
|
||||
PoseTest, ImageModeTest,
|
||||
Values(TestParams{
|
||||
/* test_name= */ "Pose",
|
||||
/* test_image_name= */ kPoseImage,
|
||||
|
@ -328,7 +328,7 @@ TEST_P(VideoModeTest, Succeeds) {
|
|||
// TODO Investigate PoseLandmarker performance in VideoMode.
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
PoseGestureTest, VideoModeTest,
|
||||
PoseTest, VideoModeTest,
|
||||
Values(TestParams{
|
||||
/* test_name= */ "Pose",
|
||||
/* test_image_name= */ kPoseImage,
|
||||
|
@ -444,7 +444,7 @@ TEST_P(LiveStreamModeTest, Succeeds) {
|
|||
// Investigate PoseLandmarker performance in LiveStreamMode.
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
PoseGestureTest, LiveStreamModeTest,
|
||||
PoseTest, LiveStreamModeTest,
|
||||
Values(TestParams{
|
||||
/* test_name= */ "Pose",
|
||||
/* test_image_name= */ kPoseImage,
|
||||
|
|
|
@ -13,7 +13,12 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
|
@ -26,6 +31,9 @@ limitations under the License.
|
|||
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/api2/stream/get_vector_item.h"
|
||||
#include "mediapipe/framework/api2/stream/image_size.h"
|
||||
#include "mediapipe/framework/api2/stream/smoothing.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
@ -47,7 +55,10 @@ namespace pose_landmarker {
|
|||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::GetImageSize;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::SmoothLandmarks;
|
||||
using ::mediapipe::api2::builder::SmoothLandmarksVisibility;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::api2::builder::Stream;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
|
@ -213,6 +224,23 @@ void ConfigureWarpAffineCalculator(
|
|||
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
|
||||
}
|
||||
|
||||
template <typename TickT>
|
||||
Stream<int> CreateIntConstantStream(Stream<TickT> tick_stream, int constant_int,
|
||||
Graph& graph) {
|
||||
auto& constant_side_packet_node =
|
||||
graph.AddNode("ConstantSidePacketCalculator");
|
||||
constant_side_packet_node
|
||||
.GetOptions<mediapipe::ConstantSidePacketCalculatorOptions>()
|
||||
.add_packet()
|
||||
->set_int_value(constant_int);
|
||||
auto side_packet = constant_side_packet_node.SideOut("PACKET");
|
||||
|
||||
auto& side_packet_to_stream = graph.AddNode("SidePacketToStreamCalculator");
|
||||
tick_stream.ConnectTo(side_packet_to_stream.In("TICK"));
|
||||
side_packet.ConnectTo(side_packet_to_stream.SideIn(""));
|
||||
return side_packet_to_stream.Out("AT_TICK").Cast<int>();
|
||||
}
|
||||
|
||||
// A "mediapipe.tasks.vision.pose_landmarker.SinglePoseLandmarksDetectorGraph"
|
||||
// performs pose landmarks detection.
|
||||
// - Accepts CPU input images and outputs Landmark on CPU.
|
||||
|
@ -669,8 +697,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
auto& pose_landmark_subgraph = graph.AddNode(
|
||||
"mediapipe.tasks.vision.pose_landmarker."
|
||||
"SinglePoseLandmarksDetectorGraph");
|
||||
pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>()
|
||||
.CopyFrom(subgraph_options);
|
||||
pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>() =
|
||||
subgraph_options;
|
||||
image >> pose_landmark_subgraph.In(kImageTag);
|
||||
pose_rect >> pose_landmark_subgraph.In(kNormRectTag);
|
||||
auto landmarks = pose_landmark_subgraph.Out(kLandmarksTag);
|
||||
|
@ -734,6 +762,70 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
|
||||
}
|
||||
|
||||
// Apply smoothing filter only on the single pose landmarks, because
|
||||
// landmarks smoothing calculator doesn't support multiple landmarks yet.
|
||||
// Notice the landmarks smoothing calculator cannot be put inside the for
|
||||
// loop calculator, because the smoothing calculator utilizes the timestamp
|
||||
// to smooth landmarks across frames but the for loop calculator makes fake
|
||||
// timestamps for the streams.
|
||||
if (subgraph_options.smooth_landmarks()) {
|
||||
Stream<std::pair<int, int>> image_size = GetImageSize(image_in, graph);
|
||||
Stream<int> zero_index =
|
||||
CreateIntConstantStream(landmark_lists, 0, graph);
|
||||
Stream<NormalizedLandmarkList> landmarks =
|
||||
GetItem(landmark_lists, zero_index, graph);
|
||||
Stream<LandmarkList> world_landmarks =
|
||||
GetItem(world_landmark_lists, zero_index, graph);
|
||||
Stream<NormalizedRect> roi =
|
||||
GetItem(pose_rects_next_frame, zero_index, graph);
|
||||
|
||||
// Apply smoothing filter on pose landmarks.
|
||||
landmarks = SmoothLandmarksVisibility(
|
||||
landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
|
||||
landmarks = SmoothLandmarks(
|
||||
landmarks, image_size, roi,
|
||||
{// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter
|
||||
// when landmark is static.
|
||||
/*min_cutoff=*/0.05f,
|
||||
// Beta 80.0 in combination with min_cutoff 0.05 results into ~0.94
|
||||
// alpha in landmark EMA filter when landmark is moving fast.
|
||||
/*beta=*/80.0f,
|
||||
// Derivative cutoff 1.0 results into ~0.17 alpha in landmark
|
||||
// velocity EMA filter.
|
||||
/*derivate_cutoff=*/1.0f},
|
||||
graph);
|
||||
|
||||
// Apply smoothing filter on pose world landmarks.
|
||||
world_landmarks = SmoothLandmarksVisibility(
|
||||
world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
|
||||
world_landmarks = SmoothLandmarks(
|
||||
world_landmarks,
|
||||
/*scale_roi=*/std::nullopt,
|
||||
{// Min cutoff 0.1 results into ~ 0.02 alpha in landmark EMA filter
|
||||
// when landmark is static.
|
||||
/*min_cutoff=*/0.1f,
|
||||
// Beta 40.0 in combination with min_cutoff 0.1 results into ~0.8
|
||||
// alpha in landmark EMA filter when landmark is moving fast.
|
||||
/*beta=*/40.0f,
|
||||
// Derivative cutoff 1.0 results into ~0.17 alpha in landmark
|
||||
// velocity EMA filter.
|
||||
/*derivate_cutoff=*/1.0f},
|
||||
graph);
|
||||
|
||||
// Wrap the single pose landmarks into a vector of landmarks.
|
||||
auto& concat_landmarks =
|
||||
graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator");
|
||||
landmarks >> concat_landmarks.In("");
|
||||
landmark_lists =
|
||||
concat_landmarks.Out("").Cast<std::vector<NormalizedLandmarkList>>();
|
||||
|
||||
auto& concat_world_landmarks =
|
||||
graph.AddNode("ConcatenateLandmarkListVectorCalculator");
|
||||
world_landmarks >> concat_world_landmarks.In("");
|
||||
world_landmark_lists =
|
||||
concat_world_landmarks.Out("").Cast<std::vector<LandmarkList>>();
|
||||
}
|
||||
|
||||
return {{
|
||||
/* landmark_lists= */ landmark_lists,
|
||||
/* world_landmark_lists= */ world_landmark_lists,
|
||||
|
|
|
@ -35,4 +35,11 @@ message PoseLandmarksDetectorGraphOptions {
|
|||
// Minimum confidence value ([0.0, 1.0]) for pose presence score to be
|
||||
// considered successfully detecting a pose in the image.
|
||||
optional float min_detection_confidence = 2 [default = 0.5];
|
||||
|
||||
// Whether to smooth the detected landmarks over timestamps. Note that
|
||||
// landmarks smoothing is only applicable for a single pose. If multiple poses
|
||||
// landmarks are given, and smooth_landmarks is true, only the first pose
|
||||
// landmarks would be smoothed, and the remaining landmarks are discarded in
|
||||
// the returned landmarks list.
|
||||
optional bool smooth_landmarks = 3;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user