diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index f59a65d95..091656b2e 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -2,6 +2,36 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) +cc_library( + name = "detections_to_rects", + srcs = ["detections_to_rects.cc"], + hdrs = ["detections_to_rects.h"], + deps = [ + "//mediapipe/calculators/util:alignment_points_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + ], +) + +cc_test( + name = "detections_to_rects_test", + srcs = ["detections_to_rects_test.cc"], + deps = [ + ":detections_to_rects", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + ], +) + cc_library( name = "landmarks_to_detection", srcs = ["landmarks_to_detection.cc"], diff --git a/mediapipe/framework/api2/stream/detections_to_rects.cc b/mediapipe/framework/api2/stream/detections_to_rects.cc new file mode 100644 index 000000000..f3daddaee --- /dev/null +++ b/mediapipe/framework/api2/stream/detections_to_rects.cc @@ -0,0 +1,100 @@ +#include "mediapipe/framework/api2/stream/detections_to_rects.h" + +#include +#include + +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" + +namespace mediapipe::api2::builder { + +namespace { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::builder::Graph; + +void AddOptions(int start_keypoint_index, int end_keypoint_index, + float target_angle, + mediapipe::api2::builder::GenericNode& node) { + auto& options = node.GetOptions(); + options.set_rotation_vector_start_keypoint_index(start_keypoint_index); + options.set_rotation_vector_end_keypoint_index(end_keypoint_index); + options.set_rotation_vector_target_angle_degrees(target_angle); +} + +} // namespace + +Stream ConvertAlignmentPointsDetectionToRect( + Stream detection, Stream> image_size, + int start_keypoint_index, int end_keypoint_index, float target_angle, + Graph& graph) { + auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator"); + AddOptions(start_keypoint_index, end_keypoint_index, target_angle, + align_node); + detection.ConnectTo(align_node.In("DETECTION")); + image_size.ConnectTo(align_node.In("IMAGE_SIZE")); + return align_node.Out("NORM_RECT").Cast(); +} + +Stream ConvertAlignmentPointsDetectionsToRect( + Stream> detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, Graph& graph) { + auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator"); + AddOptions(start_keypoint_index, end_keypoint_index, target_angle, + align_node); + detections.ConnectTo(align_node.In("DETECTIONS")); + image_size.ConnectTo(align_node.In("IMAGE_SIZE")); + return align_node.Out("NORM_RECT").Cast(); +} + +Stream ConvertDetectionToRect( + Stream detection, Stream> image_size, + int start_keypoint_index, int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph) { + auto& align_node = graph.AddNode("DetectionsToRectsCalculator"); + AddOptions(start_keypoint_index, end_keypoint_index, target_angle, + align_node); + detection.ConnectTo(align_node.In("DETECTION")); + image_size.ConnectTo(align_node.In("IMAGE_SIZE")); + return align_node.Out("NORM_RECT").Cast(); +} + +Stream> ConvertDetectionsToRects( + Stream> detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph) { + // TODO: check if we can substitute DetectionsToRectsCalculator + // with AlignmentPointsRectsCalculator and use it instead. Ideally, merge or + // remove one of calculators. + auto& align_node = graph.AddNode("DetectionsToRectsCalculator"); + AddOptions(start_keypoint_index, end_keypoint_index, target_angle, + align_node); + detections.ConnectTo(align_node.In("DETECTIONS")); + image_size.ConnectTo(align_node.In("IMAGE_SIZE")); + return align_node.Out("NORM_RECTS").Cast>(); +} + +Stream ConvertDetectionsToRectUsingKeypoints( + Stream> detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph) { + auto& node = graph.AddNode("DetectionsToRectsCalculator"); + + auto& options = node.GetOptions(); + options.set_rotation_vector_start_keypoint_index(start_keypoint_index); + options.set_rotation_vector_end_keypoint_index(end_keypoint_index); + options.set_rotation_vector_target_angle_degrees(target_angle); + options.set_conversion_mode( + DetectionsToRectsCalculatorOptions::USE_KEYPOINTS); + + detections.ConnectTo(node.In("DETECTIONS")); + image_size.ConnectTo(node.In("IMAGE_SIZE")); + return node.Out("NORM_RECT").Cast(); +} + +} // namespace mediapipe::api2::builder diff --git a/mediapipe/framework/api2/stream/detections_to_rects.h b/mediapipe/framework/api2/stream/detections_to_rects.h new file mode 100644 index 000000000..f1db9247e --- /dev/null +++ b/mediapipe/framework/api2/stream/detections_to_rects.h @@ -0,0 +1,55 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_ + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" + +namespace mediapipe::api2::builder { + +// Updates @graph to convert @detection into a `NormalizedRect` according to +// passed parameters. +Stream ConvertAlignmentPointsDetectionToRect( + Stream detection, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to convert first detection from @detections into a +// `NormalizedRect` according to passed parameters. +Stream ConvertAlignmentPointsDetectionsToRect( + Stream> detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to convert @detection into a `NormalizedRect` according to +// passed parameters. +Stream ConvertDetectionToRect( + Stream detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to convert @detections into a stream holding vector of +// `NormalizedRect` according to passed parameters. +Stream> ConvertDetectionsToRects( + Stream> detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph); + +// Updates @graph to convert @detections into a stream holding vector of +// `NormalizedRect` according to passed parameters and using keypoints. +Stream ConvertDetectionsToRectUsingKeypoints( + Stream> detections, + Stream> image_size, int start_keypoint_index, + int end_keypoint_index, float target_angle, + mediapipe::api2::builder::Graph& graph); + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_ diff --git a/mediapipe/framework/api2/stream/detections_to_rects_test.cc b/mediapipe/framework/api2/stream/detections_to_rects_test.cc new file mode 100644 index 000000000..a7d3d77db --- /dev/null +++ b/mediapipe/framework/api2/stream/detections_to_rects_test.cc @@ -0,0 +1,208 @@ +#include "mediapipe/framework/api2/stream/detections_to_rects.h" + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(DetectionsToRects, ConvertAlignmentPointsDetectionToRect) { + mediapipe::api2::builder::Graph graph; + + Stream detection = graph.In("DETECTION").Cast(); + detection.SetName("detection"); + Stream> size = + graph.In("SIZE").Cast>(); + size.SetName("size"); + Stream rect = ConvertAlignmentPointsDetectionToRect( + detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100, + /*target_angle=*/200, graph); + rect.SetName("rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "AlignmentPointsRectsCalculator" + input_stream: "DETECTION:detection" + input_stream: "IMAGE_SIZE:size" + output_stream: "NORM_RECT:rect" + options { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + rotation_vector_start_keypoint_index: 0 + rotation_vector_end_keypoint_index: 100 + rotation_vector_target_angle_degrees: 200 + } + } + } + input_stream: "DETECTION:detection" + input_stream: "SIZE:size" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DetectionsToRects, ConvertAlignmentPointsDetectionsToRect) { + mediapipe::api2::builder::Graph graph; + + Stream> detections = + graph.In("DETECTIONS").Cast>(); + detections.SetName("detections"); + Stream> size = + graph.In("SIZE").Cast>(); + size.SetName("size"); + Stream rect = ConvertAlignmentPointsDetectionsToRect( + detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100, + /*target_angle=*/200, graph); + rect.SetName("rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "AlignmentPointsRectsCalculator" + input_stream: "DETECTIONS:detections" + input_stream: "IMAGE_SIZE:size" + output_stream: "NORM_RECT:rect" + options { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + rotation_vector_start_keypoint_index: 0 + rotation_vector_end_keypoint_index: 100 + rotation_vector_target_angle_degrees: 200 + } + } + } + input_stream: "DETECTIONS:detections" + input_stream: "SIZE:size" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DetectionsToRects, ConvertDetectionToRect) { + mediapipe::api2::builder::Graph graph; + + Stream detection = graph.In("DETECTION").Cast(); + detection.SetName("detection"); + Stream> size = + graph.In("SIZE").Cast>(); + size.SetName("size"); + Stream rect = ConvertDetectionToRect( + detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100, + /*target_angle=*/200, graph); + rect.SetName("rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTION:detection" + input_stream: "IMAGE_SIZE:size" + output_stream: "NORM_RECT:rect" + options { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + rotation_vector_start_keypoint_index: 0 + rotation_vector_end_keypoint_index: 100 + rotation_vector_target_angle_degrees: 200 + } + } + } + input_stream: "DETECTION:detection" + input_stream: "SIZE:size" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DetectionsToRects, ConvertDetectionsToRects) { + mediapipe::api2::builder::Graph graph; + + Stream> detections = + graph.In("DETECTIONS").Cast>(); + detections.SetName("detections"); + Stream> size = + graph.In("SIZE").Cast>(); + size.SetName("size"); + Stream> rects = ConvertDetectionsToRects( + detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100, + /*target_angle=*/200, graph); + rects.SetName("rects"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + input_stream: "IMAGE_SIZE:size" + output_stream: "NORM_RECTS:rects" + options { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + rotation_vector_start_keypoint_index: 0 + rotation_vector_end_keypoint_index: 100 + rotation_vector_target_angle_degrees: 200 + } + } + } + input_stream: "DETECTIONS:detections" + input_stream: "SIZE:size" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DetectionsToRects, ConvertDetectionsToRectUsingKeypoints) { + mediapipe::api2::builder::Graph graph; + + Stream> detections = + graph.In("DETECTIONS").Cast>(); + detections.SetName("detections"); + Stream> size = + graph.In("SIZE").Cast>(); + size.SetName("size"); + Stream rect = ConvertDetectionsToRectUsingKeypoints( + detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100, + /*target_angle=*/200, graph); + rect.SetName("rect"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "DetectionsToRectsCalculator" + input_stream: "DETECTIONS:detections" + input_stream: "IMAGE_SIZE:size" + output_stream: "NORM_RECT:rect" + options { + [mediapipe.DetectionsToRectsCalculatorOptions.ext] { + rotation_vector_start_keypoint_index: 0 + rotation_vector_end_keypoint_index: 100 + rotation_vector_target_angle_degrees: 200 + conversion_mode: USE_KEYPOINTS + } + } + } + input_stream: "DETECTIONS:detections" + input_stream: "SIZE:size" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace mediapipe::api2::builder