diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index 4444938ac..f57dd46b5 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -56,3 +56,29 @@ cc_test( "//mediapipe/gpu:gpu_buffer", ], ) + +cc_library( + name = "rect_transformation", + srcs = ["rect_transformation.cc"], + hdrs = ["rect_transformation.h"], + deps = [ + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:rect_cc_proto", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "rect_transformation_test", + srcs = ["rect_transformation_test.cc"], + deps = [ + ":rect_transformation", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) diff --git a/mediapipe/framework/api2/stream/rect_transformation.cc b/mediapipe/framework/api2/stream/rect_transformation.cc new file mode 100644 index 000000000..3e63375fc --- /dev/null +++ b/mediapipe/framework/api2/stream/rect_transformation.cc @@ -0,0 +1,108 @@ +#include "mediapipe/framework/api2/stream/rect_transformation.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/rect.pb.h" + +namespace mediapipe::api2::builder { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::builder::GenericNode; +using ::mediapipe::api2::builder::Graph; + +template +Stream InternalScaleAndShift( + Stream transformee, Stream> image_size, + float scale_x_factor, float scale_y_factor, std::optional shift_x, + std::optional shift_y, bool square_long, Graph& graph) { + auto& node = graph.AddNode("RectTransformationCalculator"); + auto& node_opts = + node.GetOptions(); + node_opts.set_scale_x(scale_x_factor); + node_opts.set_scale_y(scale_y_factor); + if (shift_x) { + node_opts.set_shift_x(shift_x.value()); + } + if (shift_y) { + node_opts.set_shift_y(shift_y.value()); + } + if (square_long) { + node_opts.set_square_long(square_long); + } + image_size.ConnectTo(node.In("IMAGE_SIZE")); + if constexpr (std::is_same_v>) { + transformee.ConnectTo(node.In("NORM_RECTS")); + } else if constexpr (std::is_same_v) { + transformee.ConnectTo(node.In("NORM_RECT")); + } else { + static_assert(dependent_false::value, "Unsupported type."); + } + return node.Out("").template Cast(); +} + +} // namespace + +Stream ScaleAndMakeSquare( + Stream rect, Stream> image_size, + float scale_x_factor, float scale_y_factor, Graph& graph) { + return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor, + /*shift_x=*/std::nullopt, + /*shift_y=*/std::nullopt, + /*square_long=*/true, graph); +} + +Stream Scale(Stream rect, + Stream> image_size, + float scale_x_factor, float scale_y_factor, + Graph& graph) { + return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor, + /*shift_x=*/std::nullopt, + /*shift_y=*/std::nullopt, + /*square_long=*/false, graph); +} + +Stream> ScaleAndShiftAndMakeSquareLong( + Stream> rects, + Stream> image_size, float scale_x_factor, + float scale_y_factor, float shift_x, float shift_y, Graph& graph) { + return InternalScaleAndShift(rects, image_size, scale_x_factor, + scale_y_factor, shift_x, shift_y, + /*square_long=*/true, graph); +} + +Stream> ScaleAndShift( + Stream> rects, + Stream> image_size, float scale_x_factor, + float scale_y_factor, float shift_x, float shift_y, Graph& graph) { + return InternalScaleAndShift(rects, image_size, scale_x_factor, + scale_y_factor, shift_x, shift_y, + /*square_long=*/false, graph); +} + +Stream ScaleAndShiftAndMakeSquareLong( + Stream rect, Stream> image_size, + float scale_x_factor, float scale_y_factor, float shift_x, float shift_y, + Graph& graph) { + return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor, + shift_x, shift_y, + /*square_long=*/true, graph); +} + +Stream ScaleAndShift(Stream rect, + Stream> image_size, + float scale_x_factor, float scale_y_factor, + float shift_x, float shift_y, + Graph& graph) { + return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor, + shift_x, shift_y, /*square_long=*/false, graph); +} + +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/rect_transformation.h b/mediapipe/framework/api2/stream/rect_transformation.h new file mode 100644 index 000000000..9f6a98980 --- /dev/null +++ b/mediapipe/framework/api2/stream/rect_transformation.h @@ -0,0 +1,67 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_ + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/rect.pb.h" + +namespace mediapipe::api2::builder { + +// Updates @graph to scale @rect according to passed parameters. +Stream Scale(Stream rect, + Stream> image_size, + float scale_x_factor, + float scale_y_factor, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to scale @rect according to passed parameters and make it a +// square that has the same center and rotation, and with the side of the square +// equal to the long side of the rect. +// +// TODO: consider removing after migrating to `Scale`. +Stream ScaleAndMakeSquare( + Stream rect, + Stream> image_size, float scale_x_factor, + float scale_y_factor, mediapipe::api2::builder::Graph& graph); + +// Updates @graph to scale and shift vector of @rects according to parameters. +Stream> ScaleAndShift( + Stream> rects, + Stream> image_size, float scale_x_factor, + float scale_y_factor, float shift_x, float shift_y, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to scale and shift vector of @rects according to passed +// parameters and make each a square that has the same center and rotation, and +// with the side of the square equal to the long side of a particular rect. +// +// TODO: consider removing after migrating to `ScaleAndShift`. +Stream> ScaleAndShiftAndMakeSquareLong( + Stream> rects, + Stream> image_size, float scale_x_factor, + float scale_y_factor, float shift_x, float shift_y, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to scale, shift @rect according to passed parameters. +Stream ScaleAndShift( + Stream rect, + Stream> image_size, float scale_x_factor, + float scale_y_factor, float shift_x, float shift_y, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to scale and shift @rect according to passed parameters and +// make it a square that has the same center and rotation, and with the side of +// the square equal to the long side of the rect. +// +// TODO: consider removing after migrating to `ScaleAndShift`. +Stream ScaleAndShiftAndMakeSquareLong( + Stream rect, + Stream> image_size, float scale_x_factor, + float scale_y_factor, float shift_x, float shift_y, + mediapipe::api2::builder::Graph& graph); + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_ diff --git a/mediapipe/framework/api2/stream/rect_transformation_test.cc b/mediapipe/framework/api2/stream/rect_transformation_test.cc new file mode 100644 index 000000000..79fa66175 --- /dev/null +++ b/mediapipe/framework/api2/stream/rect_transformation_test.cc @@ -0,0 +1,217 @@ +#include "mediapipe/framework/api2/stream/rect_transformation.h" + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe::api2::builder { + +namespace { + +using ::mediapipe::NormalizedRect; + +TEST(RectTransformation, ScaleAndMakeSquare) { + mediapipe::api2::builder::Graph graph; + + Stream rect = graph.In("RECT").Cast(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream transformed_rect = ScaleAndMakeSquare( + rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, graph); + transformed_rect.SetName("transformed_rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECT:__stream_0" + output_stream: "transformed_rect" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + square_long: true + } + } + } + input_stream: "RECT:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, Scale) { + mediapipe::api2::builder::Graph graph; + + Stream rect = graph.In("RECT").Cast(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream transformed_rect = + Scale(rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, graph); + transformed_rect.SetName("transformed_rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECT:__stream_0" + output_stream: "transformed_rect" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + } + } + } + input_stream: "RECT:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, ScaleAndShift) { + mediapipe::api2::builder::Graph graph; + + Stream rect = graph.In("RECT").Cast(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream transformed_rect = + ScaleAndShift(rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rect.SetName("transformed_rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECT:__stream_0" + output_stream: "transformed_rect" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + } + } + } + input_stream: "RECT:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, ScaleAndShiftAndMakeSquareLong) { + mediapipe::api2::builder::Graph graph; + + Stream rect = graph.In("RECT").Cast(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream transformed_rect = ScaleAndShiftAndMakeSquareLong( + rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rect.SetName("transformed_rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECT:__stream_0" + output_stream: "transformed_rect" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + square_long: true + } + } + } + input_stream: "RECT:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, ScaleAndShiftMultipleRects) { + mediapipe::api2::builder::Graph graph; + + Stream> rects = + graph.In("RECTS").Cast>(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream> transformed_rects = + ScaleAndShift(rects, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rects.SetName("transformed_rects"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECTS:__stream_0" + output_stream: "transformed_rects" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + } + } + } + input_stream: "RECTS:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +TEST(RectTransformation, ScaleAndShiftAndMakeSquareLongMultipleRects) { + mediapipe::api2::builder::Graph graph; + + Stream> rects = + graph.In("RECTS").Cast>(); + Stream> size = + graph.In("SIZE").Cast>(); + Stream> transformed_rects = + ScaleAndShiftAndMakeSquareLong(rects, size, /*scale_x_factor=*/2, + /*scale_y_factor=*/7, + /*shift_x=*/10, /*shift_y=*/0.5f, graph); + transformed_rects.SetName("transformed_rects"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "RectTransformationCalculator" + input_stream: "IMAGE_SIZE:__stream_1" + input_stream: "NORM_RECTS:__stream_0" + output_stream: "transformed_rects" + options { + [mediapipe.RectTransformationCalculatorOptions.ext] { + scale_x: 2 + scale_y: 7 + shift_x: 10 + shift_y: 0.5 + square_long: true + } + } + } + input_stream: "RECTS:__stream_0" + input_stream: "SIZE:__stream_1" + )pb"))); +} + +} // namespace +} // namespace mediapipe::api2::builder