Internal change
PiperOrigin-RevId: 483751427
This commit is contained in:
parent
21abfc9125
commit
36bd9abb8f
|
@ -163,7 +163,6 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||||
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
|
|
||||||
const auto& input_tensors = *kInTensors(cc);
|
const auto& input_tensors = *kInTensors(cc);
|
||||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||||
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
||||||
|
@ -182,12 +181,6 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||||
auto raw_scores = view.buffer<float>();
|
auto raw_scores = view.buffer<float>();
|
||||||
|
|
||||||
auto classification_list = absl::make_unique<ClassificationList>();
|
auto classification_list = absl::make_unique<ClassificationList>();
|
||||||
if (options.has_tensor_index()) {
|
|
||||||
classification_list->set_tensor_index(options.tensor_index());
|
|
||||||
}
|
|
||||||
if (options.has_tensor_name()) {
|
|
||||||
classification_list->set_tensor_name(options.tensor_name());
|
|
||||||
}
|
|
||||||
if (is_binary_classification_) {
|
if (is_binary_classification_) {
|
||||||
Classification* class_first = classification_list->add_classification();
|
Classification* class_first = classification_list->add_classification();
|
||||||
Classification* class_second = classification_list->add_classification();
|
Classification* class_second = classification_list->add_classification();
|
||||||
|
|
|
@ -72,9 +72,4 @@ message TensorsToClassificationCalculatorOptions {
|
||||||
// that are not in the `allow_classes` field will be completely ignored.
|
// that are not in the `allow_classes` field will be completely ignored.
|
||||||
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||||
repeated int32 allow_classes = 8 [packed = true];
|
repeated int32 allow_classes = 8 [packed = true];
|
||||||
|
|
||||||
// The optional index of the tensor these classifications originate from.
|
|
||||||
optional int32 tensor_index = 10;
|
|
||||||
// The optional name of the tensor these classifications originate from.
|
|
||||||
optional string tensor_name = 11;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -240,36 +240,6 @@ TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorsToClassificationCalculatorTest,
|
|
||||||
CorrectOutputWithTensorNameAndIndex) {
|
|
||||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
|
||||||
calculator: "TensorsToClassificationCalculator"
|
|
||||||
input_stream: "TENSORS:tensors"
|
|
||||||
output_stream: "CLASSIFICATIONS:classifications"
|
|
||||||
options {
|
|
||||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
|
||||||
tensor_index: 1
|
|
||||||
tensor_name: "foo"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"));
|
|
||||||
|
|
||||||
BuildGraph(&runner, {0, 0.5, 1});
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
|
||||||
|
|
||||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
|
||||||
|
|
||||||
EXPECT_EQ(1, output_packets_.size());
|
|
||||||
|
|
||||||
const auto& classification_list =
|
|
||||||
output_packets_[0].Get<ClassificationList>();
|
|
||||||
EXPECT_EQ(3, classification_list.classification_size());
|
|
||||||
|
|
||||||
// Verify that the tensor_index and tensor_name fields are correctly set.
|
|
||||||
EXPECT_EQ(classification_list.tensor_index(), 1);
|
|
||||||
EXPECT_EQ(classification_list.tensor_name(), "foo");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorsToClassificationCalculatorTest,
|
TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
ClassNameAllowlistWithLabelItems) {
|
ClassNameAllowlistWithLabelItems) {
|
||||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
|
|
@ -37,10 +37,6 @@ message Classification {
|
||||||
// Group of Classification protos.
|
// Group of Classification protos.
|
||||||
message ClassificationList {
|
message ClassificationList {
|
||||||
repeated Classification classification = 1;
|
repeated Classification classification = 1;
|
||||||
// Optional index of the tensor that produced these classifications.
|
|
||||||
optional int32 tensor_index = 2;
|
|
||||||
// Optional name of the tensor that produced these classifications.
|
|
||||||
optional string tensor_name = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group of ClassificationList protos.
|
// Group of ClassificationList protos.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user