// Source: mediapipe/mediapipe/framework/validated_graph_config.h
// Revision: MediaPipe Team 3aeec84ac0 "Internal change for profiling"
// PiperOrigin-RevId: 494126771
// Date: 2022-12-09 03:21:10 -08:00
//
// 470 lines, 19 KiB, C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_
#define MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/status_handler.pb.h"
#include "mediapipe/framework/subgraph.h"
namespace mediapipe {
class ValidatedGraphConfig;
// Type information for a graph node (Calculator, Generator, etc).
//
// A NodeTypeInfo bundles the node's CalculatorContract (which owns the
// PacketType objects of its edges), the node's identity (type + index), the
// base indexes of its edges inside the flat edge arrays kept by
// ValidatedGraphConfig, and the set of source nodes that affect it.
//
// The type is move-only: copying is deleted, moving is defaulted.
class NodeTypeInfo {
 public:
  // The kind of node a NodeTypeInfo / NodeRef refers to.
  enum class NodeType {
    UNKNOWN = 0,
    CALCULATOR = 1,
    PACKET_GENERATOR = 2,
    GRAPH_INPUT_STREAM = 3,  // The virtual node parent of a graph input stream.
    STATUS_HANDLER = 4,
  };

  // Identifies a node by its type and its position among nodes of that type.
  struct NodeRef {
    NodeRef() = default;
    NodeRef(NodeType node_type, int node_index)
        : type(node_type), index(node_index) {}

    NodeType type = NodeType::UNKNOWN;
    // The index of a graph node among the nodes of the same type in the
    // validated graph config. Defaults to -1.
    int index = -1;
  };

  NodeTypeInfo() = default;
  ~NodeTypeInfo() = default;

  // Don't allow copy or assign (PacketTypeSet does not support copy).
  NodeTypeInfo(const NodeTypeInfo& other) = delete;
  NodeTypeInfo& operator=(const NodeTypeInfo& other) = delete;

  // Allow move (so we can create a std::vector with this type).
  NodeTypeInfo(NodeTypeInfo&& other) = default;
  // Move assignment is declared together with the move constructor so that
  // all five special members are spelled out. Without this declaration the
  // type could be move-constructed but never assigned, because the deleted
  // copy assignment would be the only candidate.
  NodeTypeInfo& operator=(NodeTypeInfo&& other) = default;

  // Initializes this object from a node config of the corresponding kind.
  // node_index is the index of this node among the nodes of the same type
  // in the validated graph config.
  absl::Status Initialize(const ValidatedGraphConfig& validated_graph,
                          const CalculatorGraphConfig::Node& node,
                          int node_index);
  absl::Status Initialize(const ValidatedGraphConfig& validated_graph,
                          const PacketGeneratorConfig& node, int node_index);
  absl::Status Initialize(const ValidatedGraphConfig& validated_graph,
                          const StatusHandlerConfig& node, int node_index);

  // Const accessors; each one forwards to the matching accessor of the
  // contract.
  // TODO: many of these accessors can be replaced by Contract().
  const PacketTypeSet& InputSidePacketTypes() const {
    return contract_.InputSidePackets();
  }
  const PacketTypeSet& OutputSidePacketTypes() const {
    return contract_.OutputSidePackets();
  }
  const PacketTypeSet& InputStreamTypes() const { return contract_.Inputs(); }
  const PacketTypeSet& OutputStreamTypes() const { return contract_.Outputs(); }
  const CalculatorContract& Contract() const { return contract_; }

  // Non-const accessors.
  PacketTypeSet& InputSidePacketTypes() { return contract_.InputSidePackets(); }
  PacketTypeSet& OutputSidePacketTypes() {
    return contract_.OutputSidePackets();
  }
  PacketTypeSet& InputStreamTypes() { return contract_.Inputs(); }
  PacketTypeSet& OutputStreamTypes() { return contract_.Outputs(); }

