diff --git a/docs/framework_concepts/building_graphs_cpp.md b/docs/framework_concepts/building_graphs_cpp.md index 87c9755bf..c7e11f131 100644 --- a/docs/framework_concepts/building_graphs_cpp.md +++ b/docs/framework_concepts/building_graphs_cpp.md @@ -156,8 +156,124 @@ And you don't need to duplicate names and tags (`InferenceCalculator`, `TENSORS`, `MODEL`) or introduce dedicated constants here and there - those details are localized to `RunInference` function. -Tip: extracting `RunInference` and similar functions to dedicated units (e.g. +Tip: extracting `RunInference` and similar functions to dedicated modules (e.g. inference.h/cc which depends on the inference calculator) enables reuse in graphs construction code and helps automatically pull in calculator dependencies (e.g. no need to manually add `:inference_calculator` dep, just let your IDE include `inference.h` and build cleaner pull in corresponding dependency). + +#### Utility Classes + +And surely, it's not only about functions, in some cases it's beneficial to +introduce utility classes which can help making your graph construction code +more readable and less error prone. + +MediaPipe offers `PassThroughCalculator` calculator, which is simply passing +through its inputs: + +``` +input_stream: "float_value" +input_stream: "int_value" +input_stream: "bool_value" + +output_stream: "passed_float_value" +output_stream: "passed_int_value" +output_stream: "passed_bool_value" + +node { + calculator: "PassThroughCalculator" + input_stream: "float_value" + input_stream: "int_value" + input_stream: "bool_value" + // The order must be the same as for inputs (or you can use explicit indexes) + output_stream: "passed_float_value" + output_stream: "passed_int_value" + output_stream: "passed_bool_value" +} +``` + +Let's see the straightforward C++ construction code to create the above graph: + +```c++ +CalculatorGraphConfig BuildGraph() { + Graph graph; + + // Graph inputs. + Stream float_value = graph.In(0).SetName("float_value").Cast(); + Stream int_value = graph.In(1).SetName("int_value").Cast(); + Stream bool_value = graph.In(2).SetName("bool_value").Cast(); + + auto& pass_node = graph.AddNode("PassThroughCalculator"); + float_value.ConnectTo(pass_node.In("")[0]); + int_value.ConnectTo(pass_node.In("")[1]); + bool_value.ConnectTo(pass_node.In("")[2]); + Stream passed_float_value = pass_node.Out("")[0].Cast(); + Stream passed_int_value = pass_node.Out("")[1].Cast(); + Stream passed_bool_value = pass_node.Out("")[2].Cast(); + + // Graph outputs. + passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0)); + passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1)); + passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2)); + + // Get `CalculatorGraphConfig` to pass it into `CalculatorGraph` + return graph.GetConfig(); +} +``` + +While `pbtxt` representation maybe error prone (when we have many inputs to pass +through), C++ code looks even worse: repeated empty tags and `Cast` calls. Let's +see how we can do better by introducing a `PassThroughNodeBuilder`: + +```c++ +class PassThroughNodeBuilder { + public: + explicit PassThroughNodeBuilder(Graph& graph) + : node_(graph.AddNode("PassThroughCalculator")) {} + + template + Stream PassThrough(Stream stream) { + stream.ConnectTo(node_.In(index_)); + return node_.Out(index_++).Cast(); + } + + private: + int index_ = 0; + GenericNode& node_; +}; +``` + +And now graph construction code can look like: + +```c++ +CalculatorGraphConfig BuildGraph() { + Graph graph; + + // Graph inputs. + Stream float_value = graph.In(0).SetName("float_value").Cast(); + Stream int_value = graph.In(1).SetName("int_value").Cast(); + Stream bool_value = graph.In(2).SetName("bool_value").Cast(); + + PassThroughNodeBuilder pass_node_builder(graph); + Stream passed_float_value = pass_node_builder.PassThrough(float_value); + Stream passed_int_value = pass_node_builder.PassThrough(int_value); + Stream passed_bool_value = pass_node_builder.PassThrough(bool_value); + + // Graph outputs. + passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0)); + passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1)); + passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2)); + + // Get `CalculatorGraphConfig` to pass it into `CalculatorGraph` + return graph.GetConfig(); +} +``` + +Now you can't have incorrect order or index in your pass through construction +code and save some typing by guessing the type for `Cast` from the `PassThrough` +input. + +Tip: the same as for the `RunInference` function, extracting +`PassThroughNodeBuilder` and similar utility classes into dedicated modules +enables reuse in graph construction code and helps to automatically pull in the +corresponding calculator dependencies.