fix TopKScoresCalculator doesn't work with TOP_K_CLASSIFICATIONS
This commit is contained in:
		
							parent
							
								
									1341720d6d
								
							
						
					
					
						commit
						4922fb0dff
					
				| 
						 | 
				
			
			@ -39,9 +39,8 @@
 | 
			
		|||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
 | 
			
		||||
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
 | 
			
		||||
constexpr char kSummaryTag[] = "SUMMARY";
 | 
			
		||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 | 
			
		||||
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
 | 
			
		||||
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
 | 
			
		||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
 | 
			
		||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
 | 
			
		||||
| 
						 | 
				
			
			@ -98,8 +97,8 @@ absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
 | 
			
		|||
  if (cc->Outputs().HasTag(kTopKLabelsTag)) {
 | 
			
		||||
    cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
 | 
			
		||||
  }
 | 
			
		||||
  if (cc->Outputs().HasTag(kClassificationsTag)) {
 | 
			
		||||
    cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
 | 
			
		||||
  if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
 | 
			
		||||
    cc->Outputs().Tag(kTopKClassificationsTag).Set<ClassificationList>();
 | 
			
		||||
  }
 | 
			
		||||
  if (cc->Outputs().HasTag(kSummaryTag)) {
 | 
			
		||||
    cc->Outputs().Tag(kSummaryTag).Set<std::string>();
 | 
			
		||||
| 
						 | 
				
			
			@ -210,7 +209,7 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
 | 
			
		|||
                       .At(cc->InputTimestamp()));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (cc->Outputs().HasTag(kTopKClassificationTag)) {
 | 
			
		||||
  if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
 | 
			
		||||
    auto classification_list = absl::make_unique<ClassificationList>();
 | 
			
		||||
    for (int index = 0; index < top_k_indexes.size(); ++index) {
 | 
			
		||||
      Classification* classification =
 | 
			
		||||
| 
						 | 
				
			
			@ -221,6 +220,9 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
 | 
			
		|||
        classification->set_label(top_k_labels[index]);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    cc->Outputs().Tag(kTopKClassificationsTag).Add(
 | 
			
		||||
        classification_list.release(), cc->InputTimestamp());
 | 
			
		||||
  }
 | 
			
		||||
  return absl::OkStatus();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,6 +15,7 @@
 | 
			
		|||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/calculator_runner.h"
 | 
			
		||||
#include "mediapipe/framework/formats/classification.pb.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -25,6 +26,7 @@ namespace mediapipe {
 | 
			
		|||
 | 
			
		||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
 | 
			
		||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
 | 
			
		||||
constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
 | 
			
		||||
constexpr char kScoresTag[] = "SCORES";
 | 
			
		||||
 | 
			
		||||
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
 | 
			
		||||
| 
						 | 
				
			
			@ -157,4 +159,42 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
 | 
			
		|||
  EXPECT_NEAR(0.3, scores[2], 1e-5);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TopKScoresCalculatorTest, TestTopKClassifications) {
 | 
			
		||||
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
 | 
			
		||||
    calculator: "TopKScoresCalculator"
 | 
			
		||||
    input_stream: "SCORES:score_vector"
 | 
			
		||||
    output_stream: "TOP_K_CLASSIFICATIONS:top_k_classifications"
 | 
			
		||||
    options: {
 | 
			
		||||
      [mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 3 }
 | 
			
		||||
    }
 | 
			
		||||
  )pb"));
 | 
			
		||||
 | 
			
		||||
  std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
 | 
			
		||||
 | 
			
		||||
  runner.MutableInputs()
 | 
			
		||||
      ->Tag(kScoresTag)
 | 
			
		||||
      .packets.push_back(
 | 
			
		||||
          MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(runner.Run());
 | 
			
		||||
  const std::vector<Packet>& classifications_outputs =
 | 
			
		||||
      runner.Outputs().Tag(kTopKClassificationsTag).packets;
 | 
			
		||||
  ASSERT_EQ(1, classifications_outputs.size());
 | 
			
		||||
  const auto& classification_list =
 | 
			
		||||
      classifications_outputs[0].Get<ClassificationList>();
 | 
			
		||||
  EXPECT_EQ(3, classification_list.classification_size());
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(3, classification_list.classification(0).index());
 | 
			
		||||
  EXPECT_EQ(0, classification_list.classification(1).index());
 | 
			
		||||
  EXPECT_EQ(2, classification_list.classification(2).index());
 | 
			
		||||
 | 
			
		||||
  EXPECT_NEAR(1.0, classification_list.classification(0).score(), 1e-5);
 | 
			
		||||
  EXPECT_NEAR(0.9, classification_list.classification(1).score(), 1e-5);
 | 
			
		||||
  EXPECT_NEAR(0.3, classification_list.classification(2).score(), 1e-5);
 | 
			
		||||
 | 
			
		||||
  ASSERT_FALSE(classification_list.classification(0).has_label());
 | 
			
		||||
  ASSERT_FALSE(classification_list.classification(1).has_label());
 | 
			
		||||
  ASSERT_FALSE(classification_list.classification(2).has_label());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user