  // Get the input/output side packet/stream index that is the first
  // for the PacketTypeSets. Subsequent id's in the collection are
  // guaranteed to be contiguous in the main flat array.
  int InputSidePacketBaseIndex() const { return input_side_packet_base_index_; }
  int OutputSidePacketBaseIndex() const {
    return output_side_packet_base_index_;
  }
  int InputStreamBaseIndex() const { return input_stream_base_index_; }
  int OutputStreamBaseIndex() const { return output_stream_base_index_; }

  // Get the type and index of this node.
  const NodeRef& Node() const { return node_; }

  // Setter methods for the indexes. This should only be used by
  // ValidatedGraphConfig.
  void SetInputSidePacketBaseIndex(int index) {
    input_side_packet_base_index_ = index;
  }
  void SetOutputSidePacketBaseIndex(int index) {
    output_side_packet_base_index_ = index;
  }
  void SetInputStreamBaseIndex(int index) { input_stream_base_index_ = index; }
  void SetOutputStreamBaseIndex(int index) {
    output_stream_base_index_ = index;
  }
  void SetNodeIndex(int index) { node_.index = index; }

  // Get the indexes (in ValidatedGraphConfig::Calculator's flat array)
  // of the source nodes which affect this node. The index can also
  // be a virtual node corresponding to a graph input stream (which are
  // listed by index contiguously after all calculators).
  // This function is only valid for a NodeTypeInfo of NodeType CALCULATOR.
  const absl::flat_hash_set<int>& AncestorSources() const {
    return ancestor_sources_;
  }

  // Records `index` as a source affecting this node.
  // Returns true if the source was not already there.
  // This function is only valid for a NodeTypeInfo of NodeType CALCULATOR.
  bool AddSource(int index) { return ancestor_sources_.insert(index).second; }

  // Convert the NodeType enum into a string (generally for error messaging).
  static std::string NodeTypeToString(NodeType node_type);

  // Returns the name of the specified InputStreamHandler, or empty string if
  // none set.
  std::string GetInputStreamHandler() const {
    return contract_.GetInputStreamHandler();
  }

  // Returns the MediaPipeOptions specified, or empty options if none set.
  MediaPipeOptions GetInputStreamHandlerOptions() const {
    return contract_.GetInputStreamHandlerOptions();
  }

 private:
  // This object owns the PacketType objects (which are referenced by
  // ValidatedGraphConfig::EdgeInfo objects).
  CalculatorContract contract_;

  // The base indexes of the first entry belonging to this node in
  // the main flat arrays of ValidatedGraphConfig. Subsequent
  // entries are guaranteed to be sequential and in the order of the
  // CollectionItemIds.
  // Example:
  //   all_input_streams
  //       [node_info.InputStreamBaseIndex() +
  //        node_info.InputStreamTypes().GetId("TAG", 2).value()];
  int input_side_packet_base_index_ = 0;
  int output_side_packet_base_index_ = 0;
  int input_stream_base_index_ = 0;
  int output_stream_base_index_ = 0;

  // The type and index of this node.
  NodeRef node_;

  // The set of sources which affect this node.
  absl::flat_hash_set<int> ancestor_sources_;
};
// Describes one end (input side or output side) of an edge. An edge is
// either a side packet or a stream.
struct EdgeInfo {
  // For an input edge (input side packet, or input stream): the index of the
  // corresponding output side which produces the data this edge will see.
  // Defaults to -1.
  int upstream{-1};

  // The node that owns this edge. For graph input streams this is a virtual
  // node, in which case there is no corresponding owning node in
  // calculators_.
  NodeTypeInfo::NodeRef parent_node;

  // The name of the edge.
  std::string name;

  // The packet type attached to this edge; null by default.
  PacketType* packet_type{nullptr};

  // Marks the edge as a back edge. Only applicable to input streams.
  bool back_edge{false};
};
// Validates and canonicalizes a CalculatorGraphConfig, and exposes the
// resulting node and edge information.
class ValidatedGraphConfig {
 public:
  // Initializes the ValidatedGraphConfig from a single graph config. One of
  // the Initialize() overloads must be called before any other function.
  // Subgraphs are specified through the global graph registry or an optional
  // local graph registry.
  absl::Status Initialize(
      CalculatorGraphConfig input_config,
      const GraphRegistry* graph_registry = nullptr,
      const Subgraph::SubgraphOptions* graph_options = nullptr,
      const GraphServiceManager* service_manager = nullptr);

