From a577dc3043a139a5a447f4ca7209e3e416c3db44 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 27 Sep 2023 23:17:32 -0700 Subject: [PATCH] smoothing stream utility function. PiperOrigin-RevId: 569074973 --- mediapipe/framework/api2/stream/BUILD | 37 +++ mediapipe/framework/api2/stream/smoothing.cc | 131 ++++++++++ mediapipe/framework/api2/stream/smoothing.h | 119 +++++++++ .../framework/api2/stream/smoothing_test.cc | 230 ++++++++++++++++++ 4 files changed, 517 insertions(+) create mode 100644 mediapipe/framework/api2/stream/smoothing.cc create mode 100644 mediapipe/framework/api2/stream/smoothing.h create mode 100644 mediapipe/framework/api2/stream/smoothing_test.cc diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index a33645b2d..fd8ca5d16 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -288,6 +288,43 @@ cc_test( ], ) +cc_library( + name = "smoothing", + srcs = ["smoothing.cc"], + hdrs = ["smoothing.h"], + deps = [ + "//mediapipe/calculators/util:landmarks_smoothing_calculator", + "//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto", + "//mediapipe/calculators/util:multi_landmarks_smoothing_calculator", + "//mediapipe/calculators/util:multi_world_landmarks_smoothing_calculator", + "//mediapipe/calculators/util:visibility_smoothing_calculator", + "//mediapipe/calculators/util:visibility_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "smoothing_test", + srcs = ["smoothing_test.cc"], + deps = [ + ":smoothing", + "//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "split", hdrs = ["split.h"], diff --git a/mediapipe/framework/api2/stream/smoothing.cc b/mediapipe/framework/api2/stream/smoothing.cc new file mode 100644 index 000000000..5ef04254c --- /dev/null +++ b/mediapipe/framework/api2/stream/smoothing.cc @@ -0,0 +1,131 @@ +#include "mediapipe/framework/api2/stream/smoothing.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" +#include "mediapipe/calculators/util/visibility_smoothing_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::api2::builder { + +namespace { + +void SetFilterConfig(const OneEuroFilterConfig& config, + bool disable_value_scaling, GenericNode& node) { + auto& smoothing_node_opts = + node.GetOptions(); + auto& one_euro_filter = *smoothing_node_opts.mutable_one_euro_filter(); + one_euro_filter.set_min_cutoff(config.min_cutoff); + one_euro_filter.set_derivate_cutoff(config.derivate_cutoff); + one_euro_filter.set_beta(config.beta); + one_euro_filter.set_disable_value_scaling(disable_value_scaling); +} + +void SetFilterConfig(const LandmarksSmoothingCalculatorOptions& config, + GenericNode& node) { + auto& smoothing_node_opts = + node.GetOptions(); + smoothing_node_opts = config; +} + +GenericNode& AddVisibilitySmoothingNode(float low_pass_filter_alpha, + Graph& graph) { + auto& smoothing_node = graph.AddNode("VisibilitySmoothingCalculator"); + auto& smoothing_node_opts = + smoothing_node.GetOptions(); + smoothing_node_opts.mutable_low_pass_filter()->set_alpha( + low_pass_filter_alpha); + return smoothing_node; +} + +} // namespace + +Stream SmoothLandmarks( + Stream landmarks, + Stream> image_size, + std::optional> scale_roi, + const OneEuroFilterConfig& config, Graph& graph) { + auto& smoothing_node = graph.AddNode("LandmarksSmoothingCalculator"); + SetFilterConfig(config, /*disable_value_scaling=*/false, smoothing_node); + + landmarks.ConnectTo(smoothing_node.In("NORM_LANDMARKS")); + image_size.ConnectTo(smoothing_node.In("IMAGE_SIZE")); + if (scale_roi) { + scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI")); + } + return smoothing_node.Out("NORM_FILTERED_LANDMARKS") + .Cast(); +} + +Stream SmoothLandmarks( + Stream landmarks, + std::optional> scale_roi, + const OneEuroFilterConfig& config, Graph& graph) { + auto& smoothing_node = graph.AddNode("LandmarksSmoothingCalculator"); + SetFilterConfig(config, /*disable_value_scaling=*/true, smoothing_node); + + landmarks.ConnectTo(smoothing_node.In("LANDMARKS")); + if (scale_roi) { + scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI")); + } + return smoothing_node.Out("FILTERED_LANDMARKS").Cast(); +} + +Stream> SmoothMultiLandmarks( + Stream> landmarks, + Stream> tracking_ids, + Stream> image_size, + std::optional>> scale_roi, + const LandmarksSmoothingCalculatorOptions& config, Graph& graph) { + auto& smoothing_node = graph.AddNode("MultiLandmarksSmoothingCalculator"); + SetFilterConfig(config, smoothing_node); + + landmarks.ConnectTo(smoothing_node.In("NORM_LANDMARKS")); + tracking_ids.ConnectTo(smoothing_node.In("TRACKING_IDS")); + image_size.ConnectTo(smoothing_node.In("IMAGE_SIZE")); + if (scale_roi) { + scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI")); + } + return smoothing_node.Out("NORM_FILTERED_LANDMARKS") + .Cast>(); +} + +Stream> SmoothMultiWorldLandmarks( + Stream> landmarks, + Stream> tracking_ids, + std::optional>> scale_roi, + const LandmarksSmoothingCalculatorOptions& config, Graph& graph) { + auto& smoothing_node = + graph.AddNode("MultiWorldLandmarksSmoothingCalculator"); + SetFilterConfig(config, smoothing_node); + + landmarks.ConnectTo(smoothing_node.In("LANDMARKS")); + tracking_ids.ConnectTo(smoothing_node.In("TRACKING_IDS")); + if (scale_roi) { + scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI")); + } + return smoothing_node.Out("FILTERED_LANDMARKS") + .Cast>(); +} + +Stream SmoothLandmarksVisibility( + Stream landmarks, float low_pass_filter_alpha, + Graph& graph) { + auto& node = AddVisibilitySmoothingNode(low_pass_filter_alpha, graph); + landmarks.ConnectTo(node.In("NORM_LANDMARKS")); + return node.Out("NORM_FILTERED_LANDMARKS").Cast(); +} + +Stream SmoothLandmarksVisibility(Stream landmarks, + float low_pass_filter_alpha, + Graph& graph) { + auto& node = AddVisibilitySmoothingNode(low_pass_filter_alpha, graph); + landmarks.ConnectTo(node.In("LANDMARKS")); + return node.Out("FILTERED_LANDMARKS").Cast(); +} + +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/smoothing.h b/mediapipe/framework/api2/stream/smoothing.h new file mode 100644 index 000000000..4c395141c --- /dev/null +++ b/mediapipe/framework/api2/stream/smoothing.h @@ -0,0 +1,119 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_ + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" + +namespace mediapipe::api2::builder { + +struct OneEuroFilterConfig { + float min_cutoff; + float beta; + float derivate_cutoff; +}; + +// Updates graph to smooth normalized landmarks and returns resulting stream. +// +// @landmarks - normalized landmarks. +// @image_size - size of image where landmarks were detected. +// @scale_roi - can be used to specify object scale. +// @config - filter config. +// @graph - graph to update. +// +// Returns: smoothed/filtered normalized landmarks. +// +// NOTE: one-euro filter is exposed only. Other filter options can be exposed +// on demand. +Stream SmoothLandmarks( + Stream landmarks, + Stream> image_size, + std::optional> scale_roi, + const OneEuroFilterConfig& config, Graph& graph); + +// Updates graph to smooth absolute landmarks and returns resulting stream. +// +// @landmarks - absolute landmarks. +// @scale_roi - can be used to specify object scale. +// @config - filter config. +// @graph - graph to update. +// +// Returns: smoothed/filtered absolute landmarks. +// +// NOTE: one-euro filter is exposed only. Other filter options can be exposed +// on demand. +Stream SmoothLandmarks( + Stream landmarks, + std::optional> scale_roi, + const OneEuroFilterConfig& config, Graph& graph); + +// Updates graph to smooth normalized landmarks and returns resulting stream. +// +// @landmarks - normalized landmarks vector. +// @tracking_ids - tracking IDs associated with landmarks +// @image_size - size of image where landmarks were detected. +// @scale_roi - can be used to specify object scales. +// @config - filter config. +// @graph - graph to update. +// +// Returns: smoothed/filtered normalized landmarks. +// +// NOTE: one-euro filter is exposed only. Other filter options can be exposed +// on demand. +Stream> SmoothMultiLandmarks( + Stream> landmarks, + Stream> tracking_ids, + Stream> image_size, + std::optional>> scale_roi, + const mediapipe::LandmarksSmoothingCalculatorOptions& config, Graph& graph); + +// Updates graph to smooth absolute landmarks and returns resulting stream. +// +// @landmarks - absolute landmarks vector. +// @tracking_ids - tracking IDs associated with landmarks +// @scale_roi - can be used to specify object scales. +// @config - filter config. +// @graph - graph to update. +// +// Returns: smoothed/filtered absolute landmarks. +// +// NOTE: one-euro filter is exposed only. Other filter options can be exposed +// on demand. +Stream> SmoothMultiWorldLandmarks( + Stream> landmarks, + Stream> tracking_ids, + std::optional>> scale_roi, + const mediapipe::LandmarksSmoothingCalculatorOptions& config, Graph& graph); + +// Updates graph to smooth visibility of landmarks. +// +// @landmarks - normalized landmarks. +// @low_pass_filter_alpha - low pass filter alpha to use for smoothing. +// @graph - graph to update. +// +// Returns: normalized landmarks containing smoothed visibility. +Stream SmoothLandmarksVisibility( + Stream landmarks, + float low_pass_filter_alpha, Graph& graph); + +// Updates graph to smooth visibility of landmarks. +// +// @landmarks - absolute landmarks. +// @low_pass_filter_alpha - low pass filter alpha to use for smoothing. +// @graph - graph to update. +// +// Returns: absolute landmarks containing smoothed visibility. +Stream SmoothLandmarksVisibility( + Stream landmarks, float low_pass_filter_alpha, + mediapipe::api2::builder::Graph& graph); + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_ diff --git a/mediapipe/framework/api2/stream/smoothing_test.cc b/mediapipe/framework/api2/stream/smoothing_test.cc new file mode 100644 index 000000000..7f6a191d0 --- /dev/null +++ b/mediapipe/framework/api2/stream/smoothing_test.cc @@ -0,0 +1,230 @@ +#include "mediapipe/framework/api2/stream/smoothing.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(Smoothing, NormLandmarks) { + mediapipe::api2::builder::Graph graph; + + Stream norm_landmarks = + graph.In("NORM_LANDMARKS").Cast(); + Stream> image_size = + graph.In("IMAGE_SIZE").Cast>(); + Stream scale_roi = + graph.In("SCALE_ROI").Cast(); + SmoothLandmarks( + norm_landmarks, image_size, scale_roi, + {.min_cutoff = 0.5f, .beta = 100.0f, .derivate_cutoff = 20.0f}, graph) + .SetName("smoothed_norm_landmarks"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "LandmarksSmoothingCalculator" + input_stream: "IMAGE_SIZE:__stream_0" + input_stream: "NORM_LANDMARKS:__stream_1" + input_stream: "OBJECT_SCALE_ROI:__stream_2" + output_stream: "NORM_FILTERED_LANDMARKS:smoothed_norm_landmarks" + options { + [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { + one_euro_filter { + min_cutoff: 0.5 + beta: 100 + derivate_cutoff: 20 + disable_value_scaling: false + } + } + } + } + input_stream: "IMAGE_SIZE:__stream_0" + input_stream: "NORM_LANDMARKS:__stream_1" + input_stream: "SCALE_ROI:__stream_2" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(Smoothing, Landmarks) { + mediapipe::api2::builder::Graph graph; + + Stream landmarks = graph.In("LANDMARKS").Cast(); + SmoothLandmarks(landmarks, /*scale_roi=*/std::nullopt, + {.min_cutoff = 1.5f, .beta = 90.0f, .derivate_cutoff = 10.0f}, + graph) + .SetName("smoothed_landmarks"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "LandmarksSmoothingCalculator" + input_stream: "LANDMARKS:__stream_0" + output_stream: "FILTERED_LANDMARKS:smoothed_landmarks" + options { + [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { + one_euro_filter { + min_cutoff: 1.5 + beta: 90 + derivate_cutoff: 10 + disable_value_scaling: true + } + } + } + } + input_stream: "LANDMARKS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(Smoothing, MultiLandmarks) { + mediapipe::api2::builder::Graph graph; + + Stream> norm_landmarks = + graph.In("NORM_LANDMARKS").Cast>(); + Stream> tracking_ids = + graph.In("TRACKING_IDS").Cast>(); + Stream> image_size = + graph.In("IMAGE_SIZE").Cast>(); + Stream> scale_roi = + graph.In("SCALE_ROI").Cast>(); + auto config = LandmarksSmoothingCalculatorOptions(); + config.mutable_no_filter(); + SmoothMultiLandmarks(norm_landmarks, tracking_ids, image_size, scale_roi, + config, graph) + .SetName("smoothed_norm_landmarks"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "MultiLandmarksSmoothingCalculator" + input_stream: "IMAGE_SIZE:__stream_0" + input_stream: "NORM_LANDMARKS:__stream_1" + input_stream: "OBJECT_SCALE_ROI:__stream_2" + input_stream: "TRACKING_IDS:__stream_3" + output_stream: "NORM_FILTERED_LANDMARKS:smoothed_norm_landmarks" + options { + [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { no_filter {} } + } + } + input_stream: "IMAGE_SIZE:__stream_0" + input_stream: "NORM_LANDMARKS:__stream_1" + input_stream: "SCALE_ROI:__stream_2" + input_stream: "TRACKING_IDS:__stream_3" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(Smoothing, MultiWorldLandmarks) { + mediapipe::api2::builder::Graph graph; + + Stream> landmarks = + graph.In("LANDMARKS").Cast>(); + Stream> tracking_ids = + graph.In("TRACKING_IDS").Cast>(); + auto config = LandmarksSmoothingCalculatorOptions(); + config.mutable_no_filter(); + SmoothMultiWorldLandmarks(landmarks, tracking_ids, /*scale_roi=*/std::nullopt, + config, graph) + .SetName("smoothed_landmarks"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "MultiWorldLandmarksSmoothingCalculator" + input_stream: "LANDMARKS:__stream_0" + input_stream: "TRACKING_IDS:__stream_1" + output_stream: "FILTERED_LANDMARKS:smoothed_landmarks" + options { + [mediapipe.LandmarksSmoothingCalculatorOptions.ext] { no_filter {} } + } + } + input_stream: "LANDMARKS:__stream_0" + input_stream: "TRACKING_IDS:__stream_1" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(Smoothing, NormLandmarksVisibility) { + mediapipe::api2::builder::Graph graph; + + Stream norm_landmarks = + graph.In("NORM_LANDMARKS").Cast(); + Stream smoothed_norm_landmarks = + SmoothLandmarksVisibility(norm_landmarks, /*low_pass_filter_alpha=*/0.9f, + graph); + smoothed_norm_landmarks.SetName("smoothed_norm_landmarks"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "VisibilitySmoothingCalculator" + input_stream: "NORM_LANDMARKS:__stream_0" + output_stream: "NORM_FILTERED_LANDMARKS:smoothed_norm_landmarks" + options { + [mediapipe.VisibilitySmoothingCalculatorOptions.ext] { + low_pass_filter { alpha: 0.9 } + } + } + } + input_stream: "NORM_LANDMARKS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +TEST(Smoothing, LandmarksVisibility) { + mediapipe::api2::builder::Graph graph; + + Stream landmarks = graph.In("LANDMARKS").Cast(); + Stream smoothed_landmarks = SmoothLandmarksVisibility( + landmarks, /*low_pass_filter_alpha=*/0.9f, graph); + smoothed_landmarks.SetName("smoothed_landmarks"); + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "VisibilitySmoothingCalculator" + input_stream: "LANDMARKS:__stream_0" + output_stream: "FILTERED_LANDMARKS:smoothed_landmarks" + options { + [mediapipe.VisibilitySmoothingCalculatorOptions.ext] { + low_pass_filter { alpha: 0.9 } + } + } + } + input_stream: "LANDMARKS:__stream_0" + )pb"))); + + CalculatorGraph calculator_graph; + MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder