Migrate packet messages auto registration to rely on MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE
PiperOrigin-RevId: 556063007
This commit is contained in:
parent
c448d54aa7
commit
3ac3b03ed5
|
@ -1655,6 +1655,7 @@ cc_test(
|
|||
":packet",
|
||||
":packet_test_cc_proto",
|
||||
":type_map",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -477,9 +477,6 @@ class GlobalFactoryRegistry {
|
|||
class RegistratorName { \
|
||||
private: \
|
||||
/* The member below triggers instantiation of the registration static. */ \
|
||||
/* Note that the constructor of calculator subclasses is only invoked */ \
|
||||
/* through the registration token, and so we cannot simply use the */ \
|
||||
/* static in theconstructor. */ \
|
||||
typename Internal##RegistratorName<T>::RequireStatics register_; \
|
||||
};
|
||||
|
||||
|
|
|
@ -455,60 +455,37 @@ struct is_concrete_proto_t
|
|||
!std::is_same<proto_ns::MessageLite, T>{} &&
|
||||
!std::is_same<proto_ns::Message, T>{}> {};
|
||||
|
||||
// Registers a message type. T must be a non-cv-qualified concrete proto type.
|
||||
template <typename T>
|
||||
struct MessageRegistrationImpl {
|
||||
static NoDestructor<mediapipe::RegistrationToken> registration;
|
||||
// This could have been a lambda inside registration's initializer below, but
|
||||
// MSVC has a bug with lambdas, so we put it here as a workaround.
|
||||
static std::unique_ptr<Holder<T>> CreateMessageHolder() {
|
||||
return absl::make_unique<Holder<T>>(new T);
|
||||
}
|
||||
};
|
||||
std::unique_ptr<HolderBase> CreateMessageHolder() {
|
||||
return absl::make_unique<Holder<T>>(new T);
|
||||
}
|
||||
|
||||
// Static members of template classes can be defined in the header.
|
||||
template <typename T>
|
||||
NoDestructor<mediapipe::RegistrationToken>
|
||||
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
|
||||
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
|
||||
// Registers a message type. T must be a non-cv-qualified concrete proto type.
|
||||
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry,
|
||||
T{}.GetTypeName(), CreateMessageHolder<T>)
|
||||
|
||||
// For non-Message payloads, this does nothing.
|
||||
template <typename T, typename Enable = void>
|
||||
struct HolderSupport {
|
||||
static void EnsureStaticInit() {}
|
||||
};
|
||||
struct HolderPayloadRegistrator {};
|
||||
|
||||
// This template ensures that, for each concrete MessageLite subclass that is
|
||||
// stored in a Packet, we register a function that allows us to create a
|
||||
// Holder with the correct payload type from the proto's type name.
|
||||
//
|
||||
// We must use std::remove_cv to ensure we don't try to register Foo twice if
|
||||
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
|
||||
// up to Holder?
|
||||
template <typename T>
|
||||
struct HolderSupport<T,
|
||||
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
|
||||
// We must use std::remove_cv to ensure we don't try to register Foo twice if
|
||||
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
|
||||
// up to Holder?
|
||||
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
|
||||
// For the registration static member to be instantiated, it needs to be
|
||||
// referenced in a context that requires the definition to exist (see ISO/IEC
|
||||
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
|
||||
// We need two different call-sites to cover proto types for which packets
|
||||
// are only ever created (i.e. the protos are only produced by calculators)
|
||||
// and proto types for which packets are only ever consumed (i.e. the protos
|
||||
// are only consumed by calculators).
|
||||
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
|
||||
};
|
||||
struct HolderPayloadRegistrator<
|
||||
T, typename std::enable_if<is_concrete_proto_t<T>{}>::type>
|
||||
: private MessageRegistrator<typename std::remove_cv<T>::type> {};
|
||||
|
||||
template <typename T>
|
||||
class Holder : public HolderBase {
|
||||
class Holder : public HolderBase, private HolderPayloadRegistrator<T> {
|
||||
public:
|
||||
explicit Holder(const T* ptr) : ptr_(ptr) {
|
||||
HolderSupport<T>::EnsureStaticInit();
|
||||
}
|
||||
explicit Holder(const T* ptr) : ptr_(ptr) {}
|
||||
~Holder() override { delete_helper(); }
|
||||
const T& data() const {
|
||||
HolderSupport<T>::EnsureStaticInit();
|
||||
return *ptr_;
|
||||
}
|
||||
const T& data() const { return *ptr_; }
|
||||
TypeId GetTypeId() const final { return kTypeId<T>; }
|
||||
// Releases the underlying data pointer and transfers the ownership to a
|
||||
// unique pointer.
|
||||
|
|
|
@ -12,7 +12,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/packet_test.pb.h"
|
||||
|
@ -24,6 +28,9 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Stream;
|
||||
|
||||
namespace test_ns {
|
||||
|
||||
constexpr char kOutTag[] = "OUT";
|
||||
|
@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator);
|
|||
|
||||
} // namespace test_ns
|
||||
|
||||
TEST(PacketTest, InputTypeRegistration) {
|
||||
TEST(PacketRegistrationTest, InputTypeRegistration) {
|
||||
using testing::Contains;
|
||||
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
|
||||
"mediapipe.InputOnlyProto");
|
||||
|
@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) {
|
|||
Contains("mediapipe.InputOnlyProto"));
|
||||
}
|
||||
|
||||
TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) {
|
||||
CalculatorGraphConfig config;
|
||||
{
|
||||
Graph graph;
|
||||
Stream<mediapipe::InputOnlyProto> input =
|
||||
graph.In(0).SetName("in").Cast<mediapipe::InputOnlyProto>();
|
||||
|
||||
auto& sink_node = graph.AddNode("TestSinkCalculator");
|
||||
input.ConnectTo(sink_node.In(test_ns::kInTag));
|
||||
Stream<int> output = sink_node.Out(test_ns::kOutTag).Cast<int>();
|
||||
|
||||
output.ConnectTo(graph.Out(0)).SetName("out");
|
||||
|
||||
config = graph.GetConfig();
|
||||
}
|
||||
|
||||
CalculatorGraph calculator_graph;
|
||||
MP_ASSERT_OK(calculator_graph.Initialize(std::move(config)));
|
||||
MP_ASSERT_OK(calculator_graph.StartRun({}));
|
||||
|
||||
int value = 10;
|
||||
auto proto = std::make_unique<mediapipe::InputOnlyProto>();
|
||||
proto->set_x(value);
|
||||
MP_ASSERT_OK(calculator_graph.AddPacketToInputStream(
|
||||
"in", Adopt(proto.release()).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user