720 lines
22 KiB
C++
720 lines
22 KiB
C++
// This file provides an API for declaring a node's ports in a concise,
// type-safe way. Example usage in a node:
//
//   static constexpr Input<int> kBase("IN");
//   static constexpr Output<float> kOut("OUT");
//   static constexpr SideInput<float>::Optional kDelta("DELTA");
//   static constexpr SideOutput<float> kForward("FORWARD");
//
// Pass a CalculatorContext to a port to access the inputs or outputs in the
// context. For example:
//
//   kBase(cc) yields an InputShardAccess<int>
//   kOut(cc) yields an OutputShardAccess<float>
//   kDelta(cc) yields an InputSidePacketAccess<float>
//   kForward(cc) yields an OutputSidePacketAccess<float>
|
|
|
|
#ifndef MEDIAPIPE_FRAMEWORK_API2_PORT_H_
|
|
#define MEDIAPIPE_FRAMEWORK_API2_PORT_H_
|
|
|
|
#include <type_traits>
|
|
#include <utility>
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "mediapipe/framework/api2/const_str.h"
|
|
#include "mediapipe/framework/api2/packet.h"
|
|
#include "mediapipe/framework/calculator_context.h"
|
|
#include "mediapipe/framework/calculator_contract.h"
|
|
#include "mediapipe/framework/output_side_packet.h"
|
|
#include "mediapipe/framework/port/logging.h"
|
|
#include "mediapipe/framework/tool/type_util.h"
|
|
|
|
namespace mediapipe {
|
|
namespace api2 {
|
|
|
|
// This is a base class for various types of port. It is not meant to be used
|
|
// directly by node code.
|
|
class PortBase {
|
|
public:
|
|
constexpr PortBase(std::size_t tag_size, const char* tag, TypeId type_id,
|
|
bool optional, bool multiple)
|
|
: tag_(tag_size, tag),
|
|
optional_(optional),
|
|
multiple_(multiple),
|
|
type_id_(type_id) {}
|
|
|
|
bool IsOptional() const { return optional_; }
|
|
bool IsMultiple() const { return multiple_; }
|
|
const char* Tag() const { return tag_.data(); }
|
|
|
|
TypeId type_id() const { return type_id_; }
|
|
|
|
const const_str tag_;
|
|
const bool optional_;
|
|
const bool multiple_;
|
|
|
|
protected:
|
|
TypeId type_id_;
|
|
};
|
|
|
|
// These four base classes are used to distinguish between ports of different
|
|
// kinds. They are not meant to be used directly by node code.
|
|
class InputBase : public PortBase {
|
|
using PortBase::PortBase;
|
|
};
|
|
class OutputBase : public PortBase {
|
|
using PortBase::PortBase;
|
|
};
|
|
class SideInputBase : public PortBase {
|
|
using PortBase::PortBase;
|
|
};
|
|
class SideOutputBase : public PortBase {
|
|
using PortBase::PortBase;
|
|
};
|
|
|
|
// Marker type used only as a port value type; its default constructor is
// deleted, so it cannot be default-constructed.
struct NoneType {
 private:
  NoneType() = delete;
};
|
|
|
|
template <auto& P>
|
|
class SameType : public DynamicType {
|
|
public:
|
|
static constexpr const decltype(P)& kPort = P;
|
|
};
|
|
|
|
// Forward declarations of the accessor classes returned by the
// SinglePortAccess overloads; the definitions appear further down this file.
class PacketTypeAccess;
class PacketTypeAccessFallback;
template <typename T> class InputShardAccess;
template <typename T> class OutputShardAccess;
template <typename T> class InputSidePacketAccess;
template <typename T> class OutputSidePacketAccess;
template <typename T> class InputShardOrSideAccess;
|
|
|
|
namespace internal {
|
|
|
|
// Declared here so that port classes can befriend it, giving it access to
// their non-public AddToContract members.
template <typename...> class Contract;
|
|
|
|
template <class CC>
|
|
auto GetCollection(CC* cc, const InputBase& port) -> decltype(cc->Inputs()) {
|
|
return cc->Inputs();
|
|
}
|
|
|
|
template <class CC>
|
|
auto GetCollection(CC* cc, const SideInputBase& port)
|
|
-> decltype(cc->InputSidePackets()) {
|
|
return cc->InputSidePackets();
|
|
}
|
|
|
|
template <class CC>
|
|
auto GetCollection(CC* cc, const OutputBase& port) -> decltype(cc->Outputs()) {
|
|
return cc->Outputs();
|
|
}
|
|
|
|
template <class CC>
|
|
auto GetCollection(CC* cc, const SideOutputBase& port)
|
|
-> decltype(cc->OutputSidePackets()) {
|
|
return cc->OutputSidePackets();
|
|
}
|
|
|
|
template <class Collection>
|
|
auto GetOrNull(Collection& collection, const std::string& tag, int index)
|
|
-> decltype(&collection.Get(std::declval<CollectionItemId>())) {
|
|
CollectionItemId id = collection.GetId(tag, index);
|
|
return id.IsValid() ? &collection.Get(id) : nullptr;
|
|
}
|
|
|
|
template <class T>
|
|
struct IsOneOf : std::false_type {};
|
|
|
|
template <class... T>
|
|
struct IsOneOf<OneOf<T...>> : std::true_type {};
|
|
|
|
template <typename T, typename std::enable_if<
|
|
!std::is_base_of<DynamicType, T>{} && !IsOneOf<T>{},
|
|
int>::type = 0>
|
|
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
|
pt.Set<T>();
|
|
}
|
|
|
|
template <typename T, typename std::enable_if<std::is_base_of<DynamicType, T>{},
|
|
int>::type = 0>
|
|
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
|
pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag()));
|
|
}
|
|
|
|
template <>
|
|
inline void SetType<AnyType>(CalculatorContract* cc, PacketType& pt) {
|
|
pt.SetAny();
|
|
}
|
|
|
|
template <>
|
|
inline void SetType<NoneType>(CalculatorContract* cc, PacketType& pt) {
|
|
// This is used for header-only streams. Should it be removed?
|
|
pt.SetNone();
|
|
}
|
|
|
|
template <typename... T>
|
|
inline void SetTypeOneOf(OneOf<T...>, CalculatorContract* cc, PacketType& pt) {
|
|
pt.SetOneOf<T...>();
|
|
}
|
|
|
|
template <typename T, typename std::enable_if<IsOneOf<T>{}, int>::type = 0>
|
|
inline void SetType(CalculatorContract* cc, PacketType& pt) {
|
|
SetTypeOneOf(T{}, cc, pt);
|
|
}
|
|
|
|
template <typename ValueT>
|
|
InputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
|
|
InputStreamShard* stream) {
|
|
return InputShardAccess<ValueT>(*cc, stream);
|
|
}
|
|
|
|
template <typename ValueT>
|
|
OutputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
|
|
OutputStreamShard* stream) {
|
|
return OutputShardAccess<ValueT>(*cc, stream);
|
|
}
|
|
|
|
template <typename ValueT>
|
|
InputSidePacketAccess<ValueT> SinglePortAccess(
|
|
mediapipe::CalculatorContext* cc, const mediapipe::Packet* packet) {
|
|
return InputSidePacketAccess<ValueT>(packet);
|
|
}
|
|
|
|
template <typename ValueT>
|
|
OutputSidePacketAccess<ValueT> SinglePortAccess(
|
|
mediapipe::CalculatorContext* cc, OutputSidePacket* osp) {
|
|
return OutputSidePacketAccess<ValueT>(osp);
|
|
}
|
|
|
|
template <typename ValueT>
|
|
InputShardOrSideAccess<ValueT> SinglePortAccess(
|
|
mediapipe::CalculatorContext* cc, InputStreamShard* stream,
|
|
const mediapipe::Packet* packet) {
|
|
return InputShardOrSideAccess<ValueT>(*cc, stream, packet);
|
|
}
|
|
|
|
template <typename ValueT>
|
|
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
|
|
PacketType* pt);
|
|
|
|
template <typename ValueT>
|
|
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
|
|
PacketType* pt, bool is_stream);
|
|
|
|
template <typename ValueT, typename PortT, class CC>
|
|
auto AccessPort(std::false_type, const PortT& port, CC* cc) {
|
|
auto& collection = GetCollection(cc, port);
|
|
return SinglePortAccess<ValueT>(
|
|
cc, internal::GetOrNull(collection, port.Tag(), 0));
|
|
}
|
|
|
|
template <typename ValueT, typename X, class CC>
|
|
class MultiplePortAccess {
|
|
public:
|
|
using AccessT = decltype(SinglePortAccess<ValueT>(std::declval<CC*>(),
|
|
std::declval<X*>()));
|
|
|
|
MultiplePortAccess(CC* cc, X* first, int count)
|
|
: cc_(cc), first_(first), count_(count) {}
|
|
|
|
// TODO: maybe this should be size(), like in a standard C++
|
|
// container?
|
|
int Count() { return count_; }
|
|
AccessT operator[](int pos) {
|
|
CHECK_GE(pos, 0);
|
|
CHECK_LT(pos, count_);
|
|
return SinglePortAccess<ValueT>(cc_, &first_[pos]);
|
|
}
|
|
|
|
class Iterator {
|
|
public:
|
|
using iterator_category = std::input_iterator_tag;
|
|
using value_type = AccessT;
|
|
using difference_type = std::ptrdiff_t;
|
|
using pointer = AccessT*;
|
|
using reference = AccessT; // allowed; see e.g. std::istreambuf_iterator
|
|
|
|
Iterator(CC* cc, X* p) : cc_(cc), p_(p) {}
|
|
Iterator& operator++() {
|
|
++p_;
|
|
return *this;
|
|
}
|
|
Iterator operator++(int) {
|
|
Iterator res = *this;
|
|
++(*this);
|
|
return res;
|
|
}
|
|
bool operator==(const Iterator& other) const { return p_ == other.p_; }
|
|
bool operator!=(const Iterator& other) const { return !(*this == other); }
|
|
AccessT operator*() const { return SinglePortAccess<ValueT>(cc_, p_); }
|
|
|
|
private:
|
|
CC* cc_;
|
|
X* p_;
|
|
};
|
|
|
|
Iterator begin() { return Iterator(cc_, first_); }
|
|
Iterator end() { return Iterator(cc_, first_ + count_); }
|
|
|
|
private:
|
|
CC* cc_;
|
|
X* first_;
|
|
int count_;
|
|
};
|
|
|
|
template <typename ValueT, typename PortT, class CC>
|
|
auto AccessPort(std::true_type, const PortT& port, CC* cc) {
|
|
auto& collection = GetCollection(cc, port);
|
|
auto* first = internal::GetOrNull(collection, port.Tag(), 0);
|
|
using EntryT = typename std::remove_pointer<decltype(first)>::type;
|
|
return MultiplePortAccess<ValueT, EntryT, CC>(
|
|
cc, first, collection.NumEntries(port.Tag()));
|
|
}
|
|
|
|
template <class Base>
|
|
struct SideBase;
|
|
|
|
template <>
|
|
struct SideBase<InputBase> {
|
|
using type = SideInputBase;
|
|
};
|
|
|
|
// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
|
|
template <typename T, class = void>
|
|
struct ActualPayloadType {
|
|
using type = T;
|
|
};
|
|
|
|
template <typename T>
|
|
struct ActualPayloadType<
|
|
T, std::enable_if_t<std::is_base_of<DynamicType, T>{}, void>> {
|
|
using type = internal::Generic;
|
|
};
|
|
|
|
} // namespace internal
|
|
|
|
// Maps special port value types, such as AnyType, to internal::Generic.
|
|
template <typename T>
|
|
using ActualPayloadT = typename internal::ActualPayloadType<T>::type;
|
|
|
|
static_assert(std::is_same_v<ActualPayloadT<int>, int>, "");
|
|
static_assert(std::is_same_v<ActualPayloadT<AnyType>, internal::Generic>, "");
|
|
|
|
// Forward declaration, needed because PortCommon names SideFallbackT in its
// SideFallback alias and befriends it. The default template arguments are
// supplied here.
template <typename Base, typename ValueT, bool IsOptional = false,
          bool IsMultiple = false>
