Migrate packet messages auto registration to rely on MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE

PiperOrigin-RevId: 556063007
This commit is contained in:
MediaPipe Team 2023-08-11 13:11:47 -07:00 committed by Copybara-Service
parent c448d54aa7
commit 3ac3b03ed5
4 changed files with 54 additions and 44 deletions

View File

@ -1655,6 +1655,7 @@ cc_test(
":packet", ":packet",
":packet_test_cc_proto", ":packet_test_cc_proto",
":type_map", ":type_map",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -477,9 +477,6 @@ class GlobalFactoryRegistry {
class RegistratorName { \ class RegistratorName { \
private: \ private: \
/* The member below triggers instantiation of the registration static. */ \ /* The member below triggers instantiation of the registration static. */ \
/* Note that the constructor of calculator subclasses is only invoked */ \
/* through the registration token, and so we cannot simply use the */ \
/* static in theconstructor. */ \
typename Internal##RegistratorName<T>::RequireStatics register_; \ typename Internal##RegistratorName<T>::RequireStatics register_; \
}; };

View File

@ -455,60 +455,37 @@ struct is_concrete_proto_t
!std::is_same<proto_ns::MessageLite, T>{} && !std::is_same<proto_ns::MessageLite, T>{} &&
!std::is_same<proto_ns::Message, T>{}> {}; !std::is_same<proto_ns::Message, T>{}> {};
// Registers a message type. T must be a non-cv-qualified concrete proto type.
template <typename T> template <typename T>
struct MessageRegistrationImpl { std::unique_ptr<HolderBase> CreateMessageHolder() {
static NoDestructor<mediapipe::RegistrationToken> registration;
// This could have been a lambda inside registration's initializer below, but
// MSVC has a bug with lambdas, so we put it here as a workaround.
static std::unique_ptr<Holder<T>> CreateMessageHolder() {
return absl::make_unique<Holder<T>>(new T); return absl::make_unique<Holder<T>>(new T);
} }
};
// Static members of template classes can be defined in the header. // Registers a message type. T must be a non-cv-qualified concrete proto type.
template <typename T> MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry,
NoDestructor<mediapipe::RegistrationToken> T{}.GetTypeName(), CreateMessageHolder<T>)
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
// For non-Message payloads, this does nothing. // For non-Message payloads, this does nothing.
template <typename T, typename Enable = void> template <typename T, typename Enable = void>
struct HolderSupport { struct HolderPayloadRegistrator {};
static void EnsureStaticInit() {}
};
// This template ensures that, for each concrete MessageLite subclass that is // This template ensures that, for each concrete MessageLite subclass that is
// stored in a Packet, we register a function that allows us to create a // stored in a Packet, we register a function that allows us to create a
// Holder with the correct payload type from the proto's type name. // Holder with the correct payload type from the proto's type name.
template <typename T> //
struct HolderSupport<T,
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
// We must use std::remove_cv to ensure we don't try to register Foo twice if // We must use std::remove_cv to ensure we don't try to register Foo twice if
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this // there are Holder<Foo> and Holder<const Foo>. TODO: lift this
// up to Holder? // up to Holder?
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>; template <typename T>
// For the registration static member to be instantiated, it needs to be struct HolderPayloadRegistrator<
// referenced in a context that requires the definition to exist (see ISO/IEC T, typename std::enable_if<is_concrete_proto_t<T>{}>::type>
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case. : private MessageRegistrator<typename std::remove_cv<T>::type> {};
// We need two different call-sites to cover proto types for which packets
// are only ever created (i.e. the protos are only produced by calculators)
// and proto types for which packets are only ever consumed (i.e. the protos
// are only consumed by calculators).
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
};
template <typename T> template <typename T>
class Holder : public HolderBase { class Holder : public HolderBase, private HolderPayloadRegistrator<T> {
public: public:
explicit Holder(const T* ptr) : ptr_(ptr) { explicit Holder(const T* ptr) : ptr_(ptr) {}
HolderSupport<T>::EnsureStaticInit();
}
~Holder() override { delete_helper(); } ~Holder() override { delete_helper(); }
const T& data() const { const T& data() const { return *ptr_; }
HolderSupport<T>::EnsureStaticInit();
return *ptr_;
}
TypeId GetTypeId() const final { return kTypeId<T>; } TypeId GetTypeId() const final { return kTypeId<T>; }
// Releases the underlying data pointer and transfers the ownership to a // Releases the underlying data pointer and transfers the ownership to a
// unique pointer. // unique pointer.

View File

@ -12,7 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include <utility>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_test.pb.h" #include "mediapipe/framework/packet_test.pb.h"
@ -24,6 +28,9 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
namespace test_ns { namespace test_ns {
constexpr char kOutTag[] = "OUT"; constexpr char kOutTag[] = "OUT";
@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator);
} // namespace test_ns } // namespace test_ns
TEST(PacketTest, InputTypeRegistration) { TEST(PacketRegistrationTest, InputTypeRegistration) {
using testing::Contains; using testing::Contains;
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(), ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
"mediapipe.InputOnlyProto"); "mediapipe.InputOnlyProto");
@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) {
Contains("mediapipe.InputOnlyProto")); Contains("mediapipe.InputOnlyProto"));
} }
TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) {
CalculatorGraphConfig config;
{
Graph graph;
Stream<mediapipe::InputOnlyProto> input =
graph.In(0).SetName("in").Cast<mediapipe::InputOnlyProto>();
auto& sink_node = graph.AddNode("TestSinkCalculator");
input.ConnectTo(sink_node.In(test_ns::kInTag));
Stream<int> output = sink_node.Out(test_ns::kOutTag).Cast<int>();
output.ConnectTo(graph.Out(0)).SetName("out");
config = graph.GetConfig();
}
CalculatorGraph calculator_graph;
MP_ASSERT_OK(calculator_graph.Initialize(std::move(config)));
MP_ASSERT_OK(calculator_graph.StartRun({}));
int value = 10;
auto proto = std::make_unique<mediapipe::InputOnlyProto>();
proto->set_x(value);
MP_ASSERT_OK(calculator_graph.AddPacketToInputStream(
"in", Adopt(proto.release()).At(Timestamp(0))));
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe