Add ConcatenateClassificationListCalculator
PiperOrigin-RevId: 485398597
This commit is contained in:
parent
700971de70
commit
c6a64683f6
|
@ -328,6 +328,7 @@ cc_library(
|
|||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -344,6 +345,7 @@ cc_test(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
|
|||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
|
||||
|
||||
class ConcatenateClassificationListCalculator
|
||||
: public ConcatenateListsCalculator<Classification, ClassificationList> {
|
||||
protected:
|
||||
int ListSize(const ClassificationList& list) const override {
|
||||
return list.classification_size();
|
||||
}
|
||||
const Classification GetItem(const ClassificationList& list,
|
||||
int idx) const override {
|
||||
return list.classification(idx);
|
||||
}
|
||||
Classification* AddItem(ClassificationList& list) const override {
|
||||
return list.add_classification();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
|
|||
}
|
||||
}
|
||||
|
||||
void AddInputClassificationLists(
|
||||
const std::vector<ClassificationList>& input_classifications_vec,
|
||||
int64 timestamp, CalculatorRunner* runner) {
|
||||
for (int i = 0; i < input_classifications_vec.size(); ++i) {
|
||||
runner->MutableInputs()->Index(i).packets.push_back(
|
||||
MakePacket<ClassificationList>(input_classifications_vec[i])
|
||||
.At(Timestamp(timestamp)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
|
||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||
/*options_string=*/"", /*num_inputs=*/3,
|
||||
|
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
|
|||
EXPECT_EQ(0, outputs.size());
|
||||
}
|
||||
|
||||
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
|
||||
CalculatorRunner runner("ConcatenateClassificationListCalculator",
|
||||
/*options_string=*/
|
||||
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||
"{only_emit_if_all_present: true}",
|
||||
/*num_inputs=*/2,
|
||||
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||
|
||||
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||
)pb");
|
||||
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||
)pb");
|
||||
std::vector<ClassificationList> inputs = {input_0, input_1};
|
||||
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||
auto result = outputs[0].Get<ClassificationList>();
|
||||
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||
)pb"),
|
||||
EqualsProto(result));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user