  // Initializes the ValidatedGraphConfig from registered graph and subgraph
  // configs. Subgraphs are looked up in the given graph registry or in the
  // global graph registry. Passing a subgraph's type as |graph_type|
  // instantiates that subgraph directly.
  absl::Status Initialize(
      const std::string& graph_type,
      const GraphRegistry* graph_registry = nullptr,
      const Subgraph::SubgraphOptions* graph_options = nullptr,
      const GraphServiceManager* service_manager = nullptr);

  // Initializes the ValidatedGraphConfig from the given graph and subgraph
  // configs. Template graph and subgraph configs are supplied through
  // |input_templates|. Every subgraph must have its graph type specified in
  // CalculatorGraphConfig.type. A subgraph can be instantiated directly by
  // specifying its type in |graph_type|. A template graph can be instantiated
  // directly by specifying its template arguments in |arguments|.
  absl::Status Initialize(
      const std::vector<CalculatorGraphConfig>& input_configs,
      const std::vector<CalculatorGraphTemplate>& input_templates,
      const std::string& graph_type = "",
      const Subgraph::SubgraphOptions* graph_options = nullptr,
      const GraphServiceManager* service_manager = nullptr);

  // Whether this ValidatedGraphConfig has been initialized.
  bool Initialized() const { return initialized_; }

  // Returns an error if the provided side packets will be generated by
  // the PacketGenerators in this graph.
  template <typename T>
  absl::Status CanAcceptSidePackets(
      const std::map<std::string, T>& side_packets) const;

  // Checks that every required side packet is present in |side_packets| and
  // has the required type.
  absl::Status ValidateRequiredSidePackets(
      const std::map<std::string, Packet>& side_packets) const;

  // Like ValidateRequiredSidePackets, but the caller supplies only the types.
  absl::Status ValidateRequiredSidePacketTypes(
      const std::map<std::string, PacketType>& side_packet_types) const;

  // The canonicalized proto configuration.
  const CalculatorGraphConfig& Config() const { return config_; }

  // Read-only views of the per-node info objects.
  const std::vector<NodeTypeInfo>& CalculatorInfos() const {
    return calculators_;
  }
  const std::vector<NodeTypeInfo>& GeneratorInfos() const {
    return generators_;
  }
  const std::vector<NodeTypeInfo>& StatusHandlerInfos() const {
    return status_handlers_;
  }

  // Read-only views of the per-edge info objects.
  const std::vector<EdgeInfo>& InputStreamInfos() const {
    return input_streams_;
  }
  const std::vector<EdgeInfo>& OutputStreamInfos() const {
    return output_streams_;
  }
  const std::vector<EdgeInfo>& InputSidePacketInfos() const {
    return input_side_packets_;
  }
  const std::vector<EdgeInfo>& OutputSidePacketInfos() const {
    return output_side_packets_;
  }

  // Returns the output_streams_ index recorded for the stream called |name|,
  // or -1 when no such stream is recorded.
  int OutputStreamIndex(const std::string& name) const {
    const auto producer = stream_to_producer_.find(name);
    return producer != stream_to_producer_.end() ? producer->second : -1;
  }

  // Returns the output_side_packets_ index recorded for the side packet
  // called |name|, or -1 when no such side packet is recorded.
  int OutputSidePacketIndex(const std::string& name) const {
    const auto producer = side_packet_to_producer_.find(name);
    return producer != side_packet_to_producer_.end() ? producer->second : -1;
  }

  // Returns the parent node index of the output stream called |name|, or -1
  // when no such stream is recorded.
  int OutputStreamToNode(const std::string& name) const {
    const auto producer = stream_to_producer_.find(name);
    return producer == stream_to_producer_.end()
               ? -1
               : output_streams_[producer->second].parent_node.index;
  }

