diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index dad9cdf1f..a3e44c536 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -29,6 +29,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", diff --git a/mediapipe/tasks/cc/core/base_options.cc b/mediapipe/tasks/cc/core/base_options.cc index a34c23168..b7987f982 100644 --- a/mediapipe/tasks/cc/core/base_options.cc +++ b/mediapipe/tasks/cc/core/base_options.cc @@ -17,15 +17,56 @@ limitations under the License. #include #include +#include +#include "absl/log/log.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" namespace mediapipe { namespace tasks { namespace core { +proto::Acceleration ConvertDelegateOptionsToAccelerationProto( + const BaseOptions::CpuOptions& options) { + proto::Acceleration acceleration_proto = proto::Acceleration(); + acceleration_proto.mutable_tflite(); + return acceleration_proto; +} + +proto::Acceleration ConvertDelegateOptionsToAccelerationProto( + const BaseOptions::GpuOptions& options) { + proto::Acceleration acceleration_proto = proto::Acceleration(); + auto* gpu = acceleration_proto.mutable_gpu(); + gpu->set_use_advanced_gpu_api(true); + gpu->set_cached_kernel_path(options.cached_kernel_path); + gpu->set_serialized_model_dir(options.serialized_model_dir); + gpu->set_model_token(options.model_token); + return acceleration_proto; +} + +template +void SetDelegateOptionsOrDie(const BaseOptions* base_options, + proto::BaseOptions& base_options_proto) { + if (base_options->delegate_options.has_value()) { + if (!std::holds_alternative(*base_options->delegate_options)) { + LOG(FATAL) << "Specified Delegate type does not match the provided " + "delegate options."; + } else { + std::visit( + [&base_options_proto](const auto& delegate_options) { + proto::Acceleration acceleration_proto = + ConvertDelegateOptionsToAccelerationProto(delegate_options); + base_options_proto.mutable_acceleration()->Swap( + &acceleration_proto); + }, + *base_options->delegate_options); + } + } +} + proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { proto::BaseOptions base_options_proto; if (!base_options->model_asset_path.empty()) { @@ -53,11 +94,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { switch (base_options->delegate) { case BaseOptions::Delegate::CPU: base_options_proto.mutable_acceleration()->mutable_tflite(); + SetDelegateOptionsOrDie(base_options, + base_options_proto); break; case BaseOptions::Delegate::GPU: base_options_proto.mutable_acceleration() ->mutable_gpu() ->set_use_advanced_gpu_api(true); + SetDelegateOptionsOrDie(base_options, + base_options_proto); break; case BaseOptions::Delegate::EDGETPU_NNAPI: base_options_proto.mutable_acceleration() @@ -65,7 +110,6 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { ->set_accelerator_name("google-edgetpu"); break; } - return base_options_proto; } } // namespace core diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 021aebbe5..738d71093 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -17,7 +17,9 @@ limitations under the License. #define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_ #include +#include #include +#include #include "absl/memory/memory.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" @@ -38,7 +40,8 @@ struct BaseOptions { std::string model_asset_path = ""; // The delegate to run MediaPipe. If the delegate is not set, the default - // delegate CPU is used. + // delegate CPU is used. Use `delegate_options` to configure advanced + // features of the selected delegate." enum Delegate { CPU = 0, GPU = 1, @@ -48,6 +51,30 @@ struct BaseOptions { Delegate delegate = CPU; + // Options for CPU. + struct CpuOptions {}; + + // Options for GPU. + struct GpuOptions { + // Load pre-compiled serialized binary cache to accelerate init process. + // Only available on Android. Kernel caching will only be enabled if this + // path is set. NOTE: binary cache usage may be skipped if valid serialized + // model, specified by "serialized_model_dir", exists. + std::string cached_kernel_path; + + // A dir to load from and save to a pre-compiled serialized model used to + // accelerate init process. + // NOTE: serialized model takes precedence over binary cache + // specified by "cached_kernel_path", which still can be used if + // serialized model is invalid or missing. + std::string serialized_model_dir; + + // Unique token identifying the model. Used in conjunction with + // "serialized_model_dir". It is the caller's responsibility to ensure + // there is no clash of the tokens. + std::string model_token; + }; + // The file descriptor to a file opened with open(2), with optional additional // offset and length information. struct FileDescriptorMeta { @@ -67,6 +94,10 @@ struct BaseOptions { // built-in Ops. std::unique_ptr op_resolver = absl::make_unique(); + + // Options for the chosen delegate. If not set, the default delegate options + // is used. + std::optional> delegate_options; }; // Converts a BaseOptions to a BaseOptionsProto. diff --git a/mediapipe/tasks/cc/core/base_options_test.cc b/mediapipe/tasks/cc/core/base_options_test.cc index dce95050d..af9a55a37 100644 --- a/mediapipe/tasks/cc/core/base_options_test.cc +++ b/mediapipe/tasks/cc/core/base_options_test.cc @@ -1,6 +1,9 @@ #include "mediapipe/tasks/cc/core/base_options.h" +#include +#include #include +#include #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/port/gmock.h" @@ -11,6 +14,8 @@ constexpr char kTestModelBundlePath[] = "mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task"; +constexpr char kCachedModelDir[] = "/data/local/tmp"; +constexpr char kModelToken[] = "dummy_model_token"; namespace mediapipe { namespace tasks { @@ -40,6 +45,44 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) { EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu"); } +TEST(DelegateOptionsTest, SucceedCpuOptions) { + BaseOptions base_options; + base_options.delegate = BaseOptions::Delegate::CPU; + BaseOptions::CpuOptions cpu_options; + base_options.delegate_options = cpu_options; + proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); + EXPECT_TRUE(proto.acceleration().has_tflite()); + ASSERT_FALSE(proto.acceleration().has_gpu()); +} + +TEST(DelegateOptionsTest, SucceedGpuOptions) { + BaseOptions base_options; + base_options.delegate = BaseOptions::Delegate::GPU; + BaseOptions::GpuOptions gpu_options; + gpu_options.cached_kernel_path = kCachedModelDir; + gpu_options.model_token = kModelToken; + base_options.delegate_options = gpu_options; + proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); + ASSERT_TRUE(proto.acceleration().has_gpu()); + ASSERT_FALSE(proto.acceleration().has_tflite()); + EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api()); + EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir); + EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken); +} + +TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) { + BaseOptions base_options; + base_options.delegate = BaseOptions::Delegate::CPU; + BaseOptions::GpuOptions gpu_options; + gpu_options.cached_kernel_path = kCachedModelDir; + gpu_options.model_token = kModelToken; + base_options.delegate_options = gpu_options; + ASSERT_DEATH( + { proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); }, + "Specified Delegate type does not match the provided " + "delegate options."); +} + } // namespace } // namespace core } // namespace tasks