Add support for [xmin, ymin, xmax, ymax] style of bbox output
PiperOrigin-RevId: 509942540
This commit is contained in:
parent
40c3e72c9c
commit
796a96d842
|
@ -64,6 +64,8 @@ bool CanUseGpu() {
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
|
||||||
|
using BoxFormat = ::mediapipe::TensorsToDetectionsCalculatorOptions::BoxFormat;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
|
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
|
||||||
|
@ -126,6 +128,15 @@ absl::Status CheckCustomTensorMapping(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BoxFormat GetBoxFormat(const TensorsToDetectionsCalculatorOptions& options) {
|
||||||
|
if (options.has_box_format()) {
|
||||||
|
return options.box_format();
|
||||||
|
} else if (options.reverse_output_order()) {
|
||||||
|
return mediapipe::TensorsToDetectionsCalculatorOptions::XYWH;
|
||||||
|
}
|
||||||
|
return mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Convert result Tensors from object detection models into MediaPipe
|
// Convert result Tensors from object detection models into MediaPipe
|
||||||
|
@ -211,6 +222,8 @@ class TensorsToDetectionsCalculator : public Node {
|
||||||
int num_boxes_ = 0;
|
int num_boxes_ = 0;
|
||||||
int num_coords_ = 0;
|
int num_coords_ = 0;
|
||||||
int max_results_ = -1;
|
int max_results_ = -1;
|
||||||
|
BoxFormat box_output_format_ =
|
||||||
|
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
||||||
|
|
||||||
// Set of allowed or ignored class indices.
|
// Set of allowed or ignored class indices.
|
||||||
struct ClassIndexSet {
|
struct ClassIndexSet {
|
||||||
|
@ -655,6 +668,7 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
|
||||||
num_classes_ = options_.num_classes();
|
num_classes_ = options_.num_classes();
|
||||||
num_boxes_ = options_.num_boxes();
|
num_boxes_ = options_.num_boxes();
|
||||||
num_coords_ = options_.num_coords();
|
num_coords_ = options_.num_coords();
|
||||||
|
box_output_format_ = GetBoxFormat(options_);
|
||||||
CHECK_NE(options_.max_results(), 0)
|
CHECK_NE(options_.max_results(), 0)
|
||||||
<< "The maximum number of the top-scored detection results must be "
|
<< "The maximum number of the top-scored detection results must be "
|
||||||
"non-zero.";
|
"non-zero.";
|
||||||
|
@ -728,17 +742,32 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
|
||||||
for (int i = 0; i < num_boxes_; ++i) {
|
for (int i = 0; i < num_boxes_; ++i) {
|
||||||
const int box_offset = i * num_coords_ + options_.box_coord_offset();
|
const int box_offset = i * num_coords_ + options_.box_coord_offset();
|
||||||
|
|
||||||
float y_center = raw_boxes[box_offset];
|
float y_center = 0.0;
|
||||||
float x_center = raw_boxes[box_offset + 1];
|
float x_center = 0.0;
|
||||||
float h = raw_boxes[box_offset + 2];
|
float h = 0.0;
|
||||||
float w = raw_boxes[box_offset + 3];
|
float w = 0.0;
|
||||||
if (options_.reverse_output_order()) {
|
// TODO
|
||||||
x_center = raw_boxes[box_offset];
|
switch (box_output_format_) {
|
||||||
y_center = raw_boxes[box_offset + 1];
|
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
|
||||||
w = raw_boxes[box_offset + 2];
|
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
|
||||||
h = raw_boxes[box_offset + 3];
|
y_center = raw_boxes[box_offset];
|
||||||
|
x_center = raw_boxes[box_offset + 1];
|
||||||
|
h = raw_boxes[box_offset + 2];
|
||||||
|
w = raw_boxes[box_offset + 3];
|
||||||
|
break;
|
||||||
|
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
|
||||||
|
x_center = raw_boxes[box_offset];
|
||||||
|
y_center = raw_boxes[box_offset + 1];
|
||||||
|
w = raw_boxes[box_offset + 2];
|
||||||
|
h = raw_boxes[box_offset + 3];
|
||||||
|
break;
|
||||||
|
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
|
||||||
|
x_center = (-raw_boxes[box_offset] + raw_boxes[box_offset + 2]) / 2;
|
||||||
|
y_center = (-raw_boxes[box_offset + 1] + raw_boxes[box_offset + 3]) / 2;
|
||||||
|
w = raw_boxes[box_offset + 2] + raw_boxes[box_offset];
|
||||||
|
h = raw_boxes[box_offset + 3] + raw_boxes[box_offset + 1];
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
x_center =
|
x_center =
|
||||||
x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center();
|
x_center / options_.x_scale() * anchors[i].w() + anchors[i].x_center();
|
||||||
y_center =
|
y_center =
|
||||||
|
@ -856,8 +885,22 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) {
|
absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) {
|
||||||
|
int output_format_flag = 0;
|
||||||
|
switch (box_output_format_) {
|
||||||
|
case mediapipe::TensorsToDetectionsCalculatorOptions::UNSPECIFIED:
|
||||||
|
case mediapipe::TensorsToDetectionsCalculatorOptions::YXHW:
|
||||||
|
output_format_flag = 0;
|
||||||
|
break;
|
||||||
|
case mediapipe::TensorsToDetectionsCalculatorOptions::XYWH:
|
||||||
|
output_format_flag = 1;
|
||||||
|
break;
|
||||||
|
case mediapipe::TensorsToDetectionsCalculatorOptions::XYXY:
|
||||||
|
output_format_flag = 2;
|
||||||
|
break;
|
||||||
|
}
|
||||||
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status {
|
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, output_format_flag]()
|
||||||
|
-> absl::Status {
|
||||||
// A shader to decode detection boxes.
|
// A shader to decode detection boxes.
|
||||||
const std::string decode_src = absl::Substitute(
|
const std::string decode_src = absl::Substitute(
|
||||||
R"( #version 310 es
|
R"( #version 310 es
|
||||||
|
@ -879,7 +922,7 @@ layout(std430, binding = 2) readonly buffer Input1 {
|
||||||
} raw_anchors;
|
} raw_anchors;
|
||||||
|
|
||||||
uint num_coords = uint($0);
|
uint num_coords = uint($0);
|
||||||
int reverse_output_order = int($1);
|
int output_format_flag = int($1);
|
||||||
int apply_exponential = int($2);
|
int apply_exponential = int($2);
|
||||||
int box_coord_offset = int($3);
|
int box_coord_offset = int($3);
|
||||||
int num_keypoints = int($4);
|
int num_keypoints = int($4);
|
||||||
|
@ -892,17 +935,25 @@ void main() {
|
||||||
uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox
|
uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox
|
||||||
|
|
||||||
float y_center, x_center, h, w;
|
float y_center, x_center, h, w;
|
||||||
|
if (output_format_flag == int(0)) {
|
||||||
if (reverse_output_order == int(0)) {
|
|
||||||
y_center = raw_boxes.data[box_offset + uint(0)];
|
y_center = raw_boxes.data[box_offset + uint(0)];
|
||||||
x_center = raw_boxes.data[box_offset + uint(1)];
|
x_center = raw_boxes.data[box_offset + uint(1)];
|
||||||
h = raw_boxes.data[box_offset + uint(2)];
|
h = raw_boxes.data[box_offset + uint(2)];
|
||||||
w = raw_boxes.data[box_offset + uint(3)];
|
w = raw_boxes.data[box_offset + uint(3)];
|
||||||
} else {
|
} else if (output_format_flag == int(1)) {
|
||||||
x_center = raw_boxes.data[box_offset + uint(0)];
|
x_center = raw_boxes.data[box_offset + uint(0)];
|
||||||
y_center = raw_boxes.data[box_offset + uint(1)];
|
y_center = raw_boxes.data[box_offset + uint(1)];
|
||||||
w = raw_boxes.data[box_offset + uint(2)];
|
w = raw_boxes.data[box_offset + uint(2)];
|
||||||
h = raw_boxes.data[box_offset + uint(3)];
|
h = raw_boxes.data[box_offset + uint(3)];
|
||||||
|
} else if (output_format_flag == int(2)) {
|
||||||
|
x_center = (-raw_boxes.data[box_offset + uint(0)]
|
||||||
|
+raw_boxes.data[box_offset + uint(2)]) / 2.0;
|
||||||
|
y_center = (-raw_boxes.data[box_offset + uint(1)]
|
||||||
|
+raw_boxes.data[box_offset + uint(3)]) / 2.0;
|
||||||
|
w = raw_boxes.data[box_offset + uint(0)]
|
||||||
|
+ raw_boxes.data[box_offset + uint(2)];
|
||||||
|
h = raw_boxes.data[box_offset + uint(1)]
|
||||||
|
+ raw_boxes.data[box_offset + uint(3)];
|
||||||
}
|
}
|
||||||
|
|
||||||
float anchor_yc = raw_anchors.data[anchor_offset + uint(0)];
|
float anchor_yc = raw_anchors.data[anchor_offset + uint(0)];
|
||||||
|
@ -936,7 +987,7 @@ void main() {
|
||||||
int kp_offset =
|
int kp_offset =
|
||||||
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
|
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
|
||||||
float kp_y, kp_x;
|
float kp_y, kp_x;
|
||||||
if (reverse_output_order == int(0)) {
|
if (output_format_flag == int(0)) {
|
||||||
kp_y = raw_boxes.data[kp_offset + int(0)];
|
kp_y = raw_boxes.data[kp_offset + int(0)];
|
||||||
kp_x = raw_boxes.data[kp_offset + int(1)];
|
kp_x = raw_boxes.data[kp_offset + int(1)];
|
||||||
} else {
|
} else {
|
||||||
|
@ -949,8 +1000,7 @@ void main() {
|
||||||
}
|
}
|
||||||
})",
|
})",
|
||||||
options_.num_coords(), // box xywh
|
options_.num_coords(), // box xywh
|
||||||
options_.reverse_output_order() ? 1 : 0,
|
output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
|
||||||
options_.apply_exponential_on_box_size() ? 1 : 0,
|
|
||||||
options_.box_coord_offset(), options_.num_keypoints(),
|
options_.box_coord_offset(), options_.num_keypoints(),
|
||||||
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
|
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
|
||||||
|
|
||||||
|
@ -1104,7 +1154,7 @@ kernel void decodeKernel(
|
||||||
uint2 gid [[ thread_position_in_grid ]]) {
|
uint2 gid [[ thread_position_in_grid ]]) {
|
||||||
|
|
||||||
uint num_coords = uint($0);
|
uint num_coords = uint($0);
|
||||||
int reverse_output_order = int($1);
|
int output_format_flag = int($1);
|
||||||
int apply_exponential = int($2);
|
int apply_exponential = int($2);
|
||||||
int box_coord_offset = int($3);
|
int box_coord_offset = int($3);
|
||||||
int num_keypoints = int($4);
|
int num_keypoints = int($4);
|
||||||
|
@ -1112,8 +1162,7 @@ kernel void decodeKernel(
|
||||||
int num_values_per_keypt = int($6);
|
int num_values_per_keypt = int($6);
|
||||||
)",
|
)",
|
||||||
options_.num_coords(), // box xywh
|
options_.num_coords(), // box xywh
|
||||||
options_.reverse_output_order() ? 1 : 0,
|
output_format_flag, options_.apply_exponential_on_box_size() ? 1 : 0,
|
||||||
options_.apply_exponential_on_box_size() ? 1 : 0,
|
|
||||||
options_.box_coord_offset(), options_.num_keypoints(),
|
options_.box_coord_offset(), options_.num_keypoints(),
|
||||||
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
|
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
|
||||||
decode_src += absl::Substitute(
|
decode_src += absl::Substitute(
|
||||||
|
@ -1129,16 +1178,25 @@ kernel void decodeKernel(
|
||||||
|
|
||||||
float y_center, x_center, h, w;
|
float y_center, x_center, h, w;
|
||||||
|
|
||||||
if (reverse_output_order == int(0)) {
|
if (output_format_flag == int(0)) {
|
||||||
y_center = raw_boxes[box_offset + uint(0)];
|
y_center = raw_boxes[box_offset + uint(0)];
|
||||||
x_center = raw_boxes[box_offset + uint(1)];
|
x_center = raw_boxes[box_offset + uint(1)];
|
||||||
h = raw_boxes[box_offset + uint(2)];
|
h = raw_boxes[box_offset + uint(2)];
|
||||||
w = raw_boxes[box_offset + uint(3)];
|
w = raw_boxes[box_offset + uint(3)];
|
||||||
} else {
|
} else if (output_format_flag == int(1)) {
|
||||||
x_center = raw_boxes[box_offset + uint(0)];
|
x_center = raw_boxes[box_offset + uint(0)];
|
||||||
y_center = raw_boxes[box_offset + uint(1)];
|
y_center = raw_boxes[box_offset + uint(1)];
|
||||||
w = raw_boxes[box_offset + uint(2)];
|
w = raw_boxes[box_offset + uint(2)];
|
||||||
h = raw_boxes[box_offset + uint(3)];
|
h = raw_boxes[box_offset + uint(3)];
|
||||||
|
} else if (output_format_flag == int(2)) {
|
||||||
|
x_center = (-raw_boxes[box_offset + uint(0)]
|
||||||
|
+raw_boxes[box_offset + uint(2)]) / 2.0;
|
||||||
|
y_center = (-raw_boxes[box_offset + uint(1)]
|
||||||
|
+raw_boxes[box_offset + uint(3)]) / 2.0;
|
||||||
|
w = raw_boxes[box_offset + uint(0)]
|
||||||
|
+ raw_boxes[box_offset + uint(2)];
|
||||||
|
h = raw_boxes[box_offset + uint(1)]
|
||||||
|
+ raw_boxes[box_offset + uint(3)];
|
||||||
}
|
}
|
||||||
|
|
||||||
float anchor_yc = raw_anchors[anchor_offset + uint(0)];
|
float anchor_yc = raw_anchors[anchor_offset + uint(0)];
|
||||||
|
@ -1172,7 +1230,7 @@ kernel void decodeKernel(
|
||||||
int kp_offset =
|
int kp_offset =
|
||||||
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
|
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
|
||||||
float kp_y, kp_x;
|
float kp_y, kp_x;
|
||||||
if (reverse_output_order == int(0)) {
|
if (output_format_flag == int(0)) {
|
||||||
kp_y = raw_boxes[kp_offset + int(0)];
|
kp_y = raw_boxes[kp_offset + int(0)];
|
||||||
kp_x = raw_boxes[kp_offset + int(1)];
|
kp_x = raw_boxes[kp_offset + int(1)];
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -54,6 +54,7 @@ message TensorsToDetectionsCalculatorOptions {
|
||||||
// Whether to reverse the order of predicted x, y from output.
|
// Whether to reverse the order of predicted x, y from output.
|
||||||
// If false, the order is [y_center, x_center, h, w], if true the order is
|
// If false, the order is [y_center, x_center, h, w], if true the order is
|
||||||
// [x_center, y_center, w, h].
|
// [x_center, y_center, w, h].
|
||||||
|
// DEPRECATED. Use `box_format` instead.
|
||||||
optional bool reverse_output_order = 14 [default = false];
|
optional bool reverse_output_order = 14 [default = false];
|
||||||
// The ids of classes that should be ignored during decoding the score for
|
// The ids of classes that should be ignored during decoding the score for
|
||||||
// each predicted box. Can be overridden with IGNORE_CLASSES side packet.
|
// each predicted box. Can be overridden with IGNORE_CLASSES side packet.
|
||||||
|
@ -112,4 +113,18 @@ message TensorsToDetectionsCalculatorOptions {
|
||||||
oneof box_indices {
|
oneof box_indices {
|
||||||
BoxBoundariesIndices box_boundaries_indices = 23;
|
BoxBoundariesIndices box_boundaries_indices = 23;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tells the calculator how to convert the detector output to bounding boxes.
|
||||||
|
// Replaces `reverse_output_order` to support more bbox output formats.
|
||||||
|
enum BoxFormat {
|
||||||
|
// if UNSPECIFIED, the calculator assumes YXHW
|
||||||
|
UNSPECIFIED = 0;
|
||||||
|
// [y_center, x_center, height, width]
|
||||||
|
YXHW = 1;
|
||||||
|
// [x_center, y_center, width, height]
|
||||||
|
XYWH = 2;
|
||||||
|
// [xmin, ymin, xmax, ymax]
|
||||||
|
XYXY = 3;
|
||||||
|
}
|
||||||
|
optional BoxFormat box_format = 24 [default = UNSPECIFIED];
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user