mediapipe/mediapipe/calculators/util/non_max_suppression_calculator.cc
MediaPipe Team 294687295d Project import generated by Copybara.
PiperOrigin-RevId: 263889205
2019-08-16 18:56:48 -07:00

378 lines
15 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
typedef std::vector<Detection> Detections;
typedef std::vector<std::pair<int, float>> IndexedScores;
namespace {
constexpr char kImageTag[] = "IMAGE";
bool SortBySecond(const std::pair<int, float>& indexed_score_0,
const std::pair<int, float>& indexed_score_1) {
return (indexed_score_0.second > indexed_score_1.second);
}
// Removes all but the max scoring label and its score from the detection.
// Returns true if the detection has at least one label.
bool RetainMaxScoringLabelOnly(Detection* detection) {
if (detection->label_id_size() == 0 && detection->label_size() == 0) {
return false;
}
CHECK(detection->label_id_size() == detection->score_size() ||
detection->label_size() == detection->score_size())
<< "Number of scores must be equal to number of detections.";
std::vector<std::pair<int, float>> indexed_scores;
for (int k = 0; k < detection->score_size(); ++k) {
indexed_scores.push_back(std::make_pair(k, detection->score(k)));
}
std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond);
const int top_index = indexed_scores[0].first;
detection->clear_score();
detection->add_score(indexed_scores[0].second);
if (detection->label_id_size() > top_index) {
const int top_label_id = detection->label_id(top_index);
detection->clear_label_id();
detection->add_label_id(top_label_id);
} else {
const std::string top_label = detection->label(top_index);
detection->clear_label();
detection->add_label(top_label);
}
return true;
}
// Computes an overlap similarity between two rectangles. Similarity measure is
// defined by overlap_type parameter.
float OverlapSimilarity(
const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type,
const Rectangle_f& rect1, const Rectangle_f& rect2) {
if (!rect1.Intersects(rect2)) return 0.0f;
const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area();
float normalization;
switch (overlap_type) {
case NonMaxSuppressionCalculatorOptions::JACCARD:
normalization = Rectangle_f(rect1).Union(rect2).Area();
break;
case NonMaxSuppressionCalculatorOptions::MODIFIED_JACCARD:
normalization = rect2.Area();
break;
case NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION:
normalization = rect1.Area() + rect2.Area() - intersection_area;
break;
default:
LOG(FATAL) << "Unrecognized overlap type: " << overlap_type;
}
return normalization > 0.0f ? intersection_area / normalization : 0.0f;
}
// Computes an overlap similarity between two locations by first extracting the
// relative box (dimension normalized by frame width/height) from the location.
float OverlapSimilarity(
const int frame_width, const int frame_height,
const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type,
const Location& location1, const Location& location2) {
const auto rect1 = location1.ConvertToRelativeBBox(frame_width, frame_height);
const auto rect2 = location2.ConvertToRelativeBBox(frame_width, frame_height);
return OverlapSimilarity(overlap_type, rect1, rect2);
}
// Computes an overlap similarity between two locations by first extracting the
// relative box from the location. It assumes that a relative-box representation
// is already available in the location, and therefore frame width and height
// are not needed for further normalization.
float OverlapSimilarity(
const NonMaxSuppressionCalculatorOptions::OverlapType overlap_type,
const Location& location1, const Location& location2) {
const auto rect1 = location1.GetRelativeBBox();
const auto rect2 = location2.GetRelativeBBox();
return OverlapSimilarity(overlap_type, rect1, rect2);
}
} // namespace
// A calculator performing non-maximum suppression on a set of detections.
// Inputs:
// 1. IMAGE (optional): A stream of ImageFrame used to obtain the frame size.
// No image data is used. Not needed if the detection bounding boxes are
// already represented in normalized dimensions (0.0~1.0).
// 2. A variable number of input streams of type std::vector<Detection>. The
// exact number of such streams should be set via num_detection_streams
// field in the calculator options.
//
// Outputs: a single stream of type std::vector<Detection> containing a subset
// of the input detections after non-maximum suppression.
//
// Example config:
// node {
// calculator: "NonMaxSuppressionCalculator"
// input_stream: "IMAGE:frames"
// input_stream: "detections1"
// input_stream: "detections2"
// output_stream: "detections"
// options {
// [mediapipe.NonMaxSuppressionCalculatorOptions.ext] {
// num_detection_streams: 2
// max_num_detections: 10
// min_suppression_threshold: 0.2
// overlap_type: JACCARD
// }
// }
// }
class NonMaxSuppressionCalculator : public CalculatorBase {
public:
NonMaxSuppressionCalculator() = default;
~NonMaxSuppressionCalculator() override = default;
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
const auto& options = cc->Options<NonMaxSuppressionCalculatorOptions>();
if (cc->Inputs().HasTag(kImageTag)) {
cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
}
for (int k = 0; k < options.num_detection_streams(); ++k) {
cc->Inputs().Index(k).Set<Detections>();
}
cc->Outputs().Index(0).Set<Detections>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<NonMaxSuppressionCalculatorOptions>();
CHECK_GT(options_.num_detection_streams(), 0)
<< "At least one detection stream need to be specified.";
CHECK_NE(options_.max_num_detections(), 0)
<< "max_num_detections=0 is not a valid value. Please choose a "
<< "positive number of you want to limit the number of output "
<< "detections, or set -1 if you do not want any limit.";
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
// Add all input detections to the same vector.
Detections input_detections;
for (int i = 0; i < options_.num_detection_streams(); ++i) {
const auto& detections_packet = cc->Inputs().Index(i).Value();
// Check whether this stream has a packet for this timestamp.
if (detections_packet.IsEmpty()) {
continue;
}
const auto& detections = detections_packet.Get<Detections>();
input_detections.insert(input_detections.end(), detections.begin(),
detections.end());
}
// Check if there are any detections at all.
if (input_detections.empty()) {
if (options_.return_empty_detections()) {
cc->Outputs().Index(0).Add(new Detections(), cc->InputTimestamp());
}
return ::mediapipe::OkStatus();
}
// Remove all but the maximum scoring label from each input detection. This
// corresponds to non-maximum suppression among detections which have
// identical locations.
Detections pruned_detections;
pruned_detections.reserve(input_detections.size());
for (auto& detection : input_detections) {
if (RetainMaxScoringLabelOnly(&detection)) {
pruned_detections.push_back(detection);
}
}
// Copy all the scores (there is a single score in each detection after
// the above pruning) to an indexed vector for sorting. The first value is
// the index of the detection in the original vector from which the score
// stems, while the second is the actual score.
IndexedScores indexed_scores;
indexed_scores.reserve(pruned_detections.size());
for (int index = 0; index < pruned_detections.size(); ++index) {
indexed_scores.push_back(
std::make_pair(index, pruned_detections[index].score(0)));
}
std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond);
const int max_num_detections =
(options_.max_num_detections() > -1)
? options_.max_num_detections()
: static_cast<int>(indexed_scores.size());
// A set of detections and locations, wrapping the location data from each
// detection, which are retained after the non-maximum suppression.
auto* retained_detections = new Detections();
retained_detections->reserve(max_num_detections);
if (options_.algorithm() == NonMaxSuppressionCalculatorOptions::WEIGHTED) {
WeightedNonMaxSuppression(indexed_scores, pruned_detections,
max_num_detections, cc, retained_detections);
} else {
NonMaxSuppression(indexed_scores, pruned_detections, max_num_detections,
cc, retained_detections);
}
cc->Outputs().Index(0).Add(retained_detections, cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
private:
void NonMaxSuppression(const IndexedScores& indexed_scores,
const Detections& detections, int max_num_detections,
CalculatorContext* cc, Detections* output_detections) {
std::vector<Location> retained_locations;
retained_locations.reserve(max_num_detections);
// We traverse the detections by decreasing score.
for (const auto& indexed_score : indexed_scores) {
const auto& detection = detections[indexed_score.first];
if (options_.min_score_threshold() > 0 &&
detection.score(0) < options_.min_score_threshold()) {
break;
}
const Location location(detection.location_data());
bool suppressed = false;
// The current detection is suppressed iff there exists a retained
// detection, whose location overlaps more than the specified
// threshold with the location of the current detection.
for (const auto& retained_location : retained_locations) {
float similarity;
if (cc->Inputs().HasTag(kImageTag)) {
const auto& frame = cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
similarity = OverlapSimilarity(frame.Width(), frame.Height(),
options_.overlap_type(),
retained_location, location);
} else {
similarity = OverlapSimilarity(options_.overlap_type(),
retained_location, location);
}
if (similarity > options_.min_suppression_threshold()) {
suppressed = true;
break;
}
}
if (!suppressed) {
output_detections->push_back(detection);
retained_locations.push_back(location);
}
if (output_detections->size() >= max_num_detections) {
break;
}
}
}
void WeightedNonMaxSuppression(const IndexedScores& indexed_scores,
const Detections& detections,
int max_num_detections, CalculatorContext* cc,
Detections* output_detections) {
IndexedScores remained_indexed_scores;
remained_indexed_scores.assign(indexed_scores.begin(),
indexed_scores.end());
IndexedScores remained;
IndexedScores candidates;
output_detections->clear();
while (!remained_indexed_scores.empty()) {
const auto& detection = detections[remained_indexed_scores[0].first];
if (options_.min_score_threshold() > 0 &&
detection.score(0) < options_.min_score_threshold()) {
break;
}
remained.clear();
candidates.clear();
const Location location(detection.location_data());
// This includes the first box.
for (const auto& indexed_score : remained_indexed_scores) {
Location rest_location(detections[indexed_score.first].location_data());
float similarity =
OverlapSimilarity(options_.overlap_type(), rest_location, location);
if (similarity > options_.min_suppression_threshold()) {
candidates.push_back(indexed_score);
} else {
remained.push_back(indexed_score);
}
}
auto weighted_detection = detection;
if (!candidates.empty()) {
const int num_keypoints =
detection.location_data().relative_keypoints_size();
std::vector<float> keypoints(num_keypoints * 2);
float w_xmin = 0.0f;
float w_ymin = 0.0f;
float w_xmax = 0.0f;
float w_ymax = 0.0f;
float total_score = 0.0f;
for (const auto& candidate : candidates) {
total_score += candidate.second;
const auto& location_data =
detections[candidate.first].location_data();
const auto& bbox = location_data.relative_bounding_box();
w_xmin += bbox.xmin() * candidate.second;
w_ymin += bbox.ymin() * candidate.second;
w_xmax += (bbox.xmin() + bbox.width()) * candidate.second;
w_ymax += (bbox.ymin() + bbox.height()) * candidate.second;
for (int i = 0; i < num_keypoints; ++i) {
keypoints[i * 2] +=
location_data.relative_keypoints(i).x() * candidate.second;
keypoints[i * 2 + 1] +=
location_data.relative_keypoints(i).y() * candidate.second;
}
}
auto* weighted_location = weighted_detection.mutable_location_data()
->mutable_relative_bounding_box();
weighted_location->set_xmin(w_xmin / total_score);
weighted_location->set_ymin(w_ymin / total_score);
weighted_location->set_width((w_xmax / total_score) -
weighted_location->xmin());
weighted_location->set_height((w_ymax / total_score) -
weighted_location->ymin());
for (int i = 0; i < num_keypoints; ++i) {
auto* keypoint = weighted_detection.mutable_location_data()
->mutable_relative_keypoints(i);
keypoint->set_x(keypoints[i * 2] / total_score);
keypoint->set_y(keypoints[i * 2 + 1] / total_score);
}
}
remained_indexed_scores = std::move(remained);
output_detections->push_back(weighted_detection);
}
}
NonMaxSuppressionCalculatorOptions options_;
};
REGISTER_CALCULATOR(NonMaxSuppressionCalculator);
} // namespace mediapipe