487 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			487 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright 2020 The MediaPipe Authors.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //      http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| #include "mediapipe/python/pybind/calculator_graph.h"
 | |
| 
 | |
| #include "absl/memory/memory.h"
 | |
| #include "absl/strings/str_cat.h"
 | |
| #include "mediapipe/framework/calculator.pb.h"
 | |
| #include "mediapipe/framework/calculator_graph.h"
 | |
| #include "mediapipe/framework/packet.h"
 | |
| #include "mediapipe/framework/port/map_util.h"
 | |
| #include "mediapipe/framework/port/parse_text_proto.h"
 | |
| #include "mediapipe/framework/port/status.h"
 | |
| #include "mediapipe/framework/tool/calculator_graph_template.pb.h"
 | |
| #include "mediapipe/python/pybind/util.h"
 | |
| #include "pybind11/embed.h"
 | |
| #include "pybind11/pybind11.h"
 | |
| #include "pybind11/stl.h"
 | |
| 
 | |
| namespace mediapipe {
 | |
| namespace python {
 | |
| 
 | |
// A mutex to guard the output stream observer python callback function.
// Only one python callback can run at once: every observer lambda installed
// by CalculatorGraph.observe_output_stream() holds this lock for the whole
// duration of its Python call, so callbacks are serialized even if MediaPipe
// delivers packets from more than one thread.
// NOTE(review): the observer lambda takes this mutex *before* acquiring the
// GIL. Any other code that needs both must take them in the same order, or it
// can deadlock against a running callback.
absl::Mutex callback_mutex;
 | |
| 
 | |
| template <typename T>
 | |
| T ParseProto(const py::object& proto_object) {
 | |
|   T proto;
 | |
|   if (!ParseTextProto<T>(proto_object.str(), &proto)) {
 | |
|     throw RaisePyError(
 | |
|         PyExc_RuntimeError,
 | |
|         absl::StrCat("Failed to parse: ", std::string(proto_object.str()))
 | |
|             .c_str());
 | |
|   }
 | |
|   return proto;
 | |
| }
 | |
| 
 | |
| namespace py = pybind11;
 | |
| 
 | |
| void CalculatorGraphSubmodule(pybind11::module* module) {
 | |
|   py::module m = module->def_submodule("calculator_graph",
 | |
|                                        "MediaPipe calculator graph module.");
 | |
| 
 | |
|   using GraphInputStreamAddMode =
 | |
|       mediapipe::CalculatorGraph::GraphInputStreamAddMode;
 | |
| 
 | |
|   py::enum_<GraphInputStreamAddMode>(m, "GraphInputStreamAddMode")
 | |
|       .value("WAIT_TILL_NOT_FULL", GraphInputStreamAddMode::WAIT_TILL_NOT_FULL)
 | |
|       .value("ADD_IF_NOT_FULL", GraphInputStreamAddMode::ADD_IF_NOT_FULL)
 | |
|       .export_values();
 | |
| 
 | |
|   // Calculator Graph
 | |
|   py::class_<CalculatorGraph> calculator_graph(
 | |
|       m, "CalculatorGraph", R"doc(The primary API for the MediaPipe Framework.
 | |
| 
 | |
|   MediaPipe processing takes place inside a graph, which defines packet flow
 | |
|   paths between nodes. A graph can have any number of inputs and outputs, and
 | |
|   data flow can branch and merge. Generally data flows forward, but backward
 | |
|   loops are possible.)doc");
 | |
| 
 | |
|   // TODO: Support graph initialization with graph templates and
 | |
|   // subgraph.
 | |
|   calculator_graph.def(
 | |
|       py::init([](py::kwargs kwargs) {
 | |
|         bool init_with_binary_graph = false;
 | |
|         bool init_with_graph_proto = false;
 | |
|         bool init_with_validated_graph_config = false;
 | |
|         CalculatorGraphConfig graph_config_proto;
 | |
|         for (const auto& kw : kwargs) {
 | |
|           const std::string& key = kw.first.cast<std::string>();
 | |
|           if (key == "binary_graph_path") {
 | |
|             init_with_binary_graph = true;
 | |
|             std::string file_name(kw.second.cast<py::object>().str());
 | |
|             graph_config_proto = ReadCalculatorGraphConfigFromFile(file_name);
 | |
|           } else if (key == "graph_config") {
 | |
|             init_with_graph_proto = true;
 | |
|             graph_config_proto =
 | |
|                 ParseProto<CalculatorGraphConfig>(kw.second.cast<py::object>());
 | |
|           } else if (key == "validated_graph_config") {
 | |
|             init_with_validated_graph_config = true;
 | |
|             graph_config_proto =
 | |
|                 py::cast<ValidatedGraphConfig*>(kw.second)->Config();
 | |
|           } else {
 | |
|             throw RaisePyError(
 | |
|                 PyExc_RuntimeError,
 | |
|                 absl::StrCat("Unknown kwargs input argument: ", key).c_str());
 | |
|           }
 | |
|         }
 | |
| 
 | |
|         if ((init_with_binary_graph ? 1 : 0) + (init_with_graph_proto ? 1 : 0) +
 | |
|                 (init_with_validated_graph_config ? 1 : 0) !=
 | |
|             1) {
 | |
|           throw RaisePyError(PyExc_ValueError,
 | |
|                              "Please provide one of the following: "
 | |
|                              "\'binary_graph_path\' to initialize the graph "
 | |
|                              "with a binary graph file, or "
 | |
|                              "\'graph_config\' to initialize the graph with a "
 | |
|                              "graph config proto, or "
 | |
|                              "\'validated_graph_config\' to initialize the "
 | |
|                              "graph with a ValidatedGraphConfig object.");
 | |
|         }
 | |
|         auto calculator_graph = absl::make_unique<CalculatorGraph>();
 | |
|         // Disable default service initialization. This allows us to use
 | |
|         // the CPU versions of calculators that only optionally request
 | |
|         // kGpuService.
 | |
|         RaisePyErrorIfNotOk(
 | |
|             calculator_graph->DisallowServiceDefaultInitialization());
 | |
|         RaisePyErrorIfNotOk(calculator_graph->Initialize(graph_config_proto));
 | |
|         return calculator_graph.release();
 | |
|       }),
 | |
|       R"doc(Initialize CalculatorGraph object.
 | |
| 
 | |
|   Args:
 | |
|     binary_graph_path: The path to a binary mediapipe graph file (.binarypb).
 | |
|     graph_config: A single CalculatorGraphConfig proto message or its text proto
 | |
|       format.
 | |
|     validated_graph_config: A ValidatedGraphConfig object.
 | |
| 
 | |
|   Raises:
 | |
|     FileNotFoundError: If the binary graph file can't be found.
 | |
|     ValueError: If the input arguments prvoided are more than needed or the
 | |
|       graph validation process contains error.
 | |
| )doc");
 | |
| 
 | |
|   // TODO: Return a Python CalculatorGraphConfig instead.
 | |
|   calculator_graph.def_property_readonly(
 | |
|       "text_config",
 | |
|       [](const CalculatorGraph& self) { return self.Config().DebugString(); });
 | |
| 
 | |
|   calculator_graph.def_property_readonly(
 | |
|       "binary_config", [](const CalculatorGraph& self) {
 | |
|         return py::bytes(self.Config().SerializeAsString());
 | |
|       });
 | |
| 
 | |
|   calculator_graph.def_property_readonly(
 | |
|       "max_queue_size",
 | |
|       [](CalculatorGraph* self) { return self->GetMaxInputStreamQueueSize(); });
 | |
| 
 | |
|   calculator_graph.def_property(
 | |
|       "graph_input_stream_add_mode",
 | |
|       [](const CalculatorGraph& self) {
 | |
|         return self.GetGraphInputStreamAddMode();
 | |
|       },
 | |
|       [](CalculatorGraph* self, CalculatorGraph::GraphInputStreamAddMode mode) {
 | |
|         self->SetGraphInputStreamAddMode(mode);
 | |
|       });
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "add_packet_to_input_stream",
 | |
|       [](CalculatorGraph* self, const std::string& stream, const Packet& packet,
 | |
|          const Timestamp& timestamp) {
 | |
|         Timestamp packet_timestamp =
 | |
|             timestamp == Timestamp::Unset() ? packet.Timestamp() : timestamp;
 | |
|         if (!packet_timestamp.IsAllowedInStream()) {
 | |
|           throw RaisePyError(
 | |
|               PyExc_ValueError,
 | |
|               absl::StrCat(packet_timestamp.DebugString(),
 | |
|                            " can't be the timestamp of a Packet in a stream.")
 | |
|                   .c_str());
 | |
|         }
 | |
|         py::gil_scoped_release gil_release;
 | |
|         RaisePyErrorIfNotOk(
 | |
|             self->AddPacketToInputStream(stream, packet.At(packet_timestamp)),
 | |
|             /**acquire_gil=*/true);
 | |
|       },
 | |
|       R"doc(Add a packet to a graph input stream.
 | |
| 
 | |
|   If the graph input stream add mode is ADD_IF_NOT_FULL, the packet will not be
 | |
|   added if any queue exceeds the max queue size specified by the graph config
 | |
|   and will raise a Python runtime error. The WAIT_TILL_NOT_FULL mode (default)
 | |
|   will block until the queues fall below the max queue size before adding the
 | |
|   packet. If the mode is max queue size is -1, then the packet is added
 | |
|   regardless of the sizes of the queues in the graph. The input stream must have
 | |
|   been specified in the configuration as a graph level input stream. On error,
 | |
|   nothing is added.
 | |
| 
 | |
|   Args:
 | |
|     stream: The name of the graph input stream.
 | |
|     packet: The packet to be added into the input stream.
 | |
|     timestamp: The timestamp of the packet. If set, the original packet
 | |
|       timestamp will be overwritten.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the stream is not a graph input stream or the packet can't
 | |
|       be added into the input stream due to the limited queue size or the wrong
 | |
|       packet type.
 | |
|     ValueError: If the timestamp of the Packet is invalid to be the timestamp of
 | |
|       a Packet in a stream.
 | |
| 
 | |
|   Examples:
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in',
 | |
|         packet=packet_creator.create_string('hello world').at(0))
 | |
| 
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in',
 | |
|         packet=packet_creator.create_string('hello world'),
 | |
|         timstamp=1)
 | |
| )doc",
 | |
