smoothing stream utility function.

PiperOrigin-RevId: 569074973
This commit is contained in:
MediaPipe Team 2023-09-27 23:17:32 -07:00 committed by Copybara-Service
parent 9edb4cd753
commit a577dc3043
4 changed files with 517 additions and 0 deletions

View File

@ -288,6 +288,43 @@ cc_test(
],
)
cc_library(
name = "smoothing",
srcs = ["smoothing.cc"],
hdrs = ["smoothing.h"],
deps = [
"//mediapipe/calculators/util:landmarks_smoothing_calculator",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto",
"//mediapipe/calculators/util:multi_landmarks_smoothing_calculator",
"//mediapipe/calculators/util:multi_world_landmarks_smoothing_calculator",
"//mediapipe/calculators/util:visibility_smoothing_calculator",
"//mediapipe/calculators/util:visibility_smoothing_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"@com_google_absl//absl/types:optional",
],
)
cc_test(
name = "smoothing_test",
srcs = ["smoothing_test.cc"],
deps = [
":smoothing",
"//mediapipe/calculators/util:landmarks_smoothing_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "split",
hdrs = ["split.h"],

View File

@ -0,0 +1,131 @@
#include "mediapipe/framework/api2/stream/smoothing.h"
#include <optional>
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/calculators/util/visibility_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe::api2::builder {
namespace {
void SetFilterConfig(const OneEuroFilterConfig& config,
bool disable_value_scaling, GenericNode& node) {
auto& smoothing_node_opts =
node.GetOptions<LandmarksSmoothingCalculatorOptions>();
auto& one_euro_filter = *smoothing_node_opts.mutable_one_euro_filter();
one_euro_filter.set_min_cutoff(config.min_cutoff);
one_euro_filter.set_derivate_cutoff(config.derivate_cutoff);
one_euro_filter.set_beta(config.beta);
one_euro_filter.set_disable_value_scaling(disable_value_scaling);
}
void SetFilterConfig(const LandmarksSmoothingCalculatorOptions& config,
GenericNode& node) {
auto& smoothing_node_opts =
node.GetOptions<LandmarksSmoothingCalculatorOptions>();
smoothing_node_opts = config;
}
GenericNode& AddVisibilitySmoothingNode(float low_pass_filter_alpha,
Graph& graph) {
auto& smoothing_node = graph.AddNode("VisibilitySmoothingCalculator");
auto& smoothing_node_opts =
smoothing_node.GetOptions<VisibilitySmoothingCalculatorOptions>();
smoothing_node_opts.mutable_low_pass_filter()->set_alpha(
low_pass_filter_alpha);
return smoothing_node;
}
} // namespace
Stream<NormalizedLandmarkList> SmoothLandmarks(
Stream<NormalizedLandmarkList> landmarks,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph) {
auto& smoothing_node = graph.AddNode("LandmarksSmoothingCalculator");
SetFilterConfig(config, /*disable_value_scaling=*/false, smoothing_node);
landmarks.ConnectTo(smoothing_node.In("NORM_LANDMARKS"));
image_size.ConnectTo(smoothing_node.In("IMAGE_SIZE"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("NORM_FILTERED_LANDMARKS")
.Cast<NormalizedLandmarkList>();
}
Stream<LandmarkList> SmoothLandmarks(
Stream<LandmarkList> landmarks,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph) {
auto& smoothing_node = graph.AddNode("LandmarksSmoothingCalculator");
SetFilterConfig(config, /*disable_value_scaling=*/true, smoothing_node);
landmarks.ConnectTo(smoothing_node.In("LANDMARKS"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("FILTERED_LANDMARKS").Cast<LandmarkList>();
}
Stream<std::vector<NormalizedLandmarkList>> SmoothMultiLandmarks(
Stream<std::vector<NormalizedLandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<std::vector<NormalizedRect>>> scale_roi,
const LandmarksSmoothingCalculatorOptions& config, Graph& graph) {
auto& smoothing_node = graph.AddNode("MultiLandmarksSmoothingCalculator");
SetFilterConfig(config, smoothing_node);
landmarks.ConnectTo(smoothing_node.In("NORM_LANDMARKS"));
tracking_ids.ConnectTo(smoothing_node.In("TRACKING_IDS"));
image_size.ConnectTo(smoothing_node.In("IMAGE_SIZE"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("NORM_FILTERED_LANDMARKS")
.Cast<std::vector<NormalizedLandmarkList>>();
}
Stream<std::vector<LandmarkList>> SmoothMultiWorldLandmarks(
Stream<std::vector<LandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
std::optional<Stream<std::vector<Rect>>> scale_roi,
const LandmarksSmoothingCalculatorOptions& config, Graph& graph) {
auto& smoothing_node =
graph.AddNode("MultiWorldLandmarksSmoothingCalculator");
SetFilterConfig(config, smoothing_node);
landmarks.ConnectTo(smoothing_node.In("LANDMARKS"));
tracking_ids.ConnectTo(smoothing_node.In("TRACKING_IDS"));
if (scale_roi) {
scale_roi->ConnectTo(smoothing_node.In("OBJECT_SCALE_ROI"));
}
return smoothing_node.Out("FILTERED_LANDMARKS")
.Cast<std::vector<LandmarkList>>();
}
Stream<NormalizedLandmarkList> SmoothLandmarksVisibility(
Stream<NormalizedLandmarkList> landmarks, float low_pass_filter_alpha,
Graph& graph) {
auto& node = AddVisibilitySmoothingNode(low_pass_filter_alpha, graph);
landmarks.ConnectTo(node.In("NORM_LANDMARKS"));
return node.Out("NORM_FILTERED_LANDMARKS").Cast<NormalizedLandmarkList>();
}
Stream<LandmarkList> SmoothLandmarksVisibility(Stream<LandmarkList> landmarks,
float low_pass_filter_alpha,
Graph& graph) {
auto& node = AddVisibilitySmoothingNode(low_pass_filter_alpha, graph);
landmarks.ConnectTo(node.In("LANDMARKS"));
return node.Out("FILTERED_LANDMARKS").Cast<LandmarkList>();
}
} // namespace mediapipe::api2::builder

View File

@ -0,0 +1,119 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe::api2::builder {
struct OneEuroFilterConfig {
float min_cutoff;
float beta;
float derivate_cutoff;
};
// Updates graph to smooth normalized landmarks and returns resulting stream.
//
// @landmarks - normalized landmarks.
// @image_size - size of image where landmarks were detected.
// @scale_roi - can be used to specify object scale.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered normalized landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<mediapipe::NormalizedLandmarkList> SmoothLandmarks(
Stream<mediapipe::NormalizedLandmarkList> landmarks,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph);
// Updates graph to smooth absolute landmarks and returns resulting stream.
//
// @landmarks - absolute landmarks.
// @scale_roi - can be used to specify object scale.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered absolute landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<mediapipe::LandmarkList> SmoothLandmarks(
Stream<mediapipe::LandmarkList> landmarks,
std::optional<Stream<NormalizedRect>> scale_roi,
const OneEuroFilterConfig& config, Graph& graph);
// Updates graph to smooth normalized landmarks and returns resulting stream.
//
// @landmarks - normalized landmarks vector.
// @tracking_ids - tracking IDs associated with landmarks
// @image_size - size of image where landmarks were detected.
// @scale_roi - can be used to specify object scales.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered normalized landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<std::vector<mediapipe::NormalizedLandmarkList>> SmoothMultiLandmarks(
Stream<std::vector<mediapipe::NormalizedLandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
Stream<std::pair<int, int>> image_size,
std::optional<Stream<std::vector<NormalizedRect>>> scale_roi,
const mediapipe::LandmarksSmoothingCalculatorOptions& config, Graph& graph);
// Updates graph to smooth absolute landmarks and returns resulting stream.
//
// @landmarks - absolute landmarks vector.
// @tracking_ids - tracking IDs associated with landmarks
// @scale_roi - can be used to specify object scales.
// @config - filter config.
// @graph - graph to update.
//
// Returns: smoothed/filtered absolute landmarks.
//
// NOTE: one-euro filter is exposed only. Other filter options can be exposed
// on demand.
Stream<std::vector<mediapipe::LandmarkList>> SmoothMultiWorldLandmarks(
Stream<std::vector<mediapipe::LandmarkList>> landmarks,
Stream<std::vector<int64_t>> tracking_ids,
std::optional<Stream<std::vector<mediapipe::Rect>>> scale_roi,
const mediapipe::LandmarksSmoothingCalculatorOptions& config, Graph& graph);
// Updates graph to smooth visibility of landmarks.
//
// @landmarks - normalized landmarks.
// @low_pass_filter_alpha - low pass filter alpha to use for smoothing.
// @graph - graph to update.
//
// Returns: normalized landmarks containing smoothed visibility.
Stream<mediapipe::NormalizedLandmarkList> SmoothLandmarksVisibility(
Stream<mediapipe::NormalizedLandmarkList> landmarks,
float low_pass_filter_alpha, Graph& graph);
// Updates graph to smooth visibility of landmarks.
//
// @landmarks - absolute landmarks.
// @low_pass_filter_alpha - low pass filter alpha to use for smoothing.
// @graph - graph to update.
//
// Returns: absolute landmarks containing smoothed visibility.
Stream<mediapipe::LandmarkList> SmoothLandmarksVisibility(
Stream<mediapipe::LandmarkList> landmarks, float low_pass_filter_alpha,
mediapipe::api2::builder::Graph& graph);
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_SMOOTHING_H_

View File

@ -0,0 +1,230 @@
#include "mediapipe/framework/api2/stream/smoothing.h"
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "mediapipe/calculators/util/landmarks_smoothing_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe::api2::builder {
namespace {
TEST(Smoothing, NormLandmarks) {
mediapipe::api2::builder::Graph graph;
Stream<NormalizedLandmarkList> norm_landmarks =
graph.In("NORM_LANDMARKS").Cast<NormalizedLandmarkList>();
Stream<std::pair<int, int>> image_size =
graph.In("IMAGE_SIZE").Cast<std::pair<int, int>>();
Stream<NormalizedRect> scale_roi =
graph.In("SCALE_ROI").Cast<NormalizedRect>();
SmoothLandmarks(
norm_landmarks, image_size, scale_roi,
{.min_cutoff = 0.5f, .beta = 100.0f, .derivate_cutoff = 20.0f}, graph)
.SetName("smoothed_norm_landmarks");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "LandmarksSmoothingCalculator"
input_stream: "IMAGE_SIZE:__stream_0"
input_stream: "NORM_LANDMARKS:__stream_1"
input_stream: "OBJECT_SCALE_ROI:__stream_2"
output_stream: "NORM_FILTERED_LANDMARKS:smoothed_norm_landmarks"
options {
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
one_euro_filter {
min_cutoff: 0.5
beta: 100
derivate_cutoff: 20
disable_value_scaling: false
}
}
}
}
input_stream: "IMAGE_SIZE:__stream_0"
input_stream: "NORM_LANDMARKS:__stream_1"
input_stream: "SCALE_ROI:__stream_2"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(Smoothing, Landmarks) {
mediapipe::api2::builder::Graph graph;
Stream<LandmarkList> landmarks = graph.In("LANDMARKS").Cast<LandmarkList>();
SmoothLandmarks(landmarks, /*scale_roi=*/std::nullopt,
{.min_cutoff = 1.5f, .beta = 90.0f, .derivate_cutoff = 10.0f},
graph)
.SetName("smoothed_landmarks");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "LandmarksSmoothingCalculator"
input_stream: "LANDMARKS:__stream_0"
output_stream: "FILTERED_LANDMARKS:smoothed_landmarks"
options {
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] {
one_euro_filter {
min_cutoff: 1.5
beta: 90
derivate_cutoff: 10
disable_value_scaling: true
}
}
}
}
input_stream: "LANDMARKS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(Smoothing, MultiLandmarks) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<NormalizedLandmarkList>> norm_landmarks =
graph.In("NORM_LANDMARKS").Cast<std::vector<NormalizedLandmarkList>>();
Stream<std::vector<int64_t>> tracking_ids =
graph.In("TRACKING_IDS").Cast<std::vector<int64_t>>();
Stream<std::pair<int, int>> image_size =
graph.In("IMAGE_SIZE").Cast<std::pair<int, int>>();
Stream<std::vector<NormalizedRect>> scale_roi =
graph.In("SCALE_ROI").Cast<std::vector<NormalizedRect>>();
auto config = LandmarksSmoothingCalculatorOptions();
config.mutable_no_filter();
SmoothMultiLandmarks(norm_landmarks, tracking_ids, image_size, scale_roi,
config, graph)
.SetName("smoothed_norm_landmarks");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "MultiLandmarksSmoothingCalculator"
input_stream: "IMAGE_SIZE:__stream_0"
input_stream: "NORM_LANDMARKS:__stream_1"
input_stream: "OBJECT_SCALE_ROI:__stream_2"
input_stream: "TRACKING_IDS:__stream_3"
output_stream: "NORM_FILTERED_LANDMARKS:smoothed_norm_landmarks"
options {
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] { no_filter {} }
}
}
input_stream: "IMAGE_SIZE:__stream_0"
input_stream: "NORM_LANDMARKS:__stream_1"
input_stream: "SCALE_ROI:__stream_2"
input_stream: "TRACKING_IDS:__stream_3"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(Smoothing, MultiWorldLandmarks) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<LandmarkList>> landmarks =
graph.In("LANDMARKS").Cast<std::vector<LandmarkList>>();
Stream<std::vector<int64_t>> tracking_ids =
graph.In("TRACKING_IDS").Cast<std::vector<int64_t>>();
auto config = LandmarksSmoothingCalculatorOptions();
config.mutable_no_filter();
SmoothMultiWorldLandmarks(landmarks, tracking_ids, /*scale_roi=*/std::nullopt,
config, graph)
.SetName("smoothed_landmarks");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "MultiWorldLandmarksSmoothingCalculator"
input_stream: "LANDMARKS:__stream_0"
input_stream: "TRACKING_IDS:__stream_1"
output_stream: "FILTERED_LANDMARKS:smoothed_landmarks"
options {
[mediapipe.LandmarksSmoothingCalculatorOptions.ext] { no_filter {} }
}
}
input_stream: "LANDMARKS:__stream_0"
input_stream: "TRACKING_IDS:__stream_1"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(Smoothing, NormLandmarksVisibility) {
mediapipe::api2::builder::Graph graph;
Stream<NormalizedLandmarkList> norm_landmarks =
graph.In("NORM_LANDMARKS").Cast<NormalizedLandmarkList>();
Stream<NormalizedLandmarkList> smoothed_norm_landmarks =
SmoothLandmarksVisibility(norm_landmarks, /*low_pass_filter_alpha=*/0.9f,
graph);
smoothed_norm_landmarks.SetName("smoothed_norm_landmarks");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "VisibilitySmoothingCalculator"
input_stream: "NORM_LANDMARKS:__stream_0"
output_stream: "NORM_FILTERED_LANDMARKS:smoothed_norm_landmarks"
options {
[mediapipe.VisibilitySmoothingCalculatorOptions.ext] {
low_pass_filter { alpha: 0.9 }
}
}
}
input_stream: "NORM_LANDMARKS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(Smoothing, LandmarksVisibility) {
mediapipe::api2::builder::Graph graph;
Stream<LandmarkList> landmarks = graph.In("LANDMARKS").Cast<LandmarkList>();
Stream<LandmarkList> smoothed_landmarks = SmoothLandmarksVisibility(
landmarks, /*low_pass_filter_alpha=*/0.9f, graph);
smoothed_landmarks.SetName("smoothed_landmarks");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "VisibilitySmoothingCalculator"
input_stream: "LANDMARKS:__stream_0"
output_stream: "FILTERED_LANDMARKS:smoothed_landmarks"
options {
[mediapipe.VisibilitySmoothingCalculatorOptions.ext] {
low_pass_filter { alpha: 0.9 }
}
}
}
input_stream: "LANDMARKS:__stream_0"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace mediapipe::api2::builder