Introduce FixGraphBackEdges utils function.
PiperOrigin-RevId: 573925628
This commit is contained in:
parent
a1e1b5d34c
commit
2e11444f5c
|
@ -32,6 +32,7 @@ namespace core {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr char kFinishedTag[] = "FINISHED";
|
constexpr char kFinishedTag[] = "FINISHED";
|
||||||
constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator";
|
constexpr char kFlowLimiterCalculatorName[] = "FlowLimiterCalculator";
|
||||||
|
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -89,6 +90,19 @@ CalculatorGraphConfig AddFlowLimiterCalculator(
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config) {
|
||||||
|
// TODO remove when support is fixed.
|
||||||
|
// As mediapipe GraphBuilder currently doesn't support configuring
|
||||||
|
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
|
||||||
|
for (int i = 0; i < graph_config.node_size(); ++i) {
|
||||||
|
if (graph_config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
|
||||||
|
auto* info = graph_config.mutable_node(i)->add_input_stream_info();
|
||||||
|
info->set_tag_index("LOOP");
|
||||||
|
info->set_back_edge(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -84,6 +84,10 @@ static TensorType* FindTensorByName(
|
||||||
std::vector<std::string> input_stream_tags, std::string finished_stream_tag,
|
std::vector<std::string> input_stream_tags, std::string finished_stream_tag,
|
||||||
int max_in_flight = 1, int max_in_queue = 1);
|
int max_in_flight = 1, int max_in_queue = 1);
|
||||||
|
|
||||||
|
// Fixs the graph config containing PreviousLoopbackCalculator where the edge
|
||||||
|
// forming a loop needs to be tagged as back edge.
|
||||||
|
void FixGraphBackEdges(::mediapipe::CalculatorGraphConfig& graph_config);
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -393,18 +393,8 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
graph[Output<std::vector<FaceGeometry>>(kFaceGeometryTag)];
|
graph[Output<std::vector<FaceGeometry>>(kFaceGeometryTag)];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO remove when support is fixed.
|
|
||||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
|
||||||
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
|
|
||||||
CalculatorGraphConfig config = graph.GetConfig();
|
CalculatorGraphConfig config = graph.GetConfig();
|
||||||
for (int i = 0; i < config.node_size(); ++i) {
|
core::FixGraphBackEdges(config);
|
||||||
if (config.node(i).calculator() == "PreviousLoopbackCalculator") {
|
|
||||||
auto* info = config.mutable_node(i)->add_input_stream_info();
|
|
||||||
info->set_tag_index(kLoopTag);
|
|
||||||
info->set_back_edge(true);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -254,18 +254,8 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
||||||
hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];
|
hand_landmarker_outputs.image >> graph[Output<Image>(kImageTag)];
|
||||||
|
|
||||||
// TODO remove when support is fixed.
|
|
||||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
|
||||||
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
|
|
||||||
CalculatorGraphConfig config = graph.GetConfig();
|
CalculatorGraphConfig config = graph.GetConfig();
|
||||||
for (int i = 0; i < config.node_size(); ++i) {
|
core::FixGraphBackEdges(config);
|
||||||
if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) {
|
|
||||||
auto* info = config.mutable_node(i)->add_input_stream_info();
|
|
||||||
info->set_tag_index("LOOP");
|
|
||||||
info->set_back_edge(true);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -138,6 +138,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
|
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
|
||||||
"//mediapipe/tasks/cc/core:model_resources_cache",
|
"//mediapipe/tasks/cc/core:model_resources_cache",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
|
||||||
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
|
"//mediapipe/tasks/cc/vision/pose_detector:pose_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
|
||||||
|
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options.pb.h"
|
||||||
|
@ -259,19 +260,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO remove when support is fixed.
|
|
||||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
|
||||||
// InputStreamInfo, modifying the CalculatorGraphConfig proto directly.
|
|
||||||
CalculatorGraphConfig config = graph.GetConfig();
|
CalculatorGraphConfig config = graph.GetConfig();
|
||||||
for (int i = 0; i < config.node_size(); ++i) {
|
core::FixGraphBackEdges(config);
|
||||||
if (config.node(i).calculator() == "PreviousLoopbackCalculator") {
|
|
||||||
auto* info = config.mutable_node(i)->add_input_stream_info();
|
|
||||||
info->set_tag_index(kLoopTag);
|
|
||||||
info->set_back_edge(true);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user