detections_to_rects stream utility function.

PiperOrigin-RevId: 566358715
This commit is contained in:
MediaPipe Team 2023-09-18 11:14:57 -07:00 committed by Copybara-Service
parent f4477f1739
commit 58a7790081
4 changed files with 393 additions and 0 deletions

View File

@ -2,6 +2,36 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"])
cc_library(
name = "detections_to_rects",
srcs = ["detections_to_rects.cc"],
hdrs = ["detections_to_rects.h"],
deps = [
"//mediapipe/calculators/util:alignment_points_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
],
)
cc_test(
name = "detections_to_rects_test",
srcs = ["detections_to_rects_test.cc"],
deps = [
":detections_to_rects",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)
cc_library(
name = "landmarks_to_detection",
srcs = ["landmarks_to_detection.cc"],

View File

@ -0,0 +1,100 @@
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include <utility>
#include <vector>
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe::api2::builder {
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::builder::Graph;
void AddOptions(int start_keypoint_index, int end_keypoint_index,
float target_angle,
mediapipe::api2::builder::GenericNode& node) {
auto& options = node.GetOptions<DetectionsToRectsCalculatorOptions>();
options.set_rotation_vector_start_keypoint_index(start_keypoint_index);
options.set_rotation_vector_end_keypoint_index(end_keypoint_index);
options.set_rotation_vector_target_angle_degrees(target_angle);
}
} // namespace
Stream<NormalizedRect> ConvertAlignmentPointsDetectionToRect(
Stream<Detection> detection, Stream<std::pair<int, int>> image_size,
int start_keypoint_index, int end_keypoint_index, float target_angle,
Graph& graph) {
auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator");
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
align_node);
detection.ConnectTo(align_node.In("DETECTION"));
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
}
Stream<NormalizedRect> ConvertAlignmentPointsDetectionsToRect(
Stream<std::vector<Detection>> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle, Graph& graph) {
auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator");
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
align_node);
detections.ConnectTo(align_node.In("DETECTIONS"));
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
}
Stream<NormalizedRect> ConvertDetectionToRect(
Stream<Detection> detection, Stream<std::pair<int, int>> image_size,
int start_keypoint_index, int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph) {
auto& align_node = graph.AddNode("DetectionsToRectsCalculator");
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
align_node);
detection.ConnectTo(align_node.In("DETECTION"));
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
}
Stream<std::vector<NormalizedRect>> ConvertDetectionsToRects(
Stream<std::vector<Detection>> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph) {
// TODO: check if we can substitute DetectionsToRectsCalculator
// with AlignmentPointsRectsCalculator and use it instead. Ideally, merge or
// remove one of calculators.
auto& align_node = graph.AddNode("DetectionsToRectsCalculator");
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
align_node);
detections.ConnectTo(align_node.In("DETECTIONS"));
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
return align_node.Out("NORM_RECTS").Cast<std::vector<NormalizedRect>>();
}
Stream<NormalizedRect> ConvertDetectionsToRectUsingKeypoints(
Stream<std::vector<Detection>> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph) {
auto& node = graph.AddNode("DetectionsToRectsCalculator");
auto& options = node.GetOptions<DetectionsToRectsCalculatorOptions>();
options.set_rotation_vector_start_keypoint_index(start_keypoint_index);
options.set_rotation_vector_end_keypoint_index(end_keypoint_index);
options.set_rotation_vector_target_angle_degrees(target_angle);
options.set_conversion_mode(
DetectionsToRectsCalculatorOptions::USE_KEYPOINTS);
detections.ConnectTo(node.In("DETECTIONS"));
image_size.ConnectTo(node.In("IMAGE_SIZE"));
return node.Out("NORM_RECT").Cast<NormalizedRect>();
}
} // namespace mediapipe::api2::builder

View File

