Object Detector deduplication
PiperOrigin-RevId: 493716159
This commit is contained in:
parent
ef1507ed5d
commit
91664eb254
|
@ -456,6 +456,23 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build target for the calculator implemented in
# detections_deduplicate_calculator.cc.
cc_library(
|
||||||
|
name = "detections_deduplicate_calculator",
|
||||||
|
srcs = [
|
||||||
|
"detections_deduplicate_calculator.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
# NOTE(review): presumably needed so the calculator's static registration is
|
||||||
|
# not dropped by the linker when nothing references its symbols -- confirm.
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "rect_transformation_calculator",
|
name = "rect_transformation_calculator",
|
||||||
srcs = ["rect_transformation_calculator.cc"],
|
srcs = ["rect_transformation_calculator.cc"],
|
||||||
|
|
114
mediapipe/calculators/util/detections_deduplicate_calculator.cc
Normal file
114
mediapipe/calculators/util/detections_deduplicate_calculator.cc
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct BoundingBoxHash {
|
||||||
|
size_t operator()(const LocationData::BoundingBox& bbox) const {
|
||||||
|
return std::hash<int>{}(bbox.xmin()) ^ std::hash<int>{}(bbox.ymin()) ^
|
||||||
|
std::hash<int>{}(bbox.width()) ^ std::hash<int>{}(bbox.height());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BoundingBoxEq {
|
||||||
|
bool operator()(const LocationData::BoundingBox& lhs,
|
||||||
|
const LocationData::BoundingBox& rhs) const {
|
||||||
|
return lhs.xmin() == rhs.xmin() && lhs.ymin() == rhs.ymin() &&
|
||||||
|
lhs.width() == rhs.width() && lhs.height() == rhs.height();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// This Calculator deduplicates the bounding boxes with exactly the same
|
||||||
|
// coordinates, and folds the labels into a single Detection proto. Note
|
||||||
|
// non-maximum-suppression removes the overlapping bounding boxes within a class,
|
||||||
|
// while the deduplication operation merges bounding boxes from different
|
||||||
|
// classes.
|
||||||
|
|
||||||
|
// Example config:
|
||||||
|
// node {
|
||||||
|
// calculator: "DetectionsDeduplicateCalculator"
|
||||||
|
// input_stream: "detections"
|
||||||
|
// output_stream: "deduplicated_detections"
|
||||||
|
// }
|
||||||
|
class DetectionsDeduplicateCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<std::vector<Detection>> kIn{""};
|
||||||
|
static constexpr Output<std::vector<Detection>> kOut{""};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
absl::Status Open(mediapipe::CalculatorContext* cc) {
|
||||||
|
cc->SetOffset(::mediapipe::TimestampDiff(0));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(mediapipe::CalculatorContext* cc) {
|
||||||
|
const std::vector<Detection>& raw_detections = kIn(cc).Get();
|
||||||
|
absl::flat_hash_map<LocationData::BoundingBox, Detection*, BoundingBoxHash,
|
||||||
|
BoundingBoxEq>
|
||||||
|
bbox_to_detections;
|
||||||
|
std::vector<Detection> deduplicated_detections;
|
||||||
|
for (const auto& detection : raw_detections) {
|
||||||
|
if (!detection.has_location_data() ||
|
||||||
|
!detection.location_data().has_bounding_box()) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"The location data of Detections must be BoundingBox.");
|
||||||
|
}
|
||||||
|
if (bbox_to_detections.contains(
|
||||||
|
detection.location_data().bounding_box())) {
|
||||||
|
// The bbox location already exists. Merge the detection labels into
|
||||||
|
// the existing detection proto.
|
||||||
|
Detection& deduplicated_detection =
|
||||||
|
*bbox_to_detections[detection.location_data().bounding_box()];
|
||||||
|
deduplicated_detection.mutable_score()->MergeFrom(detection.score());
|
||||||
|
deduplicated_detection.mutable_label()->MergeFrom(detection.label());
|
||||||
|
deduplicated_detection.mutable_label_id()->MergeFrom(
|
||||||
|
detection.label_id());
|
||||||
|
deduplicated_detection.mutable_display_name()->MergeFrom(
|
||||||
|
detection.display_name());
|
||||||
|
} else {
|
||||||
|
// The bbox location appears first time. Add the detection to output
|
||||||
|
// detection vector.
|
||||||
|
deduplicated_detections.push_back(detection);
|
||||||
|
bbox_to_detections[detection.location_data().bounding_box()] =
|
||||||
|
&deduplicated_detections.back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kOut(cc).Send(std::move(deduplicated_detections));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Registers the calculator with the framework. NOTE(review): presumably this
// is what lets graphs add the node by the string name
// "DetectionsDeduplicateCalculator" -- confirm against the macro definition.
MEDIAPIPE_REGISTER_NODE(DetectionsDeduplicateCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -63,6 +63,7 @@ cc_library(
|
||||||
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
|
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/util:detection_projection_calculator",
|
"//mediapipe/calculators/util:detection_projection_calculator",
|
||||||
"//mediapipe/calculators/util:detection_transformation_calculator",
|
"//mediapipe/calculators/util:detection_transformation_calculator",
|
||||||
|
"//mediapipe/calculators/util:detections_deduplicate_calculator",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
|
|
|
@ -662,11 +662,16 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
|
||||||
detection_transformation.Out(kPixelDetectionsTag) >>
|
detection_transformation.Out(kPixelDetectionsTag) >>
|
||||||
detection_label_id_to_text.In("");
|
detection_label_id_to_text.In("");
|
||||||
|
|
||||||
|
// Deduplicate Detections with same bounding box coordinates.
|
||||||
|
auto& detections_deduplicate =
|
||||||
|
graph.AddNode("DetectionsDeduplicateCalculator");
|
||||||
|
detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
|
||||||
|
|
||||||
// Outputs the labeled detections and the processed image as the subgraph
|
// Outputs the labeled detections and the processed image as the subgraph
|
||||||
// output streams.
|
// output streams.
|
||||||
return {{
|
return {{
|
||||||
/* detections= */
|
/* detections= */
|
||||||
detection_label_id_to_text[Output<std::vector<Detection>>("")],
|
detections_deduplicate[Output<std::vector<Detection>>("")],
|
||||||
/* image= */ preprocessing[Output<Image>(kImageTag)],
|
/* image= */ preprocessing[Output<Image>(kImageTag)],
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user