Add Capabilities/Utility Functions section in Building Graph in C++
PiperOrigin-RevId: 508773788
This commit is contained in:
parent
e719d2d423
commit
482ee8f96c
|
@ -17,8 +17,8 @@ nav_order: 1
|
||||||
C++ graph builder is a powerful tool for:
|
C++ graph builder is a powerful tool for:
|
||||||
|
|
||||||
* Building complex graphs
|
* Building complex graphs
|
||||||
* Parametrizing graphs (e.g. setting a delegate on
|
* Parametrizing graphs (e.g. setting a delegate on `InferenceCalculator`,
|
||||||
`InferenceCalculator`, enabling/disabling parts of the graph)
|
enabling/disabling parts of the graph)
|
||||||
* Deduplicating graphs (e.g. instead of CPU and GPU dedicated graphs in pbtxt
|
* Deduplicating graphs (e.g. instead of CPU and GPU dedicated graphs in pbtxt
|
||||||
you can have a single code that constructs required graphs, sharing as much
|
you can have a single code that constructs required graphs, sharing as much
|
||||||
as possible)
|
as possible)
|
||||||
|
@ -37,7 +37,6 @@ input_side_packet: "model"
|
||||||
// Graph outputs.
|
// Graph outputs.
|
||||||
output_stream: "output_tensors"
|
output_stream: "output_tensors"
|
||||||
|
|
||||||
// Nodes.
|
|
||||||
node {
|
node {
|
||||||
calculator: "InferenceCalculator"
|
calculator: "InferenceCalculator"
|
||||||
input_stream: "TENSORS:input_tensors"
|
input_stream: "TENSORS:input_tensors"
|
||||||
|
@ -64,7 +63,6 @@ CalculatorGraphConfig BuildGraph() {
|
||||||
SidePacket<TfLiteModelPtr> model =
|
SidePacket<TfLiteModelPtr> model =
|
||||||
graph.SideIn(0).SetName("model").Cast<TfLiteModelPtr>();
|
graph.SideIn(0).SetName("model").Cast<TfLiteModelPtr>();
|
||||||
|
|
||||||
// Nodes.
|
|
||||||
auto& inference_node = graph.AddNode("InferenceCalculator");
|
auto& inference_node = graph.AddNode("InferenceCalculator");
|
||||||
auto& inference_opts =
|
auto& inference_opts =
|
||||||
inference_node.GetOptions<InferenceCalculatorOptions>();
|
inference_node.GetOptions<InferenceCalculatorOptions>();
|
||||||
|
@ -87,12 +85,79 @@ Short summary:
|
||||||
|
|
||||||
* Use `Graph::In/SideIn` to get graph inputs as `Stream/SidePacket`
|
* Use `Graph::In/SideIn` to get graph inputs as `Stream/SidePacket`
|
||||||
* Use `Node::Out/SideOut` to get node outputs as `Stream/SidePacket`
|
* Use `Node::Out/SideOut` to get node outputs as `Stream/SidePacket`
|
||||||
* Use `Stream/SidePacket::ConnectTo` to connect streams and side packets to node
|
* Use `Stream/SidePacket::ConnectTo` to connect streams and side packets to
|
||||||
inputs (`Node::In/SideIn`) and graph outputs (`Graph::Out/SideOut`)
|
node inputs (`Node::In/SideIn`) and graph outputs (`Graph::Out/SideOut`)
|
||||||
* There's a "shortcut" operator `>>` that you can use instead of
|
* There's a "shortcut" operator `>>` that you can use instead of
|
||||||
`ConnectTo` function (E.g. `x >> node.In("IN")`).
|
`ConnectTo` function (E.g. `x >> node.In("IN")`).
|
||||||
* `Stream/SidePacket::Cast` is used to cast stream or side packet of `AnyType` (E.g. `Stream<AnyType> in = graph.In(0);`) to a particular type
|
* `Stream/SidePacket::Cast` is used to cast stream or side packet of `AnyType`
|
||||||
* Using actual types instead of `AnyType` sets you on a better path for unleashing graph
|
(E.g. `Stream<AnyType> in = graph.In(0);`) to a particular type
|
||||||
builder capabilities and improving your graphs readability.
|
* Using actual types instead of `AnyType` sets you on a better path for
|
||||||
|
unleashing graph builder capabilities and improving your graphs
|
||||||
|
readability.
|
||||||
|
|
||||||
|
### Advanced Usage
|
||||||
|
|
||||||
|
#### Utility Functions
|
||||||
|
|
||||||
|
Let's extract inference construction code into a dedicated utility function to
|
||||||
|
help for readability and code reuse:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
// Updates graph to run inference.
|
||||||
|
Stream<std::vector<Tensor>> RunInference(
|
||||||
|
Stream<std::vector<Tensor>> tensors, SidePacket<TfLiteModelPtr> model,
|
||||||
|
const InferenceCalculatorOptions::Delegate& delegate, Graph& graph) {
|
||||||
|
auto& inference_node = graph.AddNode("InferenceCalculator");
|
||||||
|
auto& inference_opts =
|
||||||
|
inference_node.GetOptions<InferenceCalculatorOptions>();
|
||||||
|
*inference_opts.mutable_delegate() = delegate;
|
||||||
|
tensors.ConnectTo(inference_node.In("TENSORS"));
|
||||||
|
model.ConnectTo(inference_node.SideIn("MODEL"));
|
||||||
|
return inference_node.Out("TENSORS").Cast<std::vector<Tensor>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraphConfig BuildGraph() {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
// Graph inputs.
|
||||||
|
Stream<std::vector<Tensor>> input_tensors =
|
||||||
|
graph.In(0).SetName("input_tensors").Cast<std::vector<Tensor>>();
|
||||||
|
SidePacket<TfLiteModelPtr> model =
|
||||||
|
graph.SideIn(0).SetName("model").Cast<TfLiteModelPtr>();
|
||||||
|
|
||||||
|
InferenceCalculatorOptions::Delegate delegate;
|
||||||
|
delegate.mutable_gpu();
|
||||||
|
Stream<std::vector<Tensor>> output_tensors =
|
||||||
|
RunInference(input_tensors, model, delegate, graph);
|
||||||
|
|
||||||
|
// Graph outputs.
|
||||||
|
output_tensors.SetName("output_tensors").ConnectTo(graph.Out(0));
|
||||||
|
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
As a result, `RunInference` provides a clear interface stating what are the
|
||||||
|
inputs/outputs and their types.
|
||||||
|
|
||||||
|
It can be easily reused, e.g. it's only a few lines if you want to run an extra
|
||||||
|
model inference:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
// Run first inference.
|
||||||
|
Stream<std::vector<Tensor>> output_tensors =
|
||||||
|
RunInference(input_tensors, model, delegate, graph);
|
||||||
|
// Run second inference on the output of the first one.
|
||||||
|
Stream<std::vector<Tensor>> extra_output_tensors =
|
||||||
|
RunInference(output_tensors, extra_model, delegate, graph);
|
||||||
|
```
|
||||||
|
|
||||||
|
And you don't need to duplicate names and tags (`InferenceCalculator`,
|
||||||
|
`TENSORS`, `MODEL`) or introduce dedicated constants here and there - those
|
||||||
|
details are localized to `RunInference` function.
|
||||||
|
|
||||||
|
Tip: extracting `RunInference` and similar functions to dedicated units (e.g.
|
||||||
|
inference.h/cc which depends on the inference calculator) enables reuse in
|
||||||
|
graphs construction code and helps automatically pull in calculator dependencies
|
||||||
|
(e.g. no need to manually add `:inference_calculator` dep, just let your IDE
|
||||||
|
include `inference.h` and build cleaner pull in corresponding dependency).
|
||||||
|
|
Loading…
Reference in New Issue
Block a user