Update Destination with Cast to make it aligned with Source.
PiperOrigin-RevId: 479433889
This commit is contained in:
parent
d0707227e4
commit
71a4680a16
|
@ -112,6 +112,14 @@ class MultiPort : public Single {
|
|||
std::vector<std::unique_ptr<Base>>& vec_;
|
||||
};
|
||||
|
||||
namespace internal_builder {
|
||||
|
||||
template <typename T, typename U>
|
||||
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
||||
!std::is_same_v<T, U>>;
|
||||
|
||||
} // namespace internal_builder
|
||||
|
||||
// These classes wrap references to the underlying source/destination
|
||||
// endpoints, adding type information and the user-visible API.
|
||||
template <bool IsSide, typename T = internal::Generic>
|
||||
|
@ -122,6 +130,13 @@ class DestinationImpl {
|
|||
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
|
||||
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
|
||||
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
|
||||
|
||||
template <typename U,
|
||||
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
|
||||
DestinationImpl<IsSide, U> Cast() {
|
||||
return DestinationImpl<IsSide, U>(&base_);
|
||||
}
|
||||
|
||||
DestinationBase& base_;
|
||||
};
|
||||
|
||||
|
@ -165,12 +180,8 @@ class SourceImpl {
|
|||
return AddTarget(dest);
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
struct AllowCast
|
||||
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
||||
!std::is_same_v<T, U>> {};
|
||||
|
||||
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
|
||||
template <typename U,
|
||||
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
|
||||
SourceImpl<IsSide, U> Cast() {
|
||||
return SourceImpl<IsSide, U>(base_);
|
||||
}
|
||||
|
|
|
@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
|
||||
any_type_output.SetName("any_type_output");
|
||||
|
||||
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
|
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
output_stream: "ANY_OUTPUT:any_type_output"
|
||||
}
|
||||
input_stream: "GRAPH_ANY_INPUT:__stream_0"
|
||||
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
|
||||
)pb");
|
||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user