// mediapipe/mediapipe/python/pybind/calculator_graph.cc
// Sebastian Schmidt dc63a5401c No public description
// PiperOrigin-RevId: 572266397
// 2023-10-10 09:02:48 -07:00
//
// 487 lines
// 18 KiB
// C++

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/python/pybind/calculator_graph.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/calculator_graph_template.pb.h"
#include "mediapipe/python/pybind/util.h"
#include "pybind11/embed.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace mediapipe {
namespace python {
// A mutex to guard the output stream observer python callback function.
// Only one python callback can run at once.
// It is locked inside the packet callback registered by the
// "observe_output_stream" binding in this file, before that callback acquires
// the GIL and invokes the Python callable.
absl::Mutex callback_mutex;
// Builds a message of type T from the string form of a Python object.
//
// The object's str() is handed to ParseTextProto<T>. On success the parsed
// message is returned. On failure the exception produced by RaisePyError
// (with PyExc_RuntimeError and a "Failed to parse: ..." message containing
// the object's string form) is thrown.
template <typename T>
T ParseProto(const py::object& proto_object) {
  T message;
  if (ParseTextProto<T>(proto_object.str(), &message)) {
    return message;
  }
  const std::string failure_text =
      absl::StrCat("Failed to parse: ", std::string(proto_object.str()));
  throw RaisePyError(PyExc_RuntimeError, failure_text.c_str());
}
namespace py = pybind11;
void CalculatorGraphSubmodule(pybind11::module* module) {
py::module m = module->def_submodule("calculator_graph",
"MediaPipe calculator graph module.");
using GraphInputStreamAddMode =
mediapipe::CalculatorGraph::GraphInputStreamAddMode;
py::enum_<GraphInputStreamAddMode>(m, "GraphInputStreamAddMode")
.value("WAIT_TILL_NOT_FULL", GraphInputStreamAddMode::WAIT_TILL_NOT_FULL)
.value("ADD_IF_NOT_FULL", GraphInputStreamAddMode::ADD_IF_NOT_FULL)
.export_values();
// Calculator Graph
py::class_<CalculatorGraph> calculator_graph(
m, "CalculatorGraph", R"doc(The primary API for the MediaPipe Framework.
MediaPipe processing takes place inside a graph, which defines packet flow
paths between nodes. A graph can have any number of inputs and outputs, and
data flow can branch and merge. Generally data flows forward, but backward
loops are possible.)doc");
// TODO: Support graph initialization with graph templates and
// subgraph.
// Binds CalculatorGraph.__init__(**kwargs). Exactly one of the keyword
// arguments "binary_graph_path", "graph_config" or "validated_graph_config"
// must be supplied; it provides the CalculatorGraphConfig the new graph is
// initialized with. Any other keyword raises a Python RuntimeError.
calculator_graph.def(
    py::init([](py::kwargs kwargs) {
      bool init_with_binary_graph = false;
      bool init_with_graph_proto = false;
      bool init_with_validated_graph_config = false;
      CalculatorGraphConfig graph_config_proto;
      for (const auto& kw : kwargs) {
        const std::string key = kw.first.cast<std::string>();
        if (key == "binary_graph_path") {
          init_with_binary_graph = true;
          // The path is taken from the str() form of the Python value.
          const std::string file_name(kw.second.cast<py::object>().str());
          graph_config_proto = ReadCalculatorGraphConfigFromFile(file_name);
        } else if (key == "graph_config") {
          init_with_graph_proto = true;
          graph_config_proto =
              ParseProto<CalculatorGraphConfig>(kw.second.cast<py::object>());
        } else if (key == "validated_graph_config") {
          init_with_validated_graph_config = true;
          // Copies the config held by the ValidatedGraphConfig object.
          graph_config_proto =
              py::cast<ValidatedGraphConfig*>(kw.second)->Config();
        } else {
          throw RaisePyError(
              PyExc_RuntimeError,
              absl::StrCat("Unknown kwargs input argument: ", key).c_str());
        }
      }
      // The three config sources are mutually exclusive and one is required.
      const int num_config_sources =
          (init_with_binary_graph ? 1 : 0) + (init_with_graph_proto ? 1 : 0) +
          (init_with_validated_graph_config ? 1 : 0);
      if (num_config_sources != 1) {
        throw RaisePyError(PyExc_ValueError,
                           "Please provide one of the following: "
                           "\'binary_graph_path\' to initialize the graph "
                           "with a binary graph file, or "
                           "\'graph_config\' to initialize the graph with a "
                           "graph config proto, or "
                           "\'validated_graph_config\' to initialize the "
                           "graph with a ValidatedGraphConfig object.");
      }
      // Named `graph` so it does not shadow the enclosing py::class_ object.
      auto graph = absl::make_unique<CalculatorGraph>();
      // Disable default service initialization. This allows us to use
      // the CPU versions of calculators that only optionally request
      // kGpuService.
      RaisePyErrorIfNotOk(graph->DisallowServiceDefaultInitialization());
      RaisePyErrorIfNotOk(graph->Initialize(graph_config_proto));
      // Return the std::unique_ptr holder itself; pybind11 takes ownership of
      // it, so no raw owning pointer is ever exposed.
      return graph;
    }),
    R"doc(Initialize CalculatorGraph object.
Args:
binary_graph_path: The path to a binary mediapipe graph file (.binarypb).
graph_config: A single CalculatorGraphConfig proto message or its text proto
format.
validated_graph_config: A ValidatedGraphConfig object.
Raises:
FileNotFoundError: If the binary graph file can't be found.
ValueError: If the input arguments provided are more than needed or the
graph validation process contains error.
RuntimeError: If an unknown keyword argument is provided.
)doc");
// TODO: Return a Python CalculatorGraphConfig instead.
// Read-only "text_config" property: the DebugString() rendering of the
// graph's config.
calculator_graph.def_property_readonly(
    "text_config", [](const CalculatorGraph& self) {
      const auto& graph_config = self.Config();
      return graph_config.DebugString();
    });
// Read-only "binary_config" property: the graph's config serialized with
// SerializeAsString() and handed to Python as a bytes object.
calculator_graph.def_property_readonly(
    "binary_config", [](const CalculatorGraph& self) {
      const auto serialized_config = self.Config().SerializeAsString();
      return py::bytes(serialized_config);
    });
// Read-only "max_queue_size" property backed by
// CalculatorGraph::GetMaxInputStreamQueueSize().
calculator_graph.def_property_readonly(
    "max_queue_size", [](CalculatorGraph* self) {
      const auto queue_limit = self->GetMaxInputStreamQueueSize();
      return queue_limit;
    });
calculator_graph.def_property(
"graph_input_stream_add_mode",
[](const CalculatorGraph& self) {
return self.GetGraphInputStreamAddMode();
},
[](CalculatorGraph* self, CalculatorGraph::GraphInputStreamAddMode mode) {
self->SetGraphInputStreamAddMode(mode);
});
// Binds CalculatorGraph.add_packet_to_input_stream(stream, packet,
// timestamp=Timestamp::Unset()).
calculator_graph.def(
    "add_packet_to_input_stream",
    [](CalculatorGraph* self, const std::string& stream, const Packet& packet,
       const Timestamp& timestamp) {
      // An explicit timestamp overrides the packet's own timestamp; the
      // default Timestamp::Unset() keeps the packet's timestamp.
      const Timestamp packet_timestamp =
          timestamp == Timestamp::Unset() ? packet.Timestamp() : timestamp;
      if (!packet_timestamp.IsAllowedInStream()) {
        throw RaisePyError(
            PyExc_ValueError,
            absl::StrCat(packet_timestamp.DebugString(),
                         " can't be the timestamp of a Packet in a stream.")
                .c_str());
      }
      // The GIL is released while the packet is added (the docstring notes
      // that WAIT_TILL_NOT_FULL mode blocks).
      // NOTE(review): acquire_gil=true presumably makes RaisePyErrorIfNotOk
      // re-acquire the GIL before raising - confirm against util.h.
      py::gil_scoped_release gil_release;
      RaisePyErrorIfNotOk(
          self->AddPacketToInputStream(stream, packet.At(packet_timestamp)),
          /**acquire_gil=*/true);
    },
    R"doc(Add a packet to a graph input stream.
If the graph input stream add mode is ADD_IF_NOT_FULL, the packet will not be
added if any queue exceeds the max queue size specified by the graph config
and will raise a Python runtime error. The WAIT_TILL_NOT_FULL mode (default)
will block until the queues fall below the max queue size before adding the
packet. If the max queue size is -1, then the packet is added
regardless of the sizes of the queues in the graph. The input stream must have
been specified in the configuration as a graph level input stream. On error,
nothing is added.
Args:
stream: The name of the graph input stream.
packet: The packet to be added into the input stream.
timestamp: The timestamp of the packet. If set, the original packet
timestamp will be overwritten.
Raises:
RuntimeError: If the stream is not a graph input stream or the packet can't
be added into the input stream due to the limited queue size or the wrong
packet type.
ValueError: If the timestamp of the Packet is invalid to be the timestamp of
a Packet in a stream.
Examples:
graph.add_packet_to_input_stream(
stream='in',
packet=packet_creator.create_string('hello world').at(0))
graph.add_packet_to_input_stream(
stream='in',
packet=packet_creator.create_string('hello world'),
timestamp=1)
)doc",
    py::arg("stream"), py::arg("packet"),
    py::arg("timestamp") = Timestamp::Unset());
// Binds CalculatorGraph.close_input_stream(stream): forwards to
// CalculatorGraph::CloseInputStream and passes the resulting status to
// RaisePyErrorIfNotOk.
calculator_graph.def(
    "close_input_stream",
    [](CalculatorGraph* self, const std::string& stream) {
      const auto close_status = self->CloseInputStream(stream);
      RaisePyErrorIfNotOk(close_status);
    },
    R"doc(Close the named graph input stream.
Args:
stream: The name of the stream to be closed.
Raises:
RuntimeError: If the stream is not a graph input stream.
)doc");
// Binds CalculatorGraph.close_all_packet_sources(): forwards to
// CalculatorGraph::CloseAllPacketSources and passes the resulting status to
// RaisePyErrorIfNotOk.
calculator_graph.def(
    "close_all_packet_sources",
    [](CalculatorGraph* self) {
      const auto close_status = self->CloseAllPacketSources();
      RaisePyErrorIfNotOk(close_status);
    },
    R"doc(Closes all the graph input streams and source calculator nodes.)doc");
// Binds CalculatorGraph.start_run(input_side_packets={}).
calculator_graph.def(
    "start_run",
    [](CalculatorGraph* self, const pybind11::dict& input_side_packets) {
      // Convert the Python dict into the std::map handed to StartRun. Keys
      // are cast to std::string and values to Packet; InsertIfNotPresent
      // keeps the first entry seen for a given key.
      std::map<std::string, Packet> input_side_packet_map;
      for (const auto& kv_pair : input_side_packets) {
        InsertIfNotPresent(&input_side_packet_map,
                           kv_pair.first.cast<std::string>(),
                           kv_pair.second.cast<Packet>());
      }
      RaisePyErrorIfNotOk(self->StartRun(input_side_packet_map));
    },
    R"doc(Start a run of the calculator graph.
A non-blocking call to start a run of the graph and will return when the graph
is started. If input_side_packets is provided, the method will run the graph
after adding the given extra input side packets.
start_run(), wait_until_done(), has_error(), add_packet_to_input_stream(), and
close() allow more control over the execution of the graph run. You can
insert packets directly into a stream while the graph is running.
Once start_run() has been called, the graph will continue to run until
wait_until_done() is called.
If start_run() returns an error, then the graph is not started and a
subsequent call to start_run() can be attempted.
Args:
input_side_packets: A dict that maps from the input side packet names to the
packets.
Raises:
RuntimeError: If the start run occurs any error, e.g. the graph config has
errors, the calculator can't be found, and the streams are not properly
connected.
Examples:
graph = mp.CalculatorGraph(graph_config=video_process_graph)
graph.start_run(
input_side_packets={
'input_path': packet_creator.create_string('/tmp/input.video'),
'output_path': packet_creator.create_string('/tmp/output.video')
})
graph.close()
out = []
graph = mp.CalculatorGraph(graph_config=pass_through_graph)
graph.observe_output_stream('out',
lambda stream_name, packet: out.append(packet))
graph.start_run()
graph.add_packet_to_input_stream(
stream='in', packet=packet_creator.create_int(0), timestamp=0)
graph.add_packet_to_input_stream(
stream='in', packet=packet_creator.create_int(1), timestamp=1)
graph.close()
)doc",
    py::arg("input_side_packets") = py::dict());
