diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index f57dd46b5..086c68a51 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -2,6 +2,33 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +cc_library( + name = "landmarks_projection", + srcs = ["landmarks_projection.cc"], + hdrs = ["landmarks_projection.h"], + deps = [ + "//mediapipe/calculators/util:landmark_projection_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + ], +) + +cc_test( + name = "landmarks_projection_test", + srcs = ["landmarks_projection_test.cc"], + deps = [ + ":landmarks_projection", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:status_matchers", + ], +) + cc_library( name = "loopback", hdrs = ["loopback.h"], diff --git a/mediapipe/framework/api2/stream/landmarks_projection.cc b/mediapipe/framework/api2/stream/landmarks_projection.cc new file mode 100644 index 000000000..1735dc1d6 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_projection.cc @@ -0,0 +1,20 @@ +#include "mediapipe/framework/api2/stream/landmarks_projection.h" + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::api2::builder { + +Stream ProjectLandmarks( + Stream landmarks, + Stream> projection_matrix, Graph& graph) { + auto& projector = graph.AddNode("LandmarkProjectionCalculator"); + landmarks.ConnectTo(projector.In("NORM_LANDMARKS")); + projection_matrix.ConnectTo(projector.In("PROJECTION_MATRIX")); + return projector.Out("NORM_LANDMARKS") + .Cast(); +} + +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/landmarks_projection.h b/mediapipe/framework/api2/stream/landmarks_projection.h new file mode 100644 index 000000000..3a9508a45 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_projection.h @@ -0,0 +1,23 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_PROJECTION_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_PROJECTION_H_ + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::api2::builder { + +// Updates @graph to project predicted @landmarks back to the original @image +// based on @projection_matrix +// +// @landmarks - landmarks (NormalizedLandmarkList) stream, output from the model +// @projection_matrix - matrix that stores the preprocessing information +// @graph - mediapipe graph to update. +Stream ProjectLandmarks( + Stream landmarks, + Stream> projection_matrix, Graph& graph); + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_LANDMARKS_PROJECTION_H_ diff --git a/mediapipe/framework/api2/stream/landmarks_projection_test.cc b/mediapipe/framework/api2/stream/landmarks_projection_test.cc new file mode 100644 index 000000000..2f743d808 --- /dev/null +++ b/mediapipe/framework/api2/stream/landmarks_projection_test.cc @@ -0,0 +1,45 @@ +#include "mediapipe/framework/api2/stream/landmarks_projection.h" + +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(LandmarksProjection, ProjectLandmarks) { + mediapipe::api2::builder::Graph graph; + + Stream landmarks = + graph.In("NORM_LANDMARKS").Cast(); + Stream> projection_matrix = + graph.In("PROJECTION_MATRIX").Cast>(); + Stream result = + ProjectLandmarks(landmarks, projection_matrix, graph); + result.SetName("landmarks_value"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "LandmarkProjectionCalculator" + input_stream: "NORM_LANDMARKS:__stream_0" + input_stream: "PROJECTION_MATRIX:__stream_1" + output_stream: "NORM_LANDMARKS:landmarks_value" + } + input_stream: "NORM_LANDMARKS:__stream_0" + input_stream: "PROJECTION_MATRIX:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder