fix TopKScoresCalculator doesn't work with TOP_K_CLASSIFICATIONS

This commit is contained in:
Luke 2022-12-21 20:02:02 +09:00
parent 1341720d6d
commit 4922fb0dff
2 changed files with 47 additions and 5 deletions

View File

@ -39,9 +39,8 @@
namespace mediapipe {
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
constexpr char kSummaryTag[] = "SUMMARY";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
@ -98,8 +97,8 @@ absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
}
if (cc->Outputs().HasTag(kClassificationsTag)) {
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
cc->Outputs().Tag(kTopKClassificationsTag).Set<ClassificationList>();
}
if (cc->Outputs().HasTag(kSummaryTag)) {
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
@ -210,7 +209,7 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
auto classification_list = absl::make_unique<ClassificationList>();
for (int index = 0; index < top_k_indexes.size(); ++index) {
Classification* classification =
@ -221,6 +220,9 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
classification->set_label(top_k_labels[index]);
}
}
cc->Outputs().Tag(kTopKClassificationsTag).Add(
classification_list.release(), cc->InputTimestamp());
}
return absl::OkStatus();
}

View File

@ -15,6 +15,7 @@
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
@ -25,6 +26,7 @@ namespace mediapipe {
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
@ -157,4 +159,42 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
EXPECT_NEAR(0.3, scores[2], 1e-5);
}
TEST(TopKScoresCalculatorTest, TestTopKClassifications) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TopKScoresCalculator"
input_stream: "SCORES:score_vector"
output_stream: "TOP_K_CLASSIFICATIONS:top_k_classifications"
options: {
[mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 3 }
}
)pb"));
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& classifications_outputs =
runner.Outputs().Tag(kTopKClassificationsTag).packets;
ASSERT_EQ(1, classifications_outputs.size());
const auto& classification_list =
classifications_outputs[0].Get<ClassificationList>();
EXPECT_EQ(3, classification_list.classification_size());
EXPECT_EQ(3, classification_list.classification(0).index());
EXPECT_EQ(0, classification_list.classification(1).index());
EXPECT_EQ(2, classification_list.classification(2).index());
EXPECT_NEAR(1.0, classification_list.classification(0).score(), 1e-5);
EXPECT_NEAR(0.9, classification_list.classification(1).score(), 1e-5);
EXPECT_NEAR(0.3, classification_list.classification(2).score(), 1e-5);
ASSERT_FALSE(classification_list.classification(0).has_label());
ASSERT_FALSE(classification_list.classification(1).has_label());
ASSERT_FALSE(classification_list.classification(2).has_label());
}
} // namespace mediapipe