Object Detector deduplication

PiperOrigin-RevId: 493716159
This commit is contained in:
MediaPipe Team 2022-12-07 14:52:58 -08:00 committed by Copybara-Service
parent ef1507ed5d
commit 91664eb254
4 changed files with 138 additions and 1 deletions

View File

@ -456,6 +456,23 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "detections_deduplicate_calculator",
srcs = [
"detections_deduplicate_calculator.cc",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)
cc_library(
name = "rect_transformation_calculator",
srcs = ["rect_transformation_calculator.cc"],

View File

@ -0,0 +1,114 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
struct BoundingBoxHash {
size_t operator()(const LocationData::BoundingBox& bbox) const {
return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^
std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height());
}
};
struct BoundingBoxEq {
bool operator()(const LocationData::BoundingBox& lhs,
const LocationData::BoundingBox& rhs) const {
return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() &&
lhs.width() == rhs.width() && lhs.height() == rhs.height();
}
};
} // namespace
// This Calculator deduplicates the bunding boxes with exactly the same
// coordinates, and folds the labels into a single Detection proto. Note
// non-maximum-suppression remove the overlapping bounding boxes within a class,
// while the deduplication operation merges bounding boxes from different
// classes.
// Example config:
// node {
// calculator: "DetectionsDeduplicateCalculator"
// input_stream: "detections"
// output_stream: "deduplicated_detections"
// }
class DetectionsDeduplicateCalculator : public Node {
public:
static constexpr Input<std::vector<Detection>> kIn{""};
static constexpr Output<std::vector<Detection>> kOut{""};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Open(mediapipe::CalculatorContext* cc) {
cc->SetOffset(::mediapipe::TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Process(mediapipe::CalculatorContext* cc) {
const std::vector<Detection>& raw_detections = kIn(cc).Get();
absl::flat_hash_map<LocationData::BoundingBox, Detection*, BoundingBoxHash,
BoundingBoxEq>
bbox_to_detections;
std::vector<Detection> deduplicated_detections;
for (const auto& detection : raw_detections) {
if (!detection.has_location_data() ||
!detection.location_data().has_bounding_box()) {
return absl::InvalidArgumentError(
"The location data of Detections must be BoundingBox.");
}
if (bbox_to_detections.contains(
detection.location_data().bounding_box())) {
// The bbox location already exists. Merge the detection labels into
// the existing detection proto.
Detection& deduplicated_detection =
*bbox_to_detections[detection.location_data().bounding_box()];
deduplicated_detection.mutable_score()->MergeFrom(detection.score());
deduplicated_detection.mutable_label()->MergeFrom(detection.label());
deduplicated_detection.mutable_label_id()->MergeFrom(
detection.label_id());
deduplicated_detection.mutable_display_name()->MergeFrom(
detection.display_name());
} else {
// The bbox location appears first time. Add the detection to output
// detection vector.
deduplicated_detections.push_back(detection);
bbox_to_detections[detection.location_data().bounding_box()] =
&deduplicated_detections.back();
}
}
kOut(cc).Send(std::move(deduplicated_detections));
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -63,6 +63,7 @@ cc_library(
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
"//mediapipe/calculators/util:detection_projection_calculator",
"//mediapipe/calculators/util:detection_transformation_calculator",
"//mediapipe/calculators/util:detections_deduplicate_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",

View File

@ -662,11 +662,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
detection_transformation.Out(kPixelDetectionsTag) >>
detection_label_id_to_text.In("");
// Deduplicate Detections with same bounding box coordinates.
auto& detections_deduplicate =
graph.AddNode("DetectionsDeduplicateCalculator");
detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
// Outputs the labeled detections and the processed image as the subgraph
// output streams.
return {{
/* detections= */
detection_label_id_to_text[Output<std::vector<Detection>>("")],
detections_deduplicate[Output<std::vector<Detection>>("")],
/* image= */ preprocessing[Output<Image>(kImageTag)],
}};
}