concatenate stream utility function.

PiperOrigin-RevId: 568997695
This commit is contained in:
MediaPipe Team 2023-09-27 16:47:07 -07:00 committed by Copybara-Service
parent 983fda5d4e
commit 2ecccaf076
5 changed files with 304 additions and 0 deletions

View File

@ -325,6 +325,7 @@ cc_library(
":concatenate_vector_calculator_cc_proto", ":concatenate_vector_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",

View File

@ -18,6 +18,7 @@
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
@ -128,6 +129,19 @@ class ConcatenateClassificationListCalculator
}; };
MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator); MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListCalculator);
class ConcatenateJointListCalculator
: public ConcatenateListsCalculator<Joint, JointList> {
protected:
int ListSize(const JointList& list) const override {
return list.joint_size();
}
const Joint GetItem(const JointList& list, int idx) const override {
return list.joint(idx);
}
Joint* AddItem(JointList& list) const override { return list.add_joint(); }
};
MEDIAPIPE_REGISTER_NODE(ConcatenateJointListCalculator);
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -2,6 +2,38 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) licenses(["notice"])
cc_library(
name = "concatenate",
hdrs = ["concatenate.h"],
deps = [
"//mediapipe/calculators/core:concatenate_proto_list_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
],
)
cc_test(
name = "concatenate_test",
srcs = ["concatenate_test.cc"],
deps = [
":concatenate",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)
cc_library( cc_library(
name = "detections_to_rects", name = "detections_to_rects",
srcs = ["detections_to_rects.cc"], srcs = ["detections_to_rects.cc"],

View File

@ -0,0 +1,69 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
#include <vector>
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe::api2::builder {
namespace internal_stream_concatenate {
// Helper function that adds a node to a graph, that is capable of concatenating
// a specific type (T).
template <class T>
GenericNode& AddConcatenateVectorNode(Graph& graph) {
if constexpr (std::is_same_v<T, mediapipe::LandmarkList>) {
return graph.AddNode("ConcatenateLandmarkListCalculator");
} else if constexpr (std::is_same_v<T, mediapipe::JointList>) {
return graph.AddNode("ConcatenateJointListCalculator");
} else if constexpr (std::is_same_v<T, std::vector<Tensor>>) {
return graph.AddNode("ConcatenateTensorVectorCalculator");
} else {
static_assert(dependent_false<T>::value,
"Concatenate node is not available for the specified type.");
}
}
template <typename StreamsT,
typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> Concatenate(StreamsT& streams,
const bool only_emit_if_all_present,
Graph& graph) {
auto& concatenator = AddConcatenateVectorNode<PayloadT>(graph);
for (int i = 0; i < streams.size(); ++i) {
streams[i].ConnectTo(concatenator.In("")[i]);
}
auto& concatenator_opts =
concatenator
.template GetOptions<mediapipe::ConcatenateVectorCalculatorOptions>();
concatenator_opts.set_only_emit_if_all_present(only_emit_if_all_present);
return concatenator.Out("").template Cast<PayloadT>();
}
} // namespace internal_stream_concatenate
template <typename StreamsT,
typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> Concatenate(StreamsT& streams, Graph& graph) {
return internal_stream_concatenate::Concatenate(
streams, /*only_emit_if_all_present=*/false, graph);
}
template <typename StreamsT,
typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> ConcatenateIfAllPresent(StreamsT& streams, Graph& graph) {
return internal_stream_concatenate::Concatenate(
streams, /*only_emit_if_all_present=*/true, graph);
}
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_

View File

@ -0,0 +1,188 @@
#include "mediapipe/framework/api2/stream/concatenate.h"
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe::api2::builder {
namespace {
TEST(Concatenate, ConcatenateLandmarkList) {
Graph graph;
std::vector<Stream<LandmarkList>> items = {
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
Stream<LandmarkList> landmark_list = Concatenate(items, graph);
landmark_list.SetName("landmark_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateLandmarkListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "landmark_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: false
}
}
}
input_stream: "LMK_LIST:0:__stream_0"
input_stream: "LMK_LIST:1:__stream_1"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(Concatenate, ConcatenateLandmarkList_IfAllPresent) {
Graph graph;
std::vector<Stream<LandmarkList>> items = {
graph.In("LMK_LIST")[0].Cast<LandmarkList>(),
graph.In("LMK_LIST")[1].Cast<LandmarkList>()};
Stream<LandmarkList> landmark_list = ConcatenateIfAllPresent(items, graph);
landmark_list.SetName("landmark_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateLandmarkListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "landmark_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
input_stream: "LMK_LIST:0:__stream_0"
input_stream: "LMK_LIST:1:__stream_1"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(Concatenate, ConcatenateJointList) {
Graph graph;
std::vector<Stream<JointList>> items = {
graph.In("JT_LIST")[0].Cast<JointList>(),
graph.In("JT_LIST")[1].Cast<JointList>()};
Stream<JointList> joint_list = Concatenate(items, graph);
joint_list.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateJointListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: false
}
}
}
input_stream: "JT_LIST:0:__stream_0"
input_stream: "JT_LIST:1:__stream_1"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(Concatenate, ConcatenateJointList_IfAllPresent) {
Graph graph;
std::vector<Stream<JointList>> items = {
graph.In("JT_LIST")[0].Cast<JointList>(),
graph.In("JT_LIST")[1].Cast<JointList>()};
Stream<JointList> joint_list = ConcatenateIfAllPresent(items, graph);
joint_list.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateJointListCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
input_stream: "JT_LIST:0:__stream_0"
input_stream: "JT_LIST:1:__stream_1"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(Concatenate, ConcatenateTensorVectorList) {
Graph graph;
std::vector<Stream<std::vector<Tensor>>> items = {
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
Stream<std::vector<Tensor>> tensors = Concatenate(items, graph);
tensors.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateTensorVectorCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: false
}
}
}
input_stream: "VT_LIST:0:__stream_0"
input_stream: "VT_LIST:1:__stream_1"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(Concatenate, ConcatenateTensorVectorList_IfAllPresent) {
Graph graph;
std::vector<Stream<std::vector<Tensor>>> items = {
graph.In("VT_LIST")[0].Cast<std::vector<Tensor>>(),
graph.In("VT_LIST")[1].Cast<std::vector<Tensor>>()};
Stream<std::vector<Tensor>> tensors = ConcatenateIfAllPresent(items, graph);
tensors.SetName("joint_list");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "ConcatenateTensorVectorCalculator"
input_stream: "__stream_0"
input_stream: "__stream_1"
output_stream: "joint_list"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
input_stream: "VT_LIST:0:__stream_0"
input_stream: "VT_LIST:1:__stream_1"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace mediapipe::api2::builder