diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 20043870c..11bcd21c6 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -112,6 +112,14 @@ class MultiPort : public Single { std::vector>& vec_; }; +namespace internal_builder { + +template +using AllowCast = std::integral_constant && + !std::is_same_v>; + +} // namespace internal_builder + // These classes wrap references to the underlying source/destination // endpoints, adding type information and the user-visible API. template @@ -122,6 +130,13 @@ class DestinationImpl { explicit DestinationImpl(std::vector>* vec) : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {} + + template {}, int> = 0> + DestinationImpl Cast() { + return DestinationImpl(&base_); + } + DestinationBase& base_; }; @@ -165,12 +180,8 @@ class SourceImpl { return AddTarget(dest); } - template - struct AllowCast - : public std::integral_constant && - !std::is_same_v> {}; - - template {}, int> = 0> + template {}, int> = 0> SourceImpl Cast() { return SourceImpl(base_); } diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3244e092d..810c52527 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) { node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); any_type_output.SetName("any_type_output"); + any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); + CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( node { @@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) { output_stream: "ANY_OUTPUT:any_type_output" } input_stream: "GRAPH_ANY_INPUT:__stream_0" + output_stream: "GRAPH_ANY_OUTPUT:any_type_output" )pb"); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); }