Add label_map filtering into filter_detection drishti calculator.
PiperOrigin-RevId: 517515046
This commit is contained in:
parent
1a456dcbf9
commit
3dede1a9a5
|
@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
|
||||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
constexpr char kLabelsTag[] = "LABELS";
|
constexpr char kLabelsTag[] = "LABELS";
|
||||||
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
||||||
|
constexpr char kLabelMapTag[] = "LABEL_MAP";
|
||||||
|
|
||||||
using mediapipe::RE2;
|
using mediapipe::RE2;
|
||||||
using Detections = std::vector<Detection>;
|
using Detections = std::vector<Detection>;
|
||||||
|
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||||
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||||
|
cc->InputSidePackets()
|
||||||
|
.Tag(kLabelMapTag)
|
||||||
|
.Set<std::unique_ptr<std::map<int, std::string>>>();
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
||||||
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
||||||
cc->InputSidePackets().HasTag(kLabelsCsvTag);
|
cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
|
||||||
|
cc->InputSidePackets().HasTag(kLabelMapTag);
|
||||||
if (limit_labels_) {
|
if (limit_labels_) {
|
||||||
Strings allowlist_labels;
|
Strings allowlist_labels;
|
||||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||||
|
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
||||||
for (auto& e : allowlist_labels) {
|
for (auto& e : allowlist_labels) {
|
||||||
absl::StripAsciiWhitespace(&e);
|
absl::StripAsciiWhitespace(&e);
|
||||||
}
|
}
|
||||||
} else {
|
} else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
|
||||||
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
||||||
|
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||||
|
auto label_map = cc->InputSidePackets()
|
||||||
|
.Tag(kLabelMapTag)
|
||||||
|
.Get<std::unique_ptr<std::map<int, std::string>>>()
|
||||||
|
.get();
|
||||||
|
for (const auto& [_, v] : *label_map) {
|
||||||
|
allowlist_labels.push_back(v);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
|
||||||
|
auto runner = std::make_unique<CalculatorRunner>(
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "FilterDetectionCalculator"
|
||||||
|
input_stream: "DETECTION:input"
|
||||||
|
input_side_packet: "LABEL_MAP:input_map"
|
||||||
|
output_stream: "DETECTION:output"
|
||||||
|
options {
|
||||||
|
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
runner->MutableInputs()->Tag("DETECTION").packets = {
|
||||||
|
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||||
|
label: "a"
|
||||||
|
label: "b"
|
||||||
|
label: "c"
|
||||||
|
label: "d"
|
||||||
|
score: 1
|
||||||
|
score: 0.8
|
||||||
|
score: 0.3
|
||||||
|
score: 0.9
|
||||||
|
)pb"))
|
||||||
|
.At(Timestamp(20)),
|
||||||
|
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||||
|
label: "a"
|
||||||
|
label: "b"
|
||||||
|
label: "c"
|
||||||
|
label: "e"
|
||||||
|
score: 0.6
|
||||||
|
score: 0.4
|
||||||
|
score: 0.2
|
||||||
|
score: 0.7
|
||||||
|
)pb"))
|
||||||
|
.At(Timestamp(40)),
|
||||||
|
};
|
||||||
|
|
||||||
|
auto label_map = std::make_unique<std::map<int, std::string>>();
|
||||||
|
(*label_map)[0] = "a";
|
||||||
|
(*label_map)[1] = "b";
|
||||||
|
(*label_map)[2] = "c";
|
||||||
|
runner->MutableSidePackets()->Tag("LABEL_MAP") =
|
||||||
|
AdoptAsUniquePtr(label_map.release());
|
||||||
|
|
||||||
|
// Run graph.
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
|
||||||
|
// Check output.
|
||||||
|
EXPECT_THAT(
|
||||||
|
runner->Outputs().Tag("DETECTION").packets,
|
||||||
|
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
|
||||||
|
Eq(Timestamp(20)),
|
||||||
|
EqualsProto(R"pb(
|
||||||
|
label: "a" label: "b" score: 1 score: 0.8
|
||||||
|
)pb")), // Packet 1 at timestamp 20.
|
||||||
|
PacketContainsTimestampAndPayload<Detection>(
|
||||||
|
Eq(Timestamp(40)),
|
||||||
|
EqualsProto(R"pb(
|
||||||
|
label: "a" score: 0.6
|
||||||
|
)pb")) // Packet 2 at timestamp 40.
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user