diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index b1ebb0410..e02191cb8 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -242,6 +242,8 @@ class Packet : public Packet { friend Packet PacketAdopting(const U* ptr); template friend Packet PacketAdopting(std::unique_ptr ptr); + template + friend Packet PacketSharingOwnership(std::shared_ptr ptr); }; namespace internal { @@ -464,6 +466,17 @@ Packet PacketAdopting(std::unique_ptr ptr) { return Packet(std::make_shared>(ptr.release())); } +template +Packet PacketSharingOwnership(std::shared_ptr ptr) { + return Packet( + std::make_shared>(std::move(ptr))); +} + +template +std::shared_ptr SharedPtrWithPacket(Packet packet) { + return mediapipe::SharedPtrWithPacket(std::move(packet)); +} + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/packet_test.cc b/mediapipe/framework/api2/packet_test.cc index 00bc35086..1180b787a 100644 --- a/mediapipe/framework/api2/packet_test.cc +++ b/mediapipe/framework/api2/packet_test.cc @@ -162,6 +162,21 @@ TEST(PacketTest, PacketAdopting) { EXPECT_FALSE(p.IsEmpty()); } +TEST(PacketTest, PacketSharingOwnership) { + bool deleted = false; + std::shared_ptr object(new int(42), [&deleted](const int* p) { + delete p; + deleted = true; + }); + Packet p = PacketSharingOwnership(object); + EXPECT_FALSE(p.IsEmpty()); + EXPECT_EQ(p.Get(), 42); + object = nullptr; + EXPECT_FALSE(deleted); // Packet keeps it alive. + p = {}; + ASSERT_TRUE(deleted); // last owner expired. +} + TEST(PacketTest, PacketGeneric) { // With C++17, Packet<> could be written simply as Packet. Packet<> p = PacketAdopting(new float(1.0)); @@ -281,6 +296,24 @@ TEST(PacketTest, PolymorphismAbstract) { EXPECT_EQ(base->name(), "ConcreteDerived"); } +TEST(PacketTest, ShareSubobjectOwnership) { + // Create a packet that contains a vector and tracks deletion. + bool deleted = false; + std::shared_ptr> ints(new std::vector{0, 1, 2, 3}, + [&deleted](std::vector* p) { + delete p; + deleted = true; + }); + auto vector_packet = PacketSharingOwnership(std::move(ints)); + // Create a packet that references one of the items in the vector. + Packet item_packet = PacketSharingOwnership(std::shared_ptr( + SharedPtrWithPacket(vector_packet), &vector_packet.Get()[1])); + vector_packet = {}; + ASSERT_FALSE(deleted); // item_packet keeps it alive + item_packet = {}; + ASSERT_TRUE(deleted); +} + } // namespace } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index 1024cbc15..837989a68 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -242,6 +242,12 @@ Packet Adopt(const T* ptr); // returned Packet but also all of its copies. The timestamp of the returned // Packet is Timestamp::Unset(). To set the timestamp, the caller should do // PointToForeign(...).At(...). +// TODO: deprecate PointToForeign in favor of +// MakePacketSharingOwnership. Currently, we have to provide two separate +// implementations for handling static array vs. non static array types as +// the shared_ptr does not work with static array for backward compatibility. +// Eventually we should encourage the clients to deprecate the usage of these +// functions. template Packet PointToForeign(const T* ptr); @@ -324,6 +330,17 @@ Packet MakePacket(Args&&... args) { // NOLINT(build/c++11) new T{std::forward::type>(args)...})); } +// Returns a Packet that shares ownership of its data. The packet will hold a +// reference to the provided shared_ptr throughout its lifetime. Since the +// payload of packets is expected to be immutable, the caller MUST ensure that +// the data does not change as long as the Packet is alive. +// Unlike PointToForeign, which takes a raw pointer, this allows the caller to +// know when MediaPipe (as well as any other owners) is done using the data. +// The timestamp of the returned Packet is Timestamp::Unset(). To set the +// timestamp, the caller should do MakePacketSharingOwnership(...).At(...). +template +Packet MakePacketSharingOwnership(std::shared_ptr ptr); + // Returns a mutable pointer to the data in a unique_ptr in a packet. This // is useful in combination with AdoptAsUniquePtr. The caller must // exercise caution when mutating the retrieved data, since the data @@ -579,11 +596,14 @@ class Holder : public HolderBase { } }; -// Like Holder, but does not own its data. +// Like Holder, but does not exclusively own its data. template class ForeignHolder : public Holder { public: - using Holder::Holder; + explicit ForeignHolder(std::shared_ptr ptr) + : Holder(reinterpret_cast(ptr.get())), + owner_(std::move(ptr)) {} + ~ForeignHolder() override { // Null out ptr_ so it doesn't get deleted by ~Holder. // Note that ~Holder cannot call HasForeignOwner because the subclass's @@ -591,6 +611,9 @@ class ForeignHolder : public Holder { this->ptr_ = nullptr; } bool HasForeignOwner() const final { return true; } + + protected: + const std::shared_ptr owner_; }; template @@ -768,7 +791,22 @@ Packet Adopt(const T* ptr) { template Packet PointToForeign(const T* ptr) { CHECK(ptr != nullptr); - return packet_internal::Create(new packet_internal::ForeignHolder(ptr)); + using U = typename std::shared_ptr::element_type; + // The reinterpret_cast is required here and in the ForeignHolder constructor + // in order to handle the type decay introduced by the shared_ptr for the + // statically allocated array. + return packet_internal::Create(new packet_internal::ForeignHolder( + std::shared_ptr(reinterpret_cast(ptr), [](const U*) { + // Note: PointToForeign does not own its data in any way, so + // the deleter does nothing. + }))); +} + +template +Packet MakePacketSharingOwnership(std::shared_ptr ptr) { + CHECK(ptr != nullptr); + return packet_internal::Create( + new packet_internal::ForeignHolder(std::move(ptr))); } // Equal Packets refer to the same memory contents, like equal pointers. diff --git a/mediapipe/framework/packet_test.cc b/mediapipe/framework/packet_test.cc index 88a8dff43..4efc594ee 100644 --- a/mediapipe/framework/packet_test.cc +++ b/mediapipe/framework/packet_test.cc @@ -14,6 +14,7 @@ #include "mediapipe/framework/packet.h" +#include #include #include #include @@ -262,6 +263,68 @@ TEST(PacketTest, MakePacketOfIntVector) { vector_packet2.Get>()); } +TEST(PacketTest, MakePacketSharingOwnership) { + bool deleted = false; + std::shared_ptr object(new int(42), [&deleted](const int* p) { + delete p; + deleted = true; + }); + Packet packet = MakePacketSharingOwnership(object); + MP_ASSERT_OK(packet.ValidateAsType()); + EXPECT_EQ(packet.Get(), 42); + EXPECT_FALSE(deleted); + object = nullptr; + EXPECT_FALSE(deleted); // Packet keeps it alive. + packet = {}; + EXPECT_TRUE(deleted); // last owner expired. +} + +TEST(PacketTest, ShareSubobjectOwnership) { + // Create a packet that contains a vector and tracks deletion. + bool deleted = false; + std::shared_ptr> ints(new std::vector{0, 1, 2, 3}, + [&deleted](std::vector* p) { + delete p; + deleted = true; + }); + Packet vector_packet = MakePacketSharingOwnership(std::move(ints)); + // Create a packet that references one of the items in the vector. + Packet item_packet = MakePacketSharingOwnership(std::shared_ptr( + SharedPtrWithPacket>(vector_packet), + &vector_packet.Get>()[1])); + vector_packet = {}; + ASSERT_FALSE(deleted); // item_packet keeps it alive + item_packet = {}; + ASSERT_TRUE(deleted); +} + +TEST(PacketTest, PointToForeignDynamicArray) { + int* input_translation = new int[2]; + input_translation[0] = 0; + input_translation[1] = 1; + Packet packet = PointToForeign(&input_translation); + const auto& content = packet.Get(); + // The packet content should point to the array. + EXPECT_EQ(content[0], 0); + EXPECT_EQ(content[1], 1); + packet = {}; + // The vector values should be unaffected. + EXPECT_EQ(input_translation[0], 0); + EXPECT_EQ(input_translation[1], 1); + delete[] input_translation; +} + +TEST(PacketTest, PointToForeignStaticArray) { + const int input_translation[] = {0, 1}; + auto packet = PointToForeign(&input_translation); + const auto& content = packet.Get(); + // The packet content should point to the array. + EXPECT_THAT(content, testing::ElementsAre(0, 1)); + packet = {}; + // The vector values should be unaffected. + EXPECT_THAT(input_translation, testing::ElementsAre(0, 1)); +} + TEST(PacketTest, TestPacketMoveConstructor) { std::vector* packet_vector_ptr = new std::vector(); packet_vector_ptr->push_back(MakePacket(42));