class SideFallbackT;
|
|
|
|
// This template is used to define a port. Nodes should use it through one
|
|
// of the aliases below (Input, Output, SideInput, SideOutput).
|
|
template <typename Base, typename ValueT, bool IsOptionalV = false,
|
|
bool IsMultipleV = false>
|
|
class PortCommon : public Base {
|
|
public:
|
|
using value_t = ValueT;
|
|
static constexpr bool kOptional = IsOptionalV;
|
|
static constexpr bool kMultiple = IsMultipleV;
|
|
|
|
using Optional = PortCommon<Base, ValueT, true, IsMultipleV>;
|
|
using Multiple = PortCommon<Base, ValueT, IsOptionalV, true>;
|
|
using SideFallback = SideFallbackT<Base, ValueT, IsOptionalV, IsMultipleV>;
|
|
|
|
template <std::size_t N>
|
|
explicit constexpr PortCommon(const char (&tag)[N])
|
|
: Base(N, tag, kTypeId<ValueT>, IsOptionalV, IsMultipleV) {}
|
|
|
|
using PayloadT = ActualPayloadT<ValueT>;
|
|
|
|
auto operator()(CalculatorContext* cc) const {
|
|
return internal::AccessPort<PayloadT>(
|
|
std::integral_constant<bool, IsMultipleV>{}, *this, cc);
|
|
}
|
|
|
|
auto operator()(CalculatorContract* cc) const {
|
|
return internal::AccessPort<PayloadT>(
|
|
std::integral_constant<bool, IsMultipleV>{}, *this, cc);
|
|
}
|
|
|
|
private:
|
|
absl::Status AddToContract(CalculatorContract* cc) const {
|
|
if (kMultiple) {
|
|
AddMultiple(cc);
|
|
} else {
|
|
auto& pt = internal::GetCollection(cc, *this).Tag(this->Tag());
|
|
internal::SetType<value_t>(cc, pt);
|
|
if (kOptional) {
|
|
pt.Optional();
|
|
}
|
|
}
|
|
return {};
|
|
}
|
|
|
|
void AddMultiple(CalculatorContract* cc) const {
|
|
auto& collection = internal::GetCollection(cc, *this);
|
|
int count = collection.NumEntries(this->Tag());
|
|
for (int i = 0; i < count; ++i) {
|
|
internal::SetType<value_t>(cc, collection.Get(this->Tag(), i));
|
|
}
|
|
}
|
|
|
|
template <typename...>
|
|
friend class internal::Contract;
|
|
template <typename B, typename VT, bool, bool>
|
|
friend class mediapipe::api2::SideFallbackT;
|
|
};
|
|
|
|
// Use one of these templates to define a port in node code.
|
|
template <typename T = internal::Generic>
|
|
using Input = PortCommon<InputBase, T>;
|
|
|
|
template <typename T = internal::Generic>
|
|
using Output = PortCommon<OutputBase, T>;
|
|
|
|
template <typename T = internal::Generic>
|
|
using SideInput = PortCommon<SideInputBase, T>;
|
|
|
|
template <typename T = internal::Generic>
|
|
using SideOutput = PortCommon<SideOutputBase, T>;
|
|
|
|
template <typename Base, typename ValueT, bool IsOptionalV, bool IsMultipleV>
|
|
class SideFallbackT : public Base {
|
|
public:
|
|
using value_t = ValueT;
|
|
static constexpr bool kOptional = IsOptionalV;
|
|
static constexpr bool kMultiple = IsMultipleV;
|
|
using Optional = SideFallbackT<Base, ValueT, true, IsMultipleV>;
|
|
using PayloadT = ActualPayloadT<ValueT>;
|
|
|
|
const char* Tag() const { return stream_port.Tag(); }
|
|
|
|
auto operator()(CalculatorContract* cc) const {
|
|
bool is_stream = true;
|
|
auto& stream_collection = internal::GetCollection(cc, stream_port);
|
|
auto* packet_type = internal::GetOrNull(stream_collection, Tag(), 0);
|
|
if (packet_type == nullptr) {
|
|
auto& side_collection = internal::GetCollection(cc, side_port);
|
|
packet_type = internal::GetOrNull(side_collection, Tag(), 0);
|
|
is_stream = false;
|
|
}
|
|
return internal::SinglePortAccess<PayloadT>(cc, packet_type, is_stream);
|
|
}
|
|
|
|
auto operator()(CalculatorContext* cc) const {
|
|
auto& stream_collection = internal::GetCollection(cc, stream_port);
|
|
auto& side_collection = internal::GetCollection(cc, side_port);
|
|
return internal::SinglePortAccess<PayloadT>(
|
|
cc, internal::GetOrNull(stream_collection, Tag(), 0),
|
|
internal::GetOrNull(side_collection, Tag(), 0));
|
|
}
|
|
|
|
template <std::size_t N>
|
|
explicit constexpr SideFallbackT(const char (&tag)[N])
|
|
: Base(N, tag, kTypeId<ValueT>, IsOptionalV, IsMultipleV),
|
|
stream_port(tag),
|
|
side_port(tag) {}
|
|
|
|
protected:
|
|
absl::Status AddToContract(CalculatorContract* cc) const {
|
|
stream_port.AddToContract(cc);
|
|
side_port.AddToContract(cc);
|
|
int connected_count =
|
|
stream_port(cc).IsConnected() + side_port(cc).IsConnected();
|
|
if (connected_count > 1)
|
|
return absl::InvalidArgumentError(absl::StrCat(
|
|
Tag(),
|
|
" can be connected as a stream or as a side packet, but not both"));
|
|
if (!IsOptionalV && connected_count == 0)
|
|
return absl::InvalidArgumentError(
|
|
absl::StrCat(Tag(), " must be connected"));
|
|
return {};
|
|
}
|
|
|
|
using StreamPort = PortCommon<Base, ValueT, true, IsMultipleV>;
|
|
using SidePort = PortCommon<typename internal::SideBase<Base>::type, ValueT,
|
|
true, IsMultipleV>;
|
|
StreamPort stream_port;
|
|
SidePort side_port;
|
|
|
|
template <typename...>
|
|
friend class internal::Contract;
|
|
};
|
|
|
|
// An OutputShardAccess is returned when accessing an output stream within a
|
|
// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
|
|
// OutputStreamShard. Like that class, this class will not be usually named in
|
|
// calculator code, but used as a temporary object (e.g. kOut(cc).Send(...)).
|
|
class OutputShardAccessBase {
|
|
public:
|
|
OutputShardAccessBase(const CalculatorContext& cc, OutputStreamShard* output)
|
|
: context_(cc), output_(output) {}
|
|
|
|
Timestamp NextTimestampBound() const {
|
|
return (output_) ? output_->NextTimestampBound() : Timestamp::Unset();
|
|
}
|
|
void SetNextTimestampBound(Timestamp timestamp) {
|
|
if (output_) output_->SetNextTimestampBound(timestamp);
|
|
}
|
|
|
|
bool IsClosed() const { return output_ ? output_->IsClosed() : true; }
|
|
void Close() {
|
|
if (output_) output_->Close();
|
|
}
|
|
|
|
bool IsConnected() const { return output_ != nullptr; }
|
|
|
|
protected:
|
|
const CalculatorContext& context_;
|
|
OutputStreamShard* output_;
|
|
};
|
|
|
|
template <typename T>
|
|
class OutputShardAccess : public OutputShardAccessBase {
|
|
public:
|
|
void Send(Packet<T>&& packet) {
|
|
if (output_) output_->AddPacket(ToOldPacket(std::move(packet)));
|
|
}
|
|
|
|
void Send(const Packet<T>& packet) {
|
|
if (output_) output_->AddPacket(ToOldPacket(packet));
|
|
}
|
|
|
|
void Send(const T& payload, Timestamp time) {
|
|
Send(api2::MakePacket<T>(payload).At(time));
|
|
}
|
|
|
|
void Send(const T& payload) { Send(payload, context_.InputTimestamp()); }
|
|
|
|
void Send(T&& payload, Timestamp time) {
|
|
Send(api2::MakePacket<T>(std::move(payload)).At(time));
|
|
}
|
|
|
|
void Send(T&& payload) {
|
|
Send(std::move(payload), context_.InputTimestamp());
|
|
}
|
|
|
|
void Send(std::unique_ptr<T> payload, Timestamp time) {
|
|
Send(api2::PacketAdopting(std::move(payload)).At(time));
|
|
}
|
|
|
|
void Send(std::unique_ptr<T> payload) {
|
|
Send(std::move(payload), context_.InputTimestamp());
|
|
}
|
|
|
|
void SetHeader(const PacketBase& header) {
|
|
if (output_) output_->SetHeader(ToOldPacket(header));
|
|
}
|
|
|
|
private:
|
|
OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output)
|
|
: OutputShardAccessBase(cc, output) {}
|
|
|
|
friend OutputShardAccess<T> internal::SinglePortAccess<T>(
|
|
mediapipe::CalculatorContext*, OutputStreamShard*);
|
|
};
|
|
|
|
template <>
|
|
class OutputShardAccess<internal::Generic> : public OutputShardAccessBase {
|
|
public:
|
|
void Send(PacketBase&& packet) {
|
|
if (output_) output_->AddPacket(ToOldPacket(std::move(packet)));
|
|
}
|
|
|
|
void Send(const PacketBase& packet) {
|
|
if (output_) output_->AddPacket(ToOldPacket(packet));
|
|
}
|
|
|
|
void SetHeader(const PacketBase& header) {
|
|
if (output_) output_->SetHeader(ToOldPacket(header));
|
|
}
|
|
|
|
private:
|
|
OutputShardAccess(const CalculatorContext& cc, OutputStreamShard* output)
|
|
: OutputShardAccessBase(cc, output) {}
|
|
|
|
friend OutputShardAccess<internal::Generic>
|
|
internal::SinglePortAccess<internal::Generic>(mediapipe::CalculatorContext*,
|
|
OutputStreamShard*);
|
|
};
|
|
|
|
// Equivalent of OutputShardAccess, but for side packets.
|
|
template <typename T>
|
|
class OutputSidePacketAccess {
|
|
public:
|
|
void Set(Packet<T> packet) {
|
|
if (output_) output_->Set(ToOldPacket(std::move(packet)));
|
|
}
|
|
|
|
void Set(const T& payload) { Set(MakePacket<T>(payload)); }
|
|
void Set(T&& payload) { Set(MakePacket<T>(std::move(payload))); }
|
|
|
|
private:
|
|
OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {}
|
|
OutputSidePacket* output_;
|
|
|
|
friend OutputSidePacketAccess<T> internal::SinglePortAccess<T>(
|
|
mediapipe::CalculatorContext*, OutputSidePacket*);
|
|
};
|
|
|
|
template <typename T>
|
|
class InputShardAccess : public Packet<T> {
|
|
public:
|
|
const PacketBase& packet() const& { return *this; }
|
|
// Since InputShardAccess is currently created as a temporary, this avoids
|
|
// easy mistakes with dangling references.
|
|
PacketBase packet() const&& { return *this; }
|
|
|
|
bool IsDone() const { return stream_->IsDone(); }
|
|
bool IsConnected() const { return stream_ != nullptr; }
|
|
|
|
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
|
|
|
|
// "Consume" requires exclusive ownership of the packet's payload. In the
|
|
// current interim implementation, InputShardAccess creates a new reference to
|
|
// the payload (as a Packet<T> instead of a type-erased Packet), which means
|
|
// the conditions for Consume would never be satisfied. This helper class
|
|
// defines wrappers for the Consume methods in Packet which temporarily erase
|
|
// the reference held by the underlying InputStreamShard.
|
|
// Note that we cannot simply take over the reference when InputShardAccess is
|
|
// created, because it is currently created as a temporary and we might create
|
|
// more than one instance for the same stream.
|
|
template <class U = T,
|
|
class = std::enable_if_t<std::is_same<U, T>{},
|
|
decltype(&Packet<U>::Consume)>>
|
|
absl::StatusOr<std::unique_ptr<U>> Consume() {
|
|
return WrapConsumeCall(&Packet<T>::Consume);
|
|
}
|
|
|
|
template <class V, class U = T,
|
|
std::enable_if_t<internal::IsCompatibleType<V, U>{}, int> = 0>
|
|
absl::StatusOr<std::unique_ptr<V>> Consume() {
|
|
return WrapConsumeCall(&Packet<T>::template Consume<V>);
|
|
}
|
|
|
|
template <class... F>
|
|
auto ConsumeAndVisit(F&&... args) {
|
|
auto f = &Packet<T>::template ConsumeAndVisit<F...>;
|
|
return WrapConsumeCall(f, std::forward<F>(args)...);
|
|
}
|
|
|
|
private:
|
|
InputShardAccess(const CalculatorContext&, InputStreamShard* stream)
|
|
: Packet<T>(stream ? FromOldPacket(stream->Value()).template As<T>()
|
|
: Packet<T>()),
|
|
stream_(stream) {}
|
|
|
|
template <class F, class... A>
|
|
auto WrapConsumeCall(F f, A&&... args) {
|
|
stream_->Value() = {};
|
|
auto result = (this->*f)(std::forward<A>(args)...);
|
|
if (!result.ok()) {
|
|
stream_->Value() = ToOldPacket(*this);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
InputStreamShard* stream_;
|
|
|
|
friend InputShardAccess<T> internal::SinglePortAccess<T>(
|
|
mediapipe::CalculatorContext*, InputStreamShard*);
|
|
};
|
|
|
|
template <typename T>
|
|
class InputSidePacketAccess : public Packet<T> {
|
|
public:
|
|
const PacketBase& packet() const& { return *this; }
|
|
PacketBase packet() const&& { return *this; }
|
|
|
|
bool IsConnected() const { return connected_; }
|
|
|
|
private:
|
|
InputSidePacketAccess(const mediapipe::Packet* packet)
|
|
: Packet<T>(packet ? FromOldPacket(*packet).template As<T>()
|
|
: Packet<T>()),
|
|
connected_(packet != nullptr) {}
|
|
bool connected_;
|
|
|
|
friend InputSidePacketAccess<T> internal::SinglePortAccess<T>(
|
|
mediapipe::CalculatorContext*, const mediapipe::Packet*);
|
|
};
|
|
|
|
template <typename T>
|
|
class InputShardOrSideAccess : public Packet<T> {
|
|
public:
|
|
const PacketBase& packet() const& { return *this; }
|
|
PacketBase packet() const&& { return *this; }
|
|
|
|
bool IsDone() const { return stream_->IsDone(); }
|
|
bool IsConnected() const { return connected_; }
|
|
bool IsStream() const { return stream_ != nullptr; }
|
|
|
|
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
|
|
|
|
private:
|
|
InputShardOrSideAccess(const CalculatorContext&, InputStreamShard* stream,
|
|
const mediapipe::Packet* packet)
|
|
: Packet<T>(stream ? FromOldPacket(stream->Value()).template As<T>()
|
|
: packet ? FromOldPacket(*packet).template As<T>()
|
|
: Packet<T>()),
|
|
stream_(stream),
|
|
connected_(stream_ != nullptr || packet != nullptr) {}
|
|
InputStreamShard* stream_;
|
|
bool connected_;
|
|
|
|
friend InputShardOrSideAccess<T> internal::SinglePortAccess<T>(
|
|
mediapipe::CalculatorContext*, InputStreamShard*,
|
|
const mediapipe::Packet*);
|
|
};
|
|
|
|
class PacketTypeAccess {
|
|
public:
|
|
bool IsConnected() const { return packet_type_ != nullptr; }
|
|
|
|
protected:
|
|
PacketTypeAccess(PacketType* pt) : packet_type_(pt) {}
|
|
PacketType* packet_type_;
|
|
|
|
template <typename T>
|
|
friend PacketTypeAccess internal::SinglePortAccess(
|
|
mediapipe::CalculatorContract*, PacketType*);
|
|
};
|
|
|
|
class PacketTypeAccessFallback : public PacketTypeAccess {
|
|
public:
|
|
bool IsStream() const { return is_stream_; }
|
|
|
|
private:
|
|
PacketTypeAccessFallback(PacketType* pt, bool is_stream)
|
|
: PacketTypeAccess(pt), is_stream_(is_stream) {}
|
|
bool is_stream_;
|
|
|
|
template <typename T>
|
|
friend PacketTypeAccessFallback internal::SinglePortAccess(
|
|
mediapipe::CalculatorContract*, PacketType*, bool);
|
|
};
|
|
|
|
namespace internal {
|
|
template <typename ValueT>
|
|
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
|
|
PacketType* pt) {
|
|
return PacketTypeAccess(pt);
|
|
}
|
|
template <typename ValueT>
|
|
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
|
|
PacketType* pt, bool is_stream) {
|
|
return PacketTypeAccessFallback(pt, is_stream);
|
|
}
|
|
} // namespace internal
|
|
|
|
} // namespace api2
|
|
} // namespace mediapipe
|
|
|
|
#endif // MEDIAPIPE_FRAMEWORK_API2_PORT_H_
|