Add label_map filtering into filter_detection drishti calculator.

PiperOrigin-RevId: 517515046
This commit is contained in:
MediaPipe Team 2023-03-17 14:51:42 -07:00 committed by Copybara-Service
parent 1a456dcbf9
commit 3dede1a9a5
2 changed files with 80 additions and 2 deletions

View File

@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kLabelsTag[] = "LABELS"; constexpr char kLabelsTag[] = "LABELS";
constexpr char kLabelsCsvTag[] = "LABELS_CSV"; constexpr char kLabelsCsvTag[] = "LABELS_CSV";
constexpr char kLabelMapTag[] = "LABEL_MAP";
using mediapipe::RE2; using mediapipe::RE2;
using Detections = std::vector<Detection>; using Detections = std::vector<Detection>;
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>(); cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
} }
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
cc->InputSidePackets()
.Tag(kLabelMapTag)
.Set<std::unique_ptr<std::map<int, std::string>>>();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<FilterDetectionCalculatorOptions>(); options_ = cc->Options<FilterDetectionCalculatorOptions>();
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) || limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
cc->InputSidePackets().HasTag(kLabelsCsvTag); cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
cc->InputSidePackets().HasTag(kLabelMapTag);
if (limit_labels_) { if (limit_labels_) {
Strings allowlist_labels; Strings allowlist_labels;
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
for (auto& e : allowlist_labels) { for (auto& e : allowlist_labels) {
absl::StripAsciiWhitespace(&e); absl::StripAsciiWhitespace(&e);
} }
} else { } else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>(); allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
auto label_map = cc->InputSidePackets()
.Tag(kLabelMapTag)
.Get<std::unique_ptr<std::map<int, std::string>>>()
.get();
for (const auto& [_, v] : *label_map) {
allowlist_labels.push_back(v);
}
} }
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end()); allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
} }

View File

@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
)); ));
} }
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
auto runner = std::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "FilterDetectionCalculator"
input_stream: "DETECTION:input"
input_side_packet: "LABEL_MAP:input_map"
output_stream: "DETECTION:output"
options {
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
}
)pb"));
runner->MutableInputs()->Tag("DETECTION").packets = {
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "d"
score: 1
score: 0.8
score: 0.3
score: 0.9
)pb"))
.At(Timestamp(20)),
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "e"
score: 0.6
score: 0.4
score: 0.2
score: 0.7
)pb"))
.At(Timestamp(40)),
};
auto label_map = std::make_unique<std::map<int, std::string>>();
(*label_map)[0] = "a";
(*label_map)[1] = "b";
(*label_map)[2] = "c";
runner->MutableSidePackets()->Tag("LABEL_MAP") =
AdoptAsUniquePtr(label_map.release());
// Run graph.
MP_ASSERT_OK(runner->Run());
// Check output.
EXPECT_THAT(
runner->Outputs().Tag("DETECTION").packets,
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(20)),
EqualsProto(R"pb(
label: "a" label: "b" score: 1 score: 0.8
)pb")), // Packet 1 at timestamp 20.
PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(40)),
EqualsProto(R"pb(
label: "a" score: 0.6
)pb")) // Packet 2 at timestamp 40.
));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe