concatenate stream utility function.
PiperOrigin-RevId: 568997695
This commit is contained in:
parent
983fda5d4e
commit
2ecccaf076
|
@ -325,6 +325,7 @@ cc_library(
|
|||
":concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
|
@ -128,6 +129,19 @@ class ConcatenateClassificationListCalculator
|
|||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
||||
|
||||
class ConcatenateJointListCalculator
|
||||
: public ConcatenateListsCalculator<Joint, JointList> {
|
||||
protected:
|
||||
int ListSize(const JointList& list) const override {
|
||||
return list.joint_size();
|
||||
}
|
||||
const Joint GetItem(const JointList& list, int idx) const override {
|
||||
return list.joint(idx);
|
||||
}
|
||||
Joint* AddItem(JointList& list) const override { return list.add_joint(); }
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(ConcatenateJointListCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -2,6 +2,38 @@ package(default_visibility = ["//visibility:public"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "concatenate",
|
||||
hdrs = ["concatenate.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:concatenate_proto_list_calculator",
|
||||
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
||||
"//mediapipe/calculators/core:concatenate_vector_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "concatenate_test",
|
||||
srcs = ["concatenate_test.cc"],
|
||||
deps = [
|
||||
":concatenate",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status_matchers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "detections_to_rects",
|
||||
srcs = ["detections_to_rects.cc"],
|
||||
|
|
69
mediapipe/framework/api2/stream/concatenate.h
Normal file
69
mediapipe/framework/api2/stream/concatenate.h
Normal file
|
@ -0,0 +1,69 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
namespace internal_stream_concatenate {
|
||||
|
||||
// Helper function that adds a node to a graph, that is capable of concatenating
|
||||
// a specific type (T).
|
||||
template <class T>
|
||||
GenericNode& AddConcatenateVectorNode(Graph& graph) {
|
||||
if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
|
||||
return graph.AddNode("ConcatenateLandmarkListCalculator");
|
||||
} else if constexpr (std::is_same_v<T, mediapipe::JointList>) {
|
||||
return graph.AddNode("ConcatenateJointListCalculator");
|
||||
} else if constexpr (std::is_same_v<T, std::vector<Tensor>>) {
|
||||
return graph.AddNode("ConcatenateTensorVectorCalculator");
|
||||
} else {
|
||||
static_assert(dependent_false<T>::value,
|
||||
"Concatenate node is not available for the specified type.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StreamsT,
|
||||
typename PayloadT = typename StreamsT::value_type::PayloadT>
|
||||
Stream<PayloadT> Concatenate(StreamsT& streams,
|
||||
const bool only_emit_if_all_present,
|
||||
Graph& graph) {
|
||||
auto& concatenator = AddConcatenateVectorNode<PayloadT>(graph);
|
||||
for (int i = 0; i < streams.size(); ++i) {
|
||||
streams[i].ConnectTo(concatenator.In("")[i]);
|
||||
}
|
||||
|
||||
auto& concatenator_opts =
|
||||
concatenator
|
||||
.template GetOptions<mediapipe::ConcatenateVectorCalculatorOptions>();
|
||||
concatenator_opts.set_only_emit_if_all_present(only_emit_if_all_present);
|
||||
|
||||
return concatenator.Out("").template Cast<PayloadT>();
|
||||
}
|
||||
|
||||
} // namespace internal_stream_concatenate
|
||||
|
||||
template <typename StreamsT,
|
||||
typename PayloadT = typename StreamsT::value_type::PayloadT>
|
||||
Stream<PayloadT> Concatenate(StreamsT& streams, Graph& graph) {
|
||||
return internal_stream_concatenate::Concatenate(
|
||||
streams, /*only_emit_if_all_present=*/false, graph);
|
||||
}
|
||||
|
||||
template <typename StreamsT,
|
||||
typename PayloadT = typename StreamsT::value_type::PayloadT>
|
||||
Stream<PayloadT> ConcatenateIfAllPresent(StreamsT& streams, Graph& graph) {
|
||||
return internal_stream_concatenate::Concatenate(
|
||||
streams, /*only_emit_if_all_present=*/true, graph);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
|
188
mediapipe/framework/api2/stream/concatenate_test.cc
Normal file
188
mediapipe/framework/api2/stream/concatenate_test.cc
Normal file
|
@ -0,0 +1,188 @@
|
|||
#include "mediapipe/framework/api2/stream/concatenate.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
namespace {
|
||||
|
||||
TEST(Concatenate, ConcatenateLandmarkList) {
|
||||
Graph graph;
|
||||
std::vector<Stream<LandmarkList>> items = {
|
||||
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
|
||||
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
|
||||
Stream<LandmarkList> landmark_list = Concatenate(items, graph);
|
||||
landmark_list.SetName("landmark_list");
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConcatenateLandmarkListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
input_stream: "__stream_1"
|
||||
output_stream: "landmark_list"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: false
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LMK_LIST:0:__stream_0"
|
||||
input_stream: "LMK_LIST:1:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(Concatenate, ConcatenateLandmarkList_IfAllPresent) {
|
||||
Graph graph;
|
||||
std::vector<Stream<LandmarkList>> items = {
|
||||
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
|
||||
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
|
||||
Stream<LandmarkList> landmark_list = ConcatenateIfAllPresent(items, graph);
|
||||
landmark_list.SetName("landmark_list");
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConcatenateLandmarkListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
input_stream: "__stream_1"
|
||||
output_stream: "landmark_list"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "LMK_LIST:0:__stream_0"
|
||||
input_stream: "LMK_LIST:1:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(Concatenate, ConcatenateJointList) {
|
||||
Graph graph;
|
||||
std::vector<Stream<JointList>> items = {
|
||||
graph.In("JT_LIST")[0].Cast<JointList>(),
|
||||
graph.In("JT_LIST")[1].Cast<JointList>()};
|
||||
Stream<JointList> joint_list = Concatenate(items, graph);
|
||||
joint_list.SetName("joint_list");
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConcatenateJointListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
input_stream: "__stream_1"
|
||||
output_stream: "joint_list"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: false
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "JT_LIST:0:__stream_0"
|
||||
input_stream: "JT_LIST:1:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(Concatenate, ConcatenateJointList_IfAllPresent) {
|
||||
Graph graph;
|
||||
std::vector<Stream<JointList>> items = {
|
||||
graph.In("JT_LIST")[0].Cast<JointList>(),
|
||||
graph.In("JT_LIST")[1].Cast<JointList>()};
|
||||
Stream<JointList> joint_list = ConcatenateIfAllPresent(items, graph);
|
||||
joint_list.SetName("joint_list");
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConcatenateJointListCalculator"
|
||||
input_stream: "__stream_0"
|
||||
input_stream: "__stream_1"
|
||||
output_stream: "joint_list"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "JT_LIST:0:__stream_0"
|
||||
input_stream: "JT_LIST:1:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(Concatenate, ConcatenateTensorVectorList) {
|
||||
Graph graph;
|
||||
std::vector<Stream<std::vector<Tensor>>> items = {
|
||||
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
|
||||
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
|
||||
Stream<std::vector<Tensor>> tensors = Concatenate(items, graph);
|
||||
tensors.SetName("joint_list");
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConcatenateTensorVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
input_stream: "__stream_1"
|
||||
output_stream: "joint_list"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: false
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "VT_LIST:0:__stream_0"
|
||||
input_stream: "VT_LIST:1:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(Concatenate, ConcatenateTensorVectorList_IfAllPresent) {
|
||||
Graph graph;
|
||||
std::vector<Stream<std::vector<Tensor>>> items = {
|
||||
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
|
||||
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
|
||||
|
||||
Stream<std::vector<Tensor>> tensors = ConcatenateIfAllPresent(items, graph);
|
||||
tensors.SetName("joint_list");
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "ConcatenateTensorVectorCalculator"
|
||||
input_stream: "__stream_0"
|
||||
input_stream: "__stream_1"
|
||||
output_stream: "joint_list"
|
||||
options {
|
||||
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||
only_emit_if_all_present: true
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "VT_LIST:0:__stream_0"
|
||||
input_stream: "VT_LIST:1:__stream_1"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user