Set confidence score of the bounding box label.
PiperOrigin-RevId: 554508925
This commit is contained in:
parent
e10bcd1bfd
commit
22054cd468
|
@ -243,6 +243,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
mpms::ClearBBoxNumRegions(key, sequence_.get());
|
mpms::ClearBBoxNumRegions(key, sequence_.get());
|
||||||
mpms::ClearBBoxLabelString(key, sequence_.get());
|
mpms::ClearBBoxLabelString(key, sequence_.get());
|
||||||
mpms::ClearBBoxLabelIndex(key, sequence_.get());
|
mpms::ClearBBoxLabelIndex(key, sequence_.get());
|
||||||
|
mpms::ClearBBoxLabelConfidence(key, sequence_.get());
|
||||||
mpms::ClearBBoxClassString(key, sequence_.get());
|
mpms::ClearBBoxClassString(key, sequence_.get());
|
||||||
mpms::ClearBBoxClassIndex(key, sequence_.get());
|
mpms::ClearBBoxClassIndex(key, sequence_.get());
|
||||||
mpms::ClearBBoxTrackString(key, sequence_.get());
|
mpms::ClearBBoxTrackString(key, sequence_.get());
|
||||||
|
@ -427,6 +428,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
mpms::ClearBBoxNumRegions(prefix, sequence_.get());
|
mpms::ClearBBoxNumRegions(prefix, sequence_.get());
|
||||||
mpms::ClearBBoxLabelString(prefix, sequence_.get());
|
mpms::ClearBBoxLabelString(prefix, sequence_.get());
|
||||||
mpms::ClearBBoxLabelIndex(prefix, sequence_.get());
|
mpms::ClearBBoxLabelIndex(prefix, sequence_.get());
|
||||||
|
mpms::ClearBBoxLabelConfidence(prefix, sequence_.get());
|
||||||
mpms::ClearBBoxClassString(prefix, sequence_.get());
|
mpms::ClearBBoxClassString(prefix, sequence_.get());
|
||||||
mpms::ClearBBoxClassIndex(prefix, sequence_.get());
|
mpms::ClearBBoxClassIndex(prefix, sequence_.get());
|
||||||
mpms::ClearBBoxTrackString(prefix, sequence_.get());
|
mpms::ClearBBoxTrackString(prefix, sequence_.get());
|
||||||
|
@ -494,6 +496,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
std::vector<Location> predicted_locations;
|
std::vector<Location> predicted_locations;
|
||||||
std::vector<std::string> predicted_class_strings;
|
std::vector<std::string> predicted_class_strings;
|
||||||
|
std::vector<float> predicted_class_confidences;
|
||||||
std::vector<int> predicted_label_ids;
|
std::vector<int> predicted_label_ids;
|
||||||
for (auto& detection :
|
for (auto& detection :
|
||||||
cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) {
|
cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) {
|
||||||
|
@ -522,6 +525,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
if (detection.label_id_size() > 0) {
|
if (detection.label_id_size() > 0) {
|
||||||
predicted_label_ids.push_back(detection.label_id(0));
|
predicted_label_ids.push_back(detection.label_id(0));
|
||||||
}
|
}
|
||||||
|
if (detection.score_size() > 0) {
|
||||||
|
predicted_class_confidences.push_back(detection.score(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!predicted_locations.empty()) {
|
if (!predicted_locations.empty()) {
|
||||||
|
@ -535,6 +541,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
if (!predicted_label_ids.empty()) {
|
if (!predicted_label_ids.empty()) {
|
||||||
mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
|
mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
|
||||||
}
|
}
|
||||||
|
if (!predicted_class_confidences.empty()) {
|
||||||
|
mpms::AddBBoxLabelConfidence(key, predicted_class_confidences,
|
||||||
|
sequence_.get());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -593,6 +593,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
|
||||||
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
|
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
|
||||||
ASSERT_EQ(0, class_indices[0]);
|
ASSERT_EQ(0, class_indices[0]);
|
||||||
ASSERT_EQ(1, class_indices[1]);
|
ASSERT_EQ(1, class_indices[1]);
|
||||||
|
auto class_scores =
|
||||||
|
mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
|
||||||
|
ASSERT_FLOAT_EQ(0.5, class_scores[0]);
|
||||||
|
ASSERT_FLOAT_EQ(0.75, class_scores[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -735,6 +739,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
|
||||||
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
|
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
|
||||||
ASSERT_EQ(0, class_indices[0]);
|
ASSERT_EQ(0, class_indices[0]);
|
||||||
ASSERT_EQ(1, class_indices[1]);
|
ASSERT_EQ(1, class_indices[1]);
|
||||||
|
auto class_scores =
|
||||||
|
mpms::GetPredictedBBoxLabelConfidenceAt(output_sequence, i);
|
||||||
|
ASSERT_FLOAT_EQ(0.5, class_scores[0]);
|
||||||
|
ASSERT_FLOAT_EQ(0.75, class_scores[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1129,6 +1137,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
||||||
mpms::AddBBoxNumRegions(-1, input_sequence.get());
|
mpms::AddBBoxNumRegions(-1, input_sequence.get());
|
||||||
mpms::AddBBoxLabelString({"anything"}, input_sequence.get());
|
mpms::AddBBoxLabelString({"anything"}, input_sequence.get());
|
||||||
mpms::AddBBoxLabelIndex({-1}, input_sequence.get());
|
mpms::AddBBoxLabelIndex({-1}, input_sequence.get());
|
||||||
|
mpms::AddBBoxLabelConfidence({-1}, input_sequence.get());
|
||||||
mpms::AddBBoxClassString({"anything"}, input_sequence.get());
|
mpms::AddBBoxClassString({"anything"}, input_sequence.get());
|
||||||
mpms::AddBBoxClassIndex({-1}, input_sequence.get());
|
mpms::AddBBoxClassIndex({-1}, input_sequence.get());
|
||||||
mpms::AddBBoxTrackString({"anything"}, input_sequence.get());
|
mpms::AddBBoxTrackString({"anything"}, input_sequence.get());
|
||||||
|
|
Loading…
Reference in New Issue
Block a user