Set confidence score of the bounding box label.

PiperOrigin-RevId: 554508925
This commit is contained in:
Zu Kim 2023-08-07 10:00:18 -07:00 committed by Copybara-Service
parent e10bcd1bfd
commit 22054cd468
2 changed files with 19 additions and 0 deletions

View File

@ -243,6 +243,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearBBoxNumRegions(key, sequence_.get()); mpms::ClearBBoxNumRegions(key, sequence_.get());
mpms::ClearBBoxLabelString(key, sequence_.get()); mpms::ClearBBoxLabelString(key, sequence_.get());
mpms::ClearBBoxLabelIndex(key, sequence_.get()); mpms::ClearBBoxLabelIndex(key, sequence_.get());
mpms::ClearBBoxLabelConfidence(key, sequence_.get());
mpms::ClearBBoxClassString(key, sequence_.get()); mpms::ClearBBoxClassString(key, sequence_.get());
mpms::ClearBBoxClassIndex(key, sequence_.get()); mpms::ClearBBoxClassIndex(key, sequence_.get());
mpms::ClearBBoxTrackString(key, sequence_.get()); mpms::ClearBBoxTrackString(key, sequence_.get());
@ -427,6 +428,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearBBoxNumRegions(prefix, sequence_.get()); mpms::ClearBBoxNumRegions(prefix, sequence_.get());
mpms::ClearBBoxLabelString(prefix, sequence_.get()); mpms::ClearBBoxLabelString(prefix, sequence_.get());
mpms::ClearBBoxLabelIndex(prefix, sequence_.get()); mpms::ClearBBoxLabelIndex(prefix, sequence_.get());
mpms::ClearBBoxLabelConfidence(prefix, sequence_.get());
mpms::ClearBBoxClassString(prefix, sequence_.get()); mpms::ClearBBoxClassString(prefix, sequence_.get());
mpms::ClearBBoxClassIndex(prefix, sequence_.get()); mpms::ClearBBoxClassIndex(prefix, sequence_.get());
mpms::ClearBBoxTrackString(prefix, sequence_.get()); mpms::ClearBBoxTrackString(prefix, sequence_.get());
@ -494,6 +496,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
} }
std::vector<Location> predicted_locations; std::vector<Location> predicted_locations;
std::vector<std::string> predicted_class_strings; std::vector<std::string> predicted_class_strings;
std::vector<float> predicted_class_confidences;
std::vector<int> predicted_label_ids; std::vector<int> predicted_label_ids;
for (auto& detection : for (auto& detection :
cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) { cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) {
@ -522,6 +525,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (detection.label_id_size() > 0) { if (detection.label_id_size() > 0) {
predicted_label_ids.push_back(detection.label_id(0)); predicted_label_ids.push_back(detection.label_id(0));
} }
if (detection.score_size() > 0) {
predicted_class_confidences.push_back(detection.score(0));
}
} }
} }
if (!predicted_locations.empty()) { if (!predicted_locations.empty()) {
@ -535,6 +541,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
if (!predicted_label_ids.empty()) { if (!predicted_label_ids.empty()) {
mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get()); mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
} }
if (!predicted_class_confidences.empty()) {
mpms::AddBBoxLabelConfidence(key, predicted_class_confidences,
sequence_.get());
}
} }
} }
} }

View File

@ -593,6 +593,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i); auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]); ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]); ASSERT_EQ(1, class_indices[1]);
auto class_scores =
mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
ASSERT_FLOAT_EQ(0.5, class_scores[0]);
ASSERT_FLOAT_EQ(0.75, class_scores[1]);
} }
} }
@ -735,6 +739,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i); auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]); ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]); ASSERT_EQ(1, class_indices[1]);
auto class_scores =
mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
ASSERT_FLOAT_EQ(0.5, class_scores[0]);
ASSERT_FLOAT_EQ(0.75, class_scores[1]);
} }
} }
@ -1129,6 +1137,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
mpms::AddBBoxNumRegions(-1, input_sequence.get()); mpms::AddBBoxNumRegions(-1, input_sequence.get());
mpms::AddBBoxLabelString({"anything"}, input_sequence.get()); mpms::AddBBoxLabelString({"anything"}, input_sequence.get());
mpms::AddBBoxLabelIndex({-1}, input_sequence.get()); mpms::AddBBoxLabelIndex({-1}, input_sequence.get());
mpms::AddBBoxLabelConfidence({-1}, input_sequence.get());
mpms::AddBBoxClassString({"anything"}, input_sequence.get()); mpms::AddBBoxClassString({"anything"}, input_sequence.get());
mpms::AddBBoxClassIndex({-1}, input_sequence.get()); mpms::AddBBoxClassIndex({-1}, input_sequence.get());
mpms::AddBBoxTrackString({"anything"}, input_sequence.get()); mpms::AddBBoxTrackString({"anything"}, input_sequence.get());