detections_to_rects stream utility function.
PiperOrigin-RevId: 566358715
This commit is contained in:
		
							parent
							
								
									f4477f1739
								
							
						
					
					
						commit
						58a7790081
					
				| 
						 | 
					@ -2,6 +2,36 @@ package(default_visibility = ["//visibility:public"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
licenses(["notice"])
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "detections_to_rects",
 | 
				
			||||||
 | 
					    srcs = ["detections_to_rects.cc"],
 | 
				
			||||||
 | 
					    hdrs = ["detections_to_rects.h"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/util:alignment_points_to_rects_calculator",
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/util:detections_to_rects_calculator",
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:builder",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:detection_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:rect_cc_proto",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_test(
 | 
				
			||||||
 | 
					    name = "detections_to_rects_test",
 | 
				
			||||||
 | 
					    srcs = ["detections_to_rects_test.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":detections_to_rects",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_framework",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:builder",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:detection_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:rect_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest_main",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:parse_text_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:status_matchers",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "landmarks_to_detection",
 | 
					    name = "landmarks_to_detection",
 | 
				
			||||||
    srcs = ["landmarks_to_detection.cc"],
 | 
					    srcs = ["landmarks_to_detection.cc"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										100
									
								
								mediapipe/framework/api2/stream/detections_to_rects.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								mediapipe/framework/api2/stream/detections_to_rects.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,100 @@
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/stream/detections_to_rects.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/detection.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/rect.pb.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::api2::builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using ::mediapipe::NormalizedRect;
 | 
				
			||||||
 | 
					using ::mediapipe::api2::builder::Graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void AddOptions(int start_keypoint_index, int end_keypoint_index,
 | 
				
			||||||
 | 
					                float target_angle,
 | 
				
			||||||
 | 
					                mediapipe::api2::builder::GenericNode& node) {
 | 
				
			||||||
 | 
					  auto& options = node.GetOptions<DetectionsToRectsCalculatorOptions>();
 | 
				
			||||||
 | 
					  options.set_rotation_vector_start_keypoint_index(start_keypoint_index);
 | 
				
			||||||
 | 
					  options.set_rotation_vector_end_keypoint_index(end_keypoint_index);
 | 
				
			||||||
 | 
					  options.set_rotation_vector_target_angle_degrees(target_angle);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Stream<NormalizedRect> ConvertAlignmentPointsDetectionToRect(
 | 
				
			||||||
 | 
					    Stream<Detection> detection, Stream<std::pair<int, int>> image_size,
 | 
				
			||||||
 | 
					    int start_keypoint_index, int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    Graph& graph) {
 | 
				
			||||||
 | 
					  auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator");
 | 
				
			||||||
 | 
					  AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
 | 
				
			||||||
 | 
					             align_node);
 | 
				
			||||||
 | 
					  detection.ConnectTo(align_node.In("DETECTION"));
 | 
				
			||||||
 | 
					  image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
 | 
				
			||||||
 | 
					  return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Stream<NormalizedRect> ConvertAlignmentPointsDetectionsToRect(
 | 
				
			||||||
 | 
					    Stream<std::vector<Detection>> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle, Graph& graph) {
 | 
				
			||||||
 | 
					  auto& align_node = graph.AddNode("AlignmentPointsRectsCalculator");
 | 
				
			||||||
 | 
					  AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
 | 
				
			||||||
 | 
					             align_node);
 | 
				
			||||||
 | 
					  detections.ConnectTo(align_node.In("DETECTIONS"));
 | 
				
			||||||
 | 
					  image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
 | 
				
			||||||
 | 
					  return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Stream<NormalizedRect> ConvertDetectionToRect(
 | 
				
			||||||
 | 
					    Stream<Detection> detection, Stream<std::pair<int, int>> image_size,
 | 
				
			||||||
 | 
					    int start_keypoint_index, int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph) {
 | 
				
			||||||
 | 
					  auto& align_node = graph.AddNode("DetectionsToRectsCalculator");
 | 
				
			||||||
 | 
					  AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
 | 
				
			||||||
 | 
					             align_node);
 | 
				
			||||||
 | 
					  detection.ConnectTo(align_node.In("DETECTION"));
 | 
				
			||||||
 | 
					  image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
 | 
				
			||||||
 | 
					  return align_node.Out("NORM_RECT").Cast<NormalizedRect>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Stream<std::vector<NormalizedRect>> ConvertDetectionsToRects(
 | 
				
			||||||
 | 
					    Stream<std::vector<Detection>> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph) {
 | 
				
			||||||
 | 
					  // TODO: check if we can substitute DetectionsToRectsCalculator
 | 
				
			||||||
 | 
					  // with AlignmentPointsRectsCalculator and use it instead. Ideally, merge or
 | 
				
			||||||
 | 
					  // remove one of calculators.
 | 
				
			||||||
 | 
					  auto& align_node = graph.AddNode("DetectionsToRectsCalculator");
 | 
				
			||||||
 | 
					  AddOptions(start_keypoint_index, end_keypoint_index, target_angle,
 | 
				
			||||||
 | 
					             align_node);
 | 
				
			||||||
 | 
					  detections.ConnectTo(align_node.In("DETECTIONS"));
 | 
				
			||||||
 | 
					  image_size.ConnectTo(align_node.In("IMAGE_SIZE"));
 | 
				
			||||||
 | 
					  return align_node.Out("NORM_RECTS").Cast<std::vector<NormalizedRect>>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Stream<NormalizedRect> ConvertDetectionsToRectUsingKeypoints(
 | 
				
			||||||
 | 
					    Stream<std::vector<Detection>> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph) {
 | 
				
			||||||
 | 
					  auto& node = graph.AddNode("DetectionsToRectsCalculator");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto& options = node.GetOptions<DetectionsToRectsCalculatorOptions>();
 | 
				
			||||||
 | 
					  options.set_rotation_vector_start_keypoint_index(start_keypoint_index);
 | 
				
			||||||
 | 
					  options.set_rotation_vector_end_keypoint_index(end_keypoint_index);
 | 
				
			||||||
 | 
					  options.set_rotation_vector_target_angle_degrees(target_angle);
 | 
				
			||||||
 | 
					  options.set_conversion_mode(
 | 
				
			||||||
 | 
					      DetectionsToRectsCalculatorOptions::USE_KEYPOINTS);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  detections.ConnectTo(node.In("DETECTIONS"));
 | 
				
			||||||
 | 
					  image_size.ConnectTo(node.In("IMAGE_SIZE"));
 | 
				
			||||||
 | 
					  return node.Out("NORM_RECT").Cast<NormalizedRect>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe::api2::builder
 | 
				
			||||||
							
								
								
									
										55
									
								
								mediapipe/framework/api2/stream/detections_to_rects.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								mediapipe/framework/api2/stream/detections_to_rects.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,55 @@
 | 
				
			||||||
 | 
					#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
 | 
				
			||||||
 | 
					#define MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/detection.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/rect.pb.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::api2::builder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates @graph to convert @detection into a `NormalizedRect` according to
 | 
				
			||||||
 | 
					// passed parameters.
 | 
				
			||||||
 | 
					Stream<mediapipe::NormalizedRect> ConvertAlignmentPointsDetectionToRect(
 | 
				
			||||||
 | 
					    Stream<mediapipe::Detection> detection,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates @graph to convert first detection from @detections into a
 | 
				
			||||||
 | 
					// `NormalizedRect` according to passed parameters.
 | 
				
			||||||
 | 
					Stream<mediapipe::NormalizedRect> ConvertAlignmentPointsDetectionsToRect(
 | 
				
			||||||
 | 
					    Stream<std::vector<mediapipe::Detection>> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates @graph to convert @detection into a `NormalizedRect` according to
 | 
				
			||||||
 | 
					// passed parameters.
 | 
				
			||||||
 | 
					Stream<mediapipe::NormalizedRect> ConvertDetectionToRect(
 | 
				
			||||||
 | 
					    Stream<mediapipe::Detection> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates @graph to convert @detections into a stream holding vector of
 | 
				
			||||||
 | 
					// `NormalizedRect` according to passed parameters.
 | 
				
			||||||
 | 
					Stream<std::vector<mediapipe::NormalizedRect>> ConvertDetectionsToRects(
 | 
				
			||||||
 | 
					    Stream<std::vector<mediapipe::Detection>> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Updates @graph to convert @detections into a stream holding vector of
 | 
				
			||||||
 | 
					// `NormalizedRect` according to passed parameters and using keypoints.
 | 
				
			||||||
 | 
					Stream<mediapipe::NormalizedRect> ConvertDetectionsToRectUsingKeypoints(
 | 
				
			||||||
 | 
					    Stream<std::vector<mediapipe::Detection>> detections,
 | 
				
			||||||
 | 
					    Stream<std::pair<int, int>> image_size, int start_keypoint_index,
 | 
				
			||||||
 | 
					    int end_keypoint_index, float target_angle,
 | 
				
			||||||
 | 
					    mediapipe::api2::builder::Graph& graph);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe::api2::builder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_DETECTIONS_TO_RECTS_H_
 | 
				
			||||||
							
								
								
									
										208
									
								
								mediapipe/framework/api2/stream/detections_to_rects_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										208
									
								
								mediapipe/framework/api2/stream/detections_to_rects_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,208 @@
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/stream/detections_to_rects.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/detection.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/rect.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/parse_text_proto.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/status_matchers.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe::api2::builder {
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(DetectionsToRects, ConvertAlignmentPointsDetectionToRect) {
 | 
				
			||||||
 | 
					  mediapipe::api2::builder::Graph graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Stream<Detection> detection = graph.In("DETECTION").Cast<Detection>();
 | 
				
			||||||
 | 
					  detection.SetName("detection");
 | 
				
			||||||
 | 
					  Stream<std::pair<int, int>> size =
 | 
				
			||||||
 | 
					      graph.In("SIZE").Cast<std::pair<int, int>>();
 | 
				
			||||||
 | 
					  size.SetName("size");
 | 
				
			||||||
 | 
					  Stream<NormalizedRect> rect = ConvertAlignmentPointsDetectionToRect(
 | 
				
			||||||
 | 
					      detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
 | 
				
			||||||
 | 
					      /*target_angle=*/200, graph);
 | 
				
			||||||
 | 
					  rect.SetName("rect");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      graph.GetConfig(),
 | 
				
			||||||
 | 
					      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "AlignmentPointsRectsCalculator"
 | 
				
			||||||
 | 
					          input_stream: "DETECTION:detection"
 | 
				
			||||||
 | 
					          input_stream: "IMAGE_SIZE:size"
 | 
				
			||||||
 | 
					          output_stream: "NORM_RECT:rect"
 | 
				
			||||||
 | 
					          options {
 | 
				
			||||||
 | 
					            [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					              rotation_vector_start_keypoint_index: 0
 | 
				
			||||||
 | 
					              rotation_vector_end_keypoint_index: 100
 | 
				
			||||||
 | 
					              rotation_vector_target_angle_degrees: 200
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        input_stream: "DETECTION:detection"
 | 
				
			||||||
 | 
					        input_stream: "SIZE:size"
 | 
				
			||||||
 | 
					      )pb")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CalculatorGraph calcualtor_graph;
 | 
				
			||||||
 | 
					  MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(DetectionsToRects, ConvertAlignmentPointsDetectionsToRect) {
 | 
				
			||||||
 | 
					  mediapipe::api2::builder::Graph graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Stream<std::vector<Detection>> detections =
 | 
				
			||||||
 | 
					      graph.In("DETECTIONS").Cast<std::vector<Detection>>();
 | 
				
			||||||
 | 
					  detections.SetName("detections");
 | 
				
			||||||
 | 
					  Stream<std::pair<int, int>> size =
 | 
				
			||||||
 | 
					      graph.In("SIZE").Cast<std::pair<int, int>>();
 | 
				
			||||||
 | 
					  size.SetName("size");
 | 
				
			||||||
 | 
					  Stream<NormalizedRect> rect = ConvertAlignmentPointsDetectionsToRect(
 | 
				
			||||||
 | 
					      detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
 | 
				
			||||||
 | 
					      /*target_angle=*/200, graph);
 | 
				
			||||||
 | 
					  rect.SetName("rect");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      graph.GetConfig(),
 | 
				
			||||||
 | 
					      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "AlignmentPointsRectsCalculator"
 | 
				
			||||||
 | 
					          input_stream: "DETECTIONS:detections"
 | 
				
			||||||
 | 
					          input_stream: "IMAGE_SIZE:size"
 | 
				
			||||||
 | 
					          output_stream: "NORM_RECT:rect"
 | 
				
			||||||
 | 
					          options {
 | 
				
			||||||
 | 
					            [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					              rotation_vector_start_keypoint_index: 0
 | 
				
			||||||
 | 
					              rotation_vector_end_keypoint_index: 100
 | 
				
			||||||
 | 
					              rotation_vector_target_angle_degrees: 200
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        input_stream: "DETECTIONS:detections"
 | 
				
			||||||
 | 
					        input_stream: "SIZE:size"
 | 
				
			||||||
 | 
					      )pb")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CalculatorGraph calcualtor_graph;
 | 
				
			||||||
 | 
					  MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(DetectionsToRects, ConvertDetectionToRect) {
 | 
				
			||||||
 | 
					  mediapipe::api2::builder::Graph graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Stream<Detection> detection = graph.In("DETECTION").Cast<Detection>();
 | 
				
			||||||
 | 
					  detection.SetName("detection");
 | 
				
			||||||
 | 
					  Stream<std::pair<int, int>> size =
 | 
				
			||||||
 | 
					      graph.In("SIZE").Cast<std::pair<int, int>>();
 | 
				
			||||||
 | 
					  size.SetName("size");
 | 
				
			||||||
 | 
					  Stream<NormalizedRect> rect = ConvertDetectionToRect(
 | 
				
			||||||
 | 
					      detection, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
 | 
				
			||||||
 | 
					      /*target_angle=*/200, graph);
 | 
				
			||||||
 | 
					  rect.SetName("rect");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      graph.GetConfig(),
 | 
				
			||||||
 | 
					      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "DetectionsToRectsCalculator"
 | 
				
			||||||
 | 
					          input_stream: "DETECTION:detection"
 | 
				
			||||||
 | 
					          input_stream: "IMAGE_SIZE:size"
 | 
				
			||||||
 | 
					          output_stream: "NORM_RECT:rect"
 | 
				
			||||||
 | 
					          options {
 | 
				
			||||||
 | 
					            [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					              rotation_vector_start_keypoint_index: 0
 | 
				
			||||||
 | 
					              rotation_vector_end_keypoint_index: 100
 | 
				
			||||||
 | 
					              rotation_vector_target_angle_degrees: 200
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        input_stream: "DETECTION:detection"
 | 
				
			||||||
 | 
					        input_stream: "SIZE:size"
 | 
				
			||||||
 | 
					      )pb")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CalculatorGraph calcualtor_graph;
 | 
				
			||||||
 | 
					  MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(DetectionsToRects, ConvertDetectionsToRects) {
 | 
				
			||||||
 | 
					  mediapipe::api2::builder::Graph graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Stream<std::vector<Detection>> detections =
 | 
				
			||||||
 | 
					      graph.In("DETECTIONS").Cast<std::vector<Detection>>();
 | 
				
			||||||
 | 
					  detections.SetName("detections");
 | 
				
			||||||
 | 
					  Stream<std::pair<int, int>> size =
 | 
				
			||||||
 | 
					      graph.In("SIZE").Cast<std::pair<int, int>>();
 | 
				
			||||||
 | 
					  size.SetName("size");
 | 
				
			||||||
 | 
					  Stream<std::vector<NormalizedRect>> rects = ConvertDetectionsToRects(
 | 
				
			||||||
 | 
					      detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
 | 
				
			||||||
 | 
					      /*target_angle=*/200, graph);
 | 
				
			||||||
 | 
					  rects.SetName("rects");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      graph.GetConfig(),
 | 
				
			||||||
 | 
					      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "DetectionsToRectsCalculator"
 | 
				
			||||||
 | 
					          input_stream: "DETECTIONS:detections"
 | 
				
			||||||
 | 
					          input_stream: "IMAGE_SIZE:size"
 | 
				
			||||||
 | 
					          output_stream: "NORM_RECTS:rects"
 | 
				
			||||||
 | 
					          options {
 | 
				
			||||||
 | 
					            [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					              rotation_vector_start_keypoint_index: 0
 | 
				
			||||||
 | 
					              rotation_vector_end_keypoint_index: 100
 | 
				
			||||||
 | 
					              rotation_vector_target_angle_degrees: 200
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        input_stream: "DETECTIONS:detections"
 | 
				
			||||||
 | 
					        input_stream: "SIZE:size"
 | 
				
			||||||
 | 
					      )pb")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CalculatorGraph calcualtor_graph;
 | 
				
			||||||
 | 
					  MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(DetectionsToRects, ConvertDetectionsToRectUsingKeypoints) {
 | 
				
			||||||
 | 
					  mediapipe::api2::builder::Graph graph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Stream<std::vector<Detection>> detections =
 | 
				
			||||||
 | 
					      graph.In("DETECTIONS").Cast<std::vector<Detection>>();
 | 
				
			||||||
 | 
					  detections.SetName("detections");
 | 
				
			||||||
 | 
					  Stream<std::pair<int, int>> size =
 | 
				
			||||||
 | 
					      graph.In("SIZE").Cast<std::pair<int, int>>();
 | 
				
			||||||
 | 
					  size.SetName("size");
 | 
				
			||||||
 | 
					  Stream<NormalizedRect> rect = ConvertDetectionsToRectUsingKeypoints(
 | 
				
			||||||
 | 
					      detections, size, /*start_keypoint_index=*/0, /*end_keypoint_index=*/100,
 | 
				
			||||||
 | 
					      /*target_angle=*/200, graph);
 | 
				
			||||||
 | 
					  rect.SetName("rect");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      graph.GetConfig(),
 | 
				
			||||||
 | 
					      EqualsProto(mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "DetectionsToRectsCalculator"
 | 
				
			||||||
 | 
					          input_stream: "DETECTIONS:detections"
 | 
				
			||||||
 | 
					          input_stream: "IMAGE_SIZE:size"
 | 
				
			||||||
 | 
					          output_stream: "NORM_RECT:rect"
 | 
				
			||||||
 | 
					          options {
 | 
				
			||||||
 | 
					            [mediapipe.DetectionsToRectsCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					              rotation_vector_start_keypoint_index: 0
 | 
				
			||||||
 | 
					              rotation_vector_end_keypoint_index: 100
 | 
				
			||||||
 | 
					              rotation_vector_target_angle_degrees: 200
 | 
				
			||||||
 | 
					              conversion_mode: USE_KEYPOINTS
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        input_stream: "DETECTIONS:detections"
 | 
				
			||||||
 | 
					        input_stream: "SIZE:size"
 | 
				
			||||||
 | 
					      )pb")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CalculatorGraph calcualtor_graph;
 | 
				
			||||||
 | 
					  MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					}  // namespace mediapipe::api2::builder
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user