Add tensor_index and tensor_name fields to ClassificationList
PiperOrigin-RevId: 482901854
This commit is contained in:
parent
36d69971a7
commit
d0437b7f91
|
@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||||
|
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
|
||||||
const auto& input_tensors = *kInTensors(cc);
|
const auto& input_tensors = *kInTensors(cc);
|
||||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||||
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
||||||
|
@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||||
auto raw_scores = view.buffer<float>();
|
auto raw_scores = view.buffer<float>();
|
||||||
|
|
||||||
auto classification_list = absl::make_unique<ClassificationList>();
|
auto classification_list = absl::make_unique<ClassificationList>();
|
||||||
|
if (options.has_tensor_index()) {
|
||||||
|
classification_list->set_tensor_index(options.tensor_index());
|
||||||
|
}
|
||||||
|
if (options.has_tensor_name()) {
|
||||||
|
classification_list->set_tensor_name(options.tensor_name());
|
||||||
|
}
|
||||||
if (is_binary_classification_) {
|
if (is_binary_classification_) {
|
||||||
Classification* class_first = classification_list->add_classification();
|
Classification* class_first = classification_list->add_classification();
|
||||||
Classification* class_second = classification_list->add_classification();
|
Classification* class_second = classification_list->add_classification();
|
||||||
|
|
|
@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
|
||||||
// that are not in the `allow_classes` field will be completely ignored.
|
// that are not in the `allow_classes` field will be completely ignored.
|
||||||
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||||
repeated int32 allow_classes = 8 [packed = true];
|
repeated int32 allow_classes = 8 [packed = true];
|
||||||
|
|
||||||
|
// The optional index of the tensor these classifications originate from.
|
||||||
|
optional int32 tensor_index = 10;
|
||||||
|
// The optional name of the tensor these classifications originate from.
|
||||||
|
optional string tensor_name = 11;
|
||||||
}
|
}
|
||||||
|
|
|
@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
|
CorrectOutputWithTensorNameAndIndex) {
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "TensorsToClassificationCalculator"
|
||||||
|
input_stream: "TENSORS:tensors"
|
||||||
|
output_stream: "CLASSIFICATIONS:classifications"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||||
|
tensor_index: 1
|
||||||
|
tensor_name: "foo"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0, 0.5, 1});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||||
|
|
||||||
|
EXPECT_EQ(1, output_packets_.size());
|
||||||
|
|
||||||
|
const auto& classification_list =
|
||||||
|
output_packets_[0].Get<ClassificationList>();
|
||||||
|
EXPECT_EQ(3, classification_list.classification_size());
|
||||||
|
|
||||||
|
// Verify that the tensor_index and tensor_name fields are correctly set.
|
||||||
|
EXPECT_EQ(classification_list.tensor_index(), 1);
|
||||||
|
EXPECT_EQ(classification_list.tensor_name(), "foo");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TensorsToClassificationCalculatorTest,
|
TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
ClassNameAllowlistWithLabelItems) {
|
ClassNameAllowlistWithLabelItems) {
|
||||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
|
|
@ -37,6 +37,10 @@ message Classification {
|
||||||
// Group of Classification protos.
|
// Group of Classification protos.
|
||||||
message ClassificationList {
|
message ClassificationList {
|
||||||
repeated Classification classification = 1;
|
repeated Classification classification = 1;
|
||||||
|
// Optional index of the tensor that produced these classifications.
|
||||||
|
optional int32 tensor_index = 2;
|
||||||
|
// Optional name of the tensor that produced these classifications.
|
||||||
|
optional string tensor_name = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group of ClassificationList protos.
|
// Group of ClassificationList protos.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user