TensorsToDetectionsCalculator supports multi clasees for a bbox.
PiperOrigin-RevId: 577300797
This commit is contained in:
parent
d73ef24406
commit
a96581e3b7
|
@ -15,7 +15,6 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/log/absl_log.h"
|
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
|
||||||
|
@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node {
|
||||||
const int* detection_classes,
|
const int* detection_classes,
|
||||||
std::vector<Detection>* output_detections);
|
std::vector<Detection>* output_detections);
|
||||||
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
|
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
|
||||||
float box_xmax, float score, int class_id,
|
float box_xmax, absl::Span<const float> scores,
|
||||||
|
absl::Span<const int> class_ids,
|
||||||
bool flip_vertically);
|
bool flip_vertically);
|
||||||
bool IsClassIndexAllowed(int class_index);
|
bool IsClassIndexAllowed(int class_index);
|
||||||
|
|
||||||
|
@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node {
|
||||||
int num_boxes_ = 0;
|
int num_boxes_ = 0;
|
||||||
int num_coords_ = 0;
|
int num_coords_ = 0;
|
||||||
int max_results_ = -1;
|
int max_results_ = -1;
|
||||||
|
int classes_per_detection_ = 1;
|
||||||
BoxFormat box_output_format_ =
|
BoxFormat box_output_format_ =
|
||||||
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW;
|
||||||
|
|
||||||
|
@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
||||||
auto num_boxes_view = num_boxes_tensor->GetCpuReadView();
|
auto num_boxes_view = num_boxes_tensor->GetCpuReadView();
|
||||||
auto num_boxes = num_boxes_view.buffer<float>();
|
auto num_boxes = num_boxes_view.buffer<float>();
|
||||||
num_boxes_ = num_boxes[0];
|
num_boxes_ = num_boxes[0];
|
||||||
|
// The detection model with Detection_PostProcess op may output duplicate
|
||||||
|
// boxes with different classes, in the following format:
|
||||||
|
// num_boxes_tensor = [num_boxes]
|
||||||
|
// detection_classes_tensor = [box_1_class_1, box_1_class_2, ...]
|
||||||
|
// detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ]
|
||||||
|
// detection_boxes_tensor = [box_1, box1, ... ]
|
||||||
|
// Each box repeats classes_per_detection_ times.
|
||||||
|
// Note Detection_PostProcess op is only supported in CPU.
|
||||||
|
RET_CHECK_EQ(max_detections % num_boxes_, 0);
|
||||||
|
classes_per_detection_ = max_detections / num_boxes_;
|
||||||
|
|
||||||
auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView();
|
auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView();
|
||||||
auto detection_boxes = detection_boxes_view.buffer<float>();
|
auto detection_boxes = detection_boxes_view.buffer<float>();
|
||||||
|
@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
||||||
|
|
||||||
auto detection_classes_view = detection_classes_tensor->GetCpuReadView();
|
auto detection_classes_view = detection_classes_tensor->GetCpuReadView();
|
||||||
auto detection_classes_ptr = detection_classes_view.buffer<float>();
|
auto detection_classes_ptr = detection_classes_view.buffer<float>();
|
||||||
std::vector<int> detection_classes(num_boxes_);
|
std::vector<int> detection_classes(num_boxes_ * classes_per_detection_);
|
||||||
for (int i = 0; i < num_boxes_; ++i) {
|
for (int i = 0; i < detection_classes.size(); ++i) {
|
||||||
detection_classes[i] = static_cast<int>(detection_classes_ptr[i]);
|
detection_classes[i] = static_cast<int>(detection_classes_ptr[i]);
|
||||||
}
|
}
|
||||||
MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
|
MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores,
|
||||||
|
@ -863,7 +874,8 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes(
|
||||||
absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
||||||
const float* detection_boxes, const float* detection_scores,
|
const float* detection_boxes, const float* detection_scores,
|
||||||
const int* detection_classes, std::vector<Detection>* output_detections) {
|
const int* detection_classes, std::vector<Detection>* output_detections) {
|
||||||
for (int i = 0; i < num_boxes_; ++i) {
|
for (int i = 0; i < num_boxes_ * classes_per_detection_;
|
||||||
|
i += classes_per_detection_) {
|
||||||
if (max_results_ > 0 && output_detections->size() == max_results_) {
|
if (max_results_ > 0 && output_detections->size() == max_results_) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -880,7 +892,9 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
||||||
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
|
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
|
||||||
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
|
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
|
||||||
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
|
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
|
||||||
detection_scores[i], detection_classes[i], options_.flip_vertically());
|
absl::MakeConstSpan(detection_scores + i, classes_per_detection_),
|
||||||
|
absl::MakeConstSpan(detection_classes + i, classes_per_detection_),
|
||||||
|
options_.flip_vertically());
|
||||||
const auto& bbox = detection.location_data().relative_bounding_box();
|
const auto& bbox = detection.location_data().relative_bounding_box();
|
||||||
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
|
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
|
||||||
std::isnan(bbox.height())) {
|
std::isnan(bbox.height())) {
|
||||||
|
@ -910,11 +924,17 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
||||||
}
|
}
|
||||||
|
|
||||||
Detection TensorsToDetectionsCalculator::ConvertToDetection(
|
Detection TensorsToDetectionsCalculator::ConvertToDetection(
|
||||||
float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score,
|
float box_ymin, float box_xmin, float box_ymax, float box_xmax,
|
||||||
int class_id, bool flip_vertically) {
|
absl::Span<const float> scores, absl::Span<const int> class_ids,
|
||||||
|
bool flip_vertically) {
|
||||||
Detection detection;
|
Detection detection;
|
||||||
detection.add_score(score);
|
for (int i = 0; i < scores.size(); ++i) {
|
||||||
detection.add_label_id(class_id);
|
if (!IsClassIndexAllowed(class_ids[i])) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
detection.add_score(scores[i]);
|
||||||
|
detection.add_label_id(class_ids[i]);
|
||||||
|
}
|
||||||
|
|
||||||
LocationData* location_data = detection.mutable_location_data();
|
LocationData* location_data = detection.mutable_location_data();
|
||||||
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
|
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user