Migrate packet messages auto registration to rely on MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE
PiperOrigin-RevId: 556063007
This commit is contained in:
		
							parent
							
								
									c448d54aa7
								
							
						
					
					
						commit
						3ac3b03ed5
					
				| 
						 | 
					@ -1655,6 +1655,7 @@ cc_test(
 | 
				
			||||||
        ":packet",
 | 
					        ":packet",
 | 
				
			||||||
        ":packet_test_cc_proto",
 | 
					        ":packet_test_cc_proto",
 | 
				
			||||||
        ":type_map",
 | 
					        ":type_map",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:builder",
 | 
				
			||||||
        "//mediapipe/framework/port:core_proto",
 | 
					        "//mediapipe/framework/port:core_proto",
 | 
				
			||||||
        "//mediapipe/framework/port:gtest_main",
 | 
					        "//mediapipe/framework/port:gtest_main",
 | 
				
			||||||
        "@com_google_absl//absl/strings",
 | 
					        "@com_google_absl//absl/strings",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -477,9 +477,6 @@ class GlobalFactoryRegistry {
 | 
				
			||||||
  class RegistratorName {                                                     \
 | 
					  class RegistratorName {                                                     \
 | 
				
			||||||
   private:                                                                   \
 | 
					   private:                                                                   \
 | 
				
			||||||
    /* The member below triggers instantiation of the registration static. */ \
 | 
					    /* The member below triggers instantiation of the registration static. */ \
 | 
				
			||||||
    /* Note that the constructor of calculator subclasses is only invoked  */ \
 | 
					 | 
				
			||||||
    /* through the registration token, and so we cannot simply use the     */ \
 | 
					 | 
				
			||||||
    /* static in theconstructor.                                           */ \
 | 
					 | 
				
			||||||
    typename Internal##RegistratorName<T>::RequireStatics register_;          \
 | 
					    typename Internal##RegistratorName<T>::RequireStatics register_;          \
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -455,60 +455,37 @@ struct is_concrete_proto_t
 | 
				
			||||||
                    !std::is_same<proto_ns::MessageLite, T>{} &&
 | 
					                    !std::is_same<proto_ns::MessageLite, T>{} &&
 | 
				
			||||||
                    !std::is_same<proto_ns::Message, T>{}> {};
 | 
					                    !std::is_same<proto_ns::Message, T>{}> {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Registers a message type. T must be a non-cv-qualified concrete proto type.
 | 
					 | 
				
			||||||
template <typename T>
 | 
					template <typename T>
 | 
				
			||||||
struct MessageRegistrationImpl {
 | 
					std::unique_ptr<HolderBase> CreateMessageHolder() {
 | 
				
			||||||
  static NoDestructor<mediapipe::RegistrationToken> registration;
 | 
					 | 
				
			||||||
  // This could have been a lambda inside registration's initializer below, but
 | 
					 | 
				
			||||||
  // MSVC has a bug with lambdas, so we put it here as a workaround.
 | 
					 | 
				
			||||||
  static std::unique_ptr<Holder<T>> CreateMessageHolder() {
 | 
					 | 
				
			||||||
  return absl::make_unique<Holder<T>>(new T);
 | 
					  return absl::make_unique<Holder<T>>(new T);
 | 
				
			||||||
  }
 | 
					}
 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Static members of template classes can be defined in the header.
 | 
					// Registers a message type. T must be a non-cv-qualified concrete proto type.
 | 
				
			||||||
template <typename T>
 | 
					MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry,
 | 
				
			||||||
NoDestructor<mediapipe::RegistrationToken>
 | 
					                                      T{}.GetTypeName(), CreateMessageHolder<T>)
 | 
				
			||||||
    MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
 | 
					 | 
				
			||||||
        T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// For non-Message payloads, this does nothing.
 | 
					// For non-Message payloads, this does nothing.
 | 
				
			||||||
template <typename T, typename Enable = void>
 | 
					template <typename T, typename Enable = void>
 | 
				
			||||||
struct HolderSupport {
 | 
					struct HolderPayloadRegistrator {};
 | 
				
			||||||
  static void EnsureStaticInit() {}
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// This template ensures that, for each concrete MessageLite subclass that is
 | 
					// This template ensures that, for each concrete MessageLite subclass that is
 | 
				
			||||||
// stored in a Packet, we register a function that allows us to create a
 | 
					// stored in a Packet, we register a function that allows us to create a
 | 
				
			||||||
// Holder with the correct payload type from the proto's type name.
 | 
					// Holder with the correct payload type from the proto's type name.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// We must use std::remove_cv to ensure we don't try to register Foo twice if
 | 
				
			||||||
 | 
					// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
 | 
				
			||||||
 | 
					// up to Holder?
 | 
				
			||||||
template <typename T>
 | 
					template <typename T>
 | 
				
			||||||
struct HolderSupport<T,
 | 
					struct HolderPayloadRegistrator<
 | 
				
			||||||
                     typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
 | 
					    T, typename std::enable_if<is_concrete_proto_t<T>{}>::type>
 | 
				
			||||||
  // We must use std::remove_cv to ensure we don't try to register Foo twice if
 | 
					    : private MessageRegistrator<typename std::remove_cv<T>::type> {};
 | 
				
			||||||
  // there are Holder<Foo> and Holder<const Foo>. TODO: lift this
 | 
					 | 
				
			||||||
  // up to Holder?
 | 
					 | 
				
			||||||
  using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
 | 
					 | 
				
			||||||
  // For the registration static member to be instantiated, it needs to be
 | 
					 | 
				
			||||||
  // referenced in a context that requires the definition to exist (see ISO/IEC
 | 
					 | 
				
			||||||
  // C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
 | 
					 | 
				
			||||||
  // We need two different call-sites to cover proto types for which packets
 | 
					 | 
				
			||||||
  // are only ever created (i.e. the protos are only produced by calculators)
 | 
					 | 
				
			||||||
  // and proto types for which packets are only ever consumed (i.e. the protos
 | 
					 | 
				
			||||||
  // are only consumed by calculators).
 | 
					 | 
				
			||||||
  static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename T>
 | 
					template <typename T>
 | 
				
			||||||
class Holder : public HolderBase {
 | 
					class Holder : public HolderBase, private HolderPayloadRegistrator<T> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  explicit Holder(const T* ptr) : ptr_(ptr) {
 | 
					  explicit Holder(const T* ptr) : ptr_(ptr) {}
 | 
				
			||||||
    HolderSupport<T>::EnsureStaticInit();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  ~Holder() override { delete_helper(); }
 | 
					  ~Holder() override { delete_helper(); }
 | 
				
			||||||
  const T& data() const {
 | 
					  const T& data() const { return *ptr_; }
 | 
				
			||||||
    HolderSupport<T>::EnsureStaticInit();
 | 
					 | 
				
			||||||
    return *ptr_;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  TypeId GetTypeId() const final { return kTypeId<T>; }
 | 
					  TypeId GetTypeId() const final { return kTypeId<T>; }
 | 
				
			||||||
  // Releases the underlying data pointer and transfers the ownership to a
 | 
					  // Releases the underlying data pointer and transfers the ownership to a
 | 
				
			||||||
  // unique pointer.
 | 
					  // unique pointer.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,7 +12,11 @@
 | 
				
			||||||
// See the License for the specific language governing permissions and
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
// limitations under the License.
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "absl/strings/str_cat.h"
 | 
					#include "absl/strings/str_cat.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/builder.h"
 | 
				
			||||||
#include "mediapipe/framework/calculator_framework.h"
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
#include "mediapipe/framework/packet.h"
 | 
					#include "mediapipe/framework/packet.h"
 | 
				
			||||||
#include "mediapipe/framework/packet_test.pb.h"
 | 
					#include "mediapipe/framework/packet_test.pb.h"
 | 
				
			||||||
| 
						 | 
					@ -24,6 +28,9 @@
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using ::mediapipe::api2::builder::Graph;
 | 
				
			||||||
 | 
					using ::mediapipe::api2::builder::Stream;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace test_ns {
 | 
					namespace test_ns {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
constexpr char kOutTag[] = "OUT";
 | 
					constexpr char kOutTag[] = "OUT";
 | 
				
			||||||
| 
						 | 
					@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace test_ns
 | 
					}  // namespace test_ns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(PacketTest, InputTypeRegistration) {
 | 
					TEST(PacketRegistrationTest, InputTypeRegistration) {
 | 
				
			||||||
  using testing::Contains;
 | 
					  using testing::Contains;
 | 
				
			||||||
  ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
 | 
					  ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
 | 
				
			||||||
            "mediapipe.InputOnlyProto");
 | 
					            "mediapipe.InputOnlyProto");
 | 
				
			||||||
| 
						 | 
					@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) {
 | 
				
			||||||
              Contains("mediapipe.InputOnlyProto"));
 | 
					              Contains("mediapipe.InputOnlyProto"));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) {
 | 
				
			||||||
 | 
					  CalculatorGraphConfig config;
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    Graph graph;
 | 
				
			||||||
 | 
					    Stream<mediapipe::InputOnlyProto> input =
 | 
				
			||||||
 | 
					        graph.In(0).SetName("in").Cast<mediapipe::InputOnlyProto>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto& sink_node = graph.AddNode("TestSinkCalculator");
 | 
				
			||||||
 | 
					    input.ConnectTo(sink_node.In(test_ns::kInTag));
 | 
				
			||||||
 | 
					    Stream<int> output = sink_node.Out(test_ns::kOutTag).Cast<int>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    output.ConnectTo(graph.Out(0)).SetName("out");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    config = graph.GetConfig();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  CalculatorGraph calculator_graph;
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(calculator_graph.Initialize(std::move(config)));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(calculator_graph.StartRun({}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int value = 10;
 | 
				
			||||||
 | 
					  auto proto = std::make_unique<mediapipe::InputOnlyProto>();
 | 
				
			||||||
 | 
					  proto->set_x(value);
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(calculator_graph.AddPacketToInputStream(
 | 
				
			||||||
 | 
					      "in", Adopt(proto.release()).At(Timestamp(0))));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user