Update Destination with Cast to make it aligned with Source.

PiperOrigin-RevId: 479433889
This commit is contained in:
MediaPipe Team 2022-10-06 15:56:59 -07:00 committed by Copybara-Service
parent d0707227e4
commit 71a4680a16
2 changed files with 20 additions and 6 deletions

View File

@ -112,6 +112,14 @@ class MultiPort : public Single {
std::vector<std::unique_ptr<Base>>& vec_;
};
namespace internal_builder {
template <typename T, typename U>
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>>;
} // namespace internal_builder
// These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic>
@ -122,6 +130,13 @@ class DestinationImpl {
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
template <typename U,
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
DestinationImpl<IsSide, U> Cast() {
return DestinationImpl<IsSide, U>(&base_);
}
DestinationBase& base_;
};
@ -165,12 +180,8 @@ class SourceImpl {
return AddTarget(dest);
}
template <typename U>
struct AllowCast
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>> {};
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
template <typename U,
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
SourceImpl<IsSide, U> Cast() {
return SourceImpl<IsSide, U>(base_);
}

View File

@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output");
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
output_stream: "ANY_OUTPUT:any_type_output"
}
input_stream: "GRAPH_ANY_INPUT:__stream_0"
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}