diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index fa61feb9d..bb0d4b001 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -264,6 +264,7 @@ cc_library_with_tflite( "//mediapipe/framework:executor", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:name_util", + "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/tasks/cc:common", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/mediapipe/tasks/cc/core/task_runner.cc b/mediapipe/tasks/cc/core/task_runner.cc index 88c91bcdb..e3862ddd7 100644 --- a/mediapipe/tasks/cc/core/task_runner.cc +++ b/mediapipe/tasks/cc/core/task_runner.cc @@ -39,6 +39,10 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_shared_data_internal.h" +#endif // !MEDIAPIPE_DISABLE_GPU + namespace mediapipe { namespace tasks { namespace core { @@ -88,16 +92,34 @@ absl::StatusOr GenerateOutputPacketMap( } // namespace /* static */ +#if !MEDIAPIPE_DISABLE_GPU +absl::StatusOr> TaskRunner::Create( + CalculatorGraphConfig config, + std::unique_ptr op_resolver, + PacketsCallback packets_callback, + std::shared_ptr default_executor, + std::optional input_side_packets, + std::shared_ptr<::mediapipe::GpuResources> resources) { +#else absl::StatusOr> TaskRunner::Create( CalculatorGraphConfig config, std::unique_ptr op_resolver, PacketsCallback packets_callback, std::shared_ptr default_executor, std::optional input_side_packets) { +#endif // !MEDIAPIPE_DISABLE_GPU auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); MP_RETURN_IF_ERROR(task_runner->Initialize( std::move(config), std::move(op_resolver), std::move(default_executor), std::move(input_side_packets))); + +#if !MEDIAPIPE_DISABLE_GPU + if (resources) { + MP_RETURN_IF_ERROR( + task_runner->graph_.SetGpuResources(std::move(resources))); + } +#endif // !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(task_runner->Start()); return task_runner; } diff --git a/mediapipe/tasks/cc/core/task_runner.h b/mediapipe/tasks/cc/core/task_runner.h index 810063d4b..ef48bef55 100644 --- a/mediapipe/tasks/cc/core/task_runner.h +++ b/mediapipe/tasks/cc/core/task_runner.h @@ -42,6 +42,11 @@ limitations under the License. #include "tensorflow/lite/core/api/op_resolver.h" namespace mediapipe { + +#if !MEDIAPIPE_DISABLE_GPU +class GpuResources; +#endif // !MEDIAPIPE_DISABLE_GPU + namespace tasks { namespace core { @@ -72,12 +77,22 @@ class TaskRunner { // asynchronous method, Send(), to provide the input packets. If the packets // callback is absent, clients must use the synchronous method, Process(), to // provide the input packets and receive the output packets. +#if !MEDIAPIPE_DISABLE_GPU + static absl::StatusOr> Create( + CalculatorGraphConfig config, + std::unique_ptr op_resolver = nullptr, + PacketsCallback packets_callback = nullptr, + std::shared_ptr default_executor = nullptr, + std::optional input_side_packets = std::nullopt, + std::shared_ptr<::mediapipe::GpuResources> resources = nullptr); +#else static absl::StatusOr> Create( CalculatorGraphConfig config, std::unique_ptr op_resolver = nullptr, PacketsCallback packets_callback = nullptr, std::shared_ptr default_executor = nullptr, std::optional input_side_packets = std::nullopt); +#endif // !MEDIAPIPE_DISABLE_GPU // TaskRunner is neither copyable nor movable. TaskRunner(const TaskRunner&) = delete; diff --git a/mediapipe/tasks/python/core/pybind/BUILD b/mediapipe/tasks/python/core/pybind/BUILD index 88ea05f4f..391712f27 100644 --- a/mediapipe/tasks/python/core/pybind/BUILD +++ b/mediapipe/tasks/python/core/pybind/BUILD @@ -26,9 +26,11 @@ pybind_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/python/pybind:util", "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:task_runner", + "@com_google_absl//absl/log:absl_log", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index f95cddde8..0de7d24d8 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -14,6 +14,7 @@ #include "mediapipe/tasks/python/core/pybind/task_runner.h" +#include "absl/log/absl_log.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/python/pybind/util.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" @@ -21,6 +22,9 @@ #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" #include "tensorflow/lite/core/api/op_resolver.h" +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_shared_data_internal.h" +#endif // MEDIAPIPE_DISABLE_GPU namespace mediapipe { namespace tasks { @@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc"); return absl::OkStatus(); }; } + +#if !MEDIAPIPE_DISABLE_GPU + auto gpu_resources_ = mediapipe::GpuResources::Create(); + if (!gpu_resources_.ok()) { + ABSL_LOG(INFO) << "GPU suport is not available: " + << gpu_resources_.status(); + gpu_resources_ = nullptr; + } + auto task_runner = TaskRunner::Create( + std::move(graph_config), + absl::make_unique(), + std::move(callback), + /* default_executor= */ nullptr, + /* input_side_packes= */ std::nullopt, std::move(*gpu_resources_)); +#else auto task_runner = TaskRunner::Create( std::move(graph_config), absl::make_unique(), std::move(callback)); +#endif // !MEDIAPIPE_DISABLE_GPU + RaisePyErrorIfNotOk(task_runner.status()); return std::move(*task_runner); },