Update Destination with Cast to make it aligned with Source.
PiperOrigin-RevId: 479433889
This commit is contained in:
parent
d0707227e4
commit
71a4680a16
|
@ -112,6 +112,14 @@ class MultiPort : public Single {
|
||||||
std::vector<std::unique_ptr<Base>>& vec_;
|
std::vector<std::unique_ptr<Base>>& vec_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal_builder {
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
||||||
|
!std::is_same_v<T, U>>;
|
||||||
|
|
||||||
|
} // namespace internal_builder
|
||||||
|
|
||||||
// These classes wrap references to the underlying source/destination
|
// These classes wrap references to the underlying source/destination
|
||||||
// endpoints, adding type information and the user-visible API.
|
// endpoints, adding type information and the user-visible API.
|
||||||
template <bool IsSide, typename T = internal::Generic>
|
template <bool IsSide, typename T = internal::Generic>
|
||||||
|
@ -122,6 +130,13 @@ class DestinationImpl {
|
||||||
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
|
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
|
||||||
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
|
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
|
||||||
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
|
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
|
||||||
|
|
||||||
|
template <typename U,
|
||||||
|
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
|
||||||
|
DestinationImpl<IsSide, U> Cast() {
|
||||||
|
return DestinationImpl<IsSide, U>(&base_);
|
||||||
|
}
|
||||||
|
|
||||||
DestinationBase& base_;
|
DestinationBase& base_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -165,12 +180,8 @@ class SourceImpl {
|
||||||
return AddTarget(dest);
|
return AddTarget(dest);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename U>
|
template <typename U,
|
||||||
struct AllowCast
|
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
|
||||||
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
|
|
||||||
!std::is_same_v<T, U>> {};
|
|
||||||
|
|
||||||
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
|
|
||||||
SourceImpl<IsSide, U> Cast() {
|
SourceImpl<IsSide, U> Cast() {
|
||||||
return SourceImpl<IsSide, U>(base_);
|
return SourceImpl<IsSide, U>(base_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
||||||
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
|
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
|
||||||
any_type_output.SetName("any_type_output");
|
any_type_output.SetName("any_type_output");
|
||||||
|
|
||||||
|
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
|
||||||
|
|
||||||
CalculatorGraphConfig expected =
|
CalculatorGraphConfig expected =
|
||||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
node {
|
node {
|
||||||
|
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
||||||
output_stream: "ANY_OUTPUT:any_type_output"
|
output_stream: "ANY_OUTPUT:any_type_output"
|
||||||
}
|
}
|
||||||
input_stream: "GRAPH_ANY_INPUT:__stream_0"
|
input_stream: "GRAPH_ANY_INPUT:__stream_0"
|
||||||
|
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
|
||||||
)pb");
|
)pb");
|
||||||
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user