TensorsToDetectionsCalculator supports multi clasees for a bbox.
PiperOrigin-RevId: 577300797
This commit is contained in:
parent
d73ef24406
commit
a96581e3b7
|
@ -15,7 +15,6 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/absl_log.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
|
||||
|
@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node {
|
|||
const int* detection_classes,
|
||||
std::vector<Detection>* output_detections);
|
||||
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
|
||||
float box_xmax, float score, int class_id,
|
||||
float box_xmax, absl::Span<const float> scores,
|
||||
absl::Span<const int> class_ids,
|
||||
bool flip_vertically);
|
||||
bool IsClassIndexAllowed(int class_index);
|
||||
|
||||
|
@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node {
|
|||
int num_boxes_ = 0;
|
||||
int num_coords_ = 0;
|
||||
int max_results_ = -1;
|
||||
int classes_per_detection_ = 1;
|
||||
BoxFormat box_output_format_ =
|
||||
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
||||
|
||||
|
@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
auto num_boxes_view = num_boxes_tensor->GetCpuReadView();
|
||||
auto num_boxes = num_boxes_view.buffer<float>();
|
||||
num_boxes_ = num_boxes[0];
|
||||
// The detection model with Detection_PostProcess op may output duplicate
|
||||
// boxes with different classes, in the following format:
|
||||
// num_boxes_tensor = [num_boxes]
|
||||
// detection_classes_tensor = [box_1_class_1, box_1_class_2, ...]
|
||||
// detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ]
|
||||
// detection_boxes_tensor = [box_1, box1, ... ]
|
||||
// Each box repeats classes_per_detection_ times.
|
||||
// Note Detection_PostProcess op is only supported in CPU.
|
||||
RET_CHECK_EQ(max_detections % num_boxes_, 0);
|
||||
classes_per_detection_ = max_detections / num_boxes_;
|
||||
|
||||
auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView();
|
||||
auto detection_boxes = detection_boxes_view.buffer<float>();
|
||||
|
@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
|
||||
auto detection_classes_view = detection_classes_tensor->GetCpuReadView();
|
||||
auto detection_classes_ptr = detection_classes_view.buffer<float>();
|
||||
std::vector<int> detection_classes(num_boxes_);
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
std::vector<int> detection_classes(num_boxes_ * classes_per_detection_);
|
||||
for (int i = 0; i < detection_classes.size(); ++i) {
|
||||
detection_classes[i] = static_cast<int>(detection_classes_ptr[i]);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
|
||||
|
@ -863,7 +874,8 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
|
|||
absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
||||
const float* detection_boxes, const float* detection_scores,
|
||||
const int* detection_classes, std::vector<Detection>* output_detections) {
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
for (int i = 0; i < num_boxes_ * classes_per_detection_;
|
||||
i += classes_per_detection_) {
|
||||
if (max_results_ > 0 && output_detections->size() == max_results_) {
|
||||
break;
|
||||
}
|
||||
|
@ -880,7 +892,9 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
|||
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
|
||||
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
|
||||
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
|
||||
detection_scores[i], detection_classes[i], options_.flip_vertically());
|
||||
absl::MakeConstSpan(detection_scores + i, classes_per_detection_),
|
||||
absl::MakeConstSpan(detection_classes + i, classes_per_detection_),
|
||||
options_.flip_vertically());
|
||||
const auto& bbox = detection.location_data().relative_bounding_box();
|
||||
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
|
||||
std::isnan(bbox.height())) {
|
||||
|
@ -910,11 +924,17 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
|||
}
|
||||
|
||||
Detection TensorsToDetectionsCalculator::ConvertToDetection(
|
||||
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
|
||||
int class_id, bool flip_vertically) {
|
||||
float box_ymin, float box_xmin, float box_ymax, float box_xmax,
|
||||
absl::Span<const float> scores, absl::Span<const int> class_ids,
|
||||
bool flip_vertically) {
|
||||
Detection detection;
|
||||
detection.add_score(score);
|
||||
detection.add_label_id(class_id);
|
||||
for (int i = 0; i < scores.size(); ++i) {
|
||||
if (!IsClassIndexAllowed(class_ids[i])) {
|
||||
continue;
|
||||
}
|
||||
detection.add_score(scores[i]);
|
||||
detection.add_label_id(class_ids[i]);
|
||||
}
|
||||
|
||||
LocationData* location_data = detection.mutable_location_data();
|
||||
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
|
||||
|
|
Loading…
Reference in New Issue
Block a user