Move stream API rect_transformation to third_party.

PiperOrigin-RevId: 559652775
This commit is contained in:
MediaPipe Team 2023-08-23 23:09:29 -07:00 committed by Copybara-Service
parent b2446c6ca8
commit 4b1b6ae7fb
4 changed files with 418 additions and 0 deletions

View File

@ -56,3 +56,29 @@ cc_test(
"//mediapipe/gpu:gpu_buffer",
],
)
cc_library(
name = "rect_transformation",
srcs = ["rect_transformation.cc"],
hdrs = ["rect_transformation.h"],
deps = [
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:rect_cc_proto",
"@com_google_absl//absl/types:optional",
],
)
cc_test(
name = "rect_transformation_test",
srcs = ["rect_transformation_test.cc"],
deps = [
":rect_transformation",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)

View File

@ -0,0 +1,108 @@
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/types/optional.h"
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe::api2::builder {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
template <typename TransformeeT>
Stream<TransformeeT> InternalScaleAndShift(
Stream<TransformeeT> transformee, Stream<std::pair<int, int>> image_size,
float scale_x_factor, float scale_y_factor, std::optional<float> shift_x,
std::optional<float> shift_y, bool square_long, Graph& graph) {
auto& node = graph.AddNode("RectTransformationCalculator");
auto& node_opts =
node.GetOptions<mediapipe::RectTransformationCalculatorOptions>();
node_opts.set_scale_x(scale_x_factor);
node_opts.set_scale_y(scale_y_factor);
if (shift_x) {
node_opts.set_shift_x(shift_x.value());
}
if (shift_y) {
node_opts.set_shift_y(shift_y.value());
}
if (square_long) {
node_opts.set_square_long(square_long);
}
image_size.ConnectTo(node.In("IMAGE_SIZE"));
if constexpr (std::is_same_v<TransformeeT, std::vector<NormalizedRect>>) {
transformee.ConnectTo(node.In("NORM_RECTS"));
} else if constexpr (std::is_same_v<TransformeeT, NormalizedRect>) {
transformee.ConnectTo(node.In("NORM_RECT"));
} else {
static_assert(dependent_false<TransformeeT>::value, "Unsupported type.");
}
return node.Out("").template Cast<TransformeeT>();
}
} // namespace
Stream<NormalizedRect> ScaleAndMakeSquare(
Stream<NormalizedRect> rect, Stream<std::pair<int, int>> image_size,
float scale_x_factor, float scale_y_factor, Graph& graph) {
return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor,
/*shift_x=*/std::nullopt,
/*shift_y=*/std::nullopt,
/*square_long=*/true, graph);
}
Stream<NormalizedRect> Scale(Stream<NormalizedRect> rect,
Stream<std::pair<int, int>> image_size,
float scale_x_factor, float scale_y_factor,
Graph& graph) {
return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor,
/*shift_x=*/std::nullopt,
/*shift_y=*/std::nullopt,
/*square_long=*/false, graph);
}
Stream<std::vector<NormalizedRect>> ScaleAndShiftAndMakeSquareLong(
Stream<std::vector<NormalizedRect>> rects,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, float shift_x, float shift_y, Graph& graph) {
return InternalScaleAndShift(rects, image_size, scale_x_factor,
scale_y_factor, shift_x, shift_y,
/*square_long=*/true, graph);
}
Stream<std::vector<NormalizedRect>> ScaleAndShift(
Stream<std::vector<NormalizedRect>> rects,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, float shift_x, float shift_y, Graph& graph) {
return InternalScaleAndShift(rects, image_size, scale_x_factor,
scale_y_factor, shift_x, shift_y,
/*square_long=*/false, graph);
}
Stream<NormalizedRect> ScaleAndShiftAndMakeSquareLong(
Stream<NormalizedRect> rect, Stream<std::pair<int, int>> image_size,
float scale_x_factor, float scale_y_factor, float shift_x, float shift_y,
Graph& graph) {
return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor,
shift_x, shift_y,
/*square_long=*/true, graph);
}
Stream<NormalizedRect> ScaleAndShift(Stream<NormalizedRect> rect,
Stream<std::pair<int, int>> image_size,
float scale_x_factor, float scale_y_factor,
float shift_x, float shift_y,
Graph& graph) {
return InternalScaleAndShift(rect, image_size, scale_x_factor, scale_y_factor,
shift_x, shift_y, /*square_long=*/false, graph);
}
} // namespace mediapipe::api2::builder

