Add support for [xmin, ymin, xmax, ymax] style of bbox output
PiperOrigin-RevId: 509942540
This commit is contained in:
parent
40c3e72c9c
commit
796a96d842
|
@ -64,6 +64,8 @@ bool CanUseGpu() {
|
|||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using BoxFormat = ::mediapipe::TensorsToDetectionsCalculatorOptions::BoxFormat;
|
||||
|
||||
namespace {
|
||||
|
||||
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
|
||||
|
@ -126,6 +128,15 @@ absl::Status CheckCustomTensorMapping(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) {
|
||||
if (options.has_box_format()) {
|
||||
return options.box_format();
|
||||
} else if (options.reverse_output_order()) {
|
||||
return mediapipe::TensorsToDetectionsCalculatorOptions::XYWH;
|
||||
}
|
||||
return mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Convert result Tensors from object detection models into MediaPipe
|
||||
|
@ -211,6 +222,8 @@ class TensorsToDetectionsCalculator : public Node {
|
|||
int num_boxes_ = 0;
|
||||
int num_coords_ = 0;
|
||||
int max_results_ = -1;
|
||||
BoxFormat box_output_format_ =
|
||||
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
||||
|
||||
// Set of allowed or ignored class indices.
|
||||
struct ClassIndexSet {
|
||||
|
@ -655,6 +668,7 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
|
|||
num_classes_ = options_.num_classes();
|
||||
num_boxes_ = options_.num_boxes();
|
||||
num_coords_ = options_.num_coords();
|
||||
box_output_format_ = GetBoxFormat(options_);
|
||||
CHECK_NE(options_.max_results(), 0)
|
||||
<< "The maximum number of the top-scored detection results must be "
|
||||
"non-zero.";
|
||||
|
@ -728,17 +742,32 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
|
|||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
const int box_offset = i * num_coords_ + options_.box_coord_offset();
|
||||
|
||||
float y_center = raw_boxes[box_offset];
|
||||
float x_center = raw_boxes[box_offset + 1];
|
||||
float h = raw_boxes[box_offset + 2];
|
||||
float w = raw_boxes[box_offset + 3];
|
||||
if (options_.reverse_output_order()) {
|
||||
x_center = raw_boxes[box_offset];
|
||||
y_center = raw_boxes[box_offset + 1];
|
||||
w = raw_boxes[box_offset + 2];
|
||||
h = raw_boxes[box_offset + 3];
|
||||
float y_center = 0.0;
|
||||
float x_center = 0.0;
|
||||
float h = 0.0;
|
||||
float w = 0.0;
|
||||
// TODO
|
||||
switch (box_output_format_) {
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
|
||||
y_center = raw_boxes[box_offset];
|
||||
x_center = raw_boxes[box_offset + 1];
|
||||
h = raw_boxes[box_offset + 2];
|
||||
w = raw_boxes[box_offset + 3];
|
||||
break;
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
|
||||
x_center = raw_boxes[box_offset];
|
||||
y_center = raw_boxes[box_offset + 1];
|
||||
w = raw_boxes[box_offset + 2];
|
||||
h = raw_boxes[box_offset + 3];
|
||||
break;
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
|
||||
x_center = (-raw_boxes[box_offset] + raw_boxes[box_offset + 2]) / 2;
|
||||
y_center = (-raw_boxes[box_offset + 1] + raw_boxes[box_offset + 3]) / 2;
|
||||
w = raw_boxes[box_offset + 2] + raw_boxes[box_offset];
|
||||
h = raw_boxes[box_offset + 3] + raw_boxes[box_offset + 1];
|
||||
break;
|
||||
}
|
||||
|
||||
x_center =
|
||||
x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center();
|
||||
y_center =
|
||||
|
@ -856,8 +885,22 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection(
|
|||
}
|
||||
|
||||
absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) {
|
||||
int output_format_flag = 0;
|
||||
switch (box_output_format_) {
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
|
||||
output_format_flag = 0;
|
||||
break;
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
|
||||
output_format_flag = 1;
|
||||
break;
|
||||
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
|
||||
output_format_flag = 2;
|
||||
break;
|
||||
}
|
||||
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, output_format_flag]()
|
||||
-> absl::Status {
|
||||
// A shader to decode detection boxes.
|
||||
const std::string decode_src = absl::Substitute(
|
||||
R"( #version 310 es
|
||||
|
@ -879,7 +922,7 @@ layout(std430, binding = 2) readonly buffer Input1 {
|
|||
} raw_anchors;
|
||||
|
||||
uint num_coords = uint($0);
|
||||
int reverse_output_order = int($1);
|
||||
int output_format_flag = int($1);
|
||||
int apply_exponential = int($2);
|
||||
int box_coord_offset = int($3);
|
||||
int num_keypoints = int($4);
|
||||
|
@ -892,17 +935,25 @@ void main() {
|
|||
uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox
|
||||
|
||||
float y_center, x_center, h, w;
|
||||
|
||||
if (reverse_output_order == int(0)) {
|
||||
if (output_format_flag == int(0)) {
|
||||
y_center = raw_boxes.data[box_offset + uint(0)];
|
||||
x_center = raw_boxes.data[box_offset + uint(1)];
|
||||
h = raw_boxes.data[box_offset + uint(2)];
|
||||
w = raw_boxes.data[box_offset + uint(3)];
|
||||
} else {
|
||||
} else if (output_format_flag == int(1)) {
|
||||
x_center = raw_boxes.data[box_offset + uint(0)];
|
||||
y_center = raw_boxes.data[box_offset + uint(1)];
|
||||
w = raw_boxes.data[box_offset + uint(2)];
|
||||
h = raw_boxes.data[box_offset + uint(3)];
|
||||
} else if (output_format_flag == int(2)) {
|
||||
x_center = (-raw_boxes.data[box_offset + uint(0)]
|
||||
+raw_boxes.data[box_offset + uint(2)]) / 2.0;
|
||||
y_center = (-raw_boxes.data[box_offset + uint(1)]
|
||||
+raw_boxes.data[box_offset + uint(3)]) / 2.0;
|
||||
w = raw_boxes.data[box_offset + uint(0)]
|
||||
+ raw_boxes.data[box_offset + uint(2)];
|
||||
h = raw_boxes.data[box_offset + uint(1)]
|
||||
+ raw_boxes.data[box_offset + uint(3)];
|
||||
}
|
||||
|
||||
float anchor_yc = raw_anchors.data[anchor_offset + uint(0)];
|
||||
|
@ -936,7 +987,7 @@ void main() {
|
|||
int kp_offset =
|
||||
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
|
||||
float kp_y, kp_x;
|
||||
if (reverse_output_order == int(0)) {
|
||||
if (output_format_flag == int(0)) {
|
||||
kp_y = raw_boxes.data[kp_offset + int(0)];
|
||||
kp_x = raw_boxes.data[kp_offset + int(1)];
|
||||
} else {
|
||||
|
@ -949,8 +1000,7 @@ void main() {
|
|||
}
|
||||
})",
|
||||
options_.num_coords(), // box xywh
|
||||
options_.reverse_output_order() ? 1 : 0,
|
||||
options_.apply_exponential_on_box_size() ? 1 : 0,
|
||||
output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
|
||||
options_.box_coord_offset(), options_.num_keypoints(),
|
||||
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
|
||||
|
||||
|
@ -1104,7 +1154,7 @@ kernel void decodeKernel(
|
|||
uint2 gid [[ thread_position_in_grid ]]) {
|
||||
|
||||
uint num_coords = uint($0);
|
||||
int reverse_output_order = int($1);
|
||||
int output_format_flag = int($1);
|
||||
int apply_exponential = int($2);
|
||||
int box_coord_offset = int($3);
|
||||
int num_keypoints = int($4);
|
||||
|
@ -1112,8 +1162,7 @@ kernel void decodeKernel(
|
|||
int num_values_per_keypt = int($6);
|
||||
)",
|
||||
options_.num_coords(), // box xywh
|
||||
options_.reverse_output_order() ? 1 : 0,
|
||||
options_.apply_exponential_on_box_size() ? 1 : 0,
|
||||
output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
|
||||
options_.box_coord_offset(), options_.num_keypoints(),
|
||||
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
|
||||
decode_src += absl::Substitute(
|
||||
|
@ -1129,16 +1178,25 @@ kernel void decodeKernel(
|
|||
|
||||
float y_center, x_center, h, w;
|
||||
|
||||
if (reverse_output_order == int(0)) {
|
||||
if (output_format_flag == int(0)) {
|
||||
y_center = raw_boxes[box_offset + uint(0)];
|
||||
x_center = raw_boxes[box_offset + uint(1)];
|
||||
h = raw_boxes[box_offset + uint(2)];
|
||||
w = raw_boxes[box_offset + uint(3)];
|
||||
} else {
|
||||
} else if (output_format_flag == int(1)) {
|
||||
x_center = raw_boxes[box_offset + uint(0)];
|
||||
y_center = raw_boxes[box_offset + uint(1)];
|
||||
w = raw_boxes[box_offset + uint(2)];
|
||||
h = raw_boxes[box_offset + uint(3)];
|
||||
} else if (output_format_flag == int(2)) {
|
||||
x_center = (-raw_boxes[box_offset + uint(0)]
|
||||
+raw_boxes[box_offset + uint(2)]) / 2.0;
|
||||
y_center = (-raw_boxes[box_offset + uint(1)]
|
||||
+raw_boxes[box_offset + uint(3)]) / 2.0;
|
||||
w = raw_boxes[box_offset + uint(0)]
|
||||
+ raw_boxes[box_offset + uint(2)];
|
||||
h = raw_boxes[box_offset + uint(1)]
|
||||
+ raw_boxes[box_offset + uint(3)];
|
||||
}
|
||||
|
||||
float anchor_yc = raw_anchors[anchor_offset + uint(0)];
|
||||
|
@ -1172,7 +1230,7 @@ kernel void decodeKernel(
|
|||
int kp_offset =
|
||||
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
|
||||
float kp_y, kp_x;
|
||||
if (reverse_output_order == int(0)) {
|
||||
if (output_format_flag == int(0)) {
|
||||
kp_y = raw_boxes[kp_offset + int(0)];
|
||||
kp_x = raw_boxes[kp_offset + int(1)];
|
||||
} else {
|
||||
|
|
|
@ -54,6 +54,7 @@ message TensorsToDetectionsCalculatorOptions {
|
|||
// Whether to reverse the order of predicted x, y from output.
|
||||
// If false, the order is [y_center, x_center, h, w], if true the order is
|
||||
// [x_center, y_center, w, h].
|
||||
// DEPRECATED. Use `box_format` instead.
|
||||
optional bool reverse_output_order = 14 [default = false];
|
||||
// The ids of classes that should be ignored during decoding the score for
|
||||
// each predicted box. Can be overridden with IGNORE_CLASSES side packet.
|
||||
|
@ -112,4 +113,18 @@ message TensorsToDetectionsCalculatorOptions {
|
|||
oneof box_indices {
|
||||
BoxBoundariesIndices box_boundaries_indices = 23;
|
||||
}
|
||||
|
||||
// Tells the calculator how to convert the detector output to bounding boxes.
|
||||
// Replaces `reverse_output_order` to support more bbox output formats.
|
||||
enum BoxFormat {
|
||||
// if UNSPECIFIED, the calculator assumes YXHW
|
||||
UNSPECIFIED = 0;
|
||||
// [y_center, x_center, height, width]
|
||||
YXHW = 1;
|
||||
// [x_center, y_center, width, height]
|
||||
XYWH = 2;
|
||||
// [xmin, ymin, xmax, ymax]
|
||||
XYXY = 3;
|
||||
}
|
||||
optional BoxFormat box_format = 24 [default = UNSPECIFIED];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user