diff --git a/mediapipe/tasks/cc/core/utils.cc b/mediapipe/tasks/cc/core/utils.cc index 168c4363c..d6db1d69c 100644 --- a/mediapipe/tasks/cc/core/utils.cc +++ b/mediapipe/tasks/cc/core/utils.cc @@ -32,6 +32,7 @@ namespace core { namespace { constexpr char kFinishedTag[] = "FINISHED"; constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator"; +constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator"; } // namespace @@ -89,6 +90,19 @@ CalculatorGraphConfig AddFlowLimiterCalculator( return config; } +void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config) { + // TODO remove when support is fixed. + // As mediapipe GraphBuilder currently doesn't support configuring + // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. + for (int i = 0; i < graph_config.node_size(); ++i) { + if (graph_config.node(i).calculator() == kPreviousLoopbackCalculatorName) { + auto* info = graph_config.mutable_node(i)->add_input_stream_info(); + info->set_tag_index("LOOP"); + info->set_back_edge(true); + } + } +} + } // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/utils.h b/mediapipe/tasks/cc/core/utils.h index 54d63866d..5b56c8f57 100644 --- a/mediapipe/tasks/cc/core/utils.h +++ b/mediapipe/tasks/cc/core/utils.h @@ -84,6 +84,10 @@ static TensorType* FindTensorByName( std::vector input_stream_tags, std::string finished_stream_tag, int max_in_flight = 1, int max_in_queue = 1); +// Fixs the graph config containing PreviousLoopbackCalculator where the edge +// forming a loop needs to be tagged as back edge. +void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config); + } // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc index e563ba29a..c9a9a1932 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc @@ -393,18 +393,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { graph[Output>(kFaceGeometryTag)]; } - // TODO remove when support is fixed. - // As mediapipe GraphBuilder currently doesn't support configuring - // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. CalculatorGraphConfig config = graph.GetConfig(); - for (int i = 0; i < config.node_size(); ++i) { - if (config.node(i).calculator() == "PreviousLoopbackCalculator") { - auto* info = config.mutable_node(i)->add_input_stream_info(); - info->set_tag_index(kLoopTag); - info->set_back_edge(true); - break; - } - } + core::FixGraphBackEdges(config); return config; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index bb0eb4833..34f7e7a9f 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -254,18 +254,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { graph[Output>(kPalmDetectionsTag)]; hand_landmarker_outputs.image >> graph[Output(kImageTag)]; - // TODO remove when support is fixed. - // As mediapipe GraphBuilder currently doesn't support configuring - // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. CalculatorGraphConfig config = graph.GetConfig(); - for (int i = 0; i < config.node_size(); ++i) { - if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) { - auto* info = config.mutable_node(i)->add_input_stream_info(); - info->set_tag_index("LOOP"); - info->set_back_edge(true); - break; - } - } + core::FixGraphBackEdges(config); return config; } diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index a706b405e..51ae92adc 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -138,6 +138,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph", "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index 5e33e744b..2f5a8b99b 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h" @@ -259,19 +260,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { graph[Output>(kSegmentationMaskTag)]; } - // TODO remove when support is fixed. - // As mediapipe GraphBuilder currently doesn't support configuring - // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. CalculatorGraphConfig config = graph.GetConfig(); - for (int i = 0; i < config.node_size(); ++i) { - if (config.node(i).calculator() == "PreviousLoopbackCalculator") { - auto* info = config.mutable_node(i)->add_input_stream_info(); - info->set_tag_index(kLoopTag); - info->set_back_edge(true); - break; - } - } - + core::FixGraphBackEdges(config); return config; }