fix TopKScoresCalculator doesn't work with TOP_K_CLASSIFICATIONS
This commit is contained in:
parent
1341720d6d
commit
4922fb0dff
|
@ -39,9 +39,8 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
|
|
||||||
constexpr char kSummaryTag[] = "SUMMARY";
|
constexpr char kSummaryTag[] = "SUMMARY";
|
||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
|
||||||
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
|
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
|
||||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||||
|
@ -98,8 +97,8 @@ absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||||
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
|
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag(kClassificationsTag)) {
|
if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
|
||||||
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
|
cc->Outputs().Tag(kTopKClassificationsTag).Set<ClassificationList>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag(kSummaryTag)) {
|
if (cc->Outputs().HasTag(kSummaryTag)) {
|
||||||
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
|
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
|
||||||
|
@ -210,7 +209,7 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
|
if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
|
||||||
auto classification_list = absl::make_unique<ClassificationList>();
|
auto classification_list = absl::make_unique<ClassificationList>();
|
||||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||||
Classification* classification =
|
Classification* classification =
|
||||||
|
@ -221,6 +220,9 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
||||||
classification->set_label(top_k_labels[index]);
|
classification->set_label(top_k_labels[index]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cc->Outputs().Tag(kTopKClassificationsTag).Add(
|
||||||
|
classification_list.release(), cc->InputTimestamp());
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
@ -25,6 +26,7 @@ namespace mediapipe {
|
||||||
|
|
||||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||||
|
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
|
||||||
constexpr char kScoresTag[] = "SCORES";
|
constexpr char kScoresTag[] = "SCORES";
|
||||||
|
|
||||||
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
|
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
|
||||||
|
@ -157,4 +159,42 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
|
||||||
EXPECT_NEAR(0.3, scores[2], 1e-5);
|
EXPECT_NEAR(0.3, scores[2], 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TopKScoresCalculatorTest, TestTopKClassifications) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "TopKScoresCalculator"
|
||||||
|
input_stream: "SCORES:score_vector"
|
||||||
|
output_stream: "TOP_K_CLASSIFICATIONS:top_k_classifications"
|
||||||
|
options: {
|
||||||
|
[mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 3 }
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||||
|
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag(kScoresTag)
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const std::vector<Packet>& classifications_outputs =
|
||||||
|
runner.Outputs().Tag(kTopKClassificationsTag).packets;
|
||||||
|
ASSERT_EQ(1, classifications_outputs.size());
|
||||||
|
const auto& classification_list =
|
||||||
|
classifications_outputs[0].Get<ClassificationList>();
|
||||||
|
EXPECT_EQ(3, classification_list.classification_size());
|
||||||
|
|
||||||
|
EXPECT_EQ(3, classification_list.classification(0).index());
|
||||||
|
EXPECT_EQ(0, classification_list.classification(1).index());
|
||||||
|
EXPECT_EQ(2, classification_list.classification(2).index());
|
||||||
|
|
||||||
|
EXPECT_NEAR(1.0, classification_list.classification(0).score(), 1e-5);
|
||||||
|
EXPECT_NEAR(0.9, classification_list.classification(1).score(), 1e-5);
|
||||||
|
EXPECT_NEAR(0.3, classification_list.classification(2).score(), 1e-5);
|
||||||
|
|
||||||
|
ASSERT_FALSE(classification_list.classification(0).has_label());
|
||||||
|
ASSERT_FALSE(classification_list.classification(1).has_label());
|
||||||
|
ASSERT_FALSE(classification_list.classification(2).has_label());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user