Add support for [xmin, ymin, xmax, ymax] style of bbox output
PiperOrigin-RevId: 509942540
This commit is contained in:
		
							parent
							
								
									40c3e72c9c
								
							
						
					
					
						commit
						796a96d842
					
				|  | @ -64,6 +64,8 @@ bool CanUseGpu() { | |||
| namespace mediapipe { | ||||
| namespace api2 { | ||||
| 
 | ||||
| using BoxFormat = ::mediapipe::TensorsToDetectionsCalculatorOptions::BoxFormat; | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, | ||||
|  | @ -126,6 +128,15 @@ absl::Status CheckCustomTensorMapping( | |||
|   return absl::OkStatus(); | ||||
| } | ||||
| 
 | ||||
| BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) { | ||||
|   if (options.has_box_format()) { | ||||
|     return options.box_format(); | ||||
|   } else if (options.reverse_output_order()) { | ||||
|     return mediapipe::TensorsToDetectionsCalculatorOptions::XYWH; | ||||
|   } | ||||
|   return mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| // Convert result Tensors from object detection models into MediaPipe
 | ||||
|  | @ -211,6 +222,8 @@ class TensorsToDetectionsCalculator : public Node { | |||
|   int num_boxes_ = 0; | ||||
|   int num_coords_ = 0; | ||||
|   int max_results_ = -1; | ||||
|   BoxFormat box_output_format_ = | ||||
|       mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; | ||||
| 
 | ||||
|   // Set of allowed or ignored class indices.
 | ||||
|   struct ClassIndexSet { | ||||
|  | @ -655,6 +668,7 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { | |||
|   num_classes_ = options_.num_classes(); | ||||
|   num_boxes_ = options_.num_boxes(); | ||||
|   num_coords_ = options_.num_coords(); | ||||
|   box_output_format_ = GetBoxFormat(options_); | ||||
|   CHECK_NE(options_.max_results(), 0) | ||||
|       << "The maximum number of the top-scored detection results must be " | ||||
|          "non-zero."; | ||||
|  | @ -728,17 +742,32 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes( | |||
|   for (int i = 0; i < num_boxes_; ++i) { | ||||
|     const int box_offset = i * num_coords_ + options_.box_coord_offset(); | ||||
| 
 | ||||
|     float y_center = raw_boxes[box_offset]; | ||||
|     float x_center = raw_boxes[box_offset + 1]; | ||||
|     float h = raw_boxes[box_offset + 2]; | ||||
|     float w = raw_boxes[box_offset + 3]; | ||||
|     if (options_.reverse_output_order()) { | ||||
|       x_center = raw_boxes[box_offset]; | ||||
|       y_center = raw_boxes[box_offset + 1]; | ||||
|       w = raw_boxes[box_offset + 2]; | ||||
|       h = raw_boxes[box_offset + 3]; | ||||
|     float y_center = 0.0; | ||||
|     float x_center = 0.0; | ||||
|     float h = 0.0; | ||||
|     float w = 0.0; | ||||
|     // TODO
 | ||||
|     switch (box_output_format_) { | ||||
|       case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED: | ||||
|       case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW: | ||||
|         y_center = raw_boxes[box_offset]; | ||||
|         x_center = raw_boxes[box_offset + 1]; | ||||
|         h = raw_boxes[box_offset + 2]; | ||||
|         w = raw_boxes[box_offset + 3]; | ||||
|         break; | ||||
|       case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH: | ||||
|         x_center = raw_boxes[box_offset]; | ||||
|         y_center = raw_boxes[box_offset + 1]; | ||||
|         w = raw_boxes[box_offset + 2]; | ||||
|         h = raw_boxes[box_offset + 3]; | ||||
|         break; | ||||
|       case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY: | ||||
|         x_center = (-raw_boxes[box_offset] + raw_boxes[box_offset + 2]) / 2; | ||||
|         y_center = (-raw_boxes[box_offset + 1] + raw_boxes[box_offset + 3]) / 2; | ||||
|         w = raw_boxes[box_offset + 2] + raw_boxes[box_offset]; | ||||
|         h = raw_boxes[box_offset + 3] + raw_boxes[box_offset + 1]; | ||||
|         break; | ||||
|     } | ||||
| 
 | ||||
|     x_center = | ||||
|         x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center(); | ||||
|     y_center = | ||||
|  | @ -856,8 +885,22 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection( | |||
| } | ||||
| 
 | ||||
| absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) { | ||||
|   int output_format_flag = 0; | ||||
|   switch (box_output_format_) { | ||||
|     case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED: | ||||
|     case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW: | ||||
|       output_format_flag = 0; | ||||
|       break; | ||||
|     case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH: | ||||
|       output_format_flag = 1; | ||||
|       break; | ||||
|     case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY: | ||||
|       output_format_flag = 2; | ||||
|       break; | ||||
|   } | ||||
| #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE | ||||
|   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { | ||||
|   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, output_format_flag]() | ||||
|                                                     -> absl::Status { | ||||
|     // A shader to decode detection boxes.
 | ||||
|     const std::string decode_src = absl::Substitute( | ||||
|         R"( #version 310 es | ||||
|  | @ -879,7 +922,7 @@ layout(std430, binding = 2) readonly buffer Input1 { | |||
| } raw_anchors; | ||||
| 
 | ||||
| uint num_coords = uint($0); | ||||
| int reverse_output_order = int($1); | ||||
| int output_format_flag = int($1); | ||||
| int apply_exponential = int($2); | ||||
| int box_coord_offset = int($3); | ||||
| int num_keypoints = int($4); | ||||
|  | @ -892,17 +935,25 @@ void main() { | |||
|   uint anchor_offset = g_idx * uint(4);  // check kNumCoordsPerBox
 | ||||
| 
 | ||||
|   float y_center, x_center, h, w; | ||||
| 
 | ||||
|   if (reverse_output_order == int(0)) { | ||||
|   if (output_format_flag == int(0)) { | ||||
|     y_center = raw_boxes.data[box_offset + uint(0)]; | ||||
|     x_center = raw_boxes.data[box_offset + uint(1)]; | ||||
|     h = raw_boxes.data[box_offset + uint(2)]; | ||||
|     w = raw_boxes.data[box_offset + uint(3)]; | ||||
|   } else { | ||||
|   } else if (output_format_flag == int(1)) { | ||||
|     x_center = raw_boxes.data[box_offset + uint(0)]; | ||||
|     y_center = raw_boxes.data[box_offset + uint(1)]; | ||||
|     w = raw_boxes.data[box_offset + uint(2)]; | ||||
|     h = raw_boxes.data[box_offset + uint(3)]; | ||||
|   } else if (output_format_flag == int(2)) { | ||||
|     x_center = (-raw_boxes.data[box_offset + uint(0)] | ||||
|                 +raw_boxes.data[box_offset + uint(2)]) / 2.0; | ||||
|     y_center = (-raw_boxes.data[box_offset + uint(1)] | ||||
|                 +raw_boxes.data[box_offset + uint(3)]) / 2.0; | ||||
|     w = raw_boxes.data[box_offset + uint(0)] | ||||
|       + raw_boxes.data[box_offset + uint(2)]; | ||||
|     h = raw_boxes.data[box_offset + uint(1)] | ||||
|       + raw_boxes.data[box_offset + uint(3)]; | ||||
|   } | ||||
| 
 | ||||
|   float anchor_yc = raw_anchors.data[anchor_offset + uint(0)]; | ||||
|  | @ -936,7 +987,7 @@ void main() { | |||
|       int kp_offset = | ||||
|         int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; | ||||
|       float kp_y, kp_x; | ||||
|       if (reverse_output_order == int(0)) { | ||||
|       if (output_format_flag == int(0)) { | ||||
|         kp_y = raw_boxes.data[kp_offset + int(0)]; | ||||
|         kp_x = raw_boxes.data[kp_offset + int(1)]; | ||||
|       } else { | ||||
|  | @ -949,8 +1000,7 @@ void main() { | |||
|   } | ||||
| })", | ||||
|         options_.num_coords(),  // box xywh
 | ||||
|         options_.reverse_output_order() ? 1 : 0, | ||||
|         options_.apply_exponential_on_box_size() ? 1 : 0, | ||||
|         output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0, | ||||
|         options_.box_coord_offset(), options_.num_keypoints(), | ||||
|         options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); | ||||
| 
 | ||||
|  | @ -1104,7 +1154,7 @@ kernel void decodeKernel( | |||
|     uint2                           gid         [[ thread_position_in_grid ]]) { | ||||
| 
 | ||||
|   uint num_coords = uint($0); | ||||
|   int reverse_output_order = int($1); | ||||
|   int output_format_flag = int($1); | ||||
|   int apply_exponential = int($2); | ||||
|   int box_coord_offset = int($3); | ||||
|   int num_keypoints = int($4); | ||||
|  | @ -1112,8 +1162,7 @@ kernel void decodeKernel( | |||
|   int num_values_per_keypt = int($6); | ||||
| )", | ||||
|       options_.num_coords(),  // box xywh
 | ||||
|       options_.reverse_output_order() ? 1 : 0, | ||||
|       options_.apply_exponential_on_box_size() ? 1 : 0, | ||||
|       output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0, | ||||
|       options_.box_coord_offset(), options_.num_keypoints(), | ||||
|       options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); | ||||
|   decode_src += absl::Substitute( | ||||
|  | @ -1129,16 +1178,25 @@ kernel void decodeKernel( | |||
| 
 | ||||
|   float y_center, x_center, h, w; | ||||
| 
 | ||||
|   if (reverse_output_order == int(0)) { | ||||
|   if (output_format_flag == int(0)) { | ||||
|     y_center = raw_boxes[box_offset + uint(0)]; | ||||
|     x_center = raw_boxes[box_offset + uint(1)]; | ||||
|     h = raw_boxes[box_offset + uint(2)]; | ||||
|     w = raw_boxes[box_offset + uint(3)]; | ||||
|   } else { | ||||
|   } else if (output_format_flag == int(1)) { | ||||
|     x_center = raw_boxes[box_offset + uint(0)]; | ||||
|     y_center = raw_boxes[box_offset + uint(1)]; | ||||
|     w = raw_boxes[box_offset + uint(2)]; | ||||
|     h = raw_boxes[box_offset + uint(3)]; | ||||
|   } else if (output_format_flag == int(2)) { | ||||
|     x_center = (-raw_boxes[box_offset + uint(0)] | ||||
|                 +raw_boxes[box_offset + uint(2)]) / 2.0; | ||||
|     y_center = (-raw_boxes[box_offset + uint(1)] | ||||
|                 +raw_boxes[box_offset + uint(3)]) / 2.0; | ||||
|     w = raw_boxes[box_offset + uint(0)] | ||||
|       + raw_boxes[box_offset + uint(2)]; | ||||
|     h = raw_boxes[box_offset + uint(1)] | ||||
|       + raw_boxes[box_offset + uint(3)]; | ||||
|   } | ||||
| 
 | ||||
|   float anchor_yc = raw_anchors[anchor_offset + uint(0)]; | ||||
|  | @ -1172,7 +1230,7 @@ kernel void decodeKernel( | |||
|       int kp_offset = | ||||
|         int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; | ||||
|       float kp_y, kp_x; | ||||
|       if (reverse_output_order == int(0)) { | ||||
|       if (output_format_flag == int(0)) { | ||||
|         kp_y = raw_boxes[kp_offset + int(0)]; | ||||
|         kp_x = raw_boxes[kp_offset + int(1)]; | ||||
|       } else { | ||||
|  |  | |||
|  | @ -54,6 +54,7 @@ message TensorsToDetectionsCalculatorOptions { | |||
|   // Whether to reverse the order of predicted x, y from output. | ||||
|   // If false, the order is [y_center, x_center, h, w], if true the order is | ||||
|   // [x_center, y_center, w, h]. | ||||
|   // DEPRECATED. Use `box_format` instead. | ||||
|   optional bool reverse_output_order = 14 [default = false]; | ||||
|   // The ids of classes that should be ignored during decoding the score for | ||||
|   // each predicted box. Can be overridden with IGNORE_CLASSES side packet. | ||||
|  | @ -112,4 +113,18 @@ message TensorsToDetectionsCalculatorOptions { | |||
|   oneof box_indices { | ||||
|     BoxBoundariesIndices box_boundaries_indices = 23; | ||||
|   } | ||||
| 
 | ||||
|   // Tells the calculator how to convert the detector output to bounding boxes. | ||||
|   // Replaces `reverse_output_order` to support more bbox output formats. | ||||
|   enum BoxFormat { | ||||
|     // if UNSPECIFIED, the calculator assumes YXHW | ||||
|     UNSPECIFIED = 0; | ||||
|     // [y_center, x_center, height, width] | ||||
|     YXHW = 1; | ||||
|     // [x_center, y_center, width, height] | ||||
|     XYWH = 2; | ||||
|     // [xmin, ymin, xmax, ymax] | ||||
|     XYXY = 3; | ||||
|   } | ||||
|   optional BoxFormat box_format = 24 [default = UNSPECIFIED]; | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user