// Binds CalculatorGraph.wait_until_done().
calculator_graph.def(
    "wait_until_done",
    [](CalculatorGraph* self) {
      // The GIL is released for the duration of the blocking wait. Output
      // stream callbacks registered through observe_output_stream acquire the
      // GIL themselves before calling into Python.
      // NOTE(review): acquire_gil=true presumably makes RaisePyErrorIfNotOk
      // re-acquire the GIL before raising - confirm against util.h.
      py::gil_scoped_release gil_release;
      RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
    },
    R"doc(Wait for the current run to finish.
A blocking call to wait for the current run to finish (block the current
thread until all source calculators are stopped, all graph input streams have
been closed, and no more calculators can be run). This function can be called
only after start_run().
Raises:
RuntimeError: If the graph occurs any error during the wait call.
Examples:
out = []
graph = mp.CalculatorGraph(graph_config=pass_through_graph)
graph.observe_output_stream('out', lambda stream_name, packet: out.append(packet))
graph.start_run()
graph.add_packet_to_input_stream(
stream='in', packet=packet_creator.create_int(0), timestamp=0)
graph.close_all_packet_sources()
graph.wait_until_done()
)doc");
// Binds CalculatorGraph.wait_until_idle(). The GIL is released before
// CalculatorGraph::WaitUntilIdle is called; the resulting status is passed to
// RaisePyErrorIfNotOk with acquire_gil set.
calculator_graph.def(
    "wait_until_idle",
    [](CalculatorGraph* self) {
      py::gil_scoped_release gil_release;
      const auto idle_status = self->WaitUntilIdle();
      RaisePyErrorIfNotOk(idle_status, /**acquire_gil=*/true);
    },
    R"doc(Wait until the running graph is in the idle mode.
Wait until the running graph is in the idle mode, which is when nothing can
be scheduled and nothing is running in the worker threads. This function can
be called only after start_run().
NOTE: The graph must not have any source nodes because source nodes prevent
the running graph from becoming idle until the source nodes are done.
Raises:
RuntimeError: If the graph occurs any error during the wait call.
Examples:
out = []
graph = mp.CalculatorGraph(graph_config=pass_through_graph)
graph.observe_output_stream('out',
lambda stream_name, packet: out.append(packet))
graph.start_run()
graph.add_packet_to_input_stream(
stream='in', packet=packet_creator.create_int(0), timestamp=0)
graph.wait_until_idle()
)doc");
// Binds CalculatorGraph.wait_for_observed_output(). The GIL is released
// before CalculatorGraph::WaitForObservedOutput is called; the resulting
// status is passed to RaisePyErrorIfNotOk with acquire_gil set.
calculator_graph.def(
    "wait_for_observed_output",
    [](CalculatorGraph* self) {
      py::gil_scoped_release gil_release;
      const auto observed_status = self->WaitForObservedOutput();
      RaisePyErrorIfNotOk(observed_status, /**acquire_gil=*/true);
    },
    R"doc(Wait until a packet is emitted on one of the observed output streams.
Returns immediately if a packet has already been emitted since the last
call to this function.
Raises:
RuntimeError:
If the graph occurs any error or the graph is terminated while waiting.
Examples:
out = []
graph = mp.CalculatorGraph(graph_config=pass_through_graph)
graph.observe_output_stream('out',
lambda stream_name, packet: out.append(packet))
graph.start_run()
graph.add_packet_to_input_stream(
stream='in', packet=packet_creator.create_int(0), timestamp=0)
graph.wait_for_observed_output()
value = packet_getter.get_int(out[0])
graph.add_packet_to_input_stream(
stream='in', packet=packet_creator.create_int(1), timestamp=1)
graph.wait_for_observed_output()
value = packet_getter.get_int(out[1])
)doc");
// Binds CalculatorGraph.has_error(): returns the result of
// CalculatorGraph::HasError().
calculator_graph.def(
    "has_error",
    [](const CalculatorGraph& self) {
      const auto graph_has_error = self.HasError();
      return graph_has_error;
    },
    R"doc(Quick non-locking means of checking if the graph has encountered an error)doc");
