Add label_map filtering into filter_detection drishti calculator.
PiperOrigin-RevId: 517515046
This commit is contained in:
		
							parent
							
								
									1a456dcbf9
								
							
						
					
					
						commit
						3dede1a9a5
					
				| 
						 | 
					@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
 | 
				
			||||||
constexpr char kDetectionsTag[] = "DETECTIONS";
 | 
					constexpr char kDetectionsTag[] = "DETECTIONS";
 | 
				
			||||||
constexpr char kLabelsTag[] = "LABELS";
 | 
					constexpr char kLabelsTag[] = "LABELS";
 | 
				
			||||||
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
 | 
					constexpr char kLabelsCsvTag[] = "LABELS_CSV";
 | 
				
			||||||
 | 
					constexpr char kLabelMapTag[] = "LABEL_MAP";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using mediapipe::RE2;
 | 
					using mediapipe::RE2;
 | 
				
			||||||
using Detections = std::vector<Detection>;
 | 
					using Detections = std::vector<Detection>;
 | 
				
			||||||
| 
						 | 
					@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
 | 
				
			||||||
  if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
 | 
					  if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
 | 
				
			||||||
    cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
 | 
					    cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
 | 
				
			||||||
 | 
					    cc->InputSidePackets()
 | 
				
			||||||
 | 
					        .Tag(kLabelMapTag)
 | 
				
			||||||
 | 
					        .Set<std::unique_ptr<std::map<int, std::string>>>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  return absl::OkStatus();
 | 
					  return absl::OkStatus();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
 | 
				
			||||||
  cc->SetOffset(TimestampDiff(0));
 | 
					  cc->SetOffset(TimestampDiff(0));
 | 
				
			||||||
  options_ = cc->Options<FilterDetectionCalculatorOptions>();
 | 
					  options_ = cc->Options<FilterDetectionCalculatorOptions>();
 | 
				
			||||||
  limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
 | 
					  limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
 | 
				
			||||||
                  cc->InputSidePackets().HasTag(kLabelsCsvTag);
 | 
					                  cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
 | 
				
			||||||
 | 
					                  cc->InputSidePackets().HasTag(kLabelMapTag);
 | 
				
			||||||
  if (limit_labels_) {
 | 
					  if (limit_labels_) {
 | 
				
			||||||
    Strings allowlist_labels;
 | 
					    Strings allowlist_labels;
 | 
				
			||||||
    if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
 | 
					    if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
 | 
				
			||||||
| 
						 | 
					@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
 | 
				
			||||||
      for (auto& e : allowlist_labels) {
 | 
					      for (auto& e : allowlist_labels) {
 | 
				
			||||||
        absl::StripAsciiWhitespace(&e);
 | 
					        absl::StripAsciiWhitespace(&e);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    } else {
 | 
					    } else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
 | 
				
			||||||
      allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
 | 
					      allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
 | 
				
			||||||
 | 
					    } else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
 | 
				
			||||||
 | 
					      auto label_map = cc->InputSidePackets()
 | 
				
			||||||
 | 
					                           .Tag(kLabelMapTag)
 | 
				
			||||||
 | 
					                           .Get<std::unique_ptr<std::map<int, std::string>>>()
 | 
				
			||||||
 | 
					                           .get();
 | 
				
			||||||
 | 
					      for (const auto& [_, v] : *label_map) {
 | 
				
			||||||
 | 
					        allowlist_labels.push_back(v);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
 | 
					    allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
 | 
				
			||||||
                  ));
 | 
					                  ));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
 | 
				
			||||||
 | 
					  auto runner = std::make_unique<CalculatorRunner>(
 | 
				
			||||||
 | 
					      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
 | 
				
			||||||
 | 
					        calculator: "FilterDetectionCalculator"
 | 
				
			||||||
 | 
					        input_stream: "DETECTION:input"
 | 
				
			||||||
 | 
					        input_side_packet: "LABEL_MAP:input_map"
 | 
				
			||||||
 | 
					        output_stream: "DETECTION:output"
 | 
				
			||||||
 | 
					        options {
 | 
				
			||||||
 | 
					          [mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      )pb"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  runner->MutableInputs()->Tag("DETECTION").packets = {
 | 
				
			||||||
 | 
					      MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
 | 
				
			||||||
 | 
					        label: "a"
 | 
				
			||||||
 | 
					        label: "b"
 | 
				
			||||||
 | 
					        label: "c"
 | 
				
			||||||
 | 
					        label: "d"
 | 
				
			||||||
 | 
					        score: 1
 | 
				
			||||||
 | 
					        score: 0.8
 | 
				
			||||||
 | 
					        score: 0.3
 | 
				
			||||||
 | 
					        score: 0.9
 | 
				
			||||||
 | 
					      )pb"))
 | 
				
			||||||
 | 
					          .At(Timestamp(20)),
 | 
				
			||||||
 | 
					      MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
 | 
				
			||||||
 | 
					        label: "a"
 | 
				
			||||||
 | 
					        label: "b"
 | 
				
			||||||
 | 
					        label: "c"
 | 
				
			||||||
 | 
					        label: "e"
 | 
				
			||||||
 | 
					        score: 0.6
 | 
				
			||||||
 | 
					        score: 0.4
 | 
				
			||||||
 | 
					        score: 0.2
 | 
				
			||||||
 | 
					        score: 0.7
 | 
				
			||||||
 | 
					      )pb"))
 | 
				
			||||||
 | 
					          .At(Timestamp(40)),
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto label_map = std::make_unique<std::map<int, std::string>>();
 | 
				
			||||||
 | 
					  (*label_map)[0] = "a";
 | 
				
			||||||
 | 
					  (*label_map)[1] = "b";
 | 
				
			||||||
 | 
					  (*label_map)[2] = "c";
 | 
				
			||||||
 | 
					  runner->MutableSidePackets()->Tag("LABEL_MAP") =
 | 
				
			||||||
 | 
					      AdoptAsUniquePtr(label_map.release());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Run graph.
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner->Run());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check output.
 | 
				
			||||||
 | 
					  EXPECT_THAT(
 | 
				
			||||||
 | 
					      runner->Outputs().Tag("DETECTION").packets,
 | 
				
			||||||
 | 
					      ElementsAre(PacketContainsTimestampAndPayload<Detection>(
 | 
				
			||||||
 | 
					                      Eq(Timestamp(20)),
 | 
				
			||||||
 | 
					                      EqualsProto(R"pb(
 | 
				
			||||||
 | 
					                        label: "a" label: "b" score: 1 score: 0.8
 | 
				
			||||||
 | 
					                      )pb")),  // Packet 1 at timestamp 20.
 | 
				
			||||||
 | 
					                  PacketContainsTimestampAndPayload<Detection>(
 | 
				
			||||||
 | 
					                      Eq(Timestamp(40)),
 | 
				
			||||||
 | 
					                      EqualsProto(R"pb(
 | 
				
			||||||
 | 
					                        label: "a" score: 0.6
 | 
				
			||||||
 | 
					                      )pb"))  // Packet 2 at timestamp 40.
 | 
				
			||||||
 | 
					                  ));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user