From 3ac3b03ed59ceedb9b12a90cb44000b29a981b31 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 11 Aug 2023 13:11:47 -0700 Subject: [PATCH] Migrate packet messages auto registration to rely on MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE PiperOrigin-RevId: 556063007 --- mediapipe/framework/BUILD | 1 + mediapipe/framework/deps/registration.h | 3 - mediapipe/framework/packet.h | 57 ++++++------------- .../framework/packet_registration_test.cc | 37 +++++++++++- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 721cacc95..3143fc2d8 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1655,6 +1655,7 @@ cc_test( ":packet", ":packet_test_cc_proto", ":type_map", + "//mediapipe/framework/api2:builder", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:gtest_main", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index c67f07305..67ab0b161 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -477,9 +477,6 @@ class GlobalFactoryRegistry { class RegistratorName { \ private: \ /* The member below triggers instantiation of the registration static. */ \ - /* Note that the constructor of calculator subclasses is only invoked */ \ - /* through the registration token, and so we cannot simply use the */ \ - /* static in theconstructor. */ \ typename Internal##RegistratorName::RequireStatics register_; \ }; diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index 39c6321c8..4a3399f1c 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -455,60 +455,37 @@ struct is_concrete_proto_t !std::is_same{} && !std::is_same{}> {}; -// Registers a message type. T must be a non-cv-qualified concrete proto type. template -struct MessageRegistrationImpl { - static NoDestructor registration; - // This could have been a lambda inside registration's initializer below, but - // MSVC has a bug with lambdas, so we put it here as a workaround. - static std::unique_ptr> CreateMessageHolder() { - return absl::make_unique>(new T); - } -}; +std::unique_ptr CreateMessageHolder() { + return absl::make_unique>(new T); +} -// Static members of template classes can be defined in the header. -template -NoDestructor - MessageRegistrationImpl::registration(MessageHolderRegistry::Register( - T{}.GetTypeName(), MessageRegistrationImpl::CreateMessageHolder)); +// Registers a message type. T must be a non-cv-qualified concrete proto type. +MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry, + T{}.GetTypeName(), CreateMessageHolder) // For non-Message payloads, this does nothing. template -struct HolderSupport { - static void EnsureStaticInit() {} -}; +struct HolderPayloadRegistrator {}; // This template ensures that, for each concrete MessageLite subclass that is // stored in a Packet, we register a function that allows us to create a // Holder with the correct payload type from the proto's type name. +// +// We must use std::remove_cv to ensure we don't try to register Foo twice if +// there are Holder and Holder. TODO: lift this +// up to Holder? template -struct HolderSupport{}>::type> { - // We must use std::remove_cv to ensure we don't try to register Foo twice if - // there are Holder and Holder. TODO: lift this - // up to Holder? - using R = MessageRegistrationImpl::type>; - // For the registration static member to be instantiated, it needs to be - // referenced in a context that requires the definition to exist (see ISO/IEC - // C++ 2003 standard, 14.7.1). Calling this ensures that's the case. - // We need two different call-sites to cover proto types for which packets - // are only ever created (i.e. the protos are only produced by calculators) - // and proto types for which packets are only ever consumed (i.e. the protos - // are only consumed by calculators). - static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); } -}; +struct HolderPayloadRegistrator< + T, typename std::enable_if{}>::type> + : private MessageRegistrator::type> {}; template -class Holder : public HolderBase { +class Holder : public HolderBase, private HolderPayloadRegistrator { public: - explicit Holder(const T* ptr) : ptr_(ptr) { - HolderSupport::EnsureStaticInit(); - } + explicit Holder(const T* ptr) : ptr_(ptr) {} ~Holder() override { delete_helper(); } - const T& data() const { - HolderSupport::EnsureStaticInit(); - return *ptr_; - } + const T& data() const { return *ptr_; } TypeId GetTypeId() const final { return kTypeId; } // Releases the underlying data pointer and transfers the ownership to a // unique pointer. diff --git a/mediapipe/framework/packet_registration_test.cc b/mediapipe/framework/packet_registration_test.cc index 30c7c7893..7b2ea1f79 100644 --- a/mediapipe/framework/packet_registration_test.cc +++ b/mediapipe/framework/packet_registration_test.cc @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "absl/strings/str_cat.h" +#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_test.pb.h" @@ -24,6 +28,9 @@ namespace mediapipe { namespace { +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Stream; + namespace test_ns { constexpr char kOutTag[] = "OUT"; @@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator); } // namespace test_ns -TEST(PacketTest, InputTypeRegistration) { +TEST(PacketRegistrationTest, InputTypeRegistration) { using testing::Contains; ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(), "mediapipe.InputOnlyProto"); @@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) { Contains("mediapipe.InputOnlyProto")); } +TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) { + CalculatorGraphConfig config; + { + Graph graph; + Stream input = + graph.In(0).SetName("in").Cast(); + + auto& sink_node = graph.AddNode("TestSinkCalculator"); + input.ConnectTo(sink_node.In(test_ns::kInTag)); + Stream output = sink_node.Out(test_ns::kOutTag).Cast(); + + output.ConnectTo(graph.Out(0)).SetName("out"); + + config = graph.GetConfig(); + } + + CalculatorGraph calculator_graph; + MP_ASSERT_OK(calculator_graph.Initialize(std::move(config))); + MP_ASSERT_OK(calculator_graph.StartRun({})); + + int value = 10; + auto proto = std::make_unique(); + proto->set_x(value); + MP_ASSERT_OK(calculator_graph.AddPacketToInputStream( + "in", Adopt(proto.release()).At(Timestamp(0)))); + MP_ASSERT_OK(calculator_graph.WaitUntilIdle()); +} + } // namespace } // namespace mediapipe