Introduce FixGraphBackEdges utils function.
PiperOrigin-RevId: 573925628
This commit is contained in:
		
							parent
							
								
									a1e1b5d34c
								
							
						
					
					
						commit
						2e11444f5c
					
				| 
						 | 
					@ -32,6 +32,7 @@ namespace core {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
constexpr char kFinishedTag[] = "FINISHED";
 | 
					constexpr char kFinishedTag[] = "FINISHED";
 | 
				
			||||||
constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator";
 | 
					constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator";
 | 
				
			||||||
 | 
					constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -89,6 +90,19 @@ CalculatorGraphConfig AddFlowLimiterCalculator(
 | 
				
			||||||
  return config;
 | 
					  return config;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config) {
 | 
				
			||||||
 | 
					  // TODO remove when support is fixed.
 | 
				
			||||||
 | 
					  // As mediapipe GraphBuilder currently doesn't support configuring
 | 
				
			||||||
 | 
					  // InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
 | 
				
			||||||
 | 
					  for (int i = 0; i < graph_config.node_size(); ++i) {
 | 
				
			||||||
 | 
					    if (graph_config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
 | 
				
			||||||
 | 
					      auto* info = graph_config.mutable_node(i)->add_input_stream_info();
 | 
				
			||||||
 | 
					      info->set_tag_index("LOOP");
 | 
				
			||||||
 | 
					      info->set_back_edge(true);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace core
 | 
					}  // namespace core
 | 
				
			||||||
}  // namespace tasks
 | 
					}  // namespace tasks
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,6 +84,10 @@ static TensorType* FindTensorByName(
 | 
				
			||||||
    std::vector<std::string> input_stream_tags, std::string finished_stream_tag,
 | 
					    std::vector<std::string> input_stream_tags, std::string finished_stream_tag,
 | 
				
			||||||
    int max_in_flight = 1, int max_in_queue = 1);
 | 
					    int max_in_flight = 1, int max_in_queue = 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Fixs the graph config containing PreviousLoopbackCalculator where the edge
 | 
				
			||||||
 | 
					// forming a loop needs to be tagged as back edge.
 | 
				
			||||||
 | 
					void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace core
 | 
					}  // namespace core
 | 
				
			||||||
}  // namespace tasks
 | 
					}  // namespace tasks
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -393,18 +393,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
          graph[Output<std::vector<FaceGeometry>>(kFaceGeometryTag)];
 | 
					          graph[Output<std::vector<FaceGeometry>>(kFaceGeometryTag)];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO remove when support is fixed.
 | 
					 | 
				
			||||||
    // As mediapipe GraphBuilder currently doesn't support configuring
 | 
					 | 
				
			||||||
    // InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
 | 
					 | 
				
			||||||
    CalculatorGraphConfig config = graph.GetConfig();
 | 
					    CalculatorGraphConfig config = graph.GetConfig();
 | 
				
			||||||
    for (int i = 0; i < config.node_size(); ++i) {
 | 
					    core::FixGraphBackEdges(config);
 | 
				
			||||||
      if (config.node(i).calculator() == "PreviousLoopbackCalculator") {
 | 
					 | 
				
			||||||
        auto* info = config.mutable_node(i)->add_input_stream_info();
 | 
					 | 
				
			||||||
        info->set_tag_index(kLoopTag);
 | 
					 | 
				
			||||||
        info->set_back_edge(true);
 | 
					 | 
				
			||||||
        break;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return config;
 | 
					    return config;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -254,18 +254,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
        graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
 | 
					        graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
 | 
				
			||||||
    hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];
 | 
					    hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO remove when support is fixed.
 | 
					 | 
				
			||||||
    // As mediapipe GraphBuilder currently doesn't support configuring
 | 
					 | 
				
			||||||
    // InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
 | 
					 | 
				
			||||||
    CalculatorGraphConfig config = graph.GetConfig();
 | 
					    CalculatorGraphConfig config = graph.GetConfig();
 | 
				
			||||||
    for (int i = 0; i < config.node_size(); ++i) {
 | 
					    core::FixGraphBackEdges(config);
 | 
				
			||||||
      if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
 | 
					 | 
				
			||||||
        auto* info = config.mutable_node(i)->add_input_stream_info();
 | 
					 | 
				
			||||||
        info->set_tag_index("LOOP");
 | 
					 | 
				
			||||||
        info->set_back_edge(true);
 | 
					 | 
				
			||||||
        break;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return config;
 | 
					    return config;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -138,6 +138,7 @@ cc_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
 | 
					        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:model_resources_cache",
 | 
					        "//mediapipe/tasks/cc/core:model_resources_cache",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
					        "//mediapipe/tasks/cc/core:model_task_graph",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/core:utils",
 | 
				
			||||||
        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
 | 
					        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
 | 
					        "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
 | 
					        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,6 +31,7 @@ limitations under the License.
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
 | 
					#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
 | 
					#include "mediapipe/tasks/cc/core/model_resources_cache.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
 | 
					#include "mediapipe/tasks/cc/core/model_task_graph.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/core/utils.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
 | 
					#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
 | 
					#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
 | 
				
			||||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
 | 
					#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
 | 
				
			||||||
| 
						 | 
					@ -259,19 +260,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
 | 
				
			||||||
          graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
 | 
					          graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO remove when support is fixed.
 | 
					 | 
				
			||||||
    // As mediapipe GraphBuilder currently doesn't support configuring
 | 
					 | 
				
			||||||
    // InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
 | 
					 | 
				
			||||||
    CalculatorGraphConfig config = graph.GetConfig();
 | 
					    CalculatorGraphConfig config = graph.GetConfig();
 | 
				
			||||||
    for (int i = 0; i < config.node_size(); ++i) {
 | 
					    core::FixGraphBackEdges(config);
 | 
				
			||||||
      if (config.node(i).calculator() == "PreviousLoopbackCalculator") {
 | 
					 | 
				
			||||||
        auto* info = config.mutable_node(i)->add_input_stream_info();
 | 
					 | 
				
			||||||
        info->set_tag_index(kLoopTag);
 | 
					 | 
				
			||||||
        info->set_back_edge(true);
 | 
					 | 
				
			||||||
        break;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return config;
 | 
					    return config;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user