TensorsToDetectionsCalculator supports multi clasees for a bbox.

PiperOrigin-RevId: 577300797
This commit is contained in:
MediaPipe Team 2023-10-27 14:08:39 -07:00 committed by Copybara-Service
parent d73ef24406
commit a96581e3b7

View File

@ -15,7 +15,6 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "absl/log/absl_log.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node {
const int* detection_classes, const int* detection_classes,
std::vector<Detection>* output_detections); std::vector<Detection>* output_detections);
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
float box_xmax, float score, int class_id, float box_xmax, absl::Span<const float> scores,
absl::Span<const int> class_ids,
bool flip_vertically); bool flip_vertically);
bool IsClassIndexAllowed(int class_index); bool IsClassIndexAllowed(int class_index);
@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node {
int num_boxes_ = 0; int num_boxes_ = 0;
int num_coords_ = 0; int num_coords_ = 0;
int max_results_ = -1; int max_results_ = -1;
int classes_per_detection_ = 1;
BoxFormat box_output_format_ = BoxFormat box_output_format_ =
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
auto num_boxes_view = num_boxes_tensor->GetCpuReadView(); auto num_boxes_view = num_boxes_tensor->GetCpuReadView();
auto num_boxes = num_boxes_view.buffer<float>(); auto num_boxes = num_boxes_view.buffer<float>();
num_boxes_ = num_boxes[0]; num_boxes_ = num_boxes[0];
// The detection model with Detection_PostProcess op may output duplicate
// boxes with different classes, in the following format:
// num_boxes_tensor = [num_boxes]
// detection_classes_tensor = [box_1_class_1, box_1_class_2, ...]
// detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ]
// detection_boxes_tensor = [box_1, box1, ... ]
// Each box repeats classes_per_detection_ times.
// Note Detection_PostProcess op is only supported in CPU.
RET_CHECK_EQ(max_detections % num_boxes_, 0);
classes_per_detection_ = max_detections / num_boxes_;
auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView(); auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView();
auto detection_boxes = detection_boxes_view.buffer<float>(); auto detection_boxes = detection_boxes_view.buffer<float>();
@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
auto detection_classes_view = detection_classes_tensor->GetCpuReadView(); auto detection_classes_view = detection_classes_tensor->GetCpuReadView();
auto detection_classes_ptr = detection_classes_view.buffer<float>(); auto detection_classes_ptr = detection_classes_view.buffer<float>();
std::vector<int> detection_classes(num_boxes_); std::vector<int> detection_classes(num_boxes_ * classes_per_detection_);
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < detection_classes.size(); ++i) {
detection_classes[i] = static_cast<int>(detection_classes_ptr[i]); detection_classes[i] = static_cast<int>(detection_classes_ptr[i]);
} }
MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores, MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
@ -863,7 +874,8 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
absl::Status TensorsToDetectionsCalculator::ConvertToDetections( absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
const float* detection_boxes, const float* detection_scores, const float* detection_boxes, const float* detection_scores,
const int* detection_classes, std::vector<Detection>* output_detections) { const int* detection_classes, std::vector<Detection>* output_detections) {
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < num_boxes_ * classes_per_detection_;
i += classes_per_detection_) {
if (max_results_ > 0 && output_detections->size() == max_results_) { if (max_results_ > 0 && output_detections->size() == max_results_) {
break; break;
} }
@ -880,7 +892,9 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]], /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]], /*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]], /*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
detection_scores[i], detection_classes[i], options_.flip_vertically()); absl::MakeConstSpan(detection_scores + i, classes_per_detection_),
absl::MakeConstSpan(detection_classes + i, classes_per_detection_),
options_.flip_vertically());
const auto& bbox = detection.location_data().relative_bounding_box(); const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
std::isnan(bbox.height())) { std::isnan(bbox.height())) {
@ -910,11 +924,17 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
} }
Detection TensorsToDetectionsCalculator::ConvertToDetection( Detection TensorsToDetectionsCalculator::ConvertToDetection(
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, float box_ymin, float box_xmin, float box_ymax, float box_xmax,
int class_id, bool flip_vertically) { absl::Span<const float> scores, absl::Span<const int> class_ids,
bool flip_vertically) {
Detection detection; Detection detection;
detection.add_score(score); for (int i = 0; i < scores.size(); ++i) {
detection.add_label_id(class_id); if (!IsClassIndexAllowed(class_ids[i])) {
continue;
}
detection.add_score(scores[i]);
detection.add_label_id(class_ids[i]);
}
LocationData* location_data = detection.mutable_location_data(); LocationData* location_data = detection.mutable_location_data();
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX); location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);