Add Advanced Usage/Utility Classes section in Building Graph in C++
PiperOrigin-RevId: 508776246
This commit is contained in:
parent
482ee8f96c
commit
23012f2151
|
@ -156,8 +156,124 @@ And you don't need to duplicate names and tags (`InferenceCalculator`,
|
||||||
`TENSORS`, `MODEL`) or introduce dedicated constants here and there - those
|
`TENSORS`, `MODEL`) or introduce dedicated constants here and there - those
|
||||||
details are localized to `RunInference` function.
|
details are localized to `RunInference` function.
|
||||||
|
|
||||||
Tip: extracting `RunInference` and similar functions to dedicated units (e.g.
|
Tip: extracting `RunInference` and similar functions to dedicated modules (e.g.
|
||||||
inference.h/cc which depends on the inference calculator) enables reuse in
|
inference.h/cc which depends on the inference calculator) enables reuse in
|
||||||
graphs construction code and helps automatically pull in calculator dependencies
|
graphs construction code and helps automatically pull in calculator dependencies
|
||||||
(e.g. no need to manually add `:inference_calculator` dep, just let your IDE
|
(e.g. no need to manually add `:inference_calculator` dep, just let your IDE
|
||||||
include `inference.h` and build cleaner pull in corresponding dependency).
|
include `inference.h` and build cleaner pull in corresponding dependency).
|
||||||
|
|
||||||
|
#### Utility Classes
|
||||||
|
|
||||||
|
And surely, it's not only about functions, in some cases it's beneficial to
|
||||||
|
introduce utility classes which can help making your graph construction code
|
||||||
|
more readable and less error prone.
|
||||||
|
|
||||||
|
MediaPipe offers `PassThroughCalculator` calculator, which is simply passing
|
||||||
|
through its inputs:
|
||||||
|
|
||||||
|
```
|
||||||
|
input_stream: "float_value"
|
||||||
|
input_stream: "int_value"
|
||||||
|
input_stream: "bool_value"
|
||||||
|
|
||||||
|
output_stream: "passed_float_value"
|
||||||
|
output_stream: "passed_int_value"
|
||||||
|
output_stream: "passed_bool_value"
|
||||||
|
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughCalculator"
|
||||||
|
input_stream: "float_value"
|
||||||
|
input_stream: "int_value"
|
||||||
|
input_stream: "bool_value"
|
||||||
|
// The order must be the same as for inputs (or you can use explicit indexes)
|
||||||
|
output_stream: "passed_float_value"
|
||||||
|
output_stream: "passed_int_value"
|
||||||
|
output_stream: "passed_bool_value"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Let's see the straightforward C++ construction code to create the above graph:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
CalculatorGraphConfig BuildGraph() {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
// Graph inputs.
|
||||||
|
Stream<float> float_value = graph.In(0).SetName("float_value").Cast<float>();
|
||||||
|
Stream<int> int_value = graph.In(1).SetName("int_value").Cast<int>();
|
||||||
|
Stream<bool> bool_value = graph.In(2).SetName("bool_value").Cast<bool>();
|
||||||
|
|
||||||
|
auto& pass_node = graph.AddNode("PassThroughCalculator");
|
||||||
|
float_value.ConnectTo(pass_node.In("")[0]);
|
||||||
|
int_value.ConnectTo(pass_node.In("")[1]);
|
||||||
|
bool_value.ConnectTo(pass_node.In("")[2]);
|
||||||
|
Stream<float> passed_float_value = pass_node.Out("")[0].Cast<float>();
|
||||||
|
Stream<int> passed_int_value = pass_node.Out("")[1].Cast<int>();
|
||||||
|
Stream<bool> passed_bool_value = pass_node.Out("")[2].Cast<bool>();
|
||||||
|
|
||||||
|
// Graph outputs.
|
||||||
|
passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0));
|
||||||
|
passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1));
|
||||||
|
passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2));
|
||||||
|
|
||||||
|
// Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
While `pbtxt` representation maybe error prone (when we have many inputs to pass
|
||||||
|
through), C++ code looks even worse: repeated empty tags and `Cast` calls. Let's
|
||||||
|
see how we can do better by introducing a `PassThroughNodeBuilder`:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
class PassThroughNodeBuilder {
|
||||||
|
public:
|
||||||
|
explicit PassThroughNodeBuilder(Graph& graph)
|
||||||
|
: node_(graph.AddNode("PassThroughCalculator")) {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Stream<T> PassThrough(Stream<T> stream) {
|
||||||
|
stream.ConnectTo(node_.In(index_));
|
||||||
|
return node_.Out(index_++).Cast<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int index_ = 0;
|
||||||
|
GenericNode& node_;
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
And now graph construction code can look like:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
CalculatorGraphConfig BuildGraph() {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
// Graph inputs.
|
||||||
|
Stream<float> float_value = graph.In(0).SetName("float_value").Cast<float>();
|
||||||
|
Stream<int> int_value = graph.In(1).SetName("int_value").Cast<int>();
|
||||||
|
Stream<bool> bool_value = graph.In(2).SetName("bool_value").Cast<bool>();
|
||||||
|
|
||||||
|
PassThroughNodeBuilder pass_node_builder(graph);
|
||||||
|
Stream<float> passed_float_value = pass_node_builder.PassThrough(float_value);
|
||||||
|
Stream<int> passed_int_value = pass_node_builder.PassThrough(int_value);
|
||||||
|
Stream<bool> passed_bool_value = pass_node_builder.PassThrough(bool_value);
|
||||||
|
|
||||||
|
// Graph outputs.
|
||||||
|
passed_float_value.SetName("passed_float_value").ConnectTo(graph.Out(0));
|
||||||
|
passed_int_value.SetName("passed_int_value").ConnectTo(graph.Out(1));
|
||||||
|
passed_bool_value.SetName("passed_bool_value").ConnectTo(graph.Out(2));
|
||||||
|
|
||||||
|
// Get `CalculatorGraphConfig` to pass it into `CalculatorGraph`
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can't have incorrect order or index in your pass through construction
|
||||||
|
code and save some typing by guessing the type for `Cast` from the `PassThrough`
|
||||||
|
input.
|
||||||
|
|
||||||
|
Tip: the same as for the `RunInference` function, extracting
|
||||||
|
`PassThroughNodeBuilder` and similar utility classes into dedicated modules
|
||||||
|
enables reuse in graph construction code and helps to automatically pull in the
|
||||||
|
corresponding calculator dependencies.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user