Add Advanced Usage/Utility Classes section in Building Graph in C++

PiperOrigin-RevId: 508776246
This commit is contained in:
MediaPipe Team 2023-02-10 16:04:23 -08:00 committed by Copybara-Service
parent 482ee8f96c
commit 23012f2151

View File

@ -156,8 +156,124 @@ And you don't need to duplicate names and tags (`InferenceCalculator`,
`TENSORS`, `MODEL`) or introduce dedicated constants here and there - those `TENSORS`, `MODEL`) or introduce dedicated constants here and there - those
details are localized to `RunInference` function. details are localized to `RunInference` function.
Tip: extracting `RunInference` and similar functions to dedicated units (e.g. Tip: extracting `RunInference` and similar functions to dedicated modules (e.g.
inference.h/cc which depends on the inference calculator) enables reuse in inference.h/cc which depends on the inference calculator) enables reuse in
graphs construction code and helps automatically pull in calculator dependencies graphs construction code and helps automatically pull in calculator dependencies
(e.g. no need to manually add `:inference_calculator` dep, just let your IDE (e.g. no need to manually add `:inference_calculator` dep, just let your IDE
include `inference.h` and build cleaner pull in corresponding dependency). include `inference.h` and build cleaner pull in corresponding dependency).
#### Utility Classes
And surely, it's not only about functions, in some cases it's beneficial to
introduce utility classes which can help making your graph construction code
more readable and less error prone.
MediaPipe offers `PassThroughCalculator` calculator, which is simply passing
through its inputs:
```
input_stream: "float_value"
input_stream: "int_value"
input_stream: "bool_value"
output_stream: "passed_float_value"
output_stream: "passed_int_value"
output_stream: "passed_bool_value"
node {
calculator: "PassThroughCalculator"
input_stream: "float_value"
input_stream: "int_value"
input_stream: "bool_value"
// The order must be the same as for inputs (or you can use explicit indexes)
output_stream: "passed_float_value"
output_stream: "passed_int_value"
output_stream: "passed_bool_value"
}
```
Let's see the straightforward C++ construction code to create the above graph:
```c++
CalculatorGraphConfig BuildGraph() {
Graph graph;
// Graph inputs.
Stream<float> float_value = graph.In(0).SetName("float_value").Cast<float>();
Stream<int> int_value = graph.In(1).SetName("int_value").Cast<int>();
Stream<bool> bool_value = graph.In(2).SetName("bool_value").Cast<bool>();
auto& pass_node = graph.AddNode("PassThroughCalculator");
float_value.ConnectTo(pass_node.In("")[0]);
int_value.ConnectTo(pass_node.In("")[1]);
bool_value.ConnectTo(pass_node.In("")[2]);
Stream<float> passed_float_value = pass_node.Out("")[0].Cast<float>();
Stream<int> passed_int_value = pass_node.Out("")[1].Cast<int>();
Stream<bool> passed_bool_value = pass_node.Out("")[2].Cast<bool>();
// Graph outputs.
passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0));
passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1));
passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2));
// Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
return graph.GetConfig();
}
```
While `pbtxt` representation maybe error prone (when we have many inputs to pass
through), C++ code looks even worse: repeated empty tags and `Cast` calls. Let's
see how we can do better by introducing a `PassThroughNodeBuilder`:
```c++
class PassThroughNodeBuilder {
public:
explicit PassThroughNodeBuilder(Graph& graph)
: node_(graph.AddNode("PassThroughCalculator")) {}
template <typename T>
Stream<T> PassThrough(Stream<T> stream) {
stream.ConnectTo(node_.In(index_));
return node_.Out(index_++).Cast<T>();
}
private:
int index_ = 0;
GenericNode& node_;
};
```
And now graph construction code can look like:
```c++
CalculatorGraphConfig BuildGraph() {
Graph graph;
// Graph inputs.
Stream<float> float_value = graph.In(0).SetName("float_value").Cast<float>();
Stream<int> int_value = graph.In(1).SetName("int_value").Cast<int>();
Stream<bool> bool_value = graph.In(2).SetName("bool_value").Cast<bool>();
PassThroughNodeBuilder pass_node_builder(graph);
Stream<float> passed_float_value = pass_node_builder.PassThrough(float_value);
Stream<int> passed_int_value = pass_node_builder.PassThrough(int_value);
Stream<bool> passed_bool_value = pass_node_builder.PassThrough(bool_value);
// Graph outputs.
passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0));
passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1));
passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2));
// Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
return graph.GetConfig();
}
```
Now you can't have incorrect order or index in your pass through construction
code and save some typing by guessing the type for `Cast` from the `PassThrough`
input.
Tip: the same as for the `RunInference` function, extracting
`PassThroughNodeBuilder` and similar utility classes into dedicated modules
enables reuse in graph construction code and helps to automatically pull in the
corresponding calculator dependencies.