diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index a679a80fd..6ac60d2c1 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -167,6 +167,7 @@ cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], @@ -413,6 +414,7 @@ cc_library( ":filter_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/port:status", "@com_google_absl//absl/memory", ], diff --git a/mediapipe/calculators/util/filter_detections_calculator.cc b/mediapipe/calculators/util/filter_detections_calculator.cc index a1f23ba83..7b5bcca4c 100644 --- a/mediapipe/calculators/util/filter_detections_calculator.cc +++ b/mediapipe/calculators/util/filter_detections_calculator.cc @@ -21,11 +21,13 @@ #include "mediapipe/calculators/util/filter_detections_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { const char kInputDetectionsTag[] = "INPUT_DETECTIONS"; +const char kImageSizeTag[] = "IMAGE_SIZE"; // const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS"; // @@ -41,6 +43,10 @@ class FilterDetectionsCalculator : public CalculatorBase { cc->Inputs().Tag(kInputDetectionsTag).Set>(); cc->Outputs().Tag(kOutputDetectionsTag).Set>(); + if (cc->Inputs().HasTag(kImageSizeTag)) { + cc->Inputs().Tag(kImageSizeTag).Set>(); + } + return absl::OkStatus(); } @@ -48,21 +54,51 @@ class FilterDetectionsCalculator : public CalculatorBase { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); + if (options_.has_min_pixel_size() || options_.has_max_pixel_size()) { + RET_CHECK(cc->Inputs().HasTag(kImageSizeTag)); + } + return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { const auto& input_detections = cc->Inputs().Tag(kInputDetectionsTag).Get>(); - auto output_detections = absl::make_unique>(); + int image_width = 0; + int image_height = 0; + if (cc->Inputs().HasTag(kImageSizeTag)) { + std::tie(image_width, image_height) = + cc->Inputs().Tag(kImageSizeTag).Get>(); + } + for (const Detection& detection : input_detections) { - RET_CHECK_GT(detection.score_size(), 0); - // Note: only score at index 0 supported. - if (detection.score(0) >= options_.min_score()) { - output_detections->push_back(detection); + if (options_.has_min_score()) { + RET_CHECK_GT(detection.score_size(), 0); + // Note: only score at index 0 supported. + if (detection.score(0) < options_.min_score()) { + continue; + } } + // Matches rect_size in + // mediapipe/calculators/util/rect_to_render_scale_calculator.cc + const float rect_size = + std::max(detection.location_data().relative_bounding_box().width() * + image_width, + detection.location_data().relative_bounding_box().height() * + image_height); + if (options_.has_min_pixel_size()) { + if (rect_size < options_.min_pixel_size()) { + continue; + } + } + if (options_.has_max_pixel_size()) { + if (rect_size > options_.max_pixel_size()) { + continue; + } + } + output_detections->push_back(detection); } cc->Outputs() diff --git a/mediapipe/calculators/util/filter_detections_calculator.proto b/mediapipe/calculators/util/filter_detections_calculator.proto index e16898c79..2b23236d6 100644 --- a/mediapipe/calculators/util/filter_detections_calculator.proto +++ b/mediapipe/calculators/util/filter_detections_calculator.proto @@ -25,4 +25,10 @@ message FilterDetectionsCalculatorOptions { // Detections lower than this score get filtered out. optional float min_score = 1; + + // Detections smaller than this size *in pixels* get filtered out. + optional float min_pixel_size = 2; + + // Detections larger than this size *in pixels* get filtered out. + optional float max_pixel_size = 3; } diff --git a/mediapipe/calculators/util/filter_detections_calculator_test.cc b/mediapipe/calculators/util/filter_detections_calculator_test.cc index 58b3fe41d..78093827b 100644 --- a/mediapipe/calculators/util/filter_detections_calculator_test.cc +++ b/mediapipe/calculators/util/filter_detections_calculator_test.cc @@ -17,6 +17,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -27,8 +28,8 @@ namespace { using ::testing::ElementsAre; -absl::Status RunGraph(std::vector& input_detections, - std::vector* output_detections) { +absl::Status RunScoreGraph(std::vector& input_detections, + std::vector* output_detections) { CalculatorRunner runner(R"pb( calculator: "FilterDetectionsCalculator" input_stream: "INPUT_DETECTIONS:input_detections" @@ -53,7 +54,7 @@ absl::Status RunGraph(std::vector& input_detections, return absl::OkStatus(); } -TEST(FilterDetectionsCalculatorTest, TestFilterDetections) { +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsScore) { std::vector input_detections; Detection d1, d2; d1.add_score(0.2); @@ -62,12 +63,12 @@ TEST(FilterDetectionsCalculatorTest, TestFilterDetections) { input_detections.push_back(d2); std::vector output_detections; - MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + MP_EXPECT_OK(RunScoreGraph(input_detections, &output_detections)); EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2))); } -TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) { +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsScoreMultiple) { std::vector input_detections; Detection d1, d2, d3, d4; d1.add_score(0.3); @@ -80,7 +81,7 @@ TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) { input_detections.push_back(d4); std::vector output_detections; - MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + MP_EXPECT_OK(RunScoreGraph(input_detections, &output_detections)); EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3), mediapipe::EqualsProto(d4))); @@ -90,10 +91,69 @@ TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsEmpty) { std::vector input_detections; std::vector output_detections; - MP_EXPECT_OK(RunGraph(input_detections, &output_detections)); + MP_EXPECT_OK(RunScoreGraph(input_detections, &output_detections)); EXPECT_EQ(output_detections.size(), 0); } +absl::Status RunSizeGraph(std::vector& input_detections, + std::pair image_dimensions, + std::vector* output_detections) { + CalculatorRunner runner(R"pb( + calculator: "FilterDetectionsCalculator" + input_stream: "INPUT_DETECTIONS:input_detections" + input_stream: "IMAGE_SIZE:image_dimensions" + output_stream: "OUTPUT_DETECTIONS:output_detections" + options { + [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_pixel_size: 50 } + } + )pb"); + + const Timestamp input_timestamp = Timestamp(0); + runner.MutableInputs() + ->Tag("INPUT_DETECTIONS") + .packets.push_back(MakePacket>(input_detections) + .At(input_timestamp)); + runner.MutableInputs() + ->Tag("IMAGE_SIZE") + .packets.push_back(MakePacket>(image_dimensions) + .At(input_timestamp)); + MP_RETURN_IF_ERROR(runner.Run()) << "Calculator run failed."; + + const std::vector& output_packets = + runner.Outputs().Tag("OUTPUT_DETECTIONS").packets; + RET_CHECK_EQ(output_packets.size(), 1); + + *output_detections = output_packets[0].Get>(); + return absl::OkStatus(); +} + +TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMinSize) { + std::vector input_detections; + Detection d1, d2, d3, d4, d5; + d1.mutable_location_data()->mutable_relative_bounding_box()->set_height(0.5); + d1.mutable_location_data()->mutable_relative_bounding_box()->set_width(0.49); + d2.mutable_location_data()->mutable_relative_bounding_box()->set_height(0.4); + d2.mutable_location_data()->mutable_relative_bounding_box()->set_width(0.4); + d3.mutable_location_data()->mutable_relative_bounding_box()->set_height(0.49); + d3.mutable_location_data()->mutable_relative_bounding_box()->set_width(0.5); + d4.mutable_location_data()->mutable_relative_bounding_box()->set_height(0.49); + d4.mutable_location_data()->mutable_relative_bounding_box()->set_width(0.49); + d5.mutable_location_data()->mutable_relative_bounding_box()->set_height(0.5); + d5.mutable_location_data()->mutable_relative_bounding_box()->set_width(0.5); + input_detections.push_back(d1); + input_detections.push_back(d2); + input_detections.push_back(d3); + input_detections.push_back(d4); + input_detections.push_back(d5); + + std::vector output_detections; + MP_EXPECT_OK(RunSizeGraph(input_detections, {100, 100}, &output_detections)); + + EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d1), + mediapipe::EqualsProto(d3), + mediapipe::EqualsProto(d5))); +} + } // namespace } // namespace mediapipe