fix TopKScoresCalculator doesn't work with TOP_K_CLASSIFICATIONS
This commit is contained in:
parent
1341720d6d
commit
4922fb0dff
|
@ -39,9 +39,8 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
|
||||
constexpr char kSummaryTag[] = "SUMMARY";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
|
||||
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
|
||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||
|
@ -98,8 +97,8 @@ absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
|
|||
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kClassificationsTag)) {
|
||||
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
|
||||
if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
|
||||
cc->Outputs().Tag(kTopKClassificationsTag).Set<ClassificationList>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kSummaryTag)) {
|
||||
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
|
||||
|
@ -210,7 +209,7 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
|||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
|
||||
if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
|
||||
auto classification_list = absl::make_unique<ClassificationList>();
|
||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||
Classification* classification =
|
||||
|
@ -221,6 +220,9 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
|||
classification->set_label(top_k_labels[index]);
|
||||
}
|
||||
}
|
||||
|
||||
cc->Outputs().Tag(kTopKClassificationsTag).Add(
|
||||
classification_list.release(), cc->InputTimestamp());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
@ -25,6 +26,7 @@ namespace mediapipe {
|
|||
|
||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
|
||||
constexpr char kScoresTag[] = "SCORES";
|
||||
|
||||
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
|
||||
|
@ -157,4 +159,42 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
|
|||
EXPECT_NEAR(0.3, scores[2], 1e-5);
|
||||
}
|
||||
|
||||
TEST(TopKScoresCalculatorTest, TestTopKClassifications) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "TopKScoresCalculator"
|
||||
input_stream: "SCORES:score_vector"
|
||||
output_stream: "TOP_K_CLASSIFICATIONS:top_k_classifications"
|
||||
options: {
|
||||
[mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 3 }
|
||||
}
|
||||
)pb"));
|
||||
|
||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag(kScoresTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& classifications_outputs =
|
||||
runner.Outputs().Tag(kTopKClassificationsTag).packets;
|
||||
ASSERT_EQ(1, classifications_outputs.size());
|
||||
const auto& classification_list =
|
||||
classifications_outputs[0].Get<ClassificationList>();
|
||||
EXPECT_EQ(3, classification_list.classification_size());
|
||||
|
||||
EXPECT_EQ(3, classification_list.classification(0).index());
|
||||
EXPECT_EQ(0, classification_list.classification(1).index());
|
||||
EXPECT_EQ(2, classification_list.classification(2).index());
|
||||
|
||||
EXPECT_NEAR(1.0, classification_list.classification(0).score(), 1e-5);
|
||||
EXPECT_NEAR(0.9, classification_list.classification(1).score(), 1e-5);
|
||||
EXPECT_NEAR(0.3, classification_list.classification(2).score(), 1e-5);
|
||||
|
||||
ASSERT_FALSE(classification_list.classification(0).has_label());
|
||||
ASSERT_FALSE(classification_list.classification(1).has_label());
|
||||
ASSERT_FALSE(classification_list.classification(2).has_label());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user