concatenate stream utility function.
PiperOrigin-RevId: 568997695
This commit is contained in:
parent
983fda5d4e
commit
2ecccaf076
|
@ -325,6 +325,7 @@ cc_library(
|
||||||
":concatenate_vector_calculator_cc_proto",
|
":concatenate_vector_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/api2:node",
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
|
@ -128,6 +129,19 @@ class ConcatenateClassificationListCalculator
|
||||||
};
|
};
|
||||||
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
|
||||||
|
|
||||||
|
class ConcatenateJointListCalculator
|
||||||
|
: public ConcatenateListsCalculator<Joint, JointList> {
|
||||||
|
protected:
|
||||||
|
int ListSize(const JointList& list) const override {
|
||||||
|
return list.joint_size();
|
||||||
|
}
|
||||||
|
const Joint GetItem(const JointList& list, int idx) const override {
|
||||||
|
return list.joint(idx);
|
||||||
|
}
|
||||||
|
Joint* AddItem(JointList& list) const override { return list.add_joint(); }
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(ConcatenateJointListCalculator);
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,38 @@ package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "concatenate",
|
||||||
|
hdrs = ["concatenate.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:concatenate_proto_list_calculator",
|
||||||
|
"//mediapipe/calculators/core:concatenate_vector_calculator",
|
||||||
|
"//mediapipe/calculators/core:concatenate_vector_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "concatenate_test",
|
||||||
|
srcs = ["concatenate_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":concatenate",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status_matchers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "detections_to_rects",
|
name = "detections_to_rects",
|
||||||
srcs = ["detections_to_rects.cc"],
|
srcs = ["detections_to_rects.cc"],
|
||||||
|
|
69
mediapipe/framework/api2/stream/concatenate.h
Normal file
69
mediapipe/framework/api2/stream/concatenate.h
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
|
||||||
|
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
|
||||||
|
namespace internal_stream_concatenate {
|
||||||
|
|
||||||
|
// Helper function that adds a node to a graph, that is capable of concatenating
|
||||||
|
// a specific type (T).
|
||||||
|
template <class T>
|
||||||
|
GenericNode& AddConcatenateVectorNode(Graph& graph) {
|
||||||
|
if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
|
||||||
|
return graph.AddNode("ConcatenateLandmarkListCalculator");
|
||||||
|
} else if constexpr (std::is_same_v<T, mediapipe::JointList>) {
|
||||||
|
return graph.AddNode("ConcatenateJointListCalculator");
|
||||||
|
} else if constexpr (std::is_same_v<T, std::vector<Tensor>>) {
|
||||||
|
return graph.AddNode("ConcatenateTensorVectorCalculator");
|
||||||
|
} else {
|
||||||
|
static_assert(dependent_false<T>::value,
|
||||||
|
"Concatenate node is not available for the specified type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamsT,
|
||||||
|
typename PayloadT = typename StreamsT::value_type::PayloadT>
|
||||||
|
Stream<PayloadT> Concatenate(StreamsT& streams,
|
||||||
|
const bool only_emit_if_all_present,
|
||||||
|
Graph& graph) {
|
||||||
|
auto& concatenator = AddConcatenateVectorNode<PayloadT>(graph);
|
||||||
|
for (int i = 0; i < streams.size(); ++i) {
|
||||||
|
streams[i].ConnectTo(concatenator.In("")[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& concatenator_opts =
|
||||||
|
concatenator
|
||||||
|
.template GetOptions<mediapipe::ConcatenateVectorCalculatorOptions>();
|
||||||
|
concatenator_opts.set_only_emit_if_all_present(only_emit_if_all_present);
|
||||||
|
|
||||||
|
return concatenator.Out("").template Cast<PayloadT>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace internal_stream_concatenate
|
||||||
|
|
||||||
|
template <typename StreamsT,
|
||||||
|
typename PayloadT = typename StreamsT::value_type::PayloadT>
|
||||||
|
Stream<PayloadT> Concatenate(StreamsT& streams, Graph& graph) {
|
||||||
|
return internal_stream_concatenate::Concatenate(
|
||||||
|
streams, /*only_emit_if_all_present=*/false, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamsT,
|
||||||
|
typename PayloadT = typename StreamsT::value_type::PayloadT>
|
||||||
|
Stream<PayloadT> ConcatenateIfAllPresent(StreamsT& streams, Graph& graph) {
|
||||||
|
return internal_stream_concatenate::Concatenate(
|
||||||
|
streams, /*only_emit_if_all_present=*/true, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::api2::builder
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
|
188
mediapipe/framework/api2/stream/concatenate_test.cc
Normal file
188
mediapipe/framework/api2/stream/concatenate_test.cc
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
#include "mediapipe/framework/api2/stream/concatenate.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(Concatenate, ConcatenateLandmarkList) {
|
||||||
|
Graph graph;
|
||||||
|
std::vector<Stream<LandmarkList>> items = {
|
||||||
|
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
|
||||||
|
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
|
||||||
|
Stream<LandmarkList> landmark_list = Concatenate(items, graph);
|
||||||
|
landmark_list.SetName("landmark_list");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateLandmarkListCalculator"
|
||||||
|
input_stream: "__stream_0"
|
||||||
|
input_stream: "__stream_1"
|
||||||
|
output_stream: "landmark_list"
|
||||||
|
options {
|
||||||
|
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||||
|
only_emit_if_all_present: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "LMK_LIST:0:__stream_0"
|
||||||
|
input_stream: "LMK_LIST:1:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Concatenate, ConcatenateLandmarkList_IfAllPresent) {
|
||||||
|
Graph graph;
|
||||||
|
std::vector<Stream<LandmarkList>> items = {
|
||||||
|
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
|
||||||
|
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
|
||||||
|
Stream<LandmarkList> landmark_list = ConcatenateIfAllPresent(items, graph);
|
||||||
|
landmark_list.SetName("landmark_list");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateLandmarkListCalculator"
|
||||||
|
input_stream: "__stream_0"
|
||||||
|
input_stream: "__stream_1"
|
||||||
|
output_stream: "landmark_list"
|
||||||
|
options {
|
||||||
|
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||||
|
only_emit_if_all_present: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "LMK_LIST:0:__stream_0"
|
||||||
|
input_stream: "LMK_LIST:1:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Concatenate, ConcatenateJointList) {
|
||||||
|
Graph graph;
|
||||||
|
std::vector<Stream<JointList>> items = {
|
||||||
|
graph.In("JT_LIST")[0].Cast<JointList>(),
|
||||||
|
graph.In("JT_LIST")[1].Cast<JointList>()};
|
||||||
|
Stream<JointList> joint_list = Concatenate(items, graph);
|
||||||
|
joint_list.SetName("joint_list");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateJointListCalculator"
|
||||||
|
input_stream: "__stream_0"
|
||||||
|
input_stream: "__stream_1"
|
||||||
|
output_stream: "joint_list"
|
||||||
|
options {
|
||||||
|
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||||
|
only_emit_if_all_present: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "JT_LIST:0:__stream_0"
|
||||||
|
input_stream: "JT_LIST:1:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Concatenate, ConcatenateJointList_IfAllPresent) {
|
||||||
|
Graph graph;
|
||||||
|
std::vector<Stream<JointList>> items = {
|
||||||
|
graph.In("JT_LIST")[0].Cast<JointList>(),
|
||||||
|
graph.In("JT_LIST")[1].Cast<JointList>()};
|
||||||
|
Stream<JointList> joint_list = ConcatenateIfAllPresent(items, graph);
|
||||||
|
joint_list.SetName("joint_list");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateJointListCalculator"
|
||||||
|
input_stream: "__stream_0"
|
||||||
|
input_stream: "__stream_1"
|
||||||
|
output_stream: "joint_list"
|
||||||
|
options {
|
||||||
|
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||||
|
only_emit_if_all_present: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "JT_LIST:0:__stream_0"
|
||||||
|
input_stream: "JT_LIST:1:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Concatenate, ConcatenateTensorVectorList) {
|
||||||
|
Graph graph;
|
||||||
|
std::vector<Stream<std::vector<Tensor>>> items = {
|
||||||
|
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
|
||||||
|
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
|
||||||
|
Stream<std::vector<Tensor>> tensors = Concatenate(items, graph);
|
||||||
|
tensors.SetName("joint_list");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateTensorVectorCalculator"
|
||||||
|
input_stream: "__stream_0"
|
||||||
|
input_stream: "__stream_1"
|
||||||
|
output_stream: "joint_list"
|
||||||
|
options {
|
||||||
|
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||||
|
only_emit_if_all_present: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "VT_LIST:0:__stream_0"
|
||||||
|
input_stream: "VT_LIST:1:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Concatenate, ConcatenateTensorVectorList_IfAllPresent) {
|
||||||
|
Graph graph;
|
||||||
|
std::vector<Stream<std::vector<Tensor>>> items = {
|
||||||
|
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
|
||||||
|
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
|
||||||
|
|
||||||
|
Stream<std::vector<Tensor>> tensors = ConcatenateIfAllPresent(items, graph);
|
||||||
|
tensors.SetName("joint_list");
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "ConcatenateTensorVectorCalculator"
|
||||||
|
input_stream: "__stream_0"
|
||||||
|
input_stream: "__stream_1"
|
||||||
|
output_stream: "joint_list"
|
||||||
|
options {
|
||||||
|
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
|
||||||
|
only_emit_if_all_present: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "VT_LIST:0:__stream_0"
|
||||||
|
input_stream: "VT_LIST:1:__stream_1"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user