// Binds CalculatorGraph.get_combined_error_message(). Returns the string form
// of the combined status when GetCombinedErrors reports one that is not OK,
// and an empty string otherwise.
calculator_graph.def(
    "get_combined_error_message",
    [](CalculatorGraph* self) {
      absl::Status combined_status;
      const bool has_combined_errors =
          self->GetCombinedErrors(&combined_status);
      // Nothing to report: either no errors were combined or the status is OK.
      if (!has_combined_errors || combined_status.ok()) {
        return std::string();
      }
      return combined_status.ToString();
    },
    R"doc(Combines error messages as a single string.
Examples:
if graph.has_error():
print(graph.get_combined_error_message())
)doc");
// TODO: Support passing a single-argument lambda for convenience.
// Binds CalculatorGraph.observe_output_stream(stream_name, callback_fn,
// observe_timestamp_bounds=False).
calculator_graph.def(
    "observe_output_stream",
    [](CalculatorGraph* self, const std::string& stream_name,
       pybind11::function callback_fn, bool observe_timestamp_bounds) {
      // Native packet callback that forwards every packet, together with the
      // stream name, to the Python callable.
      auto forward_to_python = [callback_fn,
                                stream_name](const Packet& packet) {
        // callback_mutex is taken first so only one Python callback runs at
        // a time.
        absl::MutexLock lock(&callback_mutex);
        // Acquires GIL before calling Python callback.
        py::gil_scoped_acquire gil_acquire;
        callback_fn(stream_name, packet);
        return absl::OkStatus();
      };
      const auto observe_status = self->ObserveOutputStream(
          stream_name, forward_to_python, observe_timestamp_bounds);
      RaisePyErrorIfNotOk(observe_status);
    },
    R"doc(Observe the named output stream.
callback_fn will be invoked on every packet emitted by the output stream.
This method can only be called before start_run().
Args:
stream_name: The name of the output stream.
callback_fn: The callback function to invoke on every packet emitted by the
output stream.
observe_timestamp_bounds: If true, emits an empty packet at
timestamp_bound -1 when timestamp bound changes.
Raises:
RuntimeError: If the calculator graph isn't initialized or the stream
doesn't exist.
Examples:
out = []
graph = mp.CalculatorGraph(graph_config=graph_config)
graph.observe_output_stream('out',
lambda stream_name, packet: out.append(packet))
)doc",
    py::arg("stream_name"), py::arg("callback_fn"),
    py::arg("observe_timestamp_bounds") = false);
