Add tensor_index and tensor_name fields to ClassificationList

PiperOrigin-RevId: 482901854
This commit is contained in:
MediaPipe Team 2022-10-21 15:25:38 -07:00 committed by Copybara-Service
parent 36d69971a7
commit d0437b7f91
4 changed files with 46 additions and 0 deletions

View File

@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
} }
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
const auto& input_tensors = *kInTensors(cc); const auto& input_tensors = *kInTensors(cc);
RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32); RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
auto raw_scores = view.buffer<float>(); auto raw_scores = view.buffer<float>();
auto classification_list = absl::make_unique<ClassificationList>(); auto classification_list = absl::make_unique<ClassificationList>();
if (options.has_tensor_index()) {
classification_list->set_tensor_index(options.tensor_index());
}
if (options.has_tensor_name()) {
classification_list->set_tensor_name(options.tensor_name());
}
if (is_binary_classification_) { if (is_binary_classification_) {
Classification* class_first = classification_list->add_classification(); Classification* class_first = classification_list->add_classification();
Classification* class_second = classification_list->add_classification(); Classification* class_second = classification_list->add_classification();

View File

@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
// that are not in the `allow_classes` field will be completely ignored. // that are not in the `allow_classes` field will be completely ignored.
// `ignore_classes` and `allow_classes` are mutually exclusive. // `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 allow_classes = 8 [packed = true]; repeated int32 allow_classes = 8 [packed = true];
// The optional index of the tensor these classifications originate from.
optional int32 tensor_index = 10;
// The optional name of the tensor these classifications originate from.
optional string tensor_name = 11;
} }

View File

@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
} }
} }
TEST_F(TensorsToClassificationCalculatorTest,
CorrectOutputWithTensorNameAndIndex) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
tensor_index: 1
tensor_name: "foo"
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(3, classification_list.classification_size());
// Verify that the tensor_index and tensor_name fields are correctly set.
EXPECT_EQ(classification_list.tensor_index(), 1);
EXPECT_EQ(classification_list.tensor_name(), "foo");
}
TEST_F(TensorsToClassificationCalculatorTest, TEST_F(TensorsToClassificationCalculatorTest,
ClassNameAllowlistWithLabelItems) { ClassNameAllowlistWithLabelItems) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(

View File

@ -37,6 +37,10 @@ message Classification {
// Group of Classification protos. // Group of Classification protos.
message ClassificationList { message ClassificationList {
repeated Classification classification = 1; repeated Classification classification = 1;
// Optional index of the tensor that produced these classifications.
optional int32 tensor_index = 2;
// Optional name of the tensor that produced these classifications.
optional string tensor_name = 3;
} }
// Group of ClassificationList protos. // Group of ClassificationList protos.