Initialize GPU support for Python Task API

PiperOrigin-RevId: 575842513
This commit is contained in:
Sebastian Schmidt 2023-10-23 09:31:58 -07:00 committed by Copybara-Service
parent 0dee33ccba
commit 6aa27d9aeb
5 changed files with 61 additions and 0 deletions

View File

@ -264,6 +264,7 @@ cc_library_with_tflite(
"//mediapipe/framework:executor",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/tasks/cc:common",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",

View File

@ -39,6 +39,10 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe {
namespace tasks {
namespace core {
@ -88,16 +92,34 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
} // namespace
/* static */
#if !MEDIAPIPE_DISABLE_GPU
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets,
std::shared_ptr<::mediapipe::GpuResources> resources) {
#else
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) {
#endif // !MEDIAPIPE_DISABLE_GPU
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
MP_RETURN_IF_ERROR(task_runner->Initialize(
std::move(config), std::move(op_resolver), std::move(default_executor),
std::move(input_side_packets)));
#if !MEDIAPIPE_DISABLE_GPU
if (resources) {
MP_RETURN_IF_ERROR(
task_runner->graph_.SetGpuResources(std::move(resources)));
}
#endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(task_runner->Start());
return task_runner;
}

View File

@ -42,6 +42,11 @@ limitations under the License.
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
#if !MEDIAPIPE_DISABLE_GPU
class GpuResources;
#endif // !MEDIAPIPE_DISABLE_GPU
namespace tasks {
namespace core {
@ -72,12 +77,22 @@ class TaskRunner {
// asynchronous method, Send(), to provide the input packets. If the packets
// callback is absent, clients must use the synchronous method, Process(), to
// provide the input packets and receive the output packets.
#if !MEDIAPIPE_DISABLE_GPU
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt,
std::shared_ptr<::mediapipe::GpuResources> resources = nullptr);
#else
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt);
#endif // !MEDIAPIPE_DISABLE_GPU
// TaskRunner is neither copyable nor movable.
TaskRunner(const TaskRunner&) = delete;

View File

@ -26,9 +26,11 @@ pybind_library(
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/python/pybind:util",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/cc/core:task_runner",
"@com_google_absl//absl/log:absl_log",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],

View File

@ -14,6 +14,7 @@
#include "mediapipe/tasks/python/core/pybind/task_runner.h"
#include "absl/log/absl_log.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/python/pybind/util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
@ -21,6 +22,9 @@
#include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // MEDIAPIPE_DISABLE_GPU
namespace mediapipe {
namespace tasks {
@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc");
return absl::OkStatus();
};
}
#if !MEDIAPIPE_DISABLE_GPU
auto gpu_resources_ = mediapipe::GpuResources::Create();
if (!gpu_resources_.ok()) {
ABSL_LOG(INFO) << "GPU suport is not available: "
<< gpu_resources_.status();
gpu_resources_ = nullptr;
}
auto task_runner = TaskRunner::Create(
std::move(graph_config),
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback),
/* default_executor= */ nullptr,
/* input_side_packes= */ std::nullopt, std::move(*gpu_resources_));
#else
auto task_runner = TaskRunner::Create(
std::move(graph_config),
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
std::move(callback));
#endif // !MEDIAPIPE_DISABLE_GPU
RaisePyErrorIfNotOk(task_runner.status());
return std::move(*task_runner);
},