#ifndef MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ #define MEDIAPIPE_FRAMEWORK_API2_BUILDER_H_ #include #include #include "absl/container/flat_hash_map.h" #include "mediapipe/framework/api2/const_str.h" #include "mediapipe/framework/api2/contract.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_contract.h" namespace mediapipe { namespace api2 { namespace builder { template T& GetWithAutoGrow(std::vector>* vecp, int index) { auto& vec = *vecp; if (vec.size() <= index) { vec.resize(index + 1); } if (vec[index] == nullptr) { vec[index] = absl::make_unique(); } return *vec[index]; } struct TagIndexLocation { const std::string& tag; std::size_t index; std::size_t count; }; template class TagIndexMap { public: std::vector>& operator[](const std::string& tag) { return map_[tag]; } void Visit(std::function fun) const { for (const auto& tagged : map_) { TagIndexLocation loc{tagged.first, 0, tagged.second.size()}; for (const auto& item : tagged.second) { fun(loc, *item); ++loc.index; } } } void Visit(std::function fun) { for (auto& tagged : map_) { TagIndexLocation loc{tagged.first, 0, tagged.second.size()}; for (auto& item : tagged.second) { fun(loc, item.get()); ++loc.index; } } } // Note: entries are held by a unique_ptr to ensure pointers remain valid. // Should use absl::flat_hash_map but ordering keys for now. std::map>> map_; }; // These structs are used internally to store information about the endpoints // of a connection. struct SourceBase; struct DestinationBase { SourceBase* source = nullptr; }; struct SourceBase { std::vector dests_; std::string name_; }; // Following existing GraphConfig usage, we allow using a multiport as a single // port as well. This is necessary for generic nodes, since we have no // information about which ports are meant to be multiports or not, but it is // also convenient with typed nodes. template class MultiPort : public Single { public: using Base = typename Single::Base; explicit MultiPort(std::vector>* vec) : Single(vec), vec_(*vec) {} Single operator[](int index) { CHECK_GE(index, 0); return Single{&GetWithAutoGrow(&vec_, index)}; } private: std::vector>& vec_; }; // These classes wrap references to the underlying source/destination // endpoints, adding type information and the user-visible API. template class DestinationImpl { public: using Base = DestinationBase; explicit DestinationImpl(std::vector>* vec) : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {} DestinationBase& base_; }; template class DestinationImpl : public MultiPort> { public: using MultiPort>::MultiPort; }; template class SourceImpl { public: using Base = SourceBase; // Src is used as the return type of fluent methods below. Since these are // single-port methods, it is desirable to always decay to a reference to the // single-port superclass, even if they are called on a multiport. using Src = SourceImpl; template using Dst = DestinationImpl; // clang-format off template struct AllowConnection : public std::integral_constant{} || std::is_same{} || std::is_same{}> {}; // clang-format on explicit SourceImpl(std::vector>* vec) : SourceImpl(&GetWithAutoGrow(vec, 0)) {} explicit SourceImpl(SourceBase* base) : base_(*base) {} template {}, int>::type = 0> Src& AddTarget(const Dst& dest) { CHECK(dest.base_.source == nullptr); dest.base_.source = &base_; base_.dests_.emplace_back(&dest.base_); return *this; } Src& SetName(std::string name) { base_.name_ = std::move(name); return *this; } template Src& operator>>(const Dst& dest) { return AddTarget(dest); } private: SourceBase& base_; }; template class SourceImpl : public MultiPort> { public: using MultiPort>::MultiPort; }; // A source and a destination correspond to an output/input stream on a node, // and a side source and side destination correspond to an output/input side // packet. // For graph inputs/outputs, however, the inputs are sources, and the outputs // are destinations. This is because graph ports are connected "from inside" // when building the graph. template using Source = SourceImpl; template using SideSource = SourceImpl; template using Destination = DestinationImpl; template using SideDestination = DestinationImpl; class NodeBase { public: // TODO: right now access to an indexed port is made directly by // specifying both a tag and an index. It would be better to represent this // as a two-step lookup, first getting a multi-port, and then accessing one // of its entries by index. However, for nodes without visible contracts we // can't know whether a tag is indexable or not, so we would need the // multi-port to also be usable as a port directly (representing index 0). Source Out(const std::string& tag) { return Source(&out_streams_[tag]); } Destination In(const std::string& tag) { return Destination(&in_streams_[tag]); } SideSource SideOut(const std::string& tag) { return SideSource(&out_sides_[tag]); } SideDestination SideIn(const std::string& tag) { return SideDestination(&in_sides_[tag]); } // Convenience methods for accessing purely index-based ports. Source Out(int index) { return Out("")[index]; } Destination In(int index) { return In("")[index]; } SideSource SideOut(int index) { return SideOut("")[index]; } SideDestination SideIn(int index) { return SideIn("")[index]; } template T& GetOptions() { options_used_ = true; return *options_.MutableExtension(T::ext); } protected: NodeBase(std::string type) : type_(std::move(type)) {} std::string type_; TagIndexMap in_streams_; TagIndexMap out_streams_; TagIndexMap in_sides_; TagIndexMap out_sides_; CalculatorOptions options_; // ideally we'd just check if any extensions are set on options_ bool options_used_ = false; friend class Graph; }; template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. explicit Node()->Node; #endif // C++17 template <> class Node : public NodeBase { public: Node(std::string type) : NodeBase(std::move(type)) {} }; using GenericNode = Node; template