// Copyright 2019 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_ #define MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_ #include #include #include "absl/container/flat_hash_set.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/graph_service_manager.h" #include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/map_util.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/status_handler.pb.h" #include "mediapipe/framework/subgraph.h" namespace mediapipe { class ValidatedGraphConfig; // Type information for a graph node (Calculator, Generator, etc). class NodeTypeInfo { public: enum class NodeType { UNKNOWN = 0, CALCULATOR = 1, PACKET_GENERATOR = 2, GRAPH_INPUT_STREAM = 3, // The virtual node parent of a graph input stream. STATUS_HANDLER = 4, }; struct NodeRef { NodeRef() = default; NodeRef(NodeType node_type, int node_index) : type(node_type), index(node_index) {} NodeType type = NodeType::UNKNOWN; // The index of a graph node among the nodes of the same type in the // validated graph config. int index = -1; }; NodeTypeInfo() = default; ~NodeTypeInfo() = default; // Don't allow copy or assign (PacketTypeSet does not support copy). NodeTypeInfo(const NodeTypeInfo& other) = delete; NodeTypeInfo& operator=(const NodeTypeInfo& other) = delete; // Allow move (so we can create a std::vector with this type). NodeTypeInfo(NodeTypeInfo&& other) = default; // node_index is the index of this node among the nodes of the same type // in the validated graph config. absl::Status Initialize(const ValidatedGraphConfig& validated_graph, const CalculatorGraphConfig::Node& node, int node_index); absl::Status Initialize(const ValidatedGraphConfig& validated_graph, const PacketGeneratorConfig& node, int node_index); absl::Status Initialize(const ValidatedGraphConfig& validated_graph, const StatusHandlerConfig& node, int node_index); // TODO: many of these accessors can be replaced by Contract(). const PacketTypeSet& InputSidePacketTypes() const { return contract_.InputSidePackets(); } const PacketTypeSet& OutputSidePacketTypes() const { return contract_.OutputSidePackets(); } const PacketTypeSet& InputStreamTypes() const { return contract_.Inputs(); } const PacketTypeSet& OutputStreamTypes() const { return contract_.Outputs(); } const CalculatorContract& Contract() const { return contract_; } // Non-const accessors. PacketTypeSet& InputSidePacketTypes() { return contract_.InputSidePackets(); } PacketTypeSet& OutputSidePacketTypes() { return contract_.OutputSidePackets(); } PacketTypeSet& InputStreamTypes() { return contract_.Inputs(); } PacketTypeSet& OutputStreamTypes() { return contract_.Outputs(); } // Get the input/output side packet/stream index that is the first // for the PacketTypeSets. Subsequent id's in the collection are // guaranteed to be contiguous in the master flat array. int InputSidePacketBaseIndex() const { return input_side_packet_base_index_; } int OutputSidePacketBaseIndex() const { return output_side_packet_base_index_; } int InputStreamBaseIndex() const { return input_stream_base_index_; } int OutputStreamBaseIndex() const { return output_stream_base_index_; } // Get the type and index of this node. const NodeRef& Node() const { return node_; } // Setter methods for the indexes. This should only be used by // ValidatedGraphConfig. void SetInputSidePacketBaseIndex(int index) { input_side_packet_base_index_ = index; } void SetOutputSidePacketBaseIndex(int index) { output_side_packet_base_index_ = index; } void SetInputStreamBaseIndex(int index) { input_stream_base_index_ = index; } void SetOutputStreamBaseIndex(int index) { output_stream_base_index_ = index; } void SetNodeIndex(int index) { node_.index = index; } // Get the indexes (in ValidatedGraphConfig::Calculator's flat array) // of the source nodes which affect this node. The index can also // be a virtual node corresponding to a graph input stream (which are // listed by index contiguously after all calculators). // This function is only valid for a NodeTypeInfo of NodeType CALCULATOR. const absl::flat_hash_set& AncestorSources() const { return ancestor_sources_; } // Returns True if the source was not already there. // This function is only valid for a NodeTypeInfo of NodeType CALCULATOR. bool AddSource(int index) { return ancestor_sources_.insert(index).second; } // Convert the NodeType enum into a std::string (generally for error // messaging). static std::string NodeTypeToString(NodeType node_type); // Returns the name of the specified InputStreamHandler, or empty std::string // if none set. std::string GetInputStreamHandler() const { return contract_.GetInputStreamHandler(); } // Returns the MediaPipeOptions specified, or empty options if none set. MediaPipeOptions GetInputStreamHandlerOptions() const { return contract_.GetInputStreamHandlerOptions(); } private: // This object owns the PacketType objects (which are referenced by // ValidatedGraphConfig::EdgeInfo objects). CalculatorContract contract_; // The base indexes of the first entry belonging to this node in // the master flat arrays of ValidatedGraphConfig. Subsequent // entries are guaranteed to be sequential and in the order of the // CollectionItemIds. // Example: // all_input_streams // [node_info.InputStreamBaseIndex() + // node_info.InputStreamTypes().GetId("TAG", 2).value()]; int input_side_packet_base_index_ = -1; int output_side_packet_base_index_ = -1; int input_stream_base_index_ = -1; int output_stream_base_index_ = -1; // The type and index of this node. NodeRef node_; // The set of sources which affect this node. absl::flat_hash_set ancestor_sources_; }; // Information for either the input or output side of an edge. An edge // is either a side packet or a stream. struct EdgeInfo { // For an input edge (input side packet, or input stream) this is the // index of the corresponding output side which produces the data this // edge will see. int upstream = -1; // The parent node which owns this edge. For graph input streams this // is a virtual node (in which case there is no corresponding owning // node in calculators_). NodeTypeInfo::NodeRef parent_node; std::string name; PacketType* packet_type = nullptr; bool back_edge = false; // Only applicable to input streams. }; // This class is used to validate and canonicalize a CalculatorGraphConfig. class ValidatedGraphConfig { public: // Initializes the ValidatedGraphConfig. This function must be called // before any other functions. Subgraphs are specified through the // global graph registry or an optional local graph registry. absl::Status Initialize(const CalculatorGraphConfig& input_config, const GraphRegistry* graph_registry = nullptr, const GraphServiceManager* service_manager = nullptr); // Initializes the ValidatedGraphConfig from registered graph and subgraph // configs. Subgraphs are retrieved from the specified graph registry or from // the global graph registry. A subgraph can be instantiated directly by // specifying its type in |graph_type|. absl::Status Initialize(const std::string& graph_type, const Subgraph::SubgraphOptions* options = nullptr, const GraphRegistry* graph_registry = nullptr, const GraphServiceManager* service_manager = nullptr); // Initializes the ValidatedGraphConfig from the specified graph and subgraph // configs. Template graph and subgraph configs can be specified through // |input_templates|. Every subgraph must have its graph type specified in // CalclatorGraphConfig.type. A subgraph can be instantiated directly by // specifying its type in |graph_type|. A template graph can be instantiated // directly by specifying its template arguments in |arguments|. absl::Status Initialize( const std::vector& input_configs, const std::vector& input_templates, const std::string& graph_type = "", const Subgraph::SubgraphOptions* arguments = nullptr, const GraphServiceManager* service_manager = nullptr); // Returns true if the ValidatedGraphConfig has been initialized. bool Initialized() const { return initialized_; } // Returns an error if the provided side packets will be generated by // the PacketGenerators in this graph. template absl::Status CanAcceptSidePackets( const std::map& side_packets) const; // Validate that all the required side packets are provided, and the // packets have the required type. absl::Status ValidateRequiredSidePackets( const std::map& side_packets) const; // Same as ValidateRequiredSidePackets but only provide the type. absl::Status ValidateRequiredSidePacketTypes( const std::map& side_packet_types) const; // The proto configuration (canonicalized). const CalculatorGraphConfig& Config() const { return config_; } // Accessors for the info objects. const std::vector& CalculatorInfos() const { return calculators_; } const std::vector& GeneratorInfos() const { return generators_; } const std::vector& StatusHandlerInfos() const { return status_handlers_; } const std::vector& InputStreamInfos() const { return input_streams_; } const std::vector& OutputStreamInfos() const { return output_streams_; } const std::vector& InputSidePacketInfos() const { return input_side_packets_; } const std::vector& OutputSidePacketInfos() const { return output_side_packets_; } int OutputStreamIndex(const std::string& name) const { return FindWithDefault(stream_to_producer_, name, -1); } int OutputSidePacketIndex(const std::string& name) const { return FindWithDefault(side_packet_to_producer_, name, -1); } int OutputStreamToNode(const std::string& name) const { auto iter = stream_to_producer_.find(name); if (iter == stream_to_producer_.end()) { return -1; } return output_streams_[iter->second].parent_node.index; } // Returns the registered type name of the specified side packet if // it can be determined, otherwise an appropriate error is returned. absl::StatusOr RegisteredSidePacketTypeName( const std::string& name); // Returns the registered type name of the specified stream if it can // be determined, otherwise an appropriate error is returned. absl::StatusOr RegisteredStreamTypeName(const std::string& name); // The namespace used for class name lookup. std::string Package() const { return config_.package(); } // Returns true if |name| is a reserved executor name. static bool IsReservedExecutorName(const std::string& name); // Returns true if a side packet is provided as an input to the graph. bool IsExternalSidePacket(const std::string& name) const { return required_side_packets_.count(name) > 0; } private: // Initialize the PacketGenerator information. absl::Status InitializeGeneratorInfo(); // Initialize the Calculator information. absl::Status InitializeCalculatorInfo(); // Initialize the StatusHandler information. absl::Status InitializeStatusHandlerInfo(); // Initialize the EdgeInfo objects for side packets. // // If need_sorting_ptr is non-null it will be set to true iff the side // packet graph is not topologically sorted. If the nodes in the side // packet graph are not in sorted order, then side_packet_to_producer_ // will still be complete, but the upstream field of input_side_packets_ // may not be accurate. // // If need_sorting_ptr is nullptr then an error will be returned if the // nodes in the side packet graph are not in topologically sorted order. absl::Status InitializeSidePacketInfo(bool* need_sorting_ptr); // Adds EdgeInfo objects to input_side_packets_ for all the input side // packets required by the node_type_info. If nodes are processed // with AddInputSidePacketsForNode and AddOutputSidePacketsForNode // sequentially, then side_packet_to_producer_ and // required_side_packets_ are used to ensure that the graph is // topologically sorted. node_type_info is updated with the proper // initial index for input side packets. absl::Status AddInputSidePacketsForNode(NodeTypeInfo* node_type_info); // Adds EdgeInfo objects to output_side_packets_ for all the output side // packets produced by the node_type_info. side_packet_to_producer_ is // updated. need_sorting_ptr will be set to true if the nodes in the // side packet graph are detected to be in unsorted order (a side packet // is output after something that required it), otherwise need_sorting_ptr // is left as is. node_type_info is updated with the proper initial index // for output side packets. absl::Status AddOutputSidePacketsForNode(NodeTypeInfo* node_type_info, bool* need_sorting_ptr); // These functions are analogous to the same operations for side // packets, with the small difference that it is an error to use an // undefined stream (whereas it is allowed to use an undefined side // packet). absl::Status InitializeStreamInfo(bool* need_sorting_ptr); absl::Status AddOutputStreamsForNode(NodeTypeInfo* node_type_info); absl::Status AddInputStreamsForNode(NodeTypeInfo* node_type_info, bool* need_sorting_ptr); // A helper function for adding a single output stream EdgeInfo. absl::Status AddOutputStream(NodeTypeInfo::NodeRef node, const std::string& name, PacketType* packet_type); // Return the index of the node adjusted for the topological sorter. int SorterIndexForNode(NodeTypeInfo::NodeRef node) const; // Convert the index for the topological sorter back to the node type // and node index. NodeTypeInfo::NodeRef NodeForSorterIndex(int index) const; // Sort the nodes based on the information gotten from // InitializeSidePacketInfo and InitializeStreamInfo. After this // function, the InitializeSidePacketInfo and InitializeStreamInfo // functions must be run again (after clearing the data structures they // fill). // // NOTE: Only the generators and calculators need to be sorted. The other // two node types, graph input streams and status handlers, can be safely // ignored in the analysis of output side packet generation or stream // header packet propagation. absl::Status TopologicalSortNodes(); // TODO Add InputStreamHandler. // TODO Add OutputStreamHandler. // Fill the "upstream" field for all back edges. absl::Status FillUpstreamFieldForBackEdges(); // Compute the dependence of nodes on sources. absl::Status ComputeSourceDependence(); // Infer the type of types set to "Any" by what they are connected to. absl::Status ResolveAnyTypes(std::vector* input_edges, std::vector* output_edges); // Returns an error if the generator graph does not have consistent // type specifications for side packets. absl::Status ValidateSidePacketTypes(); // Returns an error if the graph of calculators does not have consistent // type specifications for streams. absl::Status ValidateStreamTypes(); // Returns an error if the graph does not have valid ExecutorConfigs, or // if the executor name in a node config is reserved or is not declared // in an ExecutorConfig. absl::Status ValidateExecutors(); bool initialized_ = false; CalculatorGraphConfig config_; // The type information for each node type. std::vector calculators_; std::vector generators_; std::vector status_handlers_; // NodeTypeInfo's of generators and calculators, topologically sorted. std::vector sorted_nodes_; // Mapping from stream name to the output_streams_ index which produces it. std::map stream_to_producer_; // Mapping from side packet name to the output_side_packets_ index // which produces it. std::map side_packet_to_producer_; // A structure to manage deletion of PacketType objects which need to // be owned by this object (used for graph input stream PacketType). std::vector> owned_packet_types_; // For each side packet which must still be supplied, a list of // input_side_packets_ indexes which must be validated against it. // TODO Use the information stored here for more thorough // validation. std::map> required_side_packets_; // The EdgeInfo objects for input/output side packets/streams. std::vector input_streams_; std::vector output_streams_; std::vector input_side_packets_; std::vector output_side_packets_; }; template absl::Status ValidatedGraphConfig::CanAcceptSidePackets( const std::map& side_packets) const { for (const auto& output_side_packet : output_side_packets_) { if (ContainsKey(side_packets, output_side_packet.name)) { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Side packet \"" << output_side_packet.name << "\" is both provided and generated by a PacketGenerator."; } } return absl::OkStatus(); } } // namespace mediapipe #endif // MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_