fix TopKScoresCalculator doesn't work with TOP_K_CLASSIFICATIONS
This commit is contained in:
		
							parent
							
								
									1341720d6d
								
							
						
					
					
						commit
						4922fb0dff
					
				| 
						 | 
					@ -39,9 +39,8 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
 | 
					 | 
				
			||||||
constexpr char kSummaryTag[] = "SUMMARY";
 | 
					constexpr char kSummaryTag[] = "SUMMARY";
 | 
				
			||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 | 
					constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
 | 
				
			||||||
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
 | 
					constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
 | 
				
			||||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
 | 
					constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
 | 
				
			||||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
 | 
					constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
 | 
				
			||||||
| 
						 | 
					@ -98,8 +97,8 @@ absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
 | 
				
			||||||
  if (cc->Outputs().HasTag(kTopKLabelsTag)) {
 | 
					  if (cc->Outputs().HasTag(kTopKLabelsTag)) {
 | 
				
			||||||
    cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
 | 
					    cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  if (cc->Outputs().HasTag(kClassificationsTag)) {
 | 
					  if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
 | 
				
			||||||
    cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
 | 
					    cc->Outputs().Tag(kTopKClassificationsTag).Set<ClassificationList>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  if (cc->Outputs().HasTag(kSummaryTag)) {
 | 
					  if (cc->Outputs().HasTag(kSummaryTag)) {
 | 
				
			||||||
    cc->Outputs().Tag(kSummaryTag).Set<std::string>();
 | 
					    cc->Outputs().Tag(kSummaryTag).Set<std::string>();
 | 
				
			||||||
| 
						 | 
					@ -210,7 +209,7 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
 | 
				
			||||||
                       .At(cc->InputTimestamp()));
 | 
					                       .At(cc->InputTimestamp()));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (cc->Outputs().HasTag(kTopKClassificationTag)) {
 | 
					  if (cc->Outputs().HasTag(kTopKClassificationsTag)) {
 | 
				
			||||||
    auto classification_list = absl::make_unique<ClassificationList>();
 | 
					    auto classification_list = absl::make_unique<ClassificationList>();
 | 
				
			||||||
    for (int index = 0; index < top_k_indexes.size(); ++index) {
 | 
					    for (int index = 0; index < top_k_indexes.size(); ++index) {
 | 
				
			||||||
      Classification* classification =
 | 
					      Classification* classification =
 | 
				
			||||||
| 
						 | 
					@ -221,6 +220,9 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
 | 
				
			||||||
        classification->set_label(top_k_labels[index]);
 | 
					        classification->set_label(top_k_labels[index]);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cc->Outputs().Tag(kTopKClassificationsTag).Add(
 | 
				
			||||||
 | 
					        classification_list.release(), cc->InputTimestamp());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return absl::OkStatus();
 | 
					  return absl::OkStatus();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,6 +15,7 @@
 | 
				
			||||||
#include <vector>
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mediapipe/framework/calculator_runner.h"
 | 
					#include "mediapipe/framework/calculator_runner.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/classification.pb.h"
 | 
				
			||||||
#include "mediapipe/framework/port/gmock.h"
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
#include "mediapipe/framework/port/gtest.h"
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
					#include "mediapipe/framework/port/parse_text_proto.h"
 | 
				
			||||||
| 
						 | 
					@ -25,6 +26,7 @@ namespace mediapipe {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
 | 
					constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
 | 
				
			||||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
 | 
					constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
 | 
				
			||||||
 | 
					constexpr char kTopKClassificationsTag[] = "TOP_K_CLASSIFICATIONS";
 | 
				
			||||||
constexpr char kScoresTag[] = "SCORES";
 | 
					constexpr char kScoresTag[] = "SCORES";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
 | 
					TEST(TopKScoresCalculatorTest, TestNodeConfig) {
 | 
				
			||||||
| 
						 | 
					@ -157,4 +159,42 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
 | 
				
			||||||
  EXPECT_NEAR(0.3, scores[2], 1e-5);
 | 
					  EXPECT_NEAR(0.3, scores[2], 1e-5);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(TopKScoresCalculatorTest, TestTopKClassifications) {
 | 
				
			||||||
 | 
					  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
 | 
				
			||||||
 | 
					    calculator: "TopKScoresCalculator"
 | 
				
			||||||
 | 
					    input_stream: "SCORES:score_vector"
 | 
				
			||||||
 | 
					    output_stream: "TOP_K_CLASSIFICATIONS:top_k_classifications"
 | 
				
			||||||
 | 
					    options: {
 | 
				
			||||||
 | 
					      [mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 3 }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  )pb"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  runner.MutableInputs()
 | 
				
			||||||
 | 
					      ->Tag(kScoresTag)
 | 
				
			||||||
 | 
					      .packets.push_back(
 | 
				
			||||||
 | 
					          MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner.Run());
 | 
				
			||||||
 | 
					  const std::vector<Packet>& classifications_outputs =
 | 
				
			||||||
 | 
					      runner.Outputs().Tag(kTopKClassificationsTag).packets;
 | 
				
			||||||
 | 
					  ASSERT_EQ(1, classifications_outputs.size());
 | 
				
			||||||
 | 
					  const auto& classification_list =
 | 
				
			||||||
 | 
					      classifications_outputs[0].Get<ClassificationList>();
 | 
				
			||||||
 | 
					  EXPECT_EQ(3, classification_list.classification_size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(3, classification_list.classification(0).index());
 | 
				
			||||||
 | 
					  EXPECT_EQ(0, classification_list.classification(1).index());
 | 
				
			||||||
 | 
					  EXPECT_EQ(2, classification_list.classification(2).index());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_NEAR(1.0, classification_list.classification(0).score(), 1e-5);
 | 
				
			||||||
 | 
					  EXPECT_NEAR(0.9, classification_list.classification(1).score(), 1e-5);
 | 
				
			||||||
 | 
					  EXPECT_NEAR(0.3, classification_list.classification(2).score(), 1e-5);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ASSERT_FALSE(classification_list.classification(0).has_label());
 | 
				
			||||||
 | 
					  ASSERT_FALSE(classification_list.classification(1).has_label());
 | 
				
			||||||
 | 
					  ASSERT_FALSE(classification_list.classification(2).has_label());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user