// Binds CalculatorGraph.close(): closes all packet sources while the GIL is
// still held, then releases the GIL and waits for the run to finish.
calculator_graph.def(
    "close",
    [](CalculatorGraph* self) {
      const auto close_status = self->CloseAllPacketSources();
      RaisePyErrorIfNotOk(close_status);
      py::gil_scoped_release gil_release;
      const auto done_status = self->WaitUntilDone();
      RaisePyErrorIfNotOk(done_status, /**acquire_gil=*/true);
    },
    R"doc(Close all the input sources and shutdown the graph.)doc");
// Binds CalculatorGraph.get_output_side_packet(packet_name).
calculator_graph.def(
    "get_output_side_packet",
    [](CalculatorGraph* self, const std::string& packet_name) {
      auto status_or_packet = self->GetOutputSidePacket(packet_name);
      // value() is only read after the status has been passed through
      // RaisePyErrorIfNotOk.
      RaisePyErrorIfNotOk(status_or_packet.status());
      return status_or_packet.value();
    },
    R"doc(Get output side packet by name after the graph is done.
Args:
packet_name: The name of the output side packet.
Raises:
RuntimeError: If the graph is still running or the output side packet is not
found or empty.
Examples:
graph = mp.CalculatorGraph(graph_config=graph_config)
graph.start_run()
graph.close()
output_side_packet = graph.get_output_side_packet('packet_name')
)doc",
    py::return_value_policy::move);
}
} // namespace python
} // namespace mediapipe