diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc index fe2d599d5..9707ece63 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -39,9 +39,8 @@ namespace mediapipe { -constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION"; constexpr char kSummaryTag[] = "SUMMARY"; -constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS"; constexpr char kTopKLabelsTag[] = "TOP_K_LABELS"; constexpr char kTopKScoresTag[] = "TOP_K_SCORES"; constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; @@ -98,8 +97,8 @@ absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { if (cc->Outputs().HasTag(kTopKLabelsTag)) { cc->Outputs().Tag(kTopKLabelsTag).Set>(); } - if (cc->Outputs().HasTag(kClassificationsTag)) { - cc->Outputs().Tag(kClassificationsTag).Set(); + if (cc->Outputs().HasTag(kTopKClassificationsTag)) { + cc->Outputs().Tag(kTopKClassificationsTag).Set(); } if (cc->Outputs().HasTag(kSummaryTag)) { cc->Outputs().Tag(kSummaryTag).Set(); @@ -210,7 +209,7 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { .At(cc->InputTimestamp())); } - if (cc->Outputs().HasTag(kTopKClassificationTag)) { + if (cc->Outputs().HasTag(kTopKClassificationsTag)) { auto classification_list = absl::make_unique(); for (int index = 0; index < top_k_indexes.size(); ++index) { Classification* classification = @@ -221,6 +220,9 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { classification->set_label(top_k_labels[index]); } } + + cc->Outputs().Tag(kTopKClassificationsTag).Add( + classification_list.release(), cc->InputTimestamp()); } return absl::OkStatus(); } diff --git a/mediapipe/calculators/util/top_k_scores_calculator_test.cc b/mediapipe/calculators/util/top_k_scores_calculator_test.cc index e5a17af28..3fe25865b 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator_test.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator_test.cc @@ -15,6 +15,7 @@ #include #include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -25,6 +26,7 @@ namespace mediapipe { constexpr char kTopKScoresTag[] = "TOP_K_SCORES"; constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES"; +constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; TEST(TopKScoresCalculatorTest, TestNodeConfig) { @@ -157,4 +159,42 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) { EXPECT_NEAR(0.3, scores[2], 1e-5); } +TEST(TopKScoresCalculatorTest, TestTopKClassifications) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TopKScoresCalculator" + input_stream: "SCORES:score_vector" + output_stream: "TOP_K_CLASSIFICATIONS:top_k_classifications" + options: { + [mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 3 } + } + )pb")); + + std::vector score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; + + runner.MutableInputs() + ->Tag(kScoresTag) + .packets.push_back( + MakePacket>(score_vector).At(Timestamp(0))); + + MP_ASSERT_OK(runner.Run()); + const std::vector& classifications_outputs = + runner.Outputs().Tag(kTopKClassificationsTag).packets; + ASSERT_EQ(1, classifications_outputs.size()); + const auto& classification_list = + classifications_outputs[0].Get(); + EXPECT_EQ(3, classification_list.classification_size()); + + EXPECT_EQ(3, classification_list.classification(0).index()); + EXPECT_EQ(0, classification_list.classification(1).index()); + EXPECT_EQ(2, classification_list.classification(2).index()); + + EXPECT_NEAR(1.0, classification_list.classification(0).score(), 1e-5); + EXPECT_NEAR(0.9, classification_list.classification(1).score(), 1e-5); + EXPECT_NEAR(0.3, classification_list.classification(2).score(), 1e-5); + + ASSERT_FALSE(classification_list.classification(0).has_label()); + ASSERT_FALSE(classification_list.classification(1).has_label()); + ASSERT_FALSE(classification_list.classification(2).has_label()); +} + } // namespace mediapipe