Migrate packet messages auto registration to rely on MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE

PiperOrigin-RevId: 556063007
This commit is contained in:
MediaPipe Team 2023-08-11 13:11:47 -07:00 committed by Copybara-Service
parent c448d54aa7
commit 3ac3b03ed5
4 changed files with 54 additions and 44 deletions

View File

@ -1655,6 +1655,7 @@ cc_test(
":packet",
":packet_test_cc_proto",
":type_map",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",

View File

@ -477,9 +477,6 @@ class GlobalFactoryRegistry {
class RegistratorName { \
private: \
/* The member below triggers instantiation of the registration static. */ \
/* Note that the constructor of calculator subclasses is only invoked */ \
/* through the registration token, and so we cannot simply use the */ \
/* static in theconstructor. */ \
typename Internal##RegistratorName<T>::RequireStatics register_; \
};

View File

@ -455,60 +455,37 @@ struct is_concrete_proto_t
!std::is_same<proto_ns::MessageLite, T>{} &&
!std::is_same<proto_ns::Message, T>{}> {};
// Registers a message type. T must be a non-cv-qualified concrete proto type.
template <typename T>
struct MessageRegistrationImpl {
static NoDestructor<mediapipe::RegistrationToken> registration;
// This could have been a lambda inside registration's initializer below, but
// MSVC has a bug with lambdas, so we put it here as a workaround.
static std::unique_ptr<Holder<T>> CreateMessageHolder() {
return absl::make_unique<Holder<T>>(new T);
}
};
std::unique_ptr<HolderBase> CreateMessageHolder() {
return absl::make_unique<Holder<T>>(new T);
}
// Static members of template classes can be defined in the header.
template <typename T>
NoDestructor<mediapipe::RegistrationToken>
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
// Registers a message type. T must be a non-cv-qualified concrete proto type.
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry,
T{}.GetTypeName(), CreateMessageHolder<T>)
// For non-Message payloads, this does nothing.
template <typename T, typename Enable = void>
struct HolderSupport {
static void EnsureStaticInit() {}
};
struct HolderPayloadRegistrator {};
// This template ensures that, for each concrete MessageLite subclass that is
// stored in a Packet, we register a function that allows us to create a
// Holder with the correct payload type from the proto's type name.
//
// We must use std::remove_cv to ensure we don't try to register Foo twice if
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
// up to Holder?
template <typename T>
struct HolderSupport<T,
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
// We must use std::remove_cv to ensure we don't try to register Foo twice if
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
// up to Holder?
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
// For the registration static member to be instantiated, it needs to be
// referenced in a context that requires the definition to exist (see ISO/IEC
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
// We need two different call-sites to cover proto types for which packets
// are only ever created (i.e. the protos are only produced by calculators)
// and proto types for which packets are only ever consumed (i.e. the protos
// are only consumed by calculators).
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
};
struct HolderPayloadRegistrator<
T, typename std::enable_if<is_concrete_proto_t<T>{}>::type>
: private MessageRegistrator<typename std::remove_cv<T>::type> {};
template <typename T>
class Holder : public HolderBase {
class Holder : public HolderBase, private HolderPayloadRegistrator<T> {
public:
explicit Holder(const T* ptr) : ptr_(ptr) {
HolderSupport<T>::EnsureStaticInit();
}
explicit Holder(const T* ptr) : ptr_(ptr) {}
~Holder() override { delete_helper(); }
const T& data() const {
HolderSupport<T>::EnsureStaticInit();
return *ptr_;
}
const T& data() const { return *ptr_; }
TypeId GetTypeId() const final { return kTypeId<T>; }
// Releases the underlying data pointer and transfers the ownership to a
// unique pointer.

View File

@ -12,7 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_test.pb.h"
@ -24,6 +28,9 @@
namespace mediapipe {
namespace {
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;
namespace test_ns {
constexpr char kOutTag[] = "OUT";
@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator);
} // namespace test_ns
TEST(PacketTest, InputTypeRegistration) {
TEST(PacketRegistrationTest, InputTypeRegistration) {
using testing::Contains;
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
"mediapipe.InputOnlyProto");
@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) {
Contains("mediapipe.InputOnlyProto"));
}
TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) {
CalculatorGraphConfig config;
{
Graph graph;
Stream<mediapipe::InputOnlyProto> input =
graph.In(0).SetName("in").Cast<mediapipe::InputOnlyProto>();
auto& sink_node = graph.AddNode("TestSinkCalculator");
input.ConnectTo(sink_node.In(test_ns::kInTag));
Stream<int> output = sink_node.Out(test_ns::kOutTag).Cast<int>();
output.ConnectTo(graph.Out(0)).SetName("out");
config = graph.GetConfig();
}
CalculatorGraph calculator_graph;
MP_ASSERT_OK(calculator_graph.Initialize(std::move(config)));
MP_ASSERT_OK(calculator_graph.StartRun({}));
int value = 10;
auto proto = std::make_unique<mediapipe::InputOnlyProto>();
proto->set_x(value);
MP_ASSERT_OK(calculator_graph.AddPacketToInputStream(
"in", Adopt(proto.release()).At(Timestamp(0))));
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
}
} // namespace
} // namespace mediapipe