|       py::arg("stream"), py::arg("packet"),
 | |
|       py::arg("timestamp") = Timestamp::Unset());
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "close_input_stream",
 | |
|       [](CalculatorGraph* self, const std::string& stream) {
 | |
|         RaisePyErrorIfNotOk(self->CloseInputStream(stream));
 | |
|       },
 | |
|       R"doc(Close the named graph input stream.
 | |
| 
 | |
|   Args:
 | |
|     stream: The name of the stream to be closed.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the stream is not a graph input stream.
 | |
| 
 | |
| )doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "close_all_packet_sources",
 | |
|       [](CalculatorGraph* self) {
 | |
|         RaisePyErrorIfNotOk(self->CloseAllPacketSources());
 | |
|       },
 | |
|       R"doc(Closes all the graph input streams and source calculator nodes.)doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "start_run",
 | |
|       [](CalculatorGraph* self, const pybind11::dict& input_side_packets) {
 | |
|         std::map<std::string, Packet> input_side_packet_map;
 | |
|         for (const auto& kv_pair : input_side_packets) {
 | |
|           InsertIfNotPresent(&input_side_packet_map,
 | |
|                              kv_pair.first.cast<std::string>(),
 | |
|                              kv_pair.second.cast<Packet>());
 | |
|         }
 | |
|         RaisePyErrorIfNotOk(self->StartRun(input_side_packet_map));
 | |
|       },
 | |
