Smooth pose landmarks
PiperOrigin-RevId: 570441366
This commit is contained in:
		
							parent
							
								
									a72839ef99
								
							
						
					
					
						commit
						da8fcb6bb2
					
				| 
						 | 
				
			
			@ -54,8 +54,12 @@ cc_library(
 | 
			
		|||
    srcs = ["pose_landmarks_detector_graph.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:begin_loop_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:concatenate_vector_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:constant_side_packet_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:constant_side_packet_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/core:end_loop_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:gate_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:side_packet_to_stream_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:split_proto_list_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:split_vector_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -87,6 +91,9 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework:subgraph",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/api2:port",
 | 
			
		||||
        "//mediapipe/framework/api2/stream:get_vector_item",
 | 
			
		||||
        "//mediapipe/framework/api2/stream:image_size",
 | 
			
		||||
        "//mediapipe/framework/api2/stream:smoothing",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/formats:landmark_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -98,6 +105,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "//mediapipe/util:graph_builder_utils",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
| 
						 | 
				
			
			@ -125,21 +133,18 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/formats:landmark_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/framework/formats:tensor",
 | 
			
		||||
        "//mediapipe/framework/port:status",
 | 
			
		||||
        "//mediapipe/tasks/cc:common",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/utils:gate",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_resources_cache",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:utils",
 | 
			
		||||
        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
 | 
			
		||||
        "//mediapipe/util:graph_builder_utils",
 | 
			
		||||
        "@com_google_absl//absl/strings:str_format",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
 | 
			
		|||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <type_traits>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/strings/str_format.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/core/gate_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/util/association_calculator.pb.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -29,14 +26,11 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/framework/formats/landmark.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/tensor.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_macros.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/common.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/utils/gate.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/utils.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -292,7 +286,9 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
			
		|||
 | 
			
		||||
    auto& pose_detector =
 | 
			
		||||
        graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
 | 
			
		||||
    pose_detector.GetOptions<PoseDetectorGraphOptions>().Swap(
 | 
			
		||||
    auto& pose_detector_options =
 | 
			
		||||
        pose_detector.GetOptions<PoseDetectorGraphOptions>();
 | 
			
		||||
    pose_detector_options.Swap(
 | 
			
		||||
        tasks_options.mutable_pose_detector_graph_options());
 | 
			
		||||
    auto& clip_pose_rects =
 | 
			
		||||
        graph.AddNode("ClipNormalizedRectVectorSizeCalculator");
 | 
			
		||||
| 
						 | 
				
			
			@ -303,9 +299,23 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
			
		|||
    auto& pose_landmarks_detector_graph = graph.AddNode(
 | 
			
		||||
        "mediapipe.tasks.vision.pose_landmarker."
 | 
			
		||||
        "MultiplePoseLandmarksDetectorGraph");
 | 
			
		||||
    auto& pose_landmarks_detector_graph_options =
 | 
			
		||||
        pose_landmarks_detector_graph
 | 
			
		||||
        .GetOptions<PoseLandmarksDetectorGraphOptions>()
 | 
			
		||||
        .Swap(tasks_options.mutable_pose_landmarks_detector_graph_options());
 | 
			
		||||
            .GetOptions<PoseLandmarksDetectorGraphOptions>();
 | 
			
		||||
    pose_landmarks_detector_graph_options.Swap(
 | 
			
		||||
        tasks_options.mutable_pose_landmarks_detector_graph_options());
 | 
			
		||||
 | 
			
		||||
    // Apply smoothing filter only on the single pose landmarks, because
 | 
			
		||||
    // landmarks smoothing calculator doesn't support multiple landmarks yet.
 | 
			
		||||
    if (pose_detector_options.num_poses() == 1) {
 | 
			
		||||
      pose_landmarks_detector_graph_options.set_smooth_landmarks(
 | 
			
		||||
          tasks_options.base_options().use_stream_mode());
 | 
			
		||||
    } else if (pose_detector_options.num_poses() > 1 &&
 | 
			
		||||
               pose_landmarks_detector_graph_options.smooth_landmarks()) {
 | 
			
		||||
      return absl::InvalidArgumentError(
 | 
			
		||||
          "Currently pose landmarks smoothing only supports a single pose.");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    image_in >> pose_landmarks_detector_graph.In(kImageTag);
 | 
			
		||||
    clipped_pose_rects >> pose_landmarks_detector_graph.In(kNormRectTag);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -240,7 +240,7 @@ TEST_P(ImageModeTest, Succeeds) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_SUITE_P(
 | 
			
		||||
    PoseGestureTest, ImageModeTest,
 | 
			
		||||
    PoseTest, ImageModeTest,
 | 
			
		||||
    Values(TestParams{
 | 
			
		||||
               /* test_name= */ "Pose",
 | 
			
		||||
               /* test_image_name= */ kPoseImage,
 | 
			
		||||
| 
						 | 
				
			
			@ -328,7 +328,7 @@ TEST_P(VideoModeTest, Succeeds) {
 | 
			
		|||
// TODO Investigate PoseLandmarker performance in VideoMode.
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_SUITE_P(
 | 
			
		||||
    PoseGestureTest, VideoModeTest,
 | 
			
		||||
    PoseTest, VideoModeTest,
 | 
			
		||||
    Values(TestParams{
 | 
			
		||||
               /* test_name= */ "Pose",
 | 
			
		||||
               /* test_image_name= */ kPoseImage,
 | 
			
		||||
| 
						 | 
				
			
			@ -444,7 +444,7 @@ TEST_P(LiveStreamModeTest, Succeeds) {
 | 
			
		|||
// Investigate PoseLandmarker performance in LiveStreamMode.
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_SUITE_P(
 | 
			
		||||
    PoseGestureTest, LiveStreamModeTest,
 | 
			
		||||
    PoseTest, LiveStreamModeTest,
 | 
			
		||||
    Values(TestParams{
 | 
			
		||||
               /* test_name= */ "Pose",
 | 
			
		||||
               /* test_image_name= */ kPoseImage,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,12 @@ See the License for the specific language governing permissions and
 | 
			
		|||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "mediapipe/calculators/core/constant_side_packet_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -26,6 +31,9 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/calculators/util/visibility_copy_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/api2/stream/get_vector_item.h"
 | 
			
		||||
#include "mediapipe/framework/api2/stream/image_size.h"
 | 
			
		||||
#include "mediapipe/framework/api2/stream/smoothing.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/framework/formats/landmark.pb.h"
 | 
			
		||||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -47,7 +55,10 @@ namespace pose_landmarker {
 | 
			
		|||
using ::mediapipe::NormalizedRect;
 | 
			
		||||
using ::mediapipe::api2::Input;
 | 
			
		||||
using ::mediapipe::api2::Output;
 | 
			
		||||
using ::mediapipe::api2::builder::GetImageSize;
 | 
			
		||||
using ::mediapipe::api2::builder::Graph;
 | 
			
		||||
using ::mediapipe::api2::builder::SmoothLandmarks;
 | 
			
		||||
using ::mediapipe::api2::builder::SmoothLandmarksVisibility;
 | 
			
		||||
using ::mediapipe::api2::builder::Source;
 | 
			
		||||
using ::mediapipe::api2::builder::Stream;
 | 
			
		||||
using ::mediapipe::tasks::core::ModelResources;
 | 
			
		||||
| 
						 | 
				
			
			@ -213,6 +224,23 @@ void ConfigureWarpAffineCalculator(
 | 
			
		|||
  options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename TickT>
 | 
			
		||||
Stream<int> CreateIntConstantStream(Stream<TickT> tick_stream, int constant_int,
 | 
			
		||||
                                    Graph& graph) {
 | 
			
		||||
  auto& constant_side_packet_node =
 | 
			
		||||
      graph.AddNode("ConstantSidePacketCalculator");
 | 
			
		||||
  constant_side_packet_node
 | 
			
		||||
      .GetOptions<mediapipe::ConstantSidePacketCalculatorOptions>()
 | 
			
		||||
      .add_packet()
 | 
			
		||||
      ->set_int_value(constant_int);
 | 
			
		||||
  auto side_packet = constant_side_packet_node.SideOut("PACKET");
 | 
			
		||||
 | 
			
		||||
  auto& side_packet_to_stream = graph.AddNode("SidePacketToStreamCalculator");
 | 
			
		||||
  tick_stream.ConnectTo(side_packet_to_stream.In("TICK"));
 | 
			
		||||
  side_packet.ConnectTo(side_packet_to_stream.SideIn(""));
 | 
			
		||||
  return side_packet_to_stream.Out("AT_TICK").Cast<int>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A "mediapipe.tasks.vision.pose_landmarker.SinglePoseLandmarksDetectorGraph"
 | 
			
		||||
// performs pose landmarks detection.
 | 
			
		||||
// - Accepts CPU input images and outputs Landmark on CPU.
 | 
			
		||||
| 
						 | 
				
			
			@ -669,8 +697,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
			
		|||
    auto& pose_landmark_subgraph = graph.AddNode(
 | 
			
		||||
        "mediapipe.tasks.vision.pose_landmarker."
 | 
			
		||||
        "SinglePoseLandmarksDetectorGraph");
 | 
			
		||||
    pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>()
 | 
			
		||||
        .CopyFrom(subgraph_options);
 | 
			
		||||
    pose_landmark_subgraph.GetOptions<PoseLandmarksDetectorGraphOptions>() =
 | 
			
		||||
        subgraph_options;
 | 
			
		||||
    image >> pose_landmark_subgraph.In(kImageTag);
 | 
			
		||||
    pose_rect >> pose_landmark_subgraph.In(kNormRectTag);
 | 
			
		||||
    auto landmarks = pose_landmark_subgraph.Out(kLandmarksTag);
 | 
			
		||||
| 
						 | 
				
			
			@ -734,6 +762,70 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph {
 | 
			
		|||
          end_loop_segmentation_mask[Output<std::vector<Image>>(kIterableTag)];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Apply smoothing filter only on the single pose landmarks, because
 | 
			
		||||
    // landmarks smoothing calculator doesn't support multiple landmarks yet.
 | 
			
		||||
    // Notice the landmarks smoothing calculator cannot be put inside the for
 | 
			
		||||
    // loop calculator, because the smoothing calculator utilize the timestamp
 | 
			
		||||
    // to smoote landmarks across frames but the for loop calculator makes fake
 | 
			
		||||
    // timestamps for the streams.
 | 
			
		||||
    if (subgraph_options.smooth_landmarks()) {
 | 
			
		||||
      Stream<std::pair<int, int>> image_size = GetImageSize(image_in, graph);
 | 
			
		||||
      Stream<int> zero_index =
 | 
			
		||||
          CreateIntConstantStream(landmark_lists, 0, graph);
 | 
			
		||||
      Stream<NormalizedLandmarkList> landmarks =
 | 
			
		||||
          GetItem(landmark_lists, zero_index, graph);
 | 
			
		||||
      Stream<LandmarkList> world_landmarks =
 | 
			
		||||
          GetItem(world_landmark_lists, zero_index, graph);
 | 
			
		||||
      Stream<NormalizedRect> roi =
 | 
			
		||||
          GetItem(pose_rects_next_frame, zero_index, graph);
 | 
			
		||||
 | 
			
		||||
      // Apply smoothing filter on pose landmarks.
 | 
			
		||||
      landmarks = SmoothLandmarksVisibility(
 | 
			
		||||
          landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
 | 
			
		||||
      landmarks = SmoothLandmarks(
 | 
			
		||||
          landmarks, image_size, roi,
 | 
			
		||||
          {// Min cutoff 0.05 results into ~0.01 alpha in landmark EMA filter
 | 
			
		||||
           // when landmark is static.
 | 
			
		||||
           /*min_cutoff=*/0.05f,
 | 
			
		||||
           // Beta 80.0 in combination with min_cutoff 0.05 results into ~0.94
 | 
			
		||||
           // alpha in landmark EMA filter when landmark is moving fast.
 | 
			
		||||
           /*beta=*/80.0f,
 | 
			
		||||
           // Derivative cutoff 1.0 results into ~0.17 alpha in landmark
 | 
			
		||||
           // velocity EMA filter.
 | 
			
		||||
           /*derivate_cutoff=*/1.0f},
 | 
			
		||||
          graph);
 | 
			
		||||
 | 
			
		||||
      // Apply smoothing filter on pose world landmarks.
 | 
			
		||||
      world_landmarks = SmoothLandmarksVisibility(
 | 
			
		||||
          world_landmarks, /*low_pass_filter_alpha=*/0.1f, graph);
 | 
			
		||||
      world_landmarks = SmoothLandmarks(
 | 
			
		||||
          world_landmarks,
 | 
			
		||||
          /*scale_roi=*/std::nullopt,
 | 
			
		||||
          {// Min cutoff 0.1 results into ~ 0.02 alpha in landmark EMA filter
 | 
			
		||||
           // when landmark is static.
 | 
			
		||||
           /*min_cutoff=*/0.1f,
 | 
			
		||||
           // Beta 40.0 in combination with min_cutoff 0.1 results into ~0.8
 | 
			
		||||
           // alpha in landmark EMA filter when landmark is moving fast.
 | 
			
		||||
           /*beta=*/40.0f,
 | 
			
		||||
           // Derivative cutoff 1.0 results into ~0.17 alpha in landmark
 | 
			
		||||
           // velocity EMA filter.
 | 
			
		||||
           /*derivate_cutoff=*/1.0f},
 | 
			
		||||
          graph);
 | 
			
		||||
 | 
			
		||||
      // Wrap the single pose landmarks into a vector of landmarks.
 | 
			
		||||
      auto& concat_landmarks =
 | 
			
		||||
          graph.AddNode("ConcatenateNormalizedLandmarkListVectorCalculator");
 | 
			
		||||
      landmarks >> concat_landmarks.In("");
 | 
			
		||||
      landmark_lists =
 | 
			
		||||
          concat_landmarks.Out("").Cast<std::vector<NormalizedLandmarkList>>();
 | 
			
		||||
 | 
			
		||||
      auto& concat_world_landmarks =
 | 
			
		||||
          graph.AddNode("ConcatenateLandmarkListVectorCalculator");
 | 
			
		||||
      world_landmarks >> concat_world_landmarks.In("");
 | 
			
		||||
      world_landmark_lists =
 | 
			
		||||
          concat_world_landmarks.Out("").Cast<std::vector<LandmarkList>>();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return {{
 | 
			
		||||
        /* landmark_lists= */ landmark_lists,
 | 
			
		||||
        /* world_landmark_lists= */ world_landmark_lists,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,4 +35,11 @@ message PoseLandmarksDetectorGraphOptions {
 | 
			
		|||
  // Minimum confidence value ([0.0, 1.0]) for pose presence score to be
 | 
			
		||||
  // considered successfully detecting a pose in the image.
 | 
			
		||||
  optional float min_detection_confidence = 2 [default = 0.5];
 | 
			
		||||
 | 
			
		||||
  // Whether to smooth the detected landmarks over timestamps. Note that
 | 
			
		||||
  // landmarks smoothing is only applicable for a single pose. If multiple poses
 | 
			
		||||
  // landmarks are given, and smooth_landmarks is true, only the first pose
 | 
			
		||||
  // landmarks would be smoothed, and the remaining landmarks are discarded in
 | 
			
		||||
  // the returned landmarks list.
 | 
			
		||||
  optional bool smooth_landmarks = 3;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user