From a96581e3b7685c9f12dfc206b885a362f9fb0396 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 27 Oct 2023 14:08:39 -0700 Subject: [PATCH] TensorsToDetectionsCalculator supports multiple classes for a bbox. PiperOrigin-RevId: 577300797 --- .../tensors_to_detections_calculator.cc | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index aa2cfe734..95682d633 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -15,7 +15,6 @@ #include #include -#include "absl/log/absl_log.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" @@ -215,7 +214,8 @@ class TensorsToDetectionsCalculator : public Node { const int* detection_classes, std::vector* output_detections); Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, - float box_xmax, float score, int class_id, + float box_xmax, absl::Span scores, + absl::Span class_ids, bool flip_vertically); bool IsClassIndexAllowed(int class_index); @@ -223,6 +223,7 @@ class TensorsToDetectionsCalculator : public Node { int num_boxes_ = 0; int num_coords_ = 0; int max_results_ = -1; + int classes_per_detection_ = 1; BoxFormat box_output_format_ = mediapipe::TensorsToDetectionsCalculatorOptions::YXHW; @@ -484,6 +485,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( auto num_boxes_view = num_boxes_tensor->GetCpuReadView(); auto num_boxes = num_boxes_view.buffer(); num_boxes_ = num_boxes[0]; + // The detection model with Detection_PostProcess op may output duplicate + // boxes with different classes, in the following format: + // num_boxes_tensor = [num_boxes] + // detection_classes_tensor = [box_1_class_1, box_1_class_2, ...] 
+ // detection_scores_tensor = [box_1_score_1, box_1_score_2, ... ] + // detection_boxes_tensor = [box_1, box_1, ... ] + // Each box repeats classes_per_detection_ times. + // Note Detection_PostProcess op is only supported in CPU. + RET_CHECK_EQ(max_detections % num_boxes_, 0); + classes_per_detection_ = max_detections / num_boxes_; auto detection_boxes_view = detection_boxes_tensor->GetCpuReadView(); auto detection_boxes = detection_boxes_view.buffer(); @@ -493,8 +504,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( auto detection_classes_view = detection_classes_tensor->GetCpuReadView(); auto detection_classes_ptr = detection_classes_view.buffer(); - std::vector detection_classes(num_boxes_); - for (int i = 0; i < num_boxes_; ++i) { + std::vector detection_classes(num_boxes_ * classes_per_detection_); + for (int i = 0; i < detection_classes.size(); ++i) { detection_classes[i] = static_cast(detection_classes_ptr[i]); } MP_RETURN_IF_ERROR(ConvertToDetections(detection_boxes, detection_scores, @@ -863,7 +874,8 @@ absl::Status TensorsToDetectionsCalculator::DecodeBoxes( absl::Status TensorsToDetectionsCalculator::ConvertToDetections( const float* detection_boxes, const float* detection_scores, const int* detection_classes, std::vector* output_detections) { - for (int i = 0; i < num_boxes_; ++i) { + for (int i = 0; i < num_boxes_ * classes_per_detection_; + i += classes_per_detection_) { if (max_results_ > 0 && output_detections->size() == max_results_) { break; } @@ -880,7 +892,9 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]], /*box_ymax=*/detection_boxes[box_offset + box_indices_[2]], /*box_xmax=*/detection_boxes[box_offset + box_indices_[3]], - detection_scores[i], detection_classes[i], options_.flip_vertically()); + absl::MakeConstSpan(detection_scores + i, classes_per_detection_), + absl::MakeConstSpan(detection_classes + i, classes_per_detection_), 
options_.flip_vertically()); const auto& bbox = detection.location_data().relative_bounding_box(); if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || std::isnan(bbox.height())) { @@ -910,11 +924,17 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( } Detection TensorsToDetectionsCalculator::ConvertToDetection( - float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, - int class_id, bool flip_vertically) { + float box_ymin, float box_xmin, float box_ymax, float box_xmax, + absl::Span scores, absl::Span class_ids, + bool flip_vertically) { Detection detection; - detection.add_score(score); - detection.add_label_id(class_id); + for (int i = 0; i < scores.size(); ++i) { + if (!IsClassIndexAllowed(class_ids[i])) { + continue; + } + detection.add_score(scores[i]); + detection.add_label_id(class_ids[i]); + } LocationData* location_data = detection.mutable_location_data(); location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);