Add proper Cast for MultiPort
PiperOrigin-RevId: 487012509
This commit is contained in:
parent
7a1e55b872
commit
ace098f370
|
@ -106,6 +106,13 @@ class MultiPort : public Single {
|
||||||
return Single{&GetWithAutoGrow(&vec_, index)};
|
return Single{&GetWithAutoGrow(&vec_, index)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
auto Cast() {
|
||||||
|
using SingleCastT =
|
||||||
|
std::invoke_result_t<decltype(&Single::template Cast<U>), Single*>;
|
||||||
|
return MultiPort<SingleCastT>(&vec_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::unique_ptr<Base>>& vec_;
|
std::vector<std::unique_ptr<Base>>& vec_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -445,6 +445,57 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
||||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(BuilderTest, MultiPortIsCastToMultiPort) {
|
||||||
|
builder::Graph graph;
|
||||||
|
builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
|
||||||
|
builder::MultiSource<int> int_input = any_input.Cast<int>();
|
||||||
|
builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
|
||||||
|
builder::MultiDestination<int> int_output = any_output.Cast<int>();
|
||||||
|
int_input >> int_output;
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "ANY_INPUT:__stream_0"
|
||||||
|
output_stream: "ANY_OUTPUT:__stream_0"
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
|
||||||
|
builder::Graph graph;
|
||||||
|
builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
|
||||||
|
builder::Source<AnyType> any_input = any_multi_input;
|
||||||
|
builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
|
||||||
|
builder::Destination<AnyType> any_output = any_multi_output;
|
||||||
|
any_input >> any_output;
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "ANY_INPUT:__stream_0"
|
||||||
|
output_stream: "ANY_OUTPUT:__stream_0"
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
|
||||||
|
builder::Graph graph;
|
||||||
|
builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
|
||||||
|
builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
|
||||||
|
builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
|
||||||
|
builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
|
||||||
|
int_input >> int_output;
|
||||||
|
any_input >> any_output;
|
||||||
|
|
||||||
|
CalculatorGraphConfig expected =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "ANY_OUTPUT:__stream_0"
|
||||||
|
input_stream: "INT_INPUT:__stream_1"
|
||||||
|
output_stream: "ANY_OUTPUT:__stream_0"
|
||||||
|
output_stream: "INT_OUTPUT:__stream_1"
|
||||||
|
)pb");
|
||||||
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user