View File

@ -0,0 +1,67 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe::api2::builder {
// Updates @graph to scale @rect according to passed parameters.
Stream<mediapipe::NormalizedRect> Scale(Stream<mediapipe::NormalizedRect> rect,
Stream<std::pair<int, int>> image_size,
float scale_x_factor,
float scale_y_factor,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to scale @rect according to passed parameters and make it a
// square that has the same center and rotation, and with the side of the square
// equal to the long side of the rect.
//
// TODO: consider removing after migrating to `Scale`.
Stream<mediapipe::NormalizedRect> ScaleAndMakeSquare(
Stream<mediapipe::NormalizedRect> rect,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, mediapipe::api2::builder::Graph& graph);
// Updates @graph to scale and shift vector of @rects according to parameters.
Stream<std::vector<mediapipe::NormalizedRect>> ScaleAndShift(
Stream<std::vector<mediapipe::NormalizedRect>> rects,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, float shift_x, float shift_y,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to scale and shift vector of @rects according to passed
// parameters and make each a square that has the same center and rotation, and
// with the side of the square equal to the long side of a particular rect.
//
// TODO: consider removing after migrating to `ScaleAndShift`.
Stream<std::vector<mediapipe::NormalizedRect>> ScaleAndShiftAndMakeSquareLong(
Stream<std::vector<mediapipe::NormalizedRect>> rects,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, float shift_x, float shift_y,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to scale, shift @rect according to passed parameters.
Stream<mediapipe::NormalizedRect> ScaleAndShift(
Stream<mediapipe::NormalizedRect> rect,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, float shift_x, float shift_y,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to scale and shift @rect according to passed parameters and
// make it a square that has the same center and rotation, and with the side of
// the square equal to the long side of the rect.
//
// TODO: consider removing after migrating to `ScaleAndShift`.
Stream<mediapipe::NormalizedRect> ScaleAndShiftAndMakeSquareLong(
Stream<mediapipe::NormalizedRect> rect,
Stream<std::pair<int, int>> image_size, float scale_x_factor,
float scale_y_factor, float shift_x, float shift_y,
mediapipe::api2::builder::Graph& graph);
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_RECT_TRANSFORMATION_H_

View File

@ -0,0 +1,217 @@
#include "mediapipe/framework/api2/stream/rect_transformation.h"
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
namespace mediapipe::api2::builder {
namespace {
using ::mediapipe::NormalizedRect;
TEST(RectTransformation, ScaleAndMakeSquare) {
mediapipe::api2::builder::Graph graph;
Stream<NormalizedRect> rect = graph.In("RECT").Cast<NormalizedRect>();
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
Stream<NormalizedRect> transformed_rect = ScaleAndMakeSquare(
rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, graph);
transformed_rect.SetName("transformed_rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "RectTransformationCalculator"
input_stream: "IMAGE_SIZE:__stream_1"
input_stream: "NORM_RECT:__stream_0"
output_stream: "transformed_rect"
options {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 2
scale_y: 7
square_long: true
}
}
}
input_stream: "RECT:__stream_0"
input_stream: "SIZE:__stream_1"
)pb")));
}
TEST(RectTransformation, Scale) {
mediapipe::api2::builder::Graph graph;
Stream<NormalizedRect> rect = graph.In("RECT").Cast<NormalizedRect>();
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
Stream<NormalizedRect> transformed_rect =
Scale(rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7, graph);
transformed_rect.SetName("transformed_rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "RectTransformationCalculator"
input_stream: "IMAGE_SIZE:__stream_1"
input_stream: "NORM_RECT:__stream_0"
output_stream: "transformed_rect"
options {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 2
scale_y: 7
}
}
}
input_stream: "RECT:__stream_0"
input_stream: "SIZE:__stream_1"
)pb")));
}
TEST(RectTransformation, ScaleAndShift) {
mediapipe::api2::builder::Graph graph;
Stream<NormalizedRect> rect = graph.In("RECT").Cast<NormalizedRect>();
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
Stream<NormalizedRect> transformed_rect =
ScaleAndShift(rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7,
/*shift_x=*/10, /*shift_y=*/0.5f, graph);
transformed_rect.SetName("transformed_rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "RectTransformationCalculator"
input_stream: "IMAGE_SIZE:__stream_1"
input_stream: "NORM_RECT:__stream_0"
output_stream: "transformed_rect"
options {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 2
scale_y: 7
shift_x: 10
shift_y: 0.5
}
}
}
input_stream: "RECT:__stream_0"
input_stream: "SIZE:__stream_1"
)pb")));
}
TEST(RectTransformation, ScaleAndShiftAndMakeSquareLong) {
mediapipe::api2::builder::Graph graph;
Stream<NormalizedRect> rect = graph.In("RECT").Cast<NormalizedRect>();
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
Stream<NormalizedRect> transformed_rect = ScaleAndShiftAndMakeSquareLong(
rect, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7,
/*shift_x=*/10, /*shift_y=*/0.5f, graph);
transformed_rect.SetName("transformed_rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "RectTransformationCalculator"
input_stream: "IMAGE_SIZE:__stream_1"
input_stream: "NORM_RECT:__stream_0"
output_stream: "transformed_rect"
options {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 2
scale_y: 7
shift_x: 10
shift_y: 0.5
square_long: true
}
}
}
input_stream: "RECT:__stream_0"
input_stream: "SIZE:__stream_1"
)pb")));
}
TEST(RectTransformation, ScaleAndShiftMultipleRects) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<NormalizedRect>> rects =
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
Stream<std::vector<NormalizedRect>> transformed_rects =
ScaleAndShift(rects, size, /*scale_x_factor=*/2, /*scale_y_factor=*/7,
/*shift_x=*/10, /*shift_y=*/0.5f, graph);
transformed_rects.SetName("transformed_rects");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "RectTransformationCalculator"
input_stream: "IMAGE_SIZE:__stream_1"
input_stream: "NORM_RECTS:__stream_0"
output_stream: "transformed_rects"
options {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 2
scale_y: 7
shift_x: 10
shift_y: 0.5
}
}
}
input_stream: "RECTS:__stream_0"
input_stream: "SIZE:__stream_1"
)pb")));
}
TEST(RectTransformation, ScaleAndShiftAndMakeSquareLongMultipleRects) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<NormalizedRect>> rects =
graph.In("RECTS").Cast<std::vector<NormalizedRect>>();
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
Stream<std::vector<NormalizedRect>> transformed_rects =
ScaleAndShiftAndMakeSquareLong(rects, size, /*scale_x_factor=*/2,
/*scale_y_factor=*/7,
/*shift_x=*/10, /*shift_y=*/0.5f, graph);
transformed_rects.SetName("transformed_rects");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "RectTransformationCalculator"
input_stream: "IMAGE_SIZE:__stream_1"
input_stream: "NORM_RECTS:__stream_0"
output_stream: "transformed_rects"
options {
[mediapipe.RectTransformationCalculatorOptions.ext] {
scale_x: 2
scale_y: 7
shift_x: 10
shift_y: 0.5
square_long: true
}
}
}
input_stream: "RECTS:__stream_0"
input_stream: "SIZE:__stream_1"
)pb")));
}
} // namespace
} // namespace mediapipe::api2::builder