diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index bf7f2b399..5af9ee5e0 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -106,6 +106,13 @@ class MultiPort : public Single { return Single{&GetWithAutoGrow(&vec_, index)}; } + template + auto Cast() { + using SingleCastT = + std::invoke_result_t), Single*>; + return MultiPort(&vec_); + } + private: std::vector>& vec_; }; diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 810c52527..3bf3ec198 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } +TEST(BuilderTest, MultiPortIsCastToMultiPort) { + builder::Graph graph; + builder::MultiSource any_input = graph.In("ANY_INPUT"); + builder::MultiSource int_input = any_input.Cast(); + builder::MultiDestination any_output = graph.Out("ANY_OUTPUT"); + builder::MultiDestination int_output = any_output.Cast(); + int_input >> int_output; + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "ANY_INPUT:__stream_0" + output_stream: "ANY_OUTPUT:__stream_0" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) { + builder::Graph graph; + builder::MultiSource any_multi_input = graph.In("ANY_INPUT"); + builder::Source any_input = any_multi_input; + builder::MultiDestination any_multi_output = graph.Out("ANY_OUTPUT"); + builder::Destination any_output = any_multi_output; + any_input >> any_output; + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "ANY_INPUT:__stream_0" + output_stream: "ANY_OUTPUT:__stream_0" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) { + builder::Graph graph; + builder::Source int_input = graph.In("INT_INPUT").Cast(); + builder::Source any_input = graph.In("ANY_OUTPUT"); + builder::Destination int_output = graph.Out("INT_OUTPUT").Cast(); + builder::Destination any_output = graph.Out("ANY_OUTPUT"); + int_input >> int_output; + any_input >> any_output; + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "ANY_OUTPUT:__stream_0" + input_stream: "INT_INPUT:__stream_1" + output_stream: "ANY_OUTPUT:__stream_0" + output_stream: "INT_OUTPUT:__stream_1" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + } // namespace test } // namespace api2 } // namespace mediapipe