  // Returns a copy of the consumer node ids recorded for output stream
  // |idx|; the result is empty when nothing is recorded for it.
  std::vector<int> OutputStreamToConsumers(int idx) const {
    const auto consumers = output_streams_to_consumer_nodes_.find(idx);
    if (consumers != output_streams_to_consumer_nodes_.end()) {
      return consumers->second;
    }
    return {};
  }

  // Returns the registered type name of the specified side packet if
  // it can be determined, otherwise an appropriate error is returned.
  absl::StatusOr<std::string> RegisteredSidePacketTypeName(
      const std::string& name);

  // Returns the registered type name of the specified stream if it can
  // be determined, otherwise an appropriate error is returned.
  absl::StatusOr<std::string> RegisteredStreamTypeName(const std::string& name);

  // The namespace used for class name lookup (the package of the config).
  std::string Package() const { return config_.package(); }

  // Returns true if |name| is a reserved executor name.
  static bool IsReservedExecutorName(const std::string& name);

  // Returns true if a side packet is provided as an input to the graph, i.e.
  // |name| has an entry in required_side_packets_.
  bool IsExternalSidePacket(const std::string& name) const {
    return required_side_packets_.find(name) != required_side_packets_.end();
  }

 private:
  // Performs transforms such as converting legacy features, expanding
  // subgraphs, and populating the input stream handler.
  absl::Status PerformBasicTransforms(
      const GraphRegistry* graph_registry,
      const Subgraph::SubgraphOptions* graph_options,
      const GraphServiceManager* service_manager);

  // Initializes the PacketGenerator information.
  absl::Status InitializeGeneratorInfo();
  // Initializes the Calculator information.
  absl::Status InitializeCalculatorInfo();
  // Initializes the StatusHandler information.
  absl::Status InitializeStatusHandlerInfo();

  // Initializes the EdgeInfo objects for side packets.
  //
  // When need_sorting_ptr is non-null, it is set to true iff the side packet
  // graph is not topologically sorted. If the nodes in the side packet graph
  // are not in sorted order, side_packet_to_producer_ will still be complete,
  // but the upstream field of input_side_packets_ may not be accurate.
  //
  // When need_sorting_ptr is nullptr, an error is returned if the nodes in
  // the side packet graph are not in topologically sorted order.
  absl::Status InitializeSidePacketInfo(bool* need_sorting_ptr);

  // Appends EdgeInfo objects to input_side_packets_ for every input side
  // packet required by node_type_info. If nodes are processed with
  // AddInputSidePacketsForNode and AddOutputSidePacketsForNode sequentially,
  // then side_packet_to_producer_ and required_side_packets_ are used to
  // ensure that the graph is topologically sorted. node_type_info is updated
  // with the proper initial index for input side packets.
  absl::Status AddInputSidePacketsForNode(NodeTypeInfo* node_type_info);

  // Appends EdgeInfo objects to output_side_packets_ for every output side
  // packet produced by node_type_info, and updates side_packet_to_producer_.
  // need_sorting_ptr is set to true if the nodes in the side packet graph are
  // detected to be in unsorted order (a side packet is output after something
  // that required it); otherwise it is left as is. node_type_info is updated
  // with the proper initial index for output side packets.
  absl::Status AddOutputSidePacketsForNode(NodeTypeInfo* node_type_info,
                                           bool* need_sorting_ptr);

  // Stream counterparts of the side packet functions above. The one small
  // difference is that using an undefined stream is an error, whereas using
  // an undefined side packet is allowed.
  absl::Status InitializeStreamInfo(bool* need_sorting_ptr);
  absl::Status AddOutputStreamsForNode(NodeTypeInfo* node_type_info);
  absl::Status AddInputStreamsForNode(NodeTypeInfo* node_type_info,
                                      bool* need_sorting_ptr);

  // Helper that adds a single output stream EdgeInfo.
  absl::Status AddOutputStream(NodeTypeInfo::NodeRef node,
                               const std::string& name,
                               PacketType* packet_type);

  // Returns the index of the node adjusted for the topological sorter.
  int SorterIndexForNode(NodeTypeInfo::NodeRef node) const;

  // Converts an index used by the topological sorter back to the node type
  // and node index.
  NodeTypeInfo::NodeRef NodeForSorterIndex(int index) const;

