Add proper Cast for MultiPort

PiperOrigin-RevId: 487012509
This commit is contained in:
MediaPipe Team 2022-11-08 11:55:04 -08:00 committed by Copybara-Service
parent 7a1e55b872
commit ace098f370
2 changed files with 58 additions and 0 deletions

View File

@ -106,6 +106,13 @@ class MultiPort : public Single {
return Single{&GetWithAutoGrow(&vec_, index)};
}
template <typename U>
auto Cast() {
using SingleCastT =
std::invoke_result_t<decltype(&Single::template Cast<U>), Single*>;
return MultiPort<SingleCastT>(&vec_);
}
private:
std::vector<std::unique_ptr<Base>>& vec_;
};

View File

@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, MultiPortIsCastToMultiPort) {
builder::Graph graph;
builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
builder::MultiSource<int> int_input = any_input.Cast<int>();
builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
builder::MultiDestination<int> int_output = any_output.Cast<int>();
int_input >> int_output;
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "ANY_INPUT:__stream_0"
output_stream: "ANY_OUTPUT:__stream_0"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
builder::Graph graph;
builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
builder::Source<AnyType> any_input = any_multi_input;
builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
builder::Destination<AnyType> any_output = any_multi_output;
any_input >> any_output;
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "ANY_INPUT:__stream_0"
output_stream: "ANY_OUTPUT:__stream_0"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
builder::Graph graph;
builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
int_input >> int_output;
any_input >> any_output;
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "ANY_OUTPUT:__stream_0"
input_stream: "INT_INPUT:__stream_1"
output_stream: "ANY_OUTPUT:__stream_0"
output_stream: "INT_OUTPUT:__stream_1"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace test
} // namespace api2
} // namespace mediapipe