Add tensor_index and tensor_name fields to ClassificationList
PiperOrigin-RevId: 482901854
This commit is contained in:
parent
36d69971a7
commit
d0437b7f91
|
@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
||||
|
@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
|||
auto raw_scores = view.buffer<float>();
|
||||
|
||||
auto classification_list = absl::make_unique<ClassificationList>();
|
||||
if (options.has_tensor_index()) {
|
||||
classification_list->set_tensor_index(options.tensor_index());
|
||||
}
|
||||
if (options.has_tensor_name()) {
|
||||
classification_list->set_tensor_name(options.tensor_name());
|
||||
}
|
||||
if (is_binary_classification_) {
|
||||
Classification* class_first = classification_list->add_classification();
|
||||
Classification* class_second = classification_list->add_classification();
|
||||
|
|
|
@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
|
|||
// that are not in the `allow_classes` field will be completely ignored.
|
||||
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||
repeated int32 allow_classes = 8 [packed = true];
|
||||
|
||||
// The optional index of the tensor these classifications originate from.
|
||||
optional int32 tensor_index = 10;
|
||||
// The optional name of the tensor these classifications originate from.
|
||||
optional string tensor_name = 11;
|
||||
}
|
||||
|
|
|
@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest,
|
||||
CorrectOutputWithTensorNameAndIndex) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToClassificationCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "CLASSIFICATIONS:classifications"
|
||||
options {
|
||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||
tensor_index: 1
|
||||
tensor_name: "foo"
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
BuildGraph(&runner, {0, 0.5, 1});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||
|
||||
EXPECT_EQ(1, output_packets_.size());
|
||||
|
||||
const auto& classification_list =
|
||||
output_packets_[0].Get<ClassificationList>();
|
||||
EXPECT_EQ(3, classification_list.classification_size());
|
||||
|
||||
// Verify that the tensor_index and tensor_name fields are correctly set.
|
||||
EXPECT_EQ(classification_list.tensor_index(), 1);
|
||||
EXPECT_EQ(classification_list.tensor_name(), "foo");
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest,
|
||||
ClassNameAllowlistWithLabelItems) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
|
|
|
@ -37,6 +37,10 @@ message Classification {
|
|||
// Group of Classification protos.
|
||||
message ClassificationList {
|
||||
repeated Classification classification = 1;
|
||||
// Optional index of the tensor that produced these classifications.
|
||||
optional int32 tensor_index = 2;
|
||||
// Optional name of the tensor that produced these classifications.
|
||||
optional string tensor_name = 3;
|
||||
}
|
||||
|
||||
// Group of ClassificationList protos.
|
||||
|
|
Loading…
Reference in New Issue
Block a user