Internal change for profiling

PiperOrigin-RevId: 494126771
This commit is contained in:
MediaPipe Team 2022-12-09 03:19:45 -08:00 committed by Copybara-Service
parent bea0caae65
commit 3aeec84ac0
2 changed files with 22 additions and 0 deletions

View File

@ -369,6 +369,7 @@ absl::Status ValidatedGraphConfig::Initialize(
input_side_packets_.clear(); input_side_packets_.clear();
output_side_packets_.clear(); output_side_packets_.clear();
stream_to_producer_.clear(); stream_to_producer_.clear();
output_streams_to_consumer_nodes_.clear();
input_streams_.clear(); input_streams_.clear();
output_streams_.clear(); output_streams_.clear();
owned_packet_types_.clear(); owned_packet_types_.clear();
@ -719,6 +720,15 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
<< " does not have a corresponding output stream."; << " does not have a corresponding output stream.";
} }
} }
// Add this node as a consumer of this edge's output stream.
if (edge_info.upstream > -1) {
auto parent_node = output_streams_[edge_info.upstream].parent_node;
if (parent_node.type == NodeTypeInfo::NodeType::CALCULATOR) {
int this_idx = node_type_info->Node().index;
output_streams_to_consumer_nodes_[edge_info.upstream].push_back(
this_idx);
}
}
edge_info.parent_node = node_type_info->Node(); edge_info.parent_node = node_type_info->Node();
edge_info.name = name; edge_info.name = name;

View File

@ -282,6 +282,14 @@ class ValidatedGraphConfig {
return output_streams_[iter->second].parent_node.index; return output_streams_[iter->second].parent_node.index;
} }
std::vector<int> OutputStreamToConsumers(int idx) const {
auto iter = output_streams_to_consumer_nodes_.find(idx);
if (iter == output_streams_to_consumer_nodes_.end()) {
return {};
}
return iter->second;
}
// Returns the registered type name of the specified side packet if // Returns the registered type name of the specified side packet if
// it can be determined, otherwise an appropriate error is returned. // it can be determined, otherwise an appropriate error is returned.
absl::StatusOr<std::string> RegisteredSidePacketTypeName( absl::StatusOr<std::string> RegisteredSidePacketTypeName(
@ -418,6 +426,10 @@ class ValidatedGraphConfig {
// Mapping from stream name to the output_streams_ index which produces it. // Mapping from stream name to the output_streams_ index which produces it.
std::map<std::string, int> stream_to_producer_; std::map<std::string, int> stream_to_producer_;
// Mapping from output streams to consumer node ids. Used for profiling.
std::map<int, std::vector<int>> output_streams_to_consumer_nodes_;
// Mapping from side packet name to the output_side_packets_ index // Mapping from side packet name to the output_side_packets_ index
// which produces it. // which produces it.
std::map<std::string, int> side_packet_to_producer_; std::map<std::string, int> side_packet_to_producer_;