Add more filtering methods to detection filter calculator.
PiperOrigin-RevId: 507581281
This commit is contained in:
parent
f4b0cf1cff
commit
e2ef78433f
|
@ -167,6 +167,7 @@ cc_test(
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
],
|
],
|
||||||
|
@ -413,6 +414,7 @@ cc_library(
|
||||||
":filter_detections_calculator_cc_proto",
|
":filter_detections_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
|
|
|
@ -21,11 +21,13 @@
|
||||||
#include "mediapipe/calculators/util/filter_detections_calculator.pb.h"
|
#include "mediapipe/calculators/util/filter_detections_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
const char kInputDetectionsTag[] = "INPUT_DETECTIONS";
|
const char kInputDetectionsTag[] = "INPUT_DETECTIONS";
|
||||||
|
const char kImageSizeTag[] = "IMAGE_SIZE"; // <width, height>
|
||||||
const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS";
|
const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS";
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -41,6 +43,10 @@ class FilterDetectionsCalculator : public CalculatorBase {
|
||||||
cc->Inputs().Tag(kInputDetectionsTag).Set<std::vector<Detection>>();
|
cc->Inputs().Tag(kInputDetectionsTag).Set<std::vector<Detection>>();
|
||||||
cc->Outputs().Tag(kOutputDetectionsTag).Set<std::vector<Detection>>();
|
cc->Outputs().Tag(kOutputDetectionsTag).Set<std::vector<Detection>>();
|
||||||
|
|
||||||
|
if (cc->Inputs().HasTag(kImageSizeTag)) {
|
||||||
|
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||||
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,21 +54,51 @@ class FilterDetectionsCalculator : public CalculatorBase {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
options_ = cc->Options<mediapipe::FilterDetectionsCalculatorOptions>();
|
options_ = cc->Options<mediapipe::FilterDetectionsCalculatorOptions>();
|
||||||
|
|
||||||
|
if (options_.has_min_pixel_size() || options_.has_max_pixel_size()) {
|
||||||
|
RET_CHECK(cc->Inputs().HasTag(kImageSizeTag));
|
||||||
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
const auto& input_detections =
|
const auto& input_detections =
|
||||||
cc->Inputs().Tag(kInputDetectionsTag).Get<std::vector<Detection>>();
|
cc->Inputs().Tag(kInputDetectionsTag).Get<std::vector<Detection>>();
|
||||||
|
|
||||||
auto output_detections = absl::make_unique<std::vector<Detection>>();
|
auto output_detections = absl::make_unique<std::vector<Detection>>();
|
||||||
|
|
||||||
|
int image_width = 0;
|
||||||
|
int image_height = 0;
|
||||||
|
if (cc->Inputs().HasTag(kImageSizeTag)) {
|
||||||
|
std::tie(image_width, image_height) =
|
||||||
|
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||||
|
}
|
||||||
|
|
||||||
for (const Detection& detection : input_detections) {
|
for (const Detection& detection : input_detections) {
|
||||||
RET_CHECK_GT(detection.score_size(), 0);
|
if (options_.has_min_score()) {
|
||||||
// Note: only score at index 0 supported.
|
RET_CHECK_GT(detection.score_size(), 0);
|
||||||
if (detection.score(0) >= options_.min_score()) {
|
// Note: only score at index 0 supported.
|
||||||
output_detections->push_back(detection);
|
if (detection.score(0) < options_.min_score()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// Matches rect_size in
|
||||||
|
// mediapipe/calculators/util/rect_to_render_scale_calculator.cc
|
||||||
|
const float rect_size =
|
||||||
|
std::max(detection.location_data().relative_bounding_box().width() *
|
||||||
|
image_width,
|
||||||
|
detection.location_data().relative_bounding_box().height() *
|
||||||
|
image_height);
|
||||||
|
if (options_.has_min_pixel_size()) {
|
||||||
|
if (rect_size < options_.min_pixel_size()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (options_.has_max_pixel_size()) {
|
||||||
|
if (rect_size > options_.max_pixel_size()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output_detections->push_back(detection);
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
|
|
|
@ -25,4 +25,10 @@ message FilterDetectionsCalculatorOptions {
|
||||||
|
|
||||||
// Detections lower than this score get filtered out.
|
// Detections lower than this score get filtered out.
|
||||||
optional float min_score = 1;
|
optional float min_score = 1;
|
||||||
|
|
||||||
|
// Detections smaller than this size *in pixels* get filtered out.
|
||||||
|
optional float min_pixel_size = 2;
|
||||||
|
|
||||||
|
// Detections larger than this size *in pixels* get filtered out.
|
||||||
|
optional float max_pixel_size = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
@ -27,8 +28,8 @@ namespace {
|
||||||
|
|
||||||
using ::testing::ElementsAre;
|
using ::testing::ElementsAre;
|
||||||
|
|
||||||
absl::Status RunGraph(std::vector<Detection>& input_detections,
|
absl::Status RunScoreGraph(std::vector<Detection>& input_detections,
|
||||||
std::vector<Detection>* output_detections) {
|
std::vector<Detection>* output_detections) {
|
||||||
CalculatorRunner runner(R"pb(
|
CalculatorRunner runner(R"pb(
|
||||||
calculator: "FilterDetectionsCalculator"
|
calculator: "FilterDetectionsCalculator"
|
||||||
input_stream: "INPUT_DETECTIONS:input_detections"
|
input_stream: "INPUT_DETECTIONS:input_detections"
|
||||||
|
@ -53,7 +54,7 @@ absl::Status RunGraph(std::vector<Detection>& input_detections,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(FilterDetectionsCalculatorTest, TestFilterDetections) {
|
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsScore) {
|
||||||
std::vector<Detection> input_detections;
|
std::vector<Detection> input_detections;
|
||||||
Detection d1, d2;
|
Detection d1, d2;
|
||||||
d1.add_score(0.2);
|
d1.add_score(0.2);
|
||||||
|
@ -62,12 +63,12 @@ TEST(FilterDetectionsCalculatorTest, TestFilterDetections) {
|
||||||
input_detections.push_back(d2);
|
input_detections.push_back(d2);
|
||||||
|
|
||||||
std::vector<Detection> output_detections;
|
std::vector<Detection> output_detections;
|
||||||
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
|
MP_EXPECT_OK(RunScoreGraph(input_detections, &output_detections));
|
||||||
|
|
||||||
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2)));
|
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) {
|
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsScoreMultiple) {
|
||||||
std::vector<Detection> input_detections;
|
std::vector<Detection> input_detections;
|
||||||
Detection d1, d2, d3, d4;
|
Detection d1, d2, d3, d4;
|
||||||
d1.add_score(0.3);
|
d1.add_score(0.3);
|
||||||
|
@ -80,7 +81,7 @@ TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) {
|
||||||
input_detections.push_back(d4);
|
input_detections.push_back(d4);
|
||||||
|
|
||||||
std::vector<Detection> output_detections;
|
std::vector<Detection> output_detections;
|
||||||
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
|
MP_EXPECT_OK(RunScoreGraph(input_detections, &output_detections));
|
||||||
|
|
||||||
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3),
|
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3),
|
||||||
mediapipe::EqualsProto(d4)));
|
mediapipe::EqualsProto(d4)));
|
||||||
|
@ -90,10 +91,69 @@ TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsEmpty) {
|
||||||
std::vector<Detection> input_detections;
|
std::vector<Detection> input_detections;
|
||||||
|
|
||||||
std::vector<Detection> output_detections;
|
std::vector<Detection> output_detections;
|
||||||
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
|
MP_EXPECT_OK(RunScoreGraph(input_detections, &output_detections));
|
||||||
|
|
||||||
EXPECT_EQ(output_detections.size(), 0);
|
EXPECT_EQ(output_detections.size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs a single FilterDetectionsCalculator node configured with
// `min_pixel_size: 50`.
//
// `input_detections` is sent on INPUT_DETECTIONS and `image_dimensions` on
// IMAGE_SIZE, both at timestamp 0. On success the contents of the single
// OUTPUT_DETECTIONS packet are copied into `*output_detections`.
//
// Returns an error status if the calculator run fails or if the output stream
// does not contain exactly one packet.
absl::Status RunSizeGraph(std::vector<Detection>& input_detections,
                          std::pair<int, int> image_dimensions,
                          std::vector<Detection>* output_detections) {
  CalculatorRunner runner(R"pb(
    calculator: "FilterDetectionsCalculator"
    input_stream: "INPUT_DETECTIONS:input_detections"
    input_stream: "IMAGE_SIZE:image_dimensions"
    output_stream: "OUTPUT_DETECTIONS:output_detections"
    options {
      [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_pixel_size: 50 }
    }
  )pb");

  // Both inputs share one timestamp so they are processed together.
  const Timestamp input_timestamp = Timestamp(0);

  const Packet detections_packet =
      MakePacket<std::vector<Detection>>(input_detections).At(input_timestamp);
  const Packet image_size_packet =
      MakePacket<std::pair<int, int>>(image_dimensions).At(input_timestamp);

  auto* graph_inputs = runner.MutableInputs();
  graph_inputs->Tag("INPUT_DETECTIONS").packets.push_back(detections_packet);
  graph_inputs->Tag("IMAGE_SIZE").packets.push_back(image_size_packet);

  MP_RETURN_IF_ERROR(runner.Run()) << "Calculator run failed.";

  const std::vector<Packet>& result_packets =
      runner.Outputs().Tag("OUTPUT_DETECTIONS").packets;
  RET_CHECK_EQ(result_packets.size(), 1);

  *output_detections = result_packets.front().Get<std::vector<Detection>>();
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// Verifies min_pixel_size filtering on a 100x100 image, using the graph built
// by RunSizeGraph (min_pixel_size: 50). A detection is expected to survive
// when at least one of its relative box extents is 0.5, and to be dropped
// when both extents are below 0.5.
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMinSize) {
  // Builds a detection whose relative bounding box has the given extents.
  const auto make_detection = [](float height, float width) {
    Detection detection;
    auto* box =
        detection.mutable_location_data()->mutable_relative_bounding_box();
    box->set_height(height);
    box->set_width(width);
    return detection;
  };

  const Detection d1 = make_detection(0.5, 0.49);   // Height reaches 0.5.
  const Detection d2 = make_detection(0.4, 0.4);    // Both extents too small.
  const Detection d3 = make_detection(0.49, 0.5);   // Width reaches 0.5.
  const Detection d4 = make_detection(0.49, 0.49);  // Both just below 0.5.
  const Detection d5 = make_detection(0.5, 0.5);    // Both reach 0.5.

  std::vector<Detection> input_detections = {d1, d2, d3, d4, d5};

  std::vector<Detection> output_detections;
  MP_EXPECT_OK(RunSizeGraph(input_detections, {100, 100}, &output_detections));

  EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d1),
                                             mediapipe::EqualsProto(d3),
                                             mediapipe::EqualsProto(d5)));
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user