| 
 | |
|       R"doc(Start a run of the calculator graph.
 | |
| 
 | |
|   A non-blocking call to start a run of the graph and will return when the graph
 | |
|   is started. If input_side_packets is provided, the method will runs the graph
 | |
|   after adding the given extra input side packets.
 | |
| 
 | |
|   start_run(), wait_until_done(), has_error(), add_packet_to_input_stream(), and
 | |
|   close() allow more control over the execution of the graph run.  You can
 | |
|   insert packets directly into a stream while the graph is running.
 | |
|   Once start_run() has been called, the graph will continue to run until
 | |
|   wait_until_done() is called.
 | |
| 
 | |
|   If start_run() returns an error, then the graph is not started and a
 | |
|   subsequent call to start_run() can be attempted.
 | |
| 
 | |
|   Args:
 | |
|     input_side_packets: A dict maps from the input side packet names to the
 | |
|       packets.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the start run occurs any error, e.g. the graph config has
 | |
|       errors, the calculator can't be found, and the streams are not properly
 | |
|       connected.
 | |
| 
 | |
|   Examples:
 | |
|     graph = mp.CalculatorGraph(graph_config=video_process_graph)
 | |
|     graph.start_run(
 | |
|         input_side_packets={
 | |
|             'input_path': packet_creator.create_string('/tmp/input.video'),
 | |
|             'output_path': packet_creator.create_string('/tmp/output.video')
 | |
|         })
 | |
