// This file defines an API to define a node's ports in a concise, type-safe // way. Example usage in a node: // // static constexpr Input kBase("IN"); // static constexpr Output kOut("OUT"); // static constexpr SideInput::Optional kDelta("DELTA"); // static constexpr SideOutput kForward("FORWARD"); // // Pass a CalculatorContext to a port to access the inputs or outputs in the // context. For example: // // kBase(cc) yields an InputShardAccess // kOut(cc) yields an OutputShardAccess // kDelta(cc) yields an InputSidePacketAccess // kForward(cc) yields an OutputSidePacketAccess #ifndef MEDIAPIPE_FRAMEWORK_API2_PORT_H_ #define MEDIAPIPE_FRAMEWORK_API2_PORT_H_ #include #include #include "absl/strings/str_cat.h" #include "mediapipe/framework/api2/const_str.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/output_side_packet.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { namespace api2 { // typeid is not constexpr, but a pointer to this is. template size_t get_type_hash() { return typeid(T).hash_code(); } using type_id_fptr = size_t (*)(); // This is a base class for various types of port. It is not meant to be used // directly by node code. class PortBase { public: constexpr PortBase(std::size_t tag_size, const char* tag, type_id_fptr get_type_id, bool optional, bool multiple) : tag_(tag_size, tag), optional_(optional), multiple_(multiple), type_id_getter_(get_type_id) {} bool IsOptional() const { return optional_; } bool IsMultiple() const { return multiple_; } const char* Tag() const { return tag_.data(); } size_t type_id() const { return type_id_getter_(); } const const_str tag_; const bool optional_; const bool multiple_; protected: type_id_fptr type_id_getter_; }; // These four base classes are used to distinguish between ports of different // kinds. They are not meant to be used directly by node code. class InputBase : public PortBase { using PortBase::PortBase; }; class OutputBase : public PortBase { using PortBase::PortBase; }; class SideInputBase : public PortBase { using PortBase::PortBase; }; class SideOutputBase : public PortBase { using PortBase::PortBase; }; struct NoneType { private: NoneType() = delete; }; struct DynamicType {}; struct AnyType : public DynamicType {}; template class SameType : public DynamicType { public: static constexpr const decltype(P)& kPort = P; }; class PacketTypeAccess; class PacketTypeAccessFallback; template class InputShardAccess; template class OutputShardAccess; template class InputSidePacketAccess; template class OutputSidePacketAccess; template class InputShardOrSideAccess; namespace internal { // Forward declaration for AddToContract friend. template class Contract; template auto GetCollection(CC* cc, const InputBase& port) -> decltype(cc->Inputs()) { return cc->Inputs(); } template auto GetCollection(CC* cc, const SideInputBase& port) -> decltype(cc->InputSidePackets()) { return cc->InputSidePackets(); } template auto GetCollection(CC* cc, const OutputBase& port) -> decltype(cc->Outputs()) { return cc->Outputs(); } template auto GetCollection(CC* cc, const SideOutputBase& port) -> decltype(cc->OutputSidePackets()) { return cc->OutputSidePackets(); } template auto GetOrNull(Collection& collection, const std::string& tag, int index) -> decltype(&collection.Get(std::declval())) { CollectionItemId id = collection.GetId(tag, index); return id.IsValid() ? &collection.Get(id) : nullptr; } template struct IsOneOf : std::false_type {}; template struct IsOneOf> : std::true_type {}; template {} && !IsOneOf{}, int>::type = 0> inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.Set(); } template {}, int>::type = 0> inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag())); } template <> inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.SetAny(); } template <> inline void SetType(CalculatorContract* cc, PacketType& pt) { // This is used for header-only streams. Should it be removed? pt.SetNone(); } template {}, int>::type = 0> inline void SetType(CalculatorContract* cc, PacketType& pt) { pt.SetAny(); } template InputShardAccess SinglePortAccess(mediapipe::CalculatorContext* cc, InputStreamShard* stream) { return InputShardAccess(*cc, stream); } template OutputShardAccess SinglePortAccess(mediapipe::CalculatorContext* cc, OutputStreamShard* stream) { return OutputShardAccess(*cc, stream); } template InputSidePacketAccess SinglePortAccess( mediapipe::CalculatorContext* cc, const mediapipe::Packet* packet) { return InputSidePacketAccess(packet); } template OutputSidePacketAccess SinglePortAccess( mediapipe::CalculatorContext* cc, OutputSidePacket* osp) { return OutputSidePacketAccess(osp); } template InputShardOrSideAccess SinglePortAccess( mediapipe::CalculatorContext* cc, InputStreamShard* stream, const mediapipe::Packet* packet) { return InputShardOrSideAccess(*cc, stream, packet); } template PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc, PacketType* pt); template PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc, PacketType* pt, bool is_stream); template auto AccessPort(std::false_type, const PortT& port, CC* cc) { auto& collection = GetCollection(cc, port); return SinglePortAccess( cc, internal::GetOrNull(collection, port.Tag(), 0)); } template class MultiplePortAccess { public: using AccessT = decltype(SinglePortAccess(std::declval(), std::declval())); MultiplePortAccess(CC* cc, X* first, int count) : cc_(cc), first_(first), count_(count) {} // TODO: maybe this should be size(), like in a standard C++ // container? int Count() { return count_; } AccessT operator[](int pos) { CHECK_GE(pos, 0); CHECK_LT(pos, count_); return SinglePortAccess(cc_, &first_[pos]); } class Iterator { public: using iterator_category = std::input_iterator_tag; using value_type = AccessT; using difference_type = std::ptrdiff_t; using pointer = AccessT*; using reference = AccessT; // allowed; see e.g. std::istreambuf_iterator Iterator(CC* cc, X* p) : cc_(cc), p_(p) {} Iterator& operator++() { ++p_; return *this; } Iterator operator++(int) { Iterator res = *this; ++(*this); return res; } bool operator==(const Iterator& other) const { return p_ == other.p_; } bool operator!=(const Iterator& other) const { return !(*this == other); } AccessT operator*() const { return SinglePortAccess(cc_, p_); } private: CC* cc_; X* p_; }; Iterator begin() { return Iterator(cc_, first_); } Iterator end() { return Iterator(cc_, first_ + count_); } private: CC* cc_; X* first_; int count_; }; template auto AccessPort(std::true_type, const PortT& port, CC* cc) { auto& collection = GetCollection(cc, port); auto* first = internal::GetOrNull(collection, port.Tag(), 0); using EntryT = typename std::remove_pointer::type; return MultiplePortAccess( cc, first, collection.NumEntries(port.Tag())); } template struct SideBase; template <> struct SideBase { using type = SideInputBase; }; } // namespace internal // TODO: maybe return a PacketBase instead of a Packet? template {}, int>::type = 0> auto ActualValueT(T) -> T; auto ActualValueT(DynamicType) -> internal::Generic; template class SideFallbackT; // This template is used to define a port. Nodes should use it through one // of the aliases below (Input, Output, SideInput, SideOutput). template class PortCommon : public Base { public: using value_t = ValueT; static constexpr bool kOptional = IsOptionalV; static constexpr bool kMultiple = IsMultipleV; using Optional = PortCommon; using Multiple = PortCommon; using SideFallback = SideFallbackT; template explicit constexpr PortCommon(const char (&tag)[N]) : Base(N, tag, &get_type_hash, IsOptionalV, IsMultipleV) {} using PayloadT = decltype(ActualValueT(std::declval())); auto operator()(CalculatorContext* cc) const { return internal::AccessPort( std::integral_constant{}, *this, cc); } auto operator()(CalculatorContract* cc) const { return internal::AccessPort( std::integral_constant{}, *this, cc); } private: absl::Status AddToContract(CalculatorContract* cc) const { if (kMultiple) { AddMultiple(cc); } else { auto& pt = internal::GetCollection(cc, *this).Tag(this->Tag()); internal::SetType(cc, pt); if (kOptional) { pt.Optional(); } } return {}; } void AddMultiple(CalculatorContract* cc) const { auto& collection = internal::GetCollection(cc, *this); int count = collection.NumEntries(this->Tag()); for (int i = 0; i < count; ++i) { internal::SetType(cc, collection.Get(this->Tag(), i)); } } template friend class internal::Contract; template friend class mediapipe::api2::SideFallbackT; }; // Use one of these templates to define a port in node code. template using Input = PortCommon; template using Output = PortCommon; template using SideInput = PortCommon; template using SideOutput = PortCommon; template class SideFallbackT : public Base { public: using value_t = ValueT; static constexpr bool kOptional = IsOptionalV; static constexpr bool kMultiple = IsMultipleV; using Optional = SideFallbackT; using PayloadT = decltype(ActualValueT(std::declval())); const char* Tag() const { return stream_port.Tag(); } auto operator()(CalculatorContract* cc) const { bool is_stream = true; auto& stream_collection = internal::GetCollection(cc, stream_port); auto* packet_type = internal::GetOrNull(stream_collection, Tag(), 0); if (packet_type == nullptr) { auto& side_collection = internal::GetCollection(cc, side_port); packet_type = internal::GetOrNull(side_collection, Tag(), 0); is_stream = false; } return internal::SinglePortAccess(cc, packet_type, is_stream); } auto operator()(CalculatorContext* cc) const { auto& stream_collection = internal::GetCollection(cc, stream_port); auto& side_collection = internal::GetCollection(cc, side_port); return internal::SinglePortAccess( cc, internal::GetOrNull(stream_collection, Tag(), 0), internal::GetOrNull(side_collection, Tag(), 0)); } template explicit constexpr SideFallbackT(const char (&tag)[N]) : Base(N, tag, &get_type_hash, IsOptionalV, IsMultipleV), stream_port(tag), side_port(tag) {} protected: absl::Status AddToContract(CalculatorContract* cc) const { stream_port.AddToContract(cc); side_port.AddToContract(cc); int connected_count = stream_port(cc).IsConnected() + side_port(cc).IsConnected(); if (connected_count > 1) return absl::InvalidArgumentError(absl::StrCat( Tag(), " can be connected as a stream or as a side packet, but not both")); if (!IsOptionalV && connected_count == 0) return absl::InvalidArgumentError( absl::StrCat(Tag(), " must be connected")); return {}; } using StreamPort = PortCommon; using SidePort = PortCommon::type, ValueT, true, IsMultipleV>; StreamPort stream_port; SidePort side_port; template friend class internal::Contract; }; // An OutputShardAccess is returned when accessing an output stream within a // CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to // OutputStreamShard. Like that class, this class will not be usually named in // calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)). class OutputShardAccessBase { public: OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output) : context_(cc), output_(output) {} void SetNextTimestampBound(Timestamp timestamp) { if (output_) output_->SetNextTimestampBound(timestamp); } bool IsClosed() { return output_ ? output_->IsClosed() : true; } void Close() { if (output_) output_->Close(); } bool IsConnected() { return output_ != nullptr; } protected: const CalculatorContext& context_; OutputStreamShard* output_; }; template class OutputShardAccess : public OutputShardAccessBase { public: void Send(Packet&& packet) { if (output_) output_->AddPacket(ToOldPacket(std::move(packet))); } void Send(const Packet& packet) { if (output_) output_->AddPacket(ToOldPacket(packet)); } void Send(const T& payload, Timestamp time) { Send(api2::MakePacket(payload).At(time)); } void Send(const T& payload) { Send(payload, context_.InputTimestamp()); } void Send(T&& payload, Timestamp time) { Send(api2::MakePacket(std::move(payload)).At(time)); } void Send(T&& payload) { Send(std::move(payload), context_.InputTimestamp()); } void Send(std::unique_ptr payload, Timestamp time) { Send(api2::PacketAdopting(std::move(payload)).At(time)); } void Send(std::unique_ptr payload) { Send(std::move(payload), context_.InputTimestamp()); } private: OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output) : OutputShardAccessBase(cc, output) {} friend OutputShardAccess internal::SinglePortAccess( mediapipe::CalculatorContext*, OutputStreamShard*); }; template <> class OutputShardAccess : public OutputShardAccessBase { public: void Send(PacketBase&& packet) { if (output_) output_->AddPacket(ToOldPacket(std::move(packet))); } void Send(const PacketBase& packet) { if (output_) output_->AddPacket(ToOldPacket(packet)); } void SetHeader(const PacketBase& header) { if (output_) output_->SetHeader(ToOldPacket(header)); } private: OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output) : OutputShardAccessBase(cc, output) {} friend OutputShardAccess internal::SinglePortAccess(mediapipe::CalculatorContext*, OutputStreamShard*); }; // Equivalent of OutputShardAccess, but for side packets. template class OutputSidePacketAccess { public: void Set(Packet packet) { if (output_) output_->Set(ToOldPacket(std::move(packet))); } void Set(const T& payload) { Set(MakePacket(payload)); } void Set(T&& payload) { Set(MakePacket(std::move(payload))); } private: OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {} OutputSidePacket* output_; friend OutputSidePacketAccess internal::SinglePortAccess( mediapipe::CalculatorContext*, OutputSidePacket*); }; template class InputShardAccess : public Packet { public: const PacketBase& packet() const& { return *this; } // Since InputShardAccess is currently created as a temporary, this avoids // easy mistakes with dangling references. PacketBase packet() const&& { return *this; } bool IsDone() const { return stream_->IsDone(); } bool IsConnected() { return stream_ != nullptr; } PacketBase Header() const { return FromOldPacket(stream_->Header()); } // "Consume" requires exclusive ownership of the packet's payload. In the // current interim implementation, InputShardAccess creates a new reference to // the payload (as a Packet instead of a type-erased Packet), which means // the conditions for Consume would never be satisfied. This helper class // defines wrappers for the Consume methods in Packet which temporarily erase // the reference held by the underlying InputStreamShard. // Note that we cannot simply take over the reference when InputShardAccess is // created, because it is currently created as a temporary and we might create // more than one instance for the same stream. template {}, decltype(&Packet::Consume)>> absl::StatusOr> Consume() { return WrapConsumeCall(&Packet::Consume); } template {}, int> = 0> absl::StatusOr> Consume() { return WrapConsumeCall(&Packet::template Consume); } template auto ConsumeAndVisit(F&&... args) { auto f = &Packet::template ConsumeAndVisit; return WrapConsumeCall(f, std::forward(args)...); } private: InputShardAccess(const CalculatorContext&, InputStreamShard* stream) : Packet(stream ? FromOldPacket(stream->Value()).template As() : Packet()), stream_(stream) {} template auto WrapConsumeCall(F f, A&&... args) { stream_->Value() = {}; auto result = (this->*f)(std::forward(args)...); if (!result.ok()) { stream_->Value() = ToOldPacket(*this); } return result; } InputStreamShard* stream_; friend InputShardAccess internal::SinglePortAccess( mediapipe::CalculatorContext*, InputStreamShard*); }; template class InputSidePacketAccess : public Packet { public: const PacketBase& packet() const& { return *this; } PacketBase packet() const&& { return *this; } bool IsConnected() { return connected_; } private: InputSidePacketAccess(const mediapipe::Packet* packet) : Packet(packet ? FromOldPacket(*packet).template As() : Packet()), connected_(packet != nullptr) {} bool connected_; friend InputSidePacketAccess internal::SinglePortAccess( mediapipe::CalculatorContext*, const mediapipe::Packet*); }; template class InputShardOrSideAccess : public Packet { public: const PacketBase& packet() const& { return *this; } PacketBase packet() const&& { return *this; } bool IsDone() const { return stream_->IsDone(); } bool IsConnected() { return connected_; } bool IsStream() { return stream_ != nullptr; } PacketBase Header() const { return FromOldPacket(stream_->Header()); } private: InputShardOrSideAccess(const CalculatorContext&, InputStreamShard* stream, const mediapipe::Packet* packet) : Packet(stream ? FromOldPacket(stream->Value()).template As() : packet ? FromOldPacket(*packet).template As() : Packet()), stream_(stream), connected_(stream_ != nullptr || packet != nullptr) {} InputStreamShard* stream_; bool connected_; friend InputShardOrSideAccess internal::SinglePortAccess( mediapipe::CalculatorContext*, InputStreamShard*, const mediapipe::Packet*); }; class PacketTypeAccess { public: bool IsConnected() { return packet_type_ != nullptr; } protected: PacketTypeAccess(PacketType* pt) : packet_type_(pt) {} PacketType* packet_type_; template friend PacketTypeAccess internal::SinglePortAccess( mediapipe::CalculatorContract*, PacketType*); }; class PacketTypeAccessFallback : public PacketTypeAccess { public: bool IsStream() { return is_stream_; } private: PacketTypeAccessFallback(PacketType* pt, bool is_stream) : PacketTypeAccess(pt), is_stream_(is_stream) {} bool is_stream_; template friend PacketTypeAccessFallback internal::SinglePortAccess( mediapipe::CalculatorContract*, PacketType*, bool); }; namespace internal { template PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc, PacketType* pt) { return PacketTypeAccess(pt); } template PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc, PacketType* pt, bool is_stream) { return PacketTypeAccessFallback(pt, is_stream); } } // namespace internal } // namespace api2 } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_API2_PORT_H_