Initialize GPU support for Python Task API
PiperOrigin-RevId: 575842513
This commit is contained in:
parent
0dee33ccba
commit
6aa27d9aeb
|
@ -264,6 +264,7 @@ cc_library_with_tflite(
|
||||||
"//mediapipe/framework:executor",
|
"//mediapipe/framework:executor",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/tool:name_util",
|
"//mediapipe/framework/tool:name_util",
|
||||||
|
"//mediapipe/gpu:gpu_shared_data_internal",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
|
|
@ -39,6 +39,10 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
#include "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace core {
|
namespace core {
|
||||||
|
@ -88,16 +92,34 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
||||||
|
CalculatorGraphConfig config,
|
||||||
|
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||||
|
PacketsCallback packets_callback,
|
||||||
|
std::shared_ptr<Executor> default_executor,
|
||||||
|
std::optional<PacketMap> input_side_packets,
|
||||||
|
std::shared_ptr<::mediapipe::GpuResources> resources) {
|
||||||
|
#else
|
||||||
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
||||||
CalculatorGraphConfig config,
|
CalculatorGraphConfig config,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver,
|
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||||
PacketsCallback packets_callback,
|
PacketsCallback packets_callback,
|
||||||
std::shared_ptr<Executor> default_executor,
|
std::shared_ptr<Executor> default_executor,
|
||||||
std::optional<PacketMap> input_side_packets) {
|
std::optional<PacketMap> input_side_packets) {
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
|
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
|
||||||
MP_RETURN_IF_ERROR(task_runner->Initialize(
|
MP_RETURN_IF_ERROR(task_runner->Initialize(
|
||||||
std::move(config), std::move(op_resolver), std::move(default_executor),
|
std::move(config), std::move(op_resolver), std::move(default_executor),
|
||||||
std::move(input_side_packets)));
|
std::move(input_side_packets)));
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
if (resources) {
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
task_runner->graph_.SetGpuResources(std::move(resources)));
|
||||||
|
}
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(task_runner->Start());
|
MP_RETURN_IF_ERROR(task_runner->Start());
|
||||||
return task_runner;
|
return task_runner;
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,11 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
class GpuResources;
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace core {
|
namespace core {
|
||||||
|
|
||||||
|
@ -72,12 +77,22 @@ class TaskRunner {
|
||||||
// asynchronous method, Send(), to provide the input packets. If the packets
|
// asynchronous method, Send(), to provide the input packets. If the packets
|
||||||
// callback is absent, clients must use the synchronous method, Process(), to
|
// callback is absent, clients must use the synchronous method, Process(), to
|
||||||
// provide the input packets and receive the output packets.
|
// provide the input packets and receive the output packets.
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
||||||
|
CalculatorGraphConfig config,
|
||||||
|
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||||
|
PacketsCallback packets_callback = nullptr,
|
||||||
|
std::shared_ptr<Executor> default_executor = nullptr,
|
||||||
|
std::optional<PacketMap> input_side_packets = std::nullopt,
|
||||||
|
std::shared_ptr<::mediapipe::GpuResources> resources = nullptr);
|
||||||
|
#else
|
||||||
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
||||||
CalculatorGraphConfig config,
|
CalculatorGraphConfig config,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||||
PacketsCallback packets_callback = nullptr,
|
PacketsCallback packets_callback = nullptr,
|
||||||
std::shared_ptr<Executor> default_executor = nullptr,
|
std::shared_ptr<Executor> default_executor = nullptr,
|
||||||
std::optional<PacketMap> input_side_packets = std::nullopt);
|
std::optional<PacketMap> input_side_packets = std::nullopt);
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
// TaskRunner is neither copyable nor movable.
|
// TaskRunner is neither copyable nor movable.
|
||||||
TaskRunner(const TaskRunner&) = delete;
|
TaskRunner(const TaskRunner&) = delete;
|
||||||
|
|
|
@ -26,9 +26,11 @@ pybind_library(
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/gpu:gpu_shared_data_internal",
|
||||||
"//mediapipe/python/pybind:util",
|
"//mediapipe/python/pybind:util",
|
||||||
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
|
"@com_google_absl//absl/log:absl_log",
|
||||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
|
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
|
||||||
],
|
],
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
#include "mediapipe/tasks/python/core/pybind/task_runner.h"
|
#include "mediapipe/tasks/python/core/pybind/task_runner.h"
|
||||||
|
|
||||||
|
#include "absl/log/absl_log.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/python/pybind/util.h"
|
#include "mediapipe/python/pybind/util.h"
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
|
@ -21,6 +22,9 @@
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
#include "pybind11_protobuf/native_proto_caster.h"
|
#include "pybind11_protobuf/native_proto_caster.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
#include "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||||
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc");
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
auto gpu_resources_ = mediapipe::GpuResources::Create();
|
||||||
|
if (!gpu_resources_.ok()) {
|
||||||
|
ABSL_LOG(INFO) << "GPU suport is not available: "
|
||||||
|
<< gpu_resources_.status();
|
||||||
|
gpu_resources_ = nullptr;
|
||||||
|
}
|
||||||
|
auto task_runner = TaskRunner::Create(
|
||||||
|
std::move(graph_config),
|
||||||
|
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
|
||||||
|
std::move(callback),
|
||||||
|
/* default_executor= */ nullptr,
|
||||||
|
/* input_side_packes= */ std::nullopt, std::move(*gpu_resources_));
|
||||||
|
#else
|
||||||
auto task_runner = TaskRunner::Create(
|
auto task_runner = TaskRunner::Create(
|
||||||
std::move(graph_config),
|
std::move(graph_config),
|
||||||
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
|
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
|
||||||
std::move(callback));
|
std::move(callback));
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
RaisePyErrorIfNotOk(task_runner.status());
|
RaisePyErrorIfNotOk(task_runner.status());
|
||||||
return std::move(*task_runner);
|
return std::move(*task_runner);
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in New Issue
Block a user