@ -0,0 +1,55 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe::api2::builder {
// Updates @graph to convert @detection into a `NormalizedRect` according to
// passed parameters.
Stream<mediapipe::NormalizedRect> ConvertAlignmentPointsDetectionToRect(
Stream<mediapipe::Detection> detection,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to convert first detection from @detections into a
// `NormalizedRect` according to passed parameters.
Stream<mediapipe::NormalizedRect> ConvertAlignmentPointsDetectionsToRect(
Stream<std::vector<mediapipe::Detection>> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to convert @detection into a `NormalizedRect` according to
// passed parameters.
Stream<mediapipe::NormalizedRect> ConvertDetectionToRect(
Stream<mediapipe::Detection> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to convert @detections into a stream holding vector of
// `NormalizedRect` according to passed parameters.
Stream<std::vector<mediapipe::NormalizedRect>> ConvertDetectionsToRects(
Stream<std::vector<mediapipe::Detection>> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph);
// Updates @graph to convert @detections into a stream holding vector of
// `NormalizedRect` according to passed parameters and using keypoints.
Stream<mediapipe::NormalizedRect> ConvertDetectionsToRectUsingKeypoints(
Stream<std::vector<mediapipe::Detection>> detections,
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
int end_keypoint_index, float target_angle,
mediapipe::api2::builder::Graph& graph);
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_

View File

@ -0,0 +1,208 @@
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe::api2::builder {
namespace {
TEST(DetectionsToRects, ConvertAlignmentPointsDetectionToRect) {
mediapipe::api2::builder::Graph graph;
Stream<Detection> detection = graph.In("DETECTION").Cast<Detection>();
detection.SetName("detection");
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
size.SetName("size");
Stream<NormalizedRect> rect = ConvertAlignmentPointsDetectionToRect(
detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
/*target_angle=*/200, graph);
rect.SetName("rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "AlignmentPointsRectsCalculator"
input_stream: "DETECTION:detection"
input_stream: "IMAGE_SIZE:size"
output_stream: "NORM_RECT:rect"
options {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
rotation_vector_start_keypoint_index: 0
rotation_vector_end_keypoint_index: 100
rotation_vector_target_angle_degrees: 200
}
}
}
input_stream: "DETECTION:detection"
input_stream: "SIZE:size"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(DetectionsToRects, ConvertAlignmentPointsDetectionsToRect) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<Detection>> detections =
graph.In("DETECTIONS").Cast<std::vector<Detection>>();
detections.SetName("detections");
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
size.SetName("size");
Stream<NormalizedRect> rect = ConvertAlignmentPointsDetectionsToRect(
detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
/*target_angle=*/200, graph);
rect.SetName("rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "AlignmentPointsRectsCalculator"
input_stream: "DETECTIONS:detections"
input_stream: "IMAGE_SIZE:size"
output_stream: "NORM_RECT:rect"
options {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
rotation_vector_start_keypoint_index: 0
rotation_vector_end_keypoint_index: 100
rotation_vector_target_angle_degrees: 200
}
}
}
input_stream: "DETECTIONS:detections"
input_stream: "SIZE:size"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(DetectionsToRects, ConvertDetectionToRect) {
mediapipe::api2::builder::Graph graph;
Stream<Detection> detection = graph.In("DETECTION").Cast<Detection>();
detection.SetName("detection");
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
size.SetName("size");
Stream<NormalizedRect> rect = ConvertDetectionToRect(
detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
/*target_angle=*/200, graph);
rect.SetName("rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTION:detection"
input_stream: "IMAGE_SIZE:size"
output_stream: "NORM_RECT:rect"
options {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
rotation_vector_start_keypoint_index: 0
rotation_vector_end_keypoint_index: 100
rotation_vector_target_angle_degrees: 200
}
}
}
input_stream: "DETECTION:detection"
input_stream: "SIZE:size"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(DetectionsToRects, ConvertDetectionsToRects) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<Detection>> detections =
graph.In("DETECTIONS").Cast<std::vector<Detection>>();
detections.SetName("detections");
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
size.SetName("size");
Stream<std::vector<NormalizedRect>> rects = ConvertDetectionsToRects(
detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
/*target_angle=*/200, graph);
rects.SetName("rects");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:detections"
input_stream: "IMAGE_SIZE:size"
output_stream: "NORM_RECTS:rects"
options {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
rotation_vector_start_keypoint_index: 0
rotation_vector_end_keypoint_index: 100
rotation_vector_target_angle_degrees: 200
}
}
}
input_stream: "DETECTIONS:detections"
input_stream: "SIZE:size"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(DetectionsToRects, ConvertDetectionsToRectUsingKeypoints) {
mediapipe::api2::builder::Graph graph;
Stream<std::vector<Detection>> detections =
graph.In("DETECTIONS").Cast<std::vector<Detection>>();
detections.SetName("detections");
Stream<std::pair<int, int>> size =
graph.In("SIZE").Cast<std::pair<int, int>>();
size.SetName("size");
Stream<NormalizedRect> rect = ConvertDetectionsToRectUsingKeypoints(
detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
/*target_angle=*/200, graph);
rect.SetName("rect");
EXPECT_THAT(
graph.GetConfig(),
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:detections"
input_stream: "IMAGE_SIZE:size"
output_stream: "NORM_RECT:rect"
options {
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
rotation_vector_start_keypoint_index: 0
rotation_vector_end_keypoint_index: 100
rotation_vector_target_angle_degrees: 200
conversion_mode: USE_KEYPOINTS
}
}
}
input_stream: "DETECTIONS:detections"
input_stream: "SIZE:size"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace mediapipe::api2::builder