diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 9c4b3af7c..74398be42 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -328,6 +328,7 @@ cc_library( ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,6 +345,7 @@ cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", diff --git a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc index 9dd0dfd99..6c58e1110 100644 --- a/mediapipe/calculators/core/concatenate_proto_list_calculator.cc +++ b/mediapipe/calculators/core/concatenate_proto_list_calculator.cc @@ -18,6 +18,7 @@ #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" @@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator }; MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator); +class ConcatenateClassificationListCalculator + : public ConcatenateListsCalculator { + protected: + int ListSize(const ClassificationList& list) const override { + return list.classification_size(); + } + const Classification GetItem(const ClassificationList& list, + int idx) const override { + return list.classification(idx); + } + Classification* AddItem(ClassificationList& list) const override { + return list.add_classification(); + } +}; +MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc b/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc index fd116ece7..2167cd9d1 100644 --- a/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_proto_list_calculator_test.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -70,6 +71,16 @@ void AddInputLandmarkLists( } } +void AddInputClassificationLists( + const std::vector& input_classifications_vec, + int64 timestamp, CalculatorRunner* runner) { + for (int i = 0; i < input_classifications_vec.size(); ++i) { + runner->MutableInputs()->Index(i).packets.push_back( + MakePacket(input_classifications_vec[i]) + .At(Timestamp(timestamp))); + } +} + TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) { CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator", /*options_string=*/"", /*num_inputs=*/3, @@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) { EXPECT_EQ(0, outputs.size()); } +TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) { + CalculatorRunner runner("ConcatenateClassificationListCalculator", + /*options_string=*/ + "[mediapipe.ConcatenateVectorCalculatorOptions.ext]: " + "{only_emit_if_all_present: true}", + /*num_inputs=*/2, + /*num_outputs=*/1, /*num_side_packets=*/0); + + auto input_0 = ParseTextProtoOrDie(R"pb( + classification: { index: 0 score: 0.2 label: "test_0" } + classification: { index: 1 score: 0.3 label: "test_1" } + classification: { index: 2 score: 0.4 label: "test_2" } + )pb"); + auto input_1 = ParseTextProtoOrDie(R"pb( + classification: { index: 3 score: 0.2 label: "test_3" } + classification: { index: 4 score: 0.3 label: "test_4" } + )pb"); + std::vector inputs = {input_0, input_1}; + AddInputClassificationLists(inputs, /*timestamp=*/1, &runner); + MP_ASSERT_OK(runner.Run()); + + const std::vector& outputs = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(Timestamp(1), outputs[0].Timestamp()); + auto result = outputs[0].Get(); + EXPECT_THAT(ParseTextProtoOrDie(R"pb( + classification: { index: 0 score: 0.2 label: "test_0" } + classification: { index: 1 score: 0.3 label: "test_1" } + classification: { index: 2 score: 0.4 label: "test_2" } + classification: { index: 3 score: 0.2 label: "test_3" } + classification: { index: 4 score: 0.3 label: "test_4" } + )pb"), + EqualsProto(result)); +} + } // namespace mediapipe