From b3f9587bc2d35b70c49d959bc00c96896b120cc4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 5 Oct 2023 14:07:04 -0700 Subject: [PATCH] Add stream API merge utils. PiperOrigin-RevId: 571124981 --- mediapipe/framework/api2/stream/BUILD | 22 ++++++++++++ mediapipe/framework/api2/stream/merge.h | 20 +++++++++++ mediapipe/framework/api2/stream/merge_test.cc | 35 +++++++++++++++++++ 3 files changed, 77 insertions(+) create mode 100644 mediapipe/framework/api2/stream/merge.h create mode 100644 mediapipe/framework/api2/stream/merge_test.cc diff --git a/mediapipe/framework/api2/stream/BUILD b/mediapipe/framework/api2/stream/BUILD index 86e9052dd..127393d37 100644 --- a/mediapipe/framework/api2/stream/BUILD +++ b/mediapipe/framework/api2/stream/BUILD @@ -237,6 +237,28 @@ cc_test( ], ) +cc_library( + name = "merge", + hdrs = ["merge.h"], + deps = [ + "//mediapipe/calculators/core:merge_calculator", + "//mediapipe/framework/api2:builder", + ], +) + +cc_test( + name = "merge_test", + srcs = ["merge_test.cc"], + deps = [ + ":merge", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + ], +) + cc_library( name = "presence", hdrs = ["presence.h"], diff --git a/mediapipe/framework/api2/stream/merge.h b/mediapipe/framework/api2/stream/merge.h new file mode 100644 index 000000000..f3e54f9b0 --- /dev/null +++ b/mediapipe/framework/api2/stream/merge.h @@ -0,0 +1,20 @@ +#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_MERGE_H_ +#define MEDIAPIPE_FRAMEWORK_API2_STREAM_MERGE_H_ + +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe::api2::builder { + +// Updates @graph to choose @a stream if it's available (not empty stream at +// specific timestamp) or @b stream otherwise. +template +Stream Merge(Stream a, Stream b, Graph& graph) { + auto& merge_node = graph.AddNode("MergeCalculator"); + a.ConnectTo(merge_node.In("")[0]); + b.ConnectTo(merge_node.In("")[1]); + return merge_node.Out("").template Cast(); +} + +} // namespace mediapipe::api2::builder + +#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_MERGE_H_ diff --git a/mediapipe/framework/api2/stream/merge_test.cc b/mediapipe/framework/api2/stream/merge_test.cc new file mode 100644 index 000000000..bd1bc0aa3 --- /dev/null +++ b/mediapipe/framework/api2/stream/merge_test.cc @@ -0,0 +1,35 @@ +#include "mediapipe/framework/api2/stream/merge.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe::api2::builder { +namespace { + +TEST(Merge, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Stream input_a = graph.In("INPUT_A").Cast(); + Stream input_b = graph.In("INPUT_B").Cast(); + Stream input = Merge(input_a, input_b, graph); + input.SetName("input"); + + EXPECT_THAT( + graph.GetConfig(), + EqualsProto(mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "MergeCalculator" + input_stream: "__stream_0" + input_stream: "__stream_1" + output_stream: "input" + } + input_stream: "INPUT_A:__stream_0" + input_stream: "INPUT_B:__stream_1" + )pb"))); +} + +} // namespace +} // namespace mediapipe::api2::builder