detections_to_rects stream utility function.
PiperOrigin-RevId: 566358715
This commit is contained in:
parent
f4477f1739
commit
58a7790081
|
@ -2,6 +2,36 @@ package(default_visibility = ["//visibility:public"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "detections_to_rects",
|
||||
srcs = ["detections_to_rects.cc"],
|
||||
hdrs = ["detections_to_rects.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/util:alignment_points_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "detections_to_rects_test",
|
||||
srcs = ["detections_to_rects_test.cc"],
|
||||
deps = [
|
||||
":detections_to_rects",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status_matchers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "landmarks_to_detection",
|
||||
srcs = ["landmarks_to_detection.cc"],
|
||||
|
|
100
mediapipe/framework/api2/stream/detections_to_rects.cc
Normal file
100
mediapipe/framework/api2/stream/detections_to_rects.cc
Normal file
|
@ -0,0 +1,100 @@
|
|||
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
|
||||
void AddOptions(int start_keypoint_index, int end_keypoint_index,
|
||||
float target_angle,
|
||||
mediapipe::api2::builder::GenericNode& node) {
|
||||
auto& options = node.GetOptions<DetectionsToRectsCalculatorOptions>();
|
||||
options.set_rotation_vector_start_keypoint_index(start_keypoint_index);
|
||||
options.set_rotation_vector_end_keypoint_index(end_keypoint_index);
|
||||
options.set_rotation_vector_target_angle_degrees(target_angle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Stream<NormalizedRect> ConvertAlignmentPointsDetectionToRect(
|
||||
Stream<Detection> detection, Stream<std::pair<int, int>> image_size,
|
||||
int start_keypoint_index, int end_keypoint_index, float target_angle,
|
||||
Graph& graph) {
|
||||
auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator");
|
||||
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
|
||||
align_node);
|
||||
detection.ConnectTo(align_node.In("DETECTION"));
|
||||
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
|
||||
return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
|
||||
}
|
||||
|
||||
Stream<NormalizedRect> ConvertAlignmentPointsDetectionsToRect(
|
||||
Stream<std::vector<Detection>> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle, Graph& graph) {
|
||||
auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator");
|
||||
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
|
||||
align_node);
|
||||
detections.ConnectTo(align_node.In("DETECTIONS"));
|
||||
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
|
||||
return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
|
||||
}
|
||||
|
||||
Stream<NormalizedRect> ConvertDetectionToRect(
|
||||
Stream<Detection> detection, Stream<std::pair<int, int>> image_size,
|
||||
int start_keypoint_index, int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
auto& align_node = graph.AddNode("DetectionsToRectsCalculator");
|
||||
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
|
||||
align_node);
|
||||
detection.ConnectTo(align_node.In("DETECTION"));
|
||||
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
|
||||
return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
|
||||
}
|
||||
|
||||
Stream<std::vector<NormalizedRect>> ConvertDetectionsToRects(
|
||||
Stream<std::vector<Detection>> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
// TODO: check if we can substitute DetectionsToRectsCalculator
|
||||
// with AlignmentPointsRectsCalculator and use it instead. Ideally, merge or
|
||||
// remove one of calculators.
|
||||
auto& align_node = graph.AddNode("DetectionsToRectsCalculator");
|
||||
AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
|
||||
align_node);
|
||||
detections.ConnectTo(align_node.In("DETECTIONS"));
|
||||
image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
|
||||
return align_node.Out("NORM_RECTS").Cast<std::vector<NormalizedRect>>();
|
||||
}
|
||||
|
||||
Stream<NormalizedRect> ConvertDetectionsToRectUsingKeypoints(
|
||||
Stream<std::vector<Detection>> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph) {
|
||||
auto& node = graph.AddNode("DetectionsToRectsCalculator");
|
||||
|
||||
auto& options = node.GetOptions<DetectionsToRectsCalculatorOptions>();
|
||||
options.set_rotation_vector_start_keypoint_index(start_keypoint_index);
|
||||
options.set_rotation_vector_end_keypoint_index(end_keypoint_index);
|
||||
options.set_rotation_vector_target_angle_degrees(target_angle);
|
||||
options.set_conversion_mode(
|
||||
DetectionsToRectsCalculatorOptions::USE_KEYPOINTS);
|
||||
|
||||
detections.ConnectTo(node.In("DETECTIONS"));
|
||||
image_size.ConnectTo(node.In("IMAGE_SIZE"));
|
||||
return node.Out("NORM_RECT").Cast<NormalizedRect>();
|
||||
}
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
55
mediapipe/framework/api2/stream/detections_to_rects.h
Normal file
55
mediapipe/framework/api2/stream/detections_to_rects.h
Normal file
|
@ -0,0 +1,55 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
// Updates @graph to convert @detection into a `NormalizedRect` according to
|
||||
// passed parameters.
|
||||
Stream<mediapipe::NormalizedRect> ConvertAlignmentPointsDetectionToRect(
|
||||
Stream<mediapipe::Detection> detection,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph);
|
||||
|
||||
// Updates @graph to convert first detection from @detections into a
|
||||
// `NormalizedRect` according to passed parameters.
|
||||
Stream<mediapipe::NormalizedRect> ConvertAlignmentPointsDetectionsToRect(
|
||||
Stream<std::vector<mediapipe::Detection>> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph);
|
||||
|
||||
// Updates @graph to convert @detection into a `NormalizedRect` according to
|
||||
// passed parameters.
|
||||
Stream<mediapipe::NormalizedRect> ConvertDetectionToRect(
|
||||
Stream<mediapipe::Detection> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph);
|
||||
|
||||
// Updates @graph to convert @detections into a stream holding vector of
|
||||
// `NormalizedRect` according to passed parameters.
|
||||
Stream<std::vector<mediapipe::NormalizedRect>> ConvertDetectionsToRects(
|
||||
Stream<std::vector<mediapipe::Detection>> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph);
|
||||
|
||||
// Updates @graph to convert @detections into a stream holding vector of
|
||||
// `NormalizedRect` according to passed parameters and using keypoints.
|
||||
Stream<mediapipe::NormalizedRect> ConvertDetectionsToRectUsingKeypoints(
|
||||
Stream<std::vector<mediapipe::Detection>> detections,
|
||||
Stream<std::pair<int, int>> image_size, int start_keypoint_index,
|
||||
int end_keypoint_index, float target_angle,
|
||||
mediapipe::api2::builder::Graph& graph);
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
|
208
mediapipe/framework/api2/stream/detections_to_rects_test.cc
Normal file
208
mediapipe/framework/api2/stream/detections_to_rects_test.cc
Normal file
|
@ -0,0 +1,208 @@
|
|||
#include "mediapipe/framework/api2/stream/detections_to_rects.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
namespace {
|
||||
|
||||
TEST(DetectionsToRects, ConvertAlignmentPointsDetectionToRect) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Stream<Detection> detection = graph.In("DETECTION").Cast<Detection>();
|
||||
detection.SetName("detection");
|
||||
Stream<std::pair<int, int>> size =
|
||||
graph.In("SIZE").Cast<std::pair<int, int>>();
|
||||
size.SetName("size");
|
||||
Stream<NormalizedRect> rect = ConvertAlignmentPointsDetectionToRect(
|
||||
detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
|
||||
/*target_angle=*/200, graph);
|
||||
rect.SetName("rect");
|
||||
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "AlignmentPointsRectsCalculator"
|
||||
input_stream: "DETECTION:detection"
|
||||
input_stream: "IMAGE_SIZE:size"
|
||||
output_stream: "NORM_RECT:rect"
|
||||
options {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
rotation_vector_start_keypoint_index: 0
|
||||
rotation_vector_end_keypoint_index: 100
|
||||
rotation_vector_target_angle_degrees: 200
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "DETECTION:detection"
|
||||
input_stream: "SIZE:size"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(DetectionsToRects, ConvertAlignmentPointsDetectionsToRect) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Stream<std::vector<Detection>> detections =
|
||||
graph.In("DETECTIONS").Cast<std::vector<Detection>>();
|
||||
detections.SetName("detections");
|
||||
Stream<std::pair<int, int>> size =
|
||||
graph.In("SIZE").Cast<std::pair<int, int>>();
|
||||
size.SetName("size");
|
||||
Stream<NormalizedRect> rect = ConvertAlignmentPointsDetectionsToRect(
|
||||
detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
|
||||
/*target_angle=*/200, graph);
|
||||
rect.SetName("rect");
|
||||
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "AlignmentPointsRectsCalculator"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "IMAGE_SIZE:size"
|
||||
output_stream: "NORM_RECT:rect"
|
||||
options {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
rotation_vector_start_keypoint_index: 0
|
||||
rotation_vector_end_keypoint_index: 100
|
||||
rotation_vector_target_angle_degrees: 200
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "SIZE:size"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(DetectionsToRects, ConvertDetectionToRect) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Stream<Detection> detection = graph.In("DETECTION").Cast<Detection>();
|
||||
detection.SetName("detection");
|
||||
Stream<std::pair<int, int>> size =
|
||||
graph.In("SIZE").Cast<std::pair<int, int>>();
|
||||
size.SetName("size");
|
||||
Stream<NormalizedRect> rect = ConvertDetectionToRect(
|
||||
detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
|
||||
/*target_angle=*/200, graph);
|
||||
rect.SetName("rect");
|
||||
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTION:detection"
|
||||
input_stream: "IMAGE_SIZE:size"
|
||||
output_stream: "NORM_RECT:rect"
|
||||
options {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
rotation_vector_start_keypoint_index: 0
|
||||
rotation_vector_end_keypoint_index: 100
|
||||
rotation_vector_target_angle_degrees: 200
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "DETECTION:detection"
|
||||
input_stream: "SIZE:size"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(DetectionsToRects, ConvertDetectionsToRects) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Stream<std::vector<Detection>> detections =
|
||||
graph.In("DETECTIONS").Cast<std::vector<Detection>>();
|
||||
detections.SetName("detections");
|
||||
Stream<std::pair<int, int>> size =
|
||||
graph.In("SIZE").Cast<std::pair<int, int>>();
|
||||
size.SetName("size");
|
||||
Stream<std::vector<NormalizedRect>> rects = ConvertDetectionsToRects(
|
||||
detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
|
||||
/*target_angle=*/200, graph);
|
||||
rects.SetName("rects");
|
||||
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "IMAGE_SIZE:size"
|
||||
output_stream: "NORM_RECTS:rects"
|
||||
options {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
rotation_vector_start_keypoint_index: 0
|
||||
rotation_vector_end_keypoint_index: 100
|
||||
rotation_vector_target_angle_degrees: 200
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "SIZE:size"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(DetectionsToRects, ConvertDetectionsToRectUsingKeypoints) {
|
||||
mediapipe::api2::builder::Graph graph;
|
||||
|
||||
Stream<std::vector<Detection>> detections =
|
||||
graph.In("DETECTIONS").Cast<std::vector<Detection>>();
|
||||
detections.SetName("detections");
|
||||
Stream<std::pair<int, int>> size =
|
||||
graph.In("SIZE").Cast<std::pair<int, int>>();
|
||||
size.SetName("size");
|
||||
Stream<NormalizedRect> rect = ConvertDetectionsToRectUsingKeypoints(
|
||||
detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
|
||||
/*target_angle=*/200, graph);
|
||||
rect.SetName("rect");
|
||||
|
||||
EXPECT_THAT(
|
||||
graph.GetConfig(),
|
||||
EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "DetectionsToRectsCalculator"
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "IMAGE_SIZE:size"
|
||||
output_stream: "NORM_RECT:rect"
|
||||
options {
|
||||
[mediapipe.DetectionsToRectsCalculatorOptions.ext] {
|
||||
rotation_vector_start_keypoint_index: 0
|
||||
rotation_vector_end_keypoint_index: 100
|
||||
rotation_vector_target_angle_degrees: 200
|
||||
conversion_mode: USE_KEYPOINTS
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "DETECTIONS:detections"
|
||||
input_stream: "SIZE:size"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user