From 71a4680a1698761e1d55b4543f3a145536f4041f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 6 Oct 2022 15:56:59 -0700 Subject: [PATCH] Update Destination with Cast to make it aligned with Source. PiperOrigin-RevId: 479433889 --- mediapipe/framework/api2/builder.h | 23 +++++++++++++++++------ mediapipe/framework/api2/builder_test.cc | 3 +++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 20043870c..11bcd21c6 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -112,6 +112,14 @@ class MultiPort : public Single { std::vector>& vec_; }; +namespace internal_builder { + +template +using AllowCast = std::integral_constant && + !std::is_same_v>; + +} // namespace internal_builder + // These classes wrap references to the underlying source/destination // endpoints, adding type information and the user-visible API. template @@ -122,6 +130,13 @@ class DestinationImpl { explicit DestinationImpl(std::vector>* vec) : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {} + + template {}, int> = 0> + DestinationImpl Cast() { + return DestinationImpl(&base_); + } + DestinationBase& base_; }; @@ -165,12 +180,8 @@ class SourceImpl { return AddTarget(dest); } - template - struct AllowCast - : public std::integral_constant && - !std::is_same_v> {}; - - template {}, int> = 0> + template {}, int> = 0> SourceImpl Cast() { return SourceImpl(base_); } diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3244e092d..810c52527 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) { node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); any_type_output.SetName("any_type_output"); + any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); + CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( node { @@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) { output_stream: "ANY_OUTPUT:any_type_output" } input_stream: "GRAPH_ANY_INPUT:__stream_0" + output_stream: "GRAPH_ANY_OUTPUT:any_type_output" )pb"); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); }