|     graph.close()
 | |
| 
 | |
|     out = []
 | |
|     graph = mp.CalculatorGraph(graph_config=pass_through_graph)
 | |
|     graph.observe_output_stream('out',
 | |
|                                 lambda stream_name, packet: out.append(packet))
 | |
|     graph.start_run()
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in', packet=packet_creator.create_int(0), timestamp=0)
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in', packet=packet_creator.create_int(1), timestamp=1)
 | |
|     graph.close()
 | |
| 
 | |
| )doc",
 | |
|       py::arg("input_side_packets") = py::dict());
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "wait_until_done",
 | |
|       [](CalculatorGraph* self) {
 | |
|         py::gil_scoped_release gil_release;
 | |
|         RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
 | |
|       },
 | |
|       R"doc(Wait for the current run to finish.
 | |
| 
 | |
|   A blocking call to wait for the current run to finish (block the current
 | |
|   thread until all source calculators are stopped, all graph input streams have
 | |
|   been closed, and no more calculators can be run). This function can be called
 | |
|   only after start_run(),
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the graph occurs any error during the wait call.
 | |
| 
 | |
|   Examples:
 | |
|     out = []
 | |
|     graph = mp.CalculatorGraph(graph_config=pass_through_graph)
 | |
|     graph.observe_output_stream('out', lambda stream_name, packet: out.append(packet))
 | |
|     graph.start_run()
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in', packet=packet_creator.create_int(0), timestamp=0)
 | |
|     graph.close_all_packet_sources()
 | |
|     graph.wait_until_done()
 | |
| 
 | |
| )doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "wait_until_idle",
 | |
|       [](CalculatorGraph* self) {
 | |
|         py::gil_scoped_release gil_release;
 | |
|         RaisePyErrorIfNotOk(self->WaitUntilIdle(), /**acquire_gil=*/true);
 | |
|       },
 | |
|       R"doc(Wait until the running graph is in the idle mode.
 | |
| 
 | |
|   Wait until the running graph is in the idle mode, which is when nothing can
 | |
|   be scheduled and nothing is running in the worker threads. This function can
 | |
|   be called only after start_run().
 | |
| 
 | |
|   NOTE: The graph must not have any source nodes because source nodes prevent
 | |
|   the running graph from becoming idle until the source nodes are done.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the graph occurs any error during the wait call.
 | |
| 
 | |
|   Examples:
 | |
|     out = []
 | |
|     graph = mp.CalculatorGraph(graph_config=pass_through_graph)
 | |
|     graph.observe_output_stream('out',
 | |
|                                 lambda stream_name, packet: out.append(packet))
 | |
|     graph.start_run()
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in', packet=packet_creator.create_int(0), timestamp=0)
 | |
|     graph.wait_until_idle()
 | |
| 
 | |
| )doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "wait_for_observed_output",
 | |
|       [](CalculatorGraph* self) {
 | |
|         py::gil_scoped_release gil_release;
 | |
|         RaisePyErrorIfNotOk(self->WaitForObservedOutput(),
 | |
|                             /**acquire_gil=*/true);
 | |
|       },
 | |
|       R"doc(Wait until a packet is emitted on one of the observed output streams.
 | |
| 
 | |
|   Returns immediately if a packet has already been emitted since the last
 | |
|   call to this function.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError:
 | |
|       If the graph occurs any error or the graph is terminated while waiting.
 | |
| 
 | |
|   Examples:
 | |
|     out = []
 | |
|     graph = mp.CalculatorGraph(graph_config=pass_through_graph)
 | |
|     graph.observe_output_stream('out',
 | |
|                                 lambda stream_name, packet: out.append(packet))
 | |
|     graph.start_run()
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in', packet=packet_creator.create_int(0), timestamp=0)
 | |
|     graph.wait_for_observed_output()
 | |
|     value = packet_getter.get_int(out[0])
 | |
|     graph.add_packet_to_input_stream(
 | |
|         stream='in', packet=packet_creator.create_int(1), timestamp=1)
 | |
