Add ConcatenateClassificationListCalculator
PiperOrigin-RevId: 485398597
This commit is contained in:
parent
700971de70
commit
c6a64683f6
|
@ -328,6 +328,7 @@ cc_library(
|
||||||
":concatenate_vector_calculator_cc_proto",
|
":concatenate_vector_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/api2:node",
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
@ -344,6 +345,7 @@ cc_test(
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework:timestamp",
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
@ -111,6 +112,22 @@ class ConcatenateLandmarkListCalculator
|
||||||
};
|
};
|
||||||
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
|
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListCalculator);
|
||||||
|
|
||||||
|
class ConcatenateClassificationListCalculator
|
||||||
|
: public ConcatenateListsCalculator<Classification, ClassificationList> {
|
||||||
|
protected:
|
||||||
|
int ListSize(const ClassificationList& list) const override {
|
||||||
|
return list.classification_size();
|
||||||
|
}
|
||||||
|
const Classification GetItem(const ClassificationList& list,
|
||||||
|
int idx) const override {
|
||||||
|
return list.classification(idx);
|
||||||
|
}
|
||||||
|
Classification* AddItem(ClassificationList& list) const override {
|
||||||
|
return list.add_classification();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -70,6 +71,16 @@ void AddInputLandmarkLists(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AddInputClassificationLists(
|
||||||
|
const std::vector<ClassificationList>& input_classifications_vec,
|
||||||
|
int64 timestamp, CalculatorRunner* runner) {
|
||||||
|
for (int i = 0; i < input_classifications_vec.size(); ++i) {
|
||||||
|
runner->MutableInputs()->Index(i).packets.push_back(
|
||||||
|
MakePacket<ClassificationList>(input_classifications_vec[i])
|
||||||
|
.At(Timestamp(timestamp)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
|
TEST(ConcatenateNormalizedLandmarkListCalculatorTest, EmptyVectorInputs) {
|
||||||
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
CalculatorRunner runner("ConcatenateNormalizedLandmarkListCalculator",
|
||||||
/*options_string=*/"", /*num_inputs=*/3,
|
/*options_string=*/"", /*num_inputs=*/3,
|
||||||
|
@ -181,4 +192,39 @@ TEST(ConcatenateNormalizedLandmarkListCalculatorTest, OneEmptyStreamNoOutput) {
|
||||||
EXPECT_EQ(0, outputs.size());
|
EXPECT_EQ(0, outputs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ConcatenateClassificationListCalculatorTest, OneTimestamp) {
|
||||||
|
CalculatorRunner runner("ConcatenateClassificationListCalculator",
|
||||||
|
/*options_string=*/
|
||||||
|
"[mediapipe.ConcatenateVectorCalculatorOptions.ext]: "
|
||||||
|
"{only_emit_if_all_present: true}",
|
||||||
|
/*num_inputs=*/2,
|
||||||
|
/*num_outputs=*/1, /*num_side_packets=*/0);
|
||||||
|
|
||||||
|
auto input_0 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||||
|
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||||
|
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||||
|
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||||
|
)pb");
|
||||||
|
auto input_1 = ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||||
|
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||||
|
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||||
|
)pb");
|
||||||
|
std::vector<ClassificationList> inputs = {input_0, input_1};
|
||||||
|
AddInputClassificationLists(inputs, /*timestamp=*/1, &runner);
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
|
||||||
|
EXPECT_EQ(1, outputs.size());
|
||||||
|
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
|
||||||
|
auto result = outputs[0].Get<ClassificationList>();
|
||||||
|
EXPECT_THAT(ParseTextProtoOrDie<ClassificationList>(R"pb(
|
||||||
|
classification: { index: 0 score: 0.2 label: "test_0" }
|
||||||
|
classification: { index: 1 score: 0.3 label: "test_1" }
|
||||||
|
classification: { index: 2 score: 0.4 label: "test_2" }
|
||||||
|
classification: { index: 3 score: 0.2 label: "test_3" }
|
||||||
|
classification: { index: 4 score: 0.3 label: "test_4" }
|
||||||
|
)pb"),
|
||||||
|
EqualsProto(result));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user