  // Sorts the nodes based on the information gathered by
  // InitializeSidePacketInfo and InitializeStreamInfo. After this function,
  // InitializeSidePacketInfo and InitializeStreamInfo must be run again
  // (after clearing the data structures they fill).
  //
  // NOTE: Only the generators and calculators need to be sorted. The other
  // two node types, graph input streams and status handlers, can be safely
  // ignored in the analysis of output side packet generation or stream
  // header packet propagation.
  absl::Status TopologicalSortNodes();

  // TODO Add InputStreamHandler.
  // TODO Add OutputStreamHandler.

  // Fills the "upstream" field for all back edges.
  absl::Status FillUpstreamFieldForBackEdges();

  // Computes the dependence of nodes on sources.
  absl::Status ComputeSourceDependence();

  // Infers the type of types set to "Any" by what they are connected to.
  absl::Status ResolveAnyTypes(std::vector<EdgeInfo>* input_edges,
                               std::vector<EdgeInfo>* output_edges);

  // Narrows down OneOf types if the other end is a single type.
  absl::Status ResolveOneOfTypes(std::vector<EdgeInfo>* input_edges,
                                 std::vector<EdgeInfo>* output_edges);

  // Returns an error if the generator graph does not have consistent
  // type specifications for side packets.
  absl::Status ValidateSidePacketTypes();

  // Returns an error if the graph of calculators does not have consistent
  // type specifications for streams.
  absl::Status ValidateStreamTypes();

  // Returns an error if the graph does not have valid ExecutorConfigs, or
  // if the executor name in a node config is reserved or is not declared
  // in an ExecutorConfig.
  absl::Status ValidateExecutors();

  // Value reported by Initialized(); false by default.
  bool initialized_ = false;

  // The proto configuration returned by Config().
  CalculatorGraphConfig config_;

  // The type information for each node type.
  std::vector<NodeTypeInfo> calculators_;
  std::vector<NodeTypeInfo> generators_;
  std::vector<NodeTypeInfo> status_handlers_;

  // NodeTypeInfo's of generators and calculators, topologically sorted.
  std::vector<NodeTypeInfo*> sorted_nodes_;

  // Maps a stream name to the output_streams_ index which produces it.
  std::map<std::string, int> stream_to_producer_;

  // Maps output streams to consumer node ids. Used for profiling.
  std::map<int, std::vector<int>> output_streams_to_consumer_nodes_;

  // Maps a side packet name to the output_side_packets_ index which
  // produces it.
  std::map<std::string, int> side_packet_to_producer_;

  // Manages deletion of PacketType objects which need to be owned by this
  // object (used for graph input stream PacketType).
  std::vector<std::unique_ptr<PacketType>> owned_packet_types_;

  // For each side packet which must still be supplied, a list of
  // input_side_packets_ indexes which must be validated against it.
  // TODO Use the information stored here for more thorough
  // validation.
  std::map<std::string, std::vector<int>> required_side_packets_;

  // The EdgeInfo objects for input/output side packets/streams.
  std::vector<EdgeInfo> input_streams_;
  std::vector<EdgeInfo> output_streams_;
  std::vector<EdgeInfo> input_side_packets_;
  std::vector<EdgeInfo> output_side_packets_;
};
// Checks the caller-provided side packets against the side packets recorded
// in output_side_packets_. Returns an error naming the first recorded output
// side packet whose name is also a key of |side_packets|; returns OK when
// there is no such overlap. Only the keys of |side_packets| are inspected.
template <typename T>
absl::Status ValidatedGraphConfig::CanAcceptSidePackets(
    const std::map<std::string, T>& side_packets) const {
  for (const EdgeInfo& generated : output_side_packets_) {
    const bool also_provided = side_packets.count(generated.name) != 0;
    if (!also_provided) {
      continue;
    }
    // The same name is both supplied by the caller and produced in the graph.
    return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
           << "Side packet \"" << generated.name
           << "\" is both provided and generated by a PacketGenerator.";
  }
  return absl::OkStatus();
}
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_VALIDATED_GRAPH_CONFIG_H_