Add support for [xmin, ymin, xmax, ymax] style of bbox output

PiperOrigin-RevId: 509942540
This commit is contained in:
MediaPipe Team 2023-02-15 15:03:18 -08:00 committed by Copybara-Service
parent 40c3e72c9c
commit 796a96d842
2 changed files with 97 additions and 24 deletions

View File

@ -64,6 +64,8 @@ bool CanUseGpu() {
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
using BoxFormat = ::mediapipe::TensorsToDetectionsCalculatorOptions::BoxFormat;
namespace { namespace {
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
@ -126,6 +128,15 @@ absl::Status CheckCustomTensorMapping(
return absl::OkStatus(); return absl::OkStatus();
} }
BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) {
if (options.has_box_format()) {
return options.box_format();
} else if (options.reverse_output_order()) {
return mediapipe::TensorsToDetectionsCalculatorOptions::XYWH;
}
return mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
}
} // namespace } // namespace
// Convert result Tensors from object detection models into MediaPipe // Convert result Tensors from object detection models into MediaPipe
@ -211,6 +222,8 @@ class TensorsToDetectionsCalculator : public Node {
int num_boxes_ = 0; int num_boxes_ = 0;
int num_coords_ = 0; int num_coords_ = 0;
int max_results_ = -1; int max_results_ = -1;
BoxFormat box_output_format_ =
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
// Set of allowed or ignored class indices. // Set of allowed or ignored class indices.
struct ClassIndexSet { struct ClassIndexSet {
@ -655,6 +668,7 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
num_classes_ = options_.num_classes(); num_classes_ = options_.num_classes();
num_boxes_ = options_.num_boxes(); num_boxes_ = options_.num_boxes();
num_coords_ = options_.num_coords(); num_coords_ = options_.num_coords();
box_output_format_ = GetBoxFormat(options_);
CHECK_NE(options_.max_results(), 0) CHECK_NE(options_.max_results(), 0)
<< "The maximum number of the top-scored detection results must be " << "The maximum number of the top-scored detection results must be "
"non-zero."; "non-zero.";
@ -728,17 +742,32 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < num_boxes_; ++i) {
const int box_offset = i * num_coords_ + options_.box_coord_offset(); const int box_offset = i * num_coords_ + options_.box_coord_offset();
float y_center = raw_boxes[box_offset]; float y_center = 0.0;
float x_center = raw_boxes[box_offset + 1]; float x_center = 0.0;
float h = raw_boxes[box_offset + 2]; float h = 0.0;
float w = raw_boxes[box_offset + 3]; float w = 0.0;
if (options_.reverse_output_order()) { // TODO
switch (box_output_format_) {
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
y_center = raw_boxes[box_offset];
x_center = raw_boxes[box_offset + 1];
h = raw_boxes[box_offset + 2];
w = raw_boxes[box_offset + 3];
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
x_center = raw_boxes[box_offset]; x_center = raw_boxes[box_offset];
y_center = raw_boxes[box_offset + 1]; y_center = raw_boxes[box_offset + 1];
w = raw_boxes[box_offset + 2]; w = raw_boxes[box_offset + 2];
h = raw_boxes[box_offset + 3]; h = raw_boxes[box_offset + 3];
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
x_center = (-raw_boxes[box_offset] + raw_boxes[box_offset + 2]) / 2;
y_center = (-raw_boxes[box_offset + 1] + raw_boxes[box_offset + 3]) / 2;
w = raw_boxes[box_offset + 2] + raw_boxes[box_offset];
h = raw_boxes[box_offset + 3] + raw_boxes[box_offset + 1];
break;
} }
x_center = x_center =
x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center(); x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center();
y_center = y_center =
@ -856,8 +885,22 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection(
} }
absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) { absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) {
int output_format_flag = 0;
switch (box_output_format_) {
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
output_format_flag = 0;
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
output_format_flag = 1;
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
output_format_flag = 2;
break;
}
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, output_format_flag]()
-> absl::Status {
// A shader to decode detection boxes. // A shader to decode detection boxes.
const std::string decode_src = absl::Substitute( const std::string decode_src = absl::Substitute(
R"( #version 310 es R"( #version 310 es
@ -879,7 +922,7 @@ layout(std430, binding = 2) readonly buffer Input1 {
} raw_anchors; } raw_anchors;
uint num_coords = uint($0); uint num_coords = uint($0);
int reverse_output_order = int($1); int output_format_flag = int($1);
int apply_exponential = int($2); int apply_exponential = int($2);
int box_coord_offset = int($3); int box_coord_offset = int($3);
int num_keypoints = int($4); int num_keypoints = int($4);
@ -892,17 +935,25 @@ void main() {
uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox
float y_center, x_center, h, w; float y_center, x_center, h, w;
if (output_format_flag == int(0)) {
if (reverse_output_order == int(0)) {
y_center = raw_boxes.data[box_offset + uint(0)]; y_center = raw_boxes.data[box_offset + uint(0)];
x_center = raw_boxes.data[box_offset + uint(1)]; x_center = raw_boxes.data[box_offset + uint(1)];
h = raw_boxes.data[box_offset + uint(2)]; h = raw_boxes.data[box_offset + uint(2)];
w = raw_boxes.data[box_offset + uint(3)]; w = raw_boxes.data[box_offset + uint(3)];
} else { } else if (output_format_flag == int(1)) {
x_center = raw_boxes.data[box_offset + uint(0)]; x_center = raw_boxes.data[box_offset + uint(0)];
y_center = raw_boxes.data[box_offset + uint(1)]; y_center = raw_boxes.data[box_offset + uint(1)];
w = raw_boxes.data[box_offset + uint(2)]; w = raw_boxes.data[box_offset + uint(2)];
h = raw_boxes.data[box_offset + uint(3)]; h = raw_boxes.data[box_offset + uint(3)];
} else if (output_format_flag == int(2)) {
x_center = (-raw_boxes.data[box_offset + uint(0)]
+raw_boxes.data[box_offset + uint(2)]) / 2.0;
y_center = (-raw_boxes.data[box_offset + uint(1)]
+raw_boxes.data[box_offset + uint(3)]) / 2.0;
w = raw_boxes.data[box_offset + uint(0)]
+ raw_boxes.data[box_offset + uint(2)];
h = raw_boxes.data[box_offset + uint(1)]
+ raw_boxes.data[box_offset + uint(3)];
} }
float anchor_yc = raw_anchors.data[anchor_offset + uint(0)]; float anchor_yc = raw_anchors.data[anchor_offset + uint(0)];
@ -936,7 +987,7 @@ void main() {
int kp_offset = int kp_offset =
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
float kp_y, kp_x; float kp_y, kp_x;
if (reverse_output_order == int(0)) { if (output_format_flag == int(0)) {
kp_y = raw_boxes.data[kp_offset + int(0)]; kp_y = raw_boxes.data[kp_offset + int(0)];
kp_x = raw_boxes.data[kp_offset + int(1)]; kp_x = raw_boxes.data[kp_offset + int(1)];
} else { } else {
@ -949,8 +1000,7 @@ void main() {
} }
})", })",
options_.num_coords(), // box xywh options_.num_coords(), // box xywh
options_.reverse_output_order() ? 1 : 0, output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
options_.apply_exponential_on_box_size() ? 1 : 0,
options_.box_coord_offset(), options_.num_keypoints(), options_.box_coord_offset(), options_.num_keypoints(),
options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
@ -1104,7 +1154,7 @@ kernel void decodeKernel(
uint2 gid [[ thread_position_in_grid ]]) { uint2 gid [[ thread_position_in_grid ]]) {
uint num_coords = uint($0); uint num_coords = uint($0);
int reverse_output_order = int($1); int output_format_flag = int($1);
int apply_exponential = int($2); int apply_exponential = int($2);
int box_coord_offset = int($3); int box_coord_offset = int($3);
int num_keypoints = int($4); int num_keypoints = int($4);
@ -1112,8 +1162,7 @@ kernel void decodeKernel(
int num_values_per_keypt = int($6); int num_values_per_keypt = int($6);
)", )",
options_.num_coords(), // box xywh options_.num_coords(), // box xywh
options_.reverse_output_order() ? 1 : 0, output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
options_.apply_exponential_on_box_size() ? 1 : 0,
options_.box_coord_offset(), options_.num_keypoints(), options_.box_coord_offset(), options_.num_keypoints(),
options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
decode_src += absl::Substitute( decode_src += absl::Substitute(
@ -1129,16 +1178,25 @@ kernel void decodeKernel(
float y_center, x_center, h, w; float y_center, x_center, h, w;
if (reverse_output_order == int(0)) { if (output_format_flag == int(0)) {
y_center = raw_boxes[box_offset + uint(0)]; y_center = raw_boxes[box_offset + uint(0)];
x_center = raw_boxes[box_offset + uint(1)]; x_center = raw_boxes[box_offset + uint(1)];
h = raw_boxes[box_offset + uint(2)]; h = raw_boxes[box_offset + uint(2)];
w = raw_boxes[box_offset + uint(3)]; w = raw_boxes[box_offset + uint(3)];
} else { } else if (output_format_flag == int(1)) {
x_center = raw_boxes[box_offset + uint(0)]; x_center = raw_boxes[box_offset + uint(0)];
y_center = raw_boxes[box_offset + uint(1)]; y_center = raw_boxes[box_offset + uint(1)];
w = raw_boxes[box_offset + uint(2)]; w = raw_boxes[box_offset + uint(2)];
h = raw_boxes[box_offset + uint(3)]; h = raw_boxes[box_offset + uint(3)];
} else if (output_format_flag == int(2)) {
x_center = (-raw_boxes[box_offset + uint(0)]
+raw_boxes[box_offset + uint(2)]) / 2.0;
y_center = (-raw_boxes[box_offset + uint(1)]
+raw_boxes[box_offset + uint(3)]) / 2.0;
w = raw_boxes[box_offset + uint(0)]
+ raw_boxes[box_offset + uint(2)];
h = raw_boxes[box_offset + uint(1)]
+ raw_boxes[box_offset + uint(3)];
} }
float anchor_yc = raw_anchors[anchor_offset + uint(0)]; float anchor_yc = raw_anchors[anchor_offset + uint(0)];
@ -1172,7 +1230,7 @@ kernel void decodeKernel(
int kp_offset = int kp_offset =
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
float kp_y, kp_x; float kp_y, kp_x;
if (reverse_output_order == int(0)) { if (output_format_flag == int(0)) {
kp_y = raw_boxes[kp_offset + int(0)]; kp_y = raw_boxes[kp_offset + int(0)];
kp_x = raw_boxes[kp_offset + int(1)]; kp_x = raw_boxes[kp_offset + int(1)];
} else { } else {

View File

@ -54,6 +54,7 @@ message TensorsToDetectionsCalculatorOptions {
// Whether to reverse the order of predicted x, y from output. // Whether to reverse the order of predicted x, y from output.
// If false, the order is [y_center, x_center, h, w], if true the order is // If false, the order is [y_center, x_center, h, w], if true the order is
// [x_center, y_center, w, h]. // [x_center, y_center, w, h].
// DEPRECATED. Use `box_format` instead.
optional bool reverse_output_order = 14 [default = false]; optional bool reverse_output_order = 14 [default = false];
// The ids of classes that should be ignored during decoding the score for // The ids of classes that should be ignored during decoding the score for
// each predicted box. Can be overridden with IGNORE_CLASSES side packet. // each predicted box. Can be overridden with IGNORE_CLASSES side packet.
@ -112,4 +113,18 @@ message TensorsToDetectionsCalculatorOptions {
oneof box_indices { oneof box_indices {
BoxBoundariesIndices box_boundaries_indices = 23; BoxBoundariesIndices box_boundaries_indices = 23;
} }
// Tells the calculator how to convert the detector output to bounding boxes.
// Replaces `reverse_output_order` to support more bbox output formats.
enum BoxFormat {
// if UNSPECIFIED, the calculator assumes YXHW
UNSPECIFIED = 0;
// [y_center, x_center, height, width]
YXHW = 1;
// [x_center, y_center, width, height]
XYWH = 2;
// [xmin, ymin, xmax, ymax]
XYXY = 3;
}
optional BoxFormat box_format = 24 [default = UNSPECIFIED];
} }