Add tensor_index and tensor_name fields to ClassificationList
PiperOrigin-RevId: 482901854
This commit is contained in:
		
							parent
							
								
									36d69971a7
								
							
						
					
					
						commit
						d0437b7f91
					
				| 
						 | 
					@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
 | 
					absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
 | 
				
			||||||
 | 
					  const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
 | 
				
			||||||
  const auto& input_tensors = *kInTensors(cc);
 | 
					  const auto& input_tensors = *kInTensors(cc);
 | 
				
			||||||
  RET_CHECK_EQ(input_tensors.size(), 1);
 | 
					  RET_CHECK_EQ(input_tensors.size(), 1);
 | 
				
			||||||
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
 | 
					  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
 | 
				
			||||||
| 
						 | 
					@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
 | 
				
			||||||
  auto raw_scores = view.buffer<float>();
 | 
					  auto raw_scores = view.buffer<float>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto classification_list = absl::make_unique<ClassificationList>();
 | 
					  auto classification_list = absl::make_unique<ClassificationList>();
 | 
				
			||||||
 | 
					  if (options.has_tensor_index()) {
 | 
				
			||||||
 | 
					    classification_list->set_tensor_index(options.tensor_index());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  if (options.has_tensor_name()) {
 | 
				
			||||||
 | 
					    classification_list->set_tensor_name(options.tensor_name());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  if (is_binary_classification_) {
 | 
					  if (is_binary_classification_) {
 | 
				
			||||||
    Classification* class_first = classification_list->add_classification();
 | 
					    Classification* class_first = classification_list->add_classification();
 | 
				
			||||||
    Classification* class_second = classification_list->add_classification();
 | 
					    Classification* class_second = classification_list->add_classification();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
 | 
				
			||||||
  // that are not in the `allow_classes` field will be completely ignored.
 | 
					  // that are not in the `allow_classes` field will be completely ignored.
 | 
				
			||||||
  // `ignore_classes` and `allow_classes` are mutually exclusive.
 | 
					  // `ignore_classes` and `allow_classes` are mutually exclusive.
 | 
				
			||||||
  repeated int32 allow_classes = 8 [packed = true];
 | 
					  repeated int32 allow_classes = 8 [packed = true];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // The optional index of the tensor these classifications originate from.
 | 
				
			||||||
 | 
					  optional int32 tensor_index = 10;
 | 
				
			||||||
 | 
					  // The optional name of the tensor these classifications originate from.
 | 
				
			||||||
 | 
					  optional string tensor_name = 11;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(TensorsToClassificationCalculatorTest,
 | 
				
			||||||
 | 
					       CorrectOutputWithTensorNameAndIndex) {
 | 
				
			||||||
 | 
					  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
 | 
				
			||||||
 | 
					    calculator: "TensorsToClassificationCalculator"
 | 
				
			||||||
 | 
					    input_stream: "TENSORS:tensors"
 | 
				
			||||||
 | 
					    output_stream: "CLASSIFICATIONS:classifications"
 | 
				
			||||||
 | 
					    options {
 | 
				
			||||||
 | 
					      [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					        tensor_index: 1
 | 
				
			||||||
 | 
					        tensor_name: "foo"
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  )pb"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  BuildGraph(&runner, {0, 0.5, 1});
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner.Run());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(1, output_packets_.size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const auto& classification_list =
 | 
				
			||||||
 | 
					      output_packets_[0].Get<ClassificationList>();
 | 
				
			||||||
 | 
					  EXPECT_EQ(3, classification_list.classification_size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Verify that the tensor_index and tensor_name fields are correctly set.
 | 
				
			||||||
 | 
					  EXPECT_EQ(classification_list.tensor_index(), 1);
 | 
				
			||||||
 | 
					  EXPECT_EQ(classification_list.tensor_name(), "foo");
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_F(TensorsToClassificationCalculatorTest,
 | 
					TEST_F(TensorsToClassificationCalculatorTest,
 | 
				
			||||||
       ClassNameAllowlistWithLabelItems) {
 | 
					       ClassNameAllowlistWithLabelItems) {
 | 
				
			||||||
  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
 | 
					  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,6 +37,10 @@ message Classification {
 | 
				
			||||||
// Group of Classification protos.
 | 
					// Group of Classification protos.
 | 
				
			||||||
message ClassificationList {
 | 
					message ClassificationList {
 | 
				
			||||||
  repeated Classification classification = 1;
 | 
					  repeated Classification classification = 1;
 | 
				
			||||||
 | 
					  // Optional index of the tensor that produced these classifications.
 | 
				
			||||||
 | 
					  optional int32 tensor_index = 2;
 | 
				
			||||||
 | 
					  // Optional name of the tensor that produced these classifications.
 | 
				
			||||||
 | 
					  optional string tensor_name = 3;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Group of ClassificationList protos.
 | 
					// Group of ClassificationList protos.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user