Add proper Cast for MultiPort
PiperOrigin-RevId: 487012509
This commit is contained in:
parent
7a1e55b872
commit
ace098f370
|
@ -106,6 +106,13 @@ class MultiPort : public Single {
|
|||
return Single{&GetWithAutoGrow(&vec_, index)};
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
auto Cast() {
|
||||
using SingleCastT =
|
||||
std::invoke_result_t<decltype(&Single::template Cast<U>), Single*>;
|
||||
return MultiPort<SingleCastT>(&vec_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Base>>& vec_;
|
||||
};
|
||||
|
|
|
@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, MultiPortIsCastToMultiPort) {
|
||||
builder::Graph graph;
|
||||
builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
|
||||
builder::MultiSource<int> int_input = any_input.Cast<int>();
|
||||
builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
|
||||
builder::MultiDestination<int> int_output = any_output.Cast<int>();
|
||||
int_input >> int_output;
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "ANY_INPUT:__stream_0"
|
||||
output_stream: "ANY_OUTPUT:__stream_0"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
|
||||
builder::Graph graph;
|
||||
builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
|
||||
builder::Source<AnyType> any_input = any_multi_input;
|
||||
builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
|
||||
builder::Destination<AnyType> any_output = any_multi_output;
|
||||
any_input >> any_output;
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "ANY_INPUT:__stream_0"
|
||||
output_stream: "ANY_OUTPUT:__stream_0"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
|
||||
builder::Graph graph;
|
||||
builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
|
||||
builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
|
||||
builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
|
||||
builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
|
||||
int_input >> int_output;
|
||||
any_input >> any_output;
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "ANY_OUTPUT:__stream_0"
|
||||
input_stream: "INT_INPUT:__stream_1"
|
||||
output_stream: "ANY_OUTPUT:__stream_0"
|
||||
output_stream: "INT_OUTPUT:__stream_1"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user