Move filtering logic of score to ConvertToDetection.

PiperOrigin-RevId: 578189518
This commit is contained in:
MediaPipe Team 2023-10-31 08:16:58 -07:00 committed by Copybara-Service
parent ec032fb018
commit 7da2810b83

View File

@ -879,13 +879,6 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
if (max_results_ > 0 && output_detections->size() == max_results_) { if (max_results_ > 0 && output_detections->size() == max_results_) {
break; break;
} }
if (options_.has_min_score_thresh() &&
detection_scores[i] < options_.min_score_thresh()) {
continue;
}
if (!IsClassIndexAllowed(detection_classes[i])) {
continue;
}
const int box_offset = i * num_coords_; const int box_offset = i * num_coords_;
Detection detection = ConvertToDetection( Detection detection = ConvertToDetection(
/*box_ymin=*/detection_boxes[box_offset + box_indices_[0]], /*box_ymin=*/detection_boxes[box_offset + box_indices_[0]],
@ -895,6 +888,11 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
absl::MakeConstSpan(detection_scores + i, classes_per_detection_), absl::MakeConstSpan(detection_scores + i, classes_per_detection_),
absl::MakeConstSpan(detection_classes + i, classes_per_detection_), absl::MakeConstSpan(detection_classes + i, classes_per_detection_),
options_.flip_vertically()); options_.flip_vertically());
// if all the scores and classes are filtered out, we skip the empty
// detection.
if (detection.score().empty()) {
continue;
}
const auto& bbox = detection.location_data().relative_bounding_box(); const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
std::isnan(bbox.height())) { std::isnan(bbox.height())) {
@ -932,6 +930,10 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection(
if (!IsClassIndexAllowed(class_ids[i])) { if (!IsClassIndexAllowed(class_ids[i])) {
continue; continue;
} }
if (options_.has_min_score_thresh() &&
scores[i] < options_.min_score_thresh()) {
continue;
}
detection.add_score(scores[i]); detection.add_score(scores[i]);
detection.add_label_id(class_ids[i]); detection.add_label_id(class_ids[i]);
} }