Internal change

PiperOrigin-RevId: 483751427
This commit is contained in:
MediaPipe Team 2022-10-25 12:51:22 -07:00 committed by Copybara-Service
parent 21abfc9125
commit 36bd9abb8f
4 changed files with 0 additions and 46 deletions

View File

@ -163,7 +163,6 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
} }
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
const auto& input_tensors = *kInTensors(cc); const auto& input_tensors = *kInTensors(cc);
RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32); RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@ -182,12 +181,6 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
auto raw_scores = view.buffer<float>(); auto raw_scores = view.buffer<float>();
auto classification_list = absl::make_unique<ClassificationList>(); auto classification_list = absl::make_unique<ClassificationList>();
if (options.has_tensor_index()) {
classification_list->set_tensor_index(options.tensor_index());
}
if (options.has_tensor_name()) {
classification_list->set_tensor_name(options.tensor_name());
}
if (is_binary_classification_) { if (is_binary_classification_) {
Classification* class_first = classification_list->add_classification(); Classification* class_first = classification_list->add_classification();
Classification* class_second = classification_list->add_classification(); Classification* class_second = classification_list->add_classification();

View File

@ -72,9 +72,4 @@ message TensorsToClassificationCalculatorOptions {
// that are not in the `allow_classes` field will be completely ignored. // that are not in the `allow_classes` field will be completely ignored.
// `ignore_classes` and `allow_classes` are mutually exclusive. // `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 allow_classes = 8 [packed = true]; repeated int32 allow_classes = 8 [packed = true];
// The optional index of the tensor these classifications originate from.
optional int32 tensor_index = 10;
// The optional name of the tensor these classifications originate from.
optional string tensor_name = 11;
} }

View File

@ -240,36 +240,6 @@ TEST_F(TensorsToClassificationCalculatorTest,
} }
} }
TEST_F(TensorsToClassificationCalculatorTest,
CorrectOutputWithTensorNameAndIndex) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
tensor_index: 1
tensor_name: "foo"
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(3, classification_list.classification_size());
// Verify that the tensor_index and tensor_name fields are correctly set.
EXPECT_EQ(classification_list.tensor_index(), 1);
EXPECT_EQ(classification_list.tensor_name(), "foo");
}
TEST_F(TensorsToClassificationCalculatorTest, TEST_F(TensorsToClassificationCalculatorTest,
ClassNameAllowlistWithLabelItems) { ClassNameAllowlistWithLabelItems) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(

View File

@ -37,10 +37,6 @@ message Classification {
// Group of Classification protos. // Group of Classification protos.
message ClassificationList { message ClassificationList {
repeated Classification classification = 1; repeated Classification classification = 1;
// Optional index of the tensor that produced these classifications.
optional int32 tensor_index = 2;
// Optional name of the tensor that produced these classifications.
optional string tensor_name = 3;
} }
// Group of ClassificationList protos. // Group of ClassificationList protos.