Add support for [xmin, ymin, xmax, ymax] style of bbox output

PiperOrigin-RevId: 509942540
This commit is contained in:
MediaPipe Team 2023-02-15 15:03:18 -08:00 committed by Copybara-Service
parent 40c3e72c9c
commit 796a96d842
2 changed files with 97 additions and 24 deletions

View File

@ -64,6 +64,8 @@ bool CanUseGpu() {
namespace mediapipe {
namespace api2 {
using BoxFormat = ::mediapipe::TensorsToDetectionsCalculatorOptions::BoxFormat;
namespace {
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
@ -126,6 +128,15 @@ absl::Status CheckCustomTensorMapping(
return absl::OkStatus();
}
// Resolves which raw box layout the calculator should decode.
//
// An explicitly set `box_format` always takes precedence and is returned
// verbatim (including UNSPECIFIED). When `box_format` is not set, the
// deprecated `reverse_output_order` flag is consulted instead: true selects
// XYWH, false selects YXHW.
BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) {
  // New-style configuration: honor the explicit field as-is.
  if (options.has_box_format()) return options.box_format();

  // Legacy configuration: map the deprecated boolean onto the enum.
  return options.reverse_output_order()
             ? mediapipe::TensorsToDetectionsCalculatorOptions::XYWH
             : mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
}
} // namespace
// Convert result Tensors from object detection models into MediaPipe
@ -211,6 +222,8 @@ class TensorsToDetectionsCalculator : public Node {
int num_boxes_ = 0;
int num_coords_ = 0;
int max_results_ = -1;
BoxFormat box_output_format_ =
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
// Set of allowed or ignored class indices.
struct ClassIndexSet {
@ -655,6 +668,7 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
num_classes_ = options_.num_classes();
num_boxes_ = options_.num_boxes();
num_coords_ = options_.num_coords();
box_output_format_ = GetBoxFormat(options_);
CHECK_NE(options_.max_results(), 0)
<< "The maximum number of the top-scored detection results must be "
"non-zero.";
@ -728,17 +742,32 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
for (int i = 0; i < num_boxes_; ++i) {
const int box_offset = i * num_coords_ + options_.box_coord_offset();
float y_center = raw_boxes[box_offset];
float x_center = raw_boxes[box_offset + 1];
float h = raw_boxes[box_offset + 2];
float w = raw_boxes[box_offset + 3];
if (options_.reverse_output_order()) {
x_center = raw_boxes[box_offset];
y_center = raw_boxes[box_offset + 1];
w = raw_boxes[box_offset + 2];
h = raw_boxes[box_offset + 3];
float y_center = 0.0;
float x_center = 0.0;
float h = 0.0;
float w = 0.0;
// TODO
switch (box_output_format_) {
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
y_center = raw_boxes[box_offset];
x_center = raw_boxes[box_offset + 1];
h = raw_boxes[box_offset + 2];
w = raw_boxes[box_offset + 3];
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
x_center = raw_boxes[box_offset];
y_center = raw_boxes[box_offset + 1];
w = raw_boxes[box_offset + 2];
h = raw_boxes[box_offset + 3];
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
x_center = (-raw_boxes[box_offset] + raw_boxes[box_offset + 2]) / 2;
y_center = (-raw_boxes[box_offset + 1] + raw_boxes[box_offset + 3]) / 2;
w = raw_boxes[box_offset + 2] + raw_boxes[box_offset];
h = raw_boxes[box_offset + 3] + raw_boxes[box_offset + 1];
break;
}
x_center =
x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center();
y_center =
@ -856,8 +885,22 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection(
}
absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) {
int output_format_flag = 0;
switch (box_output_format_) {
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
output_format_flag = 0;
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
output_format_flag = 1;
break;
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
output_format_flag = 2;
break;
}
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, output_format_flag]()
-> absl::Status {
// A shader to decode detection boxes.
const std::string decode_src = absl::Substitute(
R"( #version 310 es
@ -879,7 +922,7 @@ layout(std430, binding = 2) readonly buffer Input1 {
} raw_anchors;
uint num_coords = uint($0);
int reverse_output_order = int($1);
int output_format_flag = int($1);
int apply_exponential = int($2);
int box_coord_offset = int($3);
int num_keypoints = int($4);
@ -892,17 +935,25 @@ void main() {
uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox
float y_center, x_center, h, w;
if (reverse_output_order == int(0)) {
if (output_format_flag == int(0)) {
y_center = raw_boxes.data[box_offset + uint(0)];
x_center = raw_boxes.data[box_offset + uint(1)];
h = raw_boxes.data[box_offset + uint(2)];
w = raw_boxes.data[box_offset + uint(3)];
} else {
} else if (output_format_flag == int(1)) {
x_center = raw_boxes.data[box_offset + uint(0)];
y_center = raw_boxes.data[box_offset + uint(1)];
w = raw_boxes.data[box_offset + uint(2)];
h = raw_boxes.data[box_offset + uint(3)];
} else if (output_format_flag == int(2)) {
x_center = (-raw_boxes.data[box_offset + uint(0)]
+raw_boxes.data[box_offset + uint(2)]) / 2.0;
y_center = (-raw_boxes.data[box_offset + uint(1)]
+raw_boxes.data[box_offset + uint(3)]) / 2.0;
w = raw_boxes.data[box_offset + uint(0)]
+ raw_boxes.data[box_offset + uint(2)];
h = raw_boxes.data[box_offset + uint(1)]
+ raw_boxes.data[box_offset + uint(3)];
}
float anchor_yc = raw_anchors.data[anchor_offset + uint(0)];
@ -936,7 +987,7 @@ void main() {
int kp_offset =
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
float kp_y, kp_x;
if (reverse_output_order == int(0)) {
if (output_format_flag == int(0)) {
kp_y = raw_boxes.data[kp_offset + int(0)];
kp_x = raw_boxes.data[kp_offset + int(1)];
} else {
@ -949,8 +1000,7 @@ void main() {
}
})",
options_.num_coords(), // box xywh
options_.reverse_output_order() ? 1 : 0,
options_.apply_exponential_on_box_size() ? 1 : 0,
output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
options_.box_coord_offset(), options_.num_keypoints(),
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
@ -1104,7 +1154,7 @@ kernel void decodeKernel(
uint2 gid [[ thread_position_in_grid ]]) {
uint num_coords = uint($0);
int reverse_output_order = int($1);
int output_format_flag = int($1);
int apply_exponential = int($2);
int box_coord_offset = int($3);
int num_keypoints = int($4);
@ -1112,8 +1162,7 @@ kernel void decodeKernel(
int num_values_per_keypt = int($6);
)",
options_.num_coords(), // box xywh
options_.reverse_output_order() ? 1 : 0,
options_.apply_exponential_on_box_size() ? 1 : 0,
output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
options_.box_coord_offset(), options_.num_keypoints(),
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
decode_src += absl::Substitute(
@ -1129,16 +1178,25 @@ kernel void decodeKernel(
float y_center, x_center, h, w;
if (reverse_output_order == int(0)) {
if (output_format_flag == int(0)) {
y_center = raw_boxes[box_offset + uint(0)];
x_center = raw_boxes[box_offset + uint(1)];
h = raw_boxes[box_offset + uint(2)];
w = raw_boxes[box_offset + uint(3)];
} else {
} else if (output_format_flag == int(1)) {
x_center = raw_boxes[box_offset + uint(0)];
y_center = raw_boxes[box_offset + uint(1)];
w = raw_boxes[box_offset + uint(2)];
h = raw_boxes[box_offset + uint(3)];
} else if (output_format_flag == int(2)) {
x_center = (-raw_boxes[box_offset + uint(0)]
+raw_boxes[box_offset + uint(2)]) / 2.0;
y_center = (-raw_boxes[box_offset + uint(1)]
+raw_boxes[box_offset + uint(3)]) / 2.0;
w = raw_boxes[box_offset + uint(0)]
+ raw_boxes[box_offset + uint(2)];
h = raw_boxes[box_offset + uint(1)]
+ raw_boxes[box_offset + uint(3)];
}
float anchor_yc = raw_anchors[anchor_offset + uint(0)];
@ -1172,7 +1230,7 @@ kernel void decodeKernel(
int kp_offset =
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
float kp_y, kp_x;
if (reverse_output_order == int(0)) {
if (output_format_flag == int(0)) {
kp_y = raw_boxes[kp_offset + int(0)];
kp_x = raw_boxes[kp_offset + int(1)];
} else {

View File

@ -54,6 +54,7 @@ message TensorsToDetectionsCalculatorOptions {
// Whether to reverse the order of predicted x, y from output.
// If false, the order is [y_center, x_center, h, w], if true the order is
// [x_center, y_center, w, h].
// DEPRECATED. Use `box_format` instead.
optional bool reverse_output_order = 14 [default = false];
// The ids of classes that should be ignored during decoding the score for
// each predicted box. Can be overridden with IGNORE_CLASSES side packet.
@ -112,4 +113,18 @@ message TensorsToDetectionsCalculatorOptions {
oneof box_indices {
BoxBoundariesIndices box_boundaries_indices = 23;
}
// Tells the calculator how to convert the detector output to bounding boxes.
// Replaces `reverse_output_order` to support more bbox output formats.
// Each value name spells out the order in which the four raw box values are
// read for every box.
enum BoxFormat {
// if UNSPECIFIED, the calculator assumes YXHW (the decoder handles both
// values in the same branch).
UNSPECIFIED = 0;
// [y_center, x_center, height, width]
YXHW = 1;
// [x_center, y_center, width, height]
XYWH = 2;
// [xmin, ymin, xmax, ymax]
// NOTE(review): the calculator's decoders compute the center as
// (value[2] - value[0]) / 2 and the width as value[0] + value[2] (and
// likewise value[3]/value[1] for y and height). That arithmetic looks like
// it expects distances from the anchor to each box edge rather than absolute
// corner coordinates -- confirm the intended interpretation.
XYXY = 3;
}
// Raw box layout to decode. When this field is set it takes precedence over
// `reverse_output_order`; when it is not set, the calculator falls back to
// `reverse_output_order` (true -> XYWH, false -> YXHW).
optional BoxFormat box_format = 24 [default = UNSPECIFIED];
}