diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 4bb3f0f57..45393d4f1 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -64,6 +64,8 @@ bool CanUseGpu() { namespace mediapipe { namespace api2 { +using BoxFormat = ::mediapipe::TensorsToDetectionsCalculatorOptions::BoxFormat; + namespace { void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, @@ -126,6 +128,15 @@ absl::Status CheckCustomTensorMapping( return absl::OkStatus(); } +BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) { + if (options.has_box_format()) { + return options.box_format(); + } else if (options.reverse_output_order()) { + return mediapipe::TensorsToDetectionsCalculatorOptions::XYWH; + } + return mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; +} + } // namespace // Convert result Tensors from object detection models into MediaPipe @@ -211,6 +222,8 @@ class TensorsToDetectionsCalculator : public Node { int num_boxes_ = 0; int num_coords_ = 0; int max_results_ = -1; + BoxFormat box_output_format_ = + mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; // Set of allowed or ignored class indices. struct ClassIndexSet { @@ -655,6 +668,7 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { num_classes_ = options_.num_classes(); num_boxes_ = options_.num_boxes(); num_coords_ = options_.num_coords(); + box_output_format_ = GetBoxFormat(options_); CHECK_NE(options_.max_results(), 0) << "The maximum number of the top-scored detection results must be " "non-zero."; @@ -728,17 +742,32 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes( for (int i = 0; i < num_boxes_; ++i) { const int box_offset = i * num_coords_ + options_.box_coord_offset(); - float y_center = raw_boxes[box_offset]; - float x_center = raw_boxes[box_offset + 1]; - float h = raw_boxes[box_offset + 2]; - float w = raw_boxes[box_offset + 3]; - if (options_.reverse_output_order()) { - x_center = raw_boxes[box_offset]; - y_center = raw_boxes[box_offset + 1]; - w = raw_boxes[box_offset + 2]; - h = raw_boxes[box_offset + 3]; + float y_center = 0.0; + float x_center = 0.0; + float h = 0.0; + float w = 0.0; + // TODO + switch (box_output_format_) { + case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED: + case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW: + y_center = raw_boxes[box_offset]; + x_center = raw_boxes[box_offset + 1]; + h = raw_boxes[box_offset + 2]; + w = raw_boxes[box_offset + 3]; + break; + case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH: + x_center = raw_boxes[box_offset]; + y_center = raw_boxes[box_offset + 1]; + w = raw_boxes[box_offset + 2]; + h = raw_boxes[box_offset + 3]; + break; + case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY: + x_center = (-raw_boxes[box_offset] + raw_boxes[box_offset + 2]) / 2; + y_center = (-raw_boxes[box_offset + 1] + raw_boxes[box_offset + 3]) / 2; + w = raw_boxes[box_offset + 2] + raw_boxes[box_offset]; + h = raw_boxes[box_offset + 3] + raw_boxes[box_offset + 1]; + break; } - x_center = x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center(); y_center = @@ -856,8 +885,22 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection( } absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) { + int output_format_flag = 0; + switch (box_output_format_) { + case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED: + case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW: + output_format_flag = 0; + break; + case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH: + output_format_flag = 1; + break; + case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY: + output_format_flag = 2; + break; + } #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, output_format_flag]() + -> absl::Status { // A shader to decode detection boxes. const std::string decode_src = absl::Substitute( R"( #version 310 es @@ -879,7 +922,7 @@ layout(std430, binding = 2) readonly buffer Input1 { } raw_anchors; uint num_coords = uint($0); -int reverse_output_order = int($1); +int output_format_flag = int($1); int apply_exponential = int($2); int box_coord_offset = int($3); int num_keypoints = int($4); @@ -892,17 +935,25 @@ void main() { uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox float y_center, x_center, h, w; - - if (reverse_output_order == int(0)) { + if (output_format_flag == int(0)) { y_center = raw_boxes.data[box_offset + uint(0)]; x_center = raw_boxes.data[box_offset + uint(1)]; h = raw_boxes.data[box_offset + uint(2)]; w = raw_boxes.data[box_offset + uint(3)]; - } else { + } else if (output_format_flag == int(1)) { x_center = raw_boxes.data[box_offset + uint(0)]; y_center = raw_boxes.data[box_offset + uint(1)]; w = raw_boxes.data[box_offset + uint(2)]; h = raw_boxes.data[box_offset + uint(3)]; + } else if (output_format_flag == int(2)) { + x_center = (-raw_boxes.data[box_offset + uint(0)] + +raw_boxes.data[box_offset + uint(2)]) / 2.0; + y_center = (-raw_boxes.data[box_offset + uint(1)] + +raw_boxes.data[box_offset + uint(3)]) / 2.0; + w = raw_boxes.data[box_offset + uint(0)] + + raw_boxes.data[box_offset + uint(2)]; + h = raw_boxes.data[box_offset + uint(1)] + + raw_boxes.data[box_offset + uint(3)]; } float anchor_yc = raw_anchors.data[anchor_offset + uint(0)]; @@ -936,7 +987,7 @@ void main() { int kp_offset = int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; float kp_y, kp_x; - if (reverse_output_order == int(0)) { + if (output_format_flag == int(0)) { kp_y = raw_boxes.data[kp_offset + int(0)]; kp_x = raw_boxes.data[kp_offset + int(1)]; } else { @@ -949,8 +1000,7 @@ void main() { } })", options_.num_coords(), // box xywh - options_.reverse_output_order() ? 1 : 0, - options_.apply_exponential_on_box_size() ? 1 : 0, + output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0, options_.box_coord_offset(), options_.num_keypoints(), options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); @@ -1104,7 +1154,7 @@ kernel void decodeKernel( uint2 gid [[ thread_position_in_grid ]]) { uint num_coords = uint($0); - int reverse_output_order = int($1); + int output_format_flag = int($1); int apply_exponential = int($2); int box_coord_offset = int($3); int num_keypoints = int($4); @@ -1112,8 +1162,7 @@ kernel void decodeKernel( int num_values_per_keypt = int($6); )", options_.num_coords(), // box xywh - options_.reverse_output_order() ? 1 : 0, - options_.apply_exponential_on_box_size() ? 1 : 0, + output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0, options_.box_coord_offset(), options_.num_keypoints(), options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); decode_src += absl::Substitute( @@ -1129,16 +1178,25 @@ kernel void decodeKernel( float y_center, x_center, h, w; - if (reverse_output_order == int(0)) { + if (output_format_flag == int(0)) { y_center = raw_boxes[box_offset + uint(0)]; x_center = raw_boxes[box_offset + uint(1)]; h = raw_boxes[box_offset + uint(2)]; w = raw_boxes[box_offset + uint(3)]; - } else { + } else if (output_format_flag == int(1)) { x_center = raw_boxes[box_offset + uint(0)]; y_center = raw_boxes[box_offset + uint(1)]; w = raw_boxes[box_offset + uint(2)]; h = raw_boxes[box_offset + uint(3)]; + } else if (output_format_flag == int(2)) { + x_center = (-raw_boxes[box_offset + uint(0)] + +raw_boxes[box_offset + uint(2)]) / 2.0; + y_center = (-raw_boxes[box_offset + uint(1)] + +raw_boxes[box_offset + uint(3)]) / 2.0; + w = raw_boxes[box_offset + uint(0)] + + raw_boxes[box_offset + uint(2)]; + h = raw_boxes[box_offset + uint(1)] + + raw_boxes[box_offset + uint(3)]; } float anchor_yc = raw_anchors[anchor_offset + uint(0)]; @@ -1172,7 +1230,7 @@ kernel void decodeKernel( int kp_offset = int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; float kp_y, kp_x; - if (reverse_output_order == int(0)) { + if (output_format_flag == int(0)) { kp_y = raw_boxes[kp_offset + int(0)]; kp_x = raw_boxes[kp_offset + int(1)]; } else { diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto index c9d6b69da..1ebdcce0b 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.proto @@ -54,6 +54,7 @@ message TensorsToDetectionsCalculatorOptions { // Whether to reverse the order of predicted x, y from output. // If false, the order is [y_center, x_center, h, w], if true the order is // [x_center, y_center, w, h]. + // DEPRECATED. Use `box_format` instead. optional bool reverse_output_order = 14 [default = false]; // The ids of classes that should be ignored during decoding the score for // each predicted box. Can be overridden with IGNORE_CLASSES side packet. @@ -112,4 +113,18 @@ message TensorsToDetectionsCalculatorOptions { oneof box_indices { BoxBoundariesIndices box_boundaries_indices = 23; } + + // Tells the calculator how to convert the detector output to bounding boxes. + // Replaces `reverse_output_order` to support more bbox output formats. + enum BoxFormat { + // if UNSPECIFIED, the calculator assumes YXHW + UNSPECIFIED = 0; + // [y_center, x_center, height, width] + YXHW = 1; + // [x_center, y_center, width, height] + XYWH = 2; + // [xmin, ymin, xmax, ymax] + XYXY = 3; + } + optional BoxFormat box_format = 24 [default = UNSPECIFIED]; }