Initialize GPU support for Python Task API

PiperOrigin-RevId: 575842513
This commit is contained in:
Sebastian Schmidt 2023-10-23 09:31:58 -07:00 committed by Copybara-Service
parent 0dee33ccba
commit 6aa27d9aeb
5 changed files with 61 additions and 0 deletions

View File

@ -264,6 +264,7 @@ cc_library_with_tflite(
"//mediapipe/framework:executor", "//mediapipe/framework:executor",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util", "//mediapipe/framework/tool:name_util",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",

View File

@ -39,6 +39,10 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace core { namespace core {
@ -88,16 +92,34 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
} // namespace } // namespace
/* static */ /* static */
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets,
std::shared_ptr<::mediapipe::GpuResources> resources) {
#else
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create( absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver, std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback, PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor, std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) { std::optional<PacketMap> input_side_packets) {
#endif // !MEDIAPIPE_DISABLE_GPU
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
MP_RETURN_IF_ERROR(task_runner->Initialize( MP_RETURN_IF_ERROR(task_runner->Initialize(
std::move(config), std::move(op_resolver), std::move(default_executor), std::move(config), std::move(op_resolver), std::move(default_executor),
std::move(input_side_packets))); std::move(input_side_packets)));
#if !MEDIAPIPE_DISABLE_GPU
if (resources) {
MP_RETURN_IF_ERROR(
task_runner->graph_.SetGpuResources(std::move(resources)));
}
#endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(task_runner->Start()); MP_RETURN_IF_ERROR(task_runner->Start());
return task_runner; return task_runner;
} }

View File

@ -42,6 +42,11 @@ limitations under the License.
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe { namespace mediapipe {
#if !MEDIAPIPE_DISABLE_GPU
class GpuResources;
#endif // !MEDIAPIPE_DISABLE_GPU
namespace tasks { namespace tasks {
namespace core { namespace core {
@ -72,12 +77,22 @@ class TaskRunner {
// asynchronous method, Send(), to provide the input packets. If the packets // asynchronous method, Send(), to provide the input packets. If the packets
// callback is absent, clients must use the synchronous method, Process(), to // callback is absent, clients must use the synchronous method, Process(), to
// provide the input packets and receive the output packets. // provide the input packets and receive the output packets.
#if !MEDIAPIPE_DISABLE_GPU
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt,
std::shared_ptr<::mediapipe::GpuResources> resources = nullptr);
#else
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create( static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr, std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr, PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr, std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt); std::optional<PacketMap> input_side_packets = std::nullopt);
#endif // !MEDIAPIPE_DISABLE_GPU
// TaskRunner is neither copyable nor movable. // TaskRunner is neither copyable nor movable.
TaskRunner(const TaskRunner&) = delete; TaskRunner(const TaskRunner&) = delete;

View File

@ -26,9 +26,11 @@ pybind_library(
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/python/pybind:util", "//mediapipe/python/pybind:util",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
], ],

View File

@ -14,6 +14,7 @@
#include "mediapipe/tasks/python/core/pybind/task_runner.h" #include "mediapipe/tasks/python/core/pybind/task_runner.h"
#include "absl/log/absl_log.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/python/pybind/util.h" #include "mediapipe/python/pybind/util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
@ -21,6 +22,9 @@
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h" #include "pybind11_protobuf/native_proto_caster.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc");
return absl::OkStatus(); return absl::OkStatus();
}; };
} }
#if !MEDIAPIPE_DISABLE_GPU
auto gpu_resources_ = mediapipe::GpuResources::Create();
if (!gpu_resources_.ok()) {
ABSL_LOG(INFO) << "GPU suport is not available: "
<< gpu_resources_.status();
gpu_resources_ = nullptr;
}
auto task_runner = TaskRunner::Create(
std::move(graph_config),
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback),
/* default_executor= */ nullptr,
/* input_side_packes= */ std::nullopt, std::move(*gpu_resources_));
#else
auto task_runner = TaskRunner::Create( auto task_runner = TaskRunner::Create(
std::move(graph_config), std::move(graph_config),
absl::make_unique<core::MediaPipeBuiltinOpResolver>(), absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback)); std::move(callback));
#endif // !MEDIAPIPE_DISABLE_GPU
RaisePyErrorIfNotOk(task_runner.status()); RaisePyErrorIfNotOk(task_runner.status());
return std::move(*task_runner); return std::move(*task_runner);
}, },