Add ConcatenateClassificationListCalculator

PiperOrigin-RevId: 485398597
This commit is contained in:
MediaPipe Team 2022-11-01 13:12:20 -07:00 committed by Copybara-Service
parent 700971de70
commit c6a64683f6
3 changed files with 65 additions and 0 deletions

View File

@ -328,6 +328,7 @@ cc_library(
":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -344,6 +345,7 @@ cc_test(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",

View File

@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
};
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
class ConcatenateClassificationListCalculator
: public ConcatenateListsCalculator<Classification, ClassificationList> {
protected:
int ListSize(const ClassificationList& list) const override {
return list.classification_size();
}
const Classification GetItem(const ClassificationList& list,
int idx) const override {
return list.classification(idx);
}
Classification* AddItem(ClassificationList& list) const override {
return list.add_classification();
}
};
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -18,6 +18,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
}
}
void AddInputClassificationLists(
const std::vector<ClassificationList>& input_classifications_vec,
int64 timestamp, CalculatorRunner* runner) {
for (int i = 0; i < input_classifications_vec.size(); ++i) {
runner->MutableInputs()->Index(i).packets.push_back(
MakePacket<ClassificationList>(input_classifications_vec[i])
.At(Timestamp(timestamp)));
}
}
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
/*options_string=*/"", /*num_inputs=*/3,
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size());
}
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
CalculatorRunner runner("ConcatenateClassificationListCalculator",
/*options_string=*/
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
"{only_emit_if_all_present: true}",
/*num_inputs=*/2,
/*num_outputs=*/1, /*num_side_packets=*/0);
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
)pb");
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb");
std::vector<ClassificationList> inputs = {input_0, input_1};
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
auto result = outputs[0].Get<ClassificationList>();
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
classification: { index: 0 score: 0.2 label: "test_0" }
classification: { index: 1 score: 0.3 label: "test_1" }
classification: { index: 2 score: 0.4 label: "test_2" }
classification: { index: 3 score: 0.2 label: "test_3" }
classification: { index: 4 score: 0.3 label: "test_4" }
)pb"),
EqualsProto(result));
}
} // namespace mediapipe