Add label_map filtering into filter_detection drishti calculator.
PiperOrigin-RevId: 517515046
This commit is contained in:
parent
1a456dcbf9
commit
3dede1a9a5
|
@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
|
|||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kLabelsTag[] = "LABELS";
|
||||
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
||||
constexpr char kLabelMapTag[] = "LABEL_MAP";
|
||||
|
||||
using mediapipe::RE2;
|
||||
using Detections = std::vector<Detection>;
|
||||
|
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
|
|||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||
cc->InputSidePackets()
|
||||
.Tag(kLabelMapTag)
|
||||
.Set<std::unique_ptr<std::map<int, std::string>>>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
|||
cc->SetOffset(TimestampDiff(0));
|
||||
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
||||
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
||||
cc->InputSidePackets().HasTag(kLabelsCsvTag);
|
||||
cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
|
||||
cc->InputSidePackets().HasTag(kLabelMapTag);
|
||||
if (limit_labels_) {
|
||||
Strings allowlist_labels;
|
||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||
|
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
|||
for (auto& e : allowlist_labels) {
|
||||
absl::StripAsciiWhitespace(&e);
|
||||
}
|
||||
} else {
|
||||
} else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
|
||||
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
||||
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||
auto label_map = cc->InputSidePackets()
|
||||
.Tag(kLabelMapTag)
|
||||
.Get<std::unique_ptr<std::map<int, std::string>>>()
|
||||
.get();
|
||||
for (const auto& [_, v] : *label_map) {
|
||||
allowlist_labels.push_back(v);
|
||||
}
|
||||
}
|
||||
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
||||
}
|
||||
|
|
|
@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
|
|||
));
|
||||
}
|
||||
|
||||
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
|
||||
auto runner = std::make_unique<CalculatorRunner>(
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "FilterDetectionCalculator"
|
||||
input_stream: "DETECTION:input"
|
||||
input_side_packet: "LABEL_MAP:input_map"
|
||||
output_stream: "DETECTION:output"
|
||||
options {
|
||||
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
|
||||
}
|
||||
)pb"));
|
||||
|
||||
runner->MutableInputs()->Tag("DETECTION").packets = {
|
||||
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||
label: "a"
|
||||
label: "b"
|
||||
label: "c"
|
||||
label: "d"
|
||||
score: 1
|
||||
score: 0.8
|
||||
score: 0.3
|
||||
score: 0.9
|
||||
)pb"))
|
||||
.At(Timestamp(20)),
|
||||
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||
label: "a"
|
||||
label: "b"
|
||||
label: "c"
|
||||
label: "e"
|
||||
score: 0.6
|
||||
score: 0.4
|
||||
score: 0.2
|
||||
score: 0.7
|
||||
)pb"))
|
||||
.At(Timestamp(40)),
|
||||
};
|
||||
|
||||
auto label_map = std::make_unique<std::map<int, std::string>>();
|
||||
(*label_map)[0] = "a";
|
||||
(*label_map)[1] = "b";
|
||||
(*label_map)[2] = "c";
|
||||
runner->MutableSidePackets()->Tag("LABEL_MAP") =
|
||||
AdoptAsUniquePtr(label_map.release());
|
||||
|
||||
// Run graph.
|
||||
MP_ASSERT_OK(runner->Run());
|
||||
|
||||
// Check output.
|
||||
EXPECT_THAT(
|
||||
runner->Outputs().Tag("DETECTION").packets,
|
||||
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
|
||||
Eq(Timestamp(20)),
|
||||
EqualsProto(R"pb(
|
||||
label: "a" label: "b" score: 1 score: 0.8
|
||||
)pb")), // Packet 1 at timestamp 20.
|
||||
PacketContainsTimestampAndPayload<Detection>(
|
||||
Eq(Timestamp(40)),
|
||||
EqualsProto(R"pb(
|
||||
label: "a" score: 0.6
|
||||
)pb")) // Packet 2 at timestamp 40.
|
||||
));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user