Initialize GPU support for Python Task API
PiperOrigin-RevId: 575842513
This commit is contained in:
parent
0dee33ccba
commit
6aa27d9aeb
|
@ -264,6 +264,7 @@ cc_library_with_tflite(
|
|||
"//mediapipe/framework:executor",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:name_util",
|
||||
"//mediapipe/gpu:gpu_shared_data_internal",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
|
|
@ -39,6 +39,10 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace core {
|
||||
|
@ -88,16 +92,34 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
|
|||
} // namespace
|
||||
|
||||
/* static */
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||
PacketsCallback packets_callback,
|
||||
std::shared_ptr<Executor> default_executor,
|
||||
std::optional<PacketMap> input_side_packets,
|
||||
std::shared_ptr<::mediapipe::GpuResources> resources) {
|
||||
#else
|
||||
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||
PacketsCallback packets_callback,
|
||||
std::shared_ptr<Executor> default_executor,
|
||||
std::optional<PacketMap> input_side_packets) {
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
|
||||
MP_RETURN_IF_ERROR(task_runner->Initialize(
|
||||
std::move(config), std::move(op_resolver), std::move(default_executor),
|
||||
std::move(input_side_packets)));
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
if (resources) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
task_runner->graph_.SetGpuResources(std::move(resources)));
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
MP_RETURN_IF_ERROR(task_runner->Start());
|
||||
return task_runner;
|
||||
}
|
||||
|
|
|
@ -42,6 +42,11 @@ limitations under the License.
|
|||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
class GpuResources;
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace tasks {
|
||||
namespace core {
|
||||
|
||||
|
@ -72,12 +77,22 @@ class TaskRunner {
|
|||
// asynchronous method, Send(), to provide the input packets. If the packets
|
||||
// callback is absent, clients must use the synchronous method, Process(), to
|
||||
// provide the input packets and receive the output packets.
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||
PacketsCallback packets_callback = nullptr,
|
||||
std::shared_ptr<Executor> default_executor = nullptr,
|
||||
std::optional<PacketMap> input_side_packets = std::nullopt,
|
||||
std::shared_ptr<::mediapipe::GpuResources> resources = nullptr);
|
||||
#else
|
||||
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||
PacketsCallback packets_callback = nullptr,
|
||||
std::shared_ptr<Executor> default_executor = nullptr,
|
||||
std::optional<PacketMap> input_side_packets = std::nullopt);
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
// TaskRunner is neither copyable nor movable.
|
||||
TaskRunner(const TaskRunner&) = delete;
|
||||
|
|
|
@ -26,9 +26,11 @@ pybind_library(
|
|||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/gpu:gpu_shared_data_internal",
|
||||
"//mediapipe/python/pybind:util",
|
||||
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"@com_google_absl//absl/log:absl_log",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
|
||||
],
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "mediapipe/tasks/python/core/pybind/task_runner.h"
|
||||
|
||||
#include "absl/log/absl_log.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/python/pybind/util.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
|
@ -21,6 +22,9 @@
|
|||
#include "pybind11/stl.h"
|
||||
#include "pybind11_protobuf/native_proto_caster.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -74,10 +78,27 @@ mode) or not (synchronous mode).)doc");
|
|||
return absl::OkStatus();
|
||||
};
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
auto gpu_resources_ = mediapipe::GpuResources::Create();
|
||||
if (!gpu_resources_.ok()) {
|
||||
ABSL_LOG(INFO) << "GPU suport is not available: "
|
||||
<< gpu_resources_.status();
|
||||
gpu_resources_ = nullptr;
|
||||
}
|
||||
auto task_runner = TaskRunner::Create(
|
||||
std::move(graph_config),
|
||||
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
|
||||
std::move(callback),
|
||||
/* default_executor= */ nullptr,
|
||||
/* input_side_packes= */ std::nullopt, std::move(*gpu_resources_));
|
||||
#else
|
||||
auto task_runner = TaskRunner::Create(
|
||||
std::move(graph_config),
|
||||
absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
|
||||
std::move(callback));
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
RaisePyErrorIfNotOk(task_runner.status());
|
||||
return std::move(*task_runner);
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue
Block a user