Migrate packet messages auto registration to rely on MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE
PiperOrigin-RevId: 556063007
This commit is contained in:
parent
c448d54aa7
commit
3ac3b03ed5
|
@ -1655,6 +1655,7 @@ cc_test(
|
||||||
":packet",
|
":packet",
|
||||||
":packet_test_cc_proto",
|
":packet_test_cc_proto",
|
||||||
":type_map",
|
":type_map",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/port:core_proto",
|
"//mediapipe/framework/port:core_proto",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -477,9 +477,6 @@ class GlobalFactoryRegistry {
|
||||||
class RegistratorName { \
|
class RegistratorName { \
|
||||||
private: \
|
private: \
|
||||||
/* The member below triggers instantiation of the registration static. */ \
|
/* The member below triggers instantiation of the registration static. */ \
|
||||||
/* Note that the constructor of calculator subclasses is only invoked */ \
|
|
||||||
/* through the registration token, and so we cannot simply use the */ \
|
|
||||||
/* static in theconstructor. */ \
|
|
||||||
typename Internal##RegistratorName<T>::RequireStatics register_; \
|
typename Internal##RegistratorName<T>::RequireStatics register_; \
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -455,60 +455,37 @@ struct is_concrete_proto_t
|
||||||
!std::is_same<proto_ns::MessageLite, T>{} &&
|
!std::is_same<proto_ns::MessageLite, T>{} &&
|
||||||
!std::is_same<proto_ns::Message, T>{}> {};
|
!std::is_same<proto_ns::Message, T>{}> {};
|
||||||
|
|
||||||
// Registers a message type. T must be a non-cv-qualified concrete proto type.
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct MessageRegistrationImpl {
|
std::unique_ptr<HolderBase> CreateMessageHolder() {
|
||||||
static NoDestructor<mediapipe::RegistrationToken> registration;
|
|
||||||
// This could have been a lambda inside registration's initializer below, but
|
|
||||||
// MSVC has a bug with lambdas, so we put it here as a workaround.
|
|
||||||
static std::unique_ptr<Holder<T>> CreateMessageHolder() {
|
|
||||||
return absl::make_unique<Holder<T>>(new T);
|
return absl::make_unique<Holder<T>>(new T);
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
// Static members of template classes can be defined in the header.
|
// Registers a message type. T must be a non-cv-qualified concrete proto type.
|
||||||
template <typename T>
|
MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(MessageRegistrator, MessageHolderRegistry,
|
||||||
NoDestructor<mediapipe::RegistrationToken>
|
T{}.GetTypeName(), CreateMessageHolder<T>)
|
||||||
MessageRegistrationImpl<T>::registration(MessageHolderRegistry::Register(
|
|
||||||
T{}.GetTypeName(), MessageRegistrationImpl<T>::CreateMessageHolder));
|
|
||||||
|
|
||||||
// For non-Message payloads, this does nothing.
|
// For non-Message payloads, this does nothing.
|
||||||
template <typename T, typename Enable = void>
|
template <typename T, typename Enable = void>
|
||||||
struct HolderSupport {
|
struct HolderPayloadRegistrator {};
|
||||||
static void EnsureStaticInit() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
// This template ensures that, for each concrete MessageLite subclass that is
|
// This template ensures that, for each concrete MessageLite subclass that is
|
||||||
// stored in a Packet, we register a function that allows us to create a
|
// stored in a Packet, we register a function that allows us to create a
|
||||||
// Holder with the correct payload type from the proto's type name.
|
// Holder with the correct payload type from the proto's type name.
|
||||||
|
//
|
||||||
|
// We must use std::remove_cv to ensure we don't try to register Foo twice if
|
||||||
|
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
|
||||||
|
// up to Holder?
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct HolderSupport<T,
|
struct HolderPayloadRegistrator<
|
||||||
typename std::enable_if<is_concrete_proto_t<T>{}>::type> {
|
T, typename std::enable_if<is_concrete_proto_t<T>{}>::type>
|
||||||
// We must use std::remove_cv to ensure we don't try to register Foo twice if
|
: private MessageRegistrator<typename std::remove_cv<T>::type> {};
|
||||||
// there are Holder<Foo> and Holder<const Foo>. TODO: lift this
|
|
||||||
// up to Holder?
|
|
||||||
using R = MessageRegistrationImpl<typename std::remove_cv<T>::type>;
|
|
||||||
// For the registration static member to be instantiated, it needs to be
|
|
||||||
// referenced in a context that requires the definition to exist (see ISO/IEC
|
|
||||||
// C++ 2003 standard, 14.7.1). Calling this ensures that's the case.
|
|
||||||
// We need two different call-sites to cover proto types for which packets
|
|
||||||
// are only ever created (i.e. the protos are only produced by calculators)
|
|
||||||
// and proto types for which packets are only ever consumed (i.e. the protos
|
|
||||||
// are only consumed by calculators).
|
|
||||||
static void EnsureStaticInit() { CHECK(R::registration.get() != nullptr); }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class Holder : public HolderBase {
|
class Holder : public HolderBase, private HolderPayloadRegistrator<T> {
|
||||||
public:
|
public:
|
||||||
explicit Holder(const T* ptr) : ptr_(ptr) {
|
explicit Holder(const T* ptr) : ptr_(ptr) {}
|
||||||
HolderSupport<T>::EnsureStaticInit();
|
|
||||||
}
|
|
||||||
~Holder() override { delete_helper(); }
|
~Holder() override { delete_helper(); }
|
||||||
const T& data() const {
|
const T& data() const { return *ptr_; }
|
||||||
HolderSupport<T>::EnsureStaticInit();
|
|
||||||
return *ptr_;
|
|
||||||
}
|
|
||||||
TypeId GetTypeId() const final { return kTypeId<T>; }
|
TypeId GetTypeId() const final { return kTypeId<T>; }
|
||||||
// Releases the underlying data pointer and transfers the ownership to a
|
// Releases the underlying data pointer and transfers the ownership to a
|
||||||
// unique pointer.
|
// unique pointer.
|
||||||
|
|
|
@ -12,7 +12,11 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/packet_test.pb.h"
|
#include "mediapipe/framework/packet_test.pb.h"
|
||||||
|
@ -24,6 +28,9 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Stream;
|
||||||
|
|
||||||
namespace test_ns {
|
namespace test_ns {
|
||||||
|
|
||||||
constexpr char kOutTag[] = "OUT";
|
constexpr char kOutTag[] = "OUT";
|
||||||
|
@ -48,7 +55,7 @@ REGISTER_CALCULATOR(TestSinkCalculator);
|
||||||
|
|
||||||
} // namespace test_ns
|
} // namespace test_ns
|
||||||
|
|
||||||
TEST(PacketTest, InputTypeRegistration) {
|
TEST(PacketRegistrationTest, InputTypeRegistration) {
|
||||||
using testing::Contains;
|
using testing::Contains;
|
||||||
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
|
ASSERT_EQ(mediapipe::InputOnlyProto{}.GetTypeName(),
|
||||||
"mediapipe.InputOnlyProto");
|
"mediapipe.InputOnlyProto");
|
||||||
|
@ -56,5 +63,33 @@ TEST(PacketTest, InputTypeRegistration) {
|
||||||
Contains("mediapipe.InputOnlyProto"));
|
Contains("mediapipe.InputOnlyProto"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(PacketRegistrationTest, AdoptingRegisteredProtoWorks) {
|
||||||
|
CalculatorGraphConfig config;
|
||||||
|
{
|
||||||
|
Graph graph;
|
||||||
|
Stream<mediapipe::InputOnlyProto> input =
|
||||||
|
graph.In(0).SetName("in").Cast<mediapipe::InputOnlyProto>();
|
||||||
|
|
||||||
|
auto& sink_node = graph.AddNode("TestSinkCalculator");
|
||||||
|
input.ConnectTo(sink_node.In(test_ns::kInTag));
|
||||||
|
Stream<int> output = sink_node.Out(test_ns::kOutTag).Cast<int>();
|
||||||
|
|
||||||
|
output.ConnectTo(graph.Out(0)).SetName("out");
|
||||||
|
|
||||||
|
config = graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraph calculator_graph;
|
||||||
|
MP_ASSERT_OK(calculator_graph.Initialize(std::move(config)));
|
||||||
|
MP_ASSERT_OK(calculator_graph.StartRun({}));
|
||||||
|
|
||||||
|
int value = 10;
|
||||||
|
auto proto = std::make_unique<mediapipe::InputOnlyProto>();
|
||||||
|
proto->set_x(value);
|
||||||
|
MP_ASSERT_OK(calculator_graph.AddPacketToInputStream(
|
||||||
|
"in", Adopt(proto.release()).At(Timestamp(0))));
|
||||||
|
MP_ASSERT_OK(calculator_graph.WaitUntilIdle());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user