|     graph.wait_for_observed_output()
 | |
|     value = packet_getter.get_int(out[1])
 | |
| 
 | |
| )doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "has_error", [](const CalculatorGraph& self) { return self.HasError(); },
 | |
|       R"doc(Quick non-locking means of checking if the graph has encountered an error)doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "get_combined_error_message",
 | |
|       [](CalculatorGraph* self) {
 | |
|         absl::Status error_status;
 | |
|         if (self->GetCombinedErrors(&error_status) && !error_status.ok()) {
 | |
|           return error_status.ToString();
 | |
|         }
 | |
|         return std::string();
 | |
|       },
 | |
|       R"doc(Combines error messages as a single string.
 | |
| 
 | |
|   Examples:
 | |
|     if graph.has_error():
 | |
|       print(graph.get_combined_error_message())
 | |
| 
 | |
| )doc");
 | |
| 
 | |
|   // TODO: Support passing a single-argument lambda for convenience.
 | |
|   calculator_graph.def(
 | |
|       "observe_output_stream",
 | |
|       [](CalculatorGraph* self, const std::string& stream_name,
 | |
|          pybind11::function callback_fn, bool observe_timestamp_bounds) {
 | |
|         RaisePyErrorIfNotOk(self->ObserveOutputStream(
 | |
|             stream_name,
 | |
|             [callback_fn, stream_name](const Packet& packet) {
 | |
|               absl::MutexLock lock(&callback_mutex);
 | |
|               // Acquires GIL before calling Python callback.
 | |
|               py::gil_scoped_acquire gil_acquire;
 | |
|               callback_fn(stream_name, packet);
 | |
|               return absl::OkStatus();
 | |
|             },
 | |
|             observe_timestamp_bounds));
 | |
|       },
 | |
|       R"doc(Observe the named output stream.
 | |
| 
 | |
|   callback_fn will be invoked on every packet emitted by the output stream.
 | |
|   This method can only be called before start_run().
 | |
| 
 | |
|   Args:
 | |
|     stream_name: The name of the output stream.
 | |
|     callback_fn: The callback function to invoke on every packet emitted by the
 | |
|       output stream.
 | |
|     observe_timestamp_bounds: If true, emits an empty packet at
 | |
|       timestamp_bound -1 when timestamp bound changes.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the calculator graph isn't initialized or the stream
 | |
|       doesn't exist.
 | |
| 
 | |
|   Examples:
 | |
|     out = []
 | |
|     graph = mp.CalculatorGraph(graph_config=graph_config)
 | |
|     graph.observe_output_stream('out',
 | |
|                                 lambda stream_name, packet: out.append(packet))
 | |
| 
 | |
| )doc",
 | |
|       py::arg("stream_name"), py::arg("callback_fn"),
 | |
|       py::arg("observe_timestamp_bounds") = false);
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "close",
 | |
|       [](CalculatorGraph* self) {
 | |
|         RaisePyErrorIfNotOk(self->CloseAllPacketSources());
 | |
|         py::gil_scoped_release gil_release;
 | |
|         RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
 | |
|       },
 | |
|       R"doc(Close all the input sources and shutdown the graph.)doc");
 | |
| 
 | |
|   calculator_graph.def(
 | |
|       "get_output_side_packet",
 | |
|       [](CalculatorGraph* self, const std::string& packet_name) {
 | |
|         auto status_or_packet = self->GetOutputSidePacket(packet_name);
 | |
|         RaisePyErrorIfNotOk(status_or_packet.status());
 | |
|         return status_or_packet.value();
 | |
|       },
 | |
|       R"doc(Get output side packet by name after the graph is done.
 | |
| 
 | |
|   Args:
 | |
|     stream: The name of the outnput stream.
 | |
| 
 | |
|   Raises:
 | |
|     RuntimeError: If the graph is still running or the output side packet is not
 | |
|       found or empty.
 | |
| 
 | |
|   Examples:
 | |
|     graph = mp.CalculatorGraph(graph_config=graph_config)
 | |
|     graph.start_run()
 | |
|     graph.close()
 | |
|     output_side_packet = graph.get_output_side_packet('packet_name')
 | |
| 
 | |
| )doc",
 | |
|       py::return_value_policy::move);
 | |
| }
 | |
| 
 | |
| }  // namespace python
 | |
| }  // namespace mediapipe
 |