Internal update

PiperOrigin-RevId: 543508346
This commit is contained in:
MediaPipe Team 2023-06-26 12:18:25 -07:00 committed by Copybara-Service
parent 570880190b
commit 9de1b2577f
4 changed files with 121 additions and 2 deletions

View File

@ -29,6 +29,7 @@ cc_library(
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"@com_google_absl//absl/log",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",

View File

@ -17,15 +17,56 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include <variant>
#include "absl/log/log.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace core { namespace core {
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
const BaseOptions::CpuOptions& options) {
proto::Acceleration acceleration_proto = proto::Acceleration();
acceleration_proto.mutable_tflite();
return acceleration_proto;
}
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
const BaseOptions::GpuOptions& options) {
proto::Acceleration acceleration_proto = proto::Acceleration();
auto* gpu = acceleration_proto.mutable_gpu();
gpu->set_use_advanced_gpu_api(true);
gpu->set_cached_kernel_path(options.cached_kernel_path);
gpu->set_serialized_model_dir(options.serialized_model_dir);
gpu->set_model_token(options.model_token);
return acceleration_proto;
}
template <typename T>
void SetDelegateOptionsOrDie(const BaseOptions* base_options,
proto::BaseOptions& base_options_proto) {
if (base_options->delegate_options.has_value()) {
if (!std::holds_alternative<T>(*base_options->delegate_options)) {
LOG(FATAL) << "Specified Delegate type does not match the provided "
"delegate options.";
} else {
std::visit(
[&base_options_proto](const auto& delegate_options) {
proto::Acceleration acceleration_proto =
ConvertDelegateOptionsToAccelerationProto(delegate_options);
base_options_proto.mutable_acceleration()->Swap(
&acceleration_proto);
},
*base_options->delegate_options);
}
}
}
proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
proto::BaseOptions base_options_proto; proto::BaseOptions base_options_proto;
if (!base_options->model_asset_path.empty()) { if (!base_options->model_asset_path.empty()) {
@ -53,11 +94,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
switch (base_options->delegate) { switch (base_options->delegate) {
case BaseOptions::Delegate::CPU: case BaseOptions::Delegate::CPU:
base_options_proto.mutable_acceleration()->mutable_tflite(); base_options_proto.mutable_acceleration()->mutable_tflite();
SetDelegateOptionsOrDie<BaseOptions::CpuOptions>(base_options,
base_options_proto);
break; break;
case BaseOptions::Delegate::GPU: case BaseOptions::Delegate::GPU:
base_options_proto.mutable_acceleration() base_options_proto.mutable_acceleration()
->mutable_gpu() ->mutable_gpu()
->set_use_advanced_gpu_api(true); ->set_use_advanced_gpu_api(true);
SetDelegateOptionsOrDie<BaseOptions::GpuOptions>(base_options,
base_options_proto);
break; break;
case BaseOptions::Delegate::EDGETPU_NNAPI: case BaseOptions::Delegate::EDGETPU_NNAPI:
base_options_proto.mutable_acceleration() base_options_proto.mutable_acceleration()
@ -65,7 +110,6 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
->set_accelerator_name("google-edgetpu"); ->set_accelerator_name("google-edgetpu");
break; break;
} }
return base_options_proto; return base_options_proto;
} }
} // namespace core } // namespace core

View File

@ -17,7 +17,9 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <variant>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
@ -38,7 +40,8 @@ struct BaseOptions {
std::string model_asset_path = ""; std::string model_asset_path = "";
// The delegate to run MediaPipe. If the delegate is not set, the default // The delegate to run MediaPipe. If the delegate is not set, the default
// delegate CPU is used. // delegate CPU is used. Use `delegate_options` to configure advanced
// features of the selected delegate."
enum Delegate { enum Delegate {
CPU = 0, CPU = 0,
GPU = 1, GPU = 1,
@ -48,6 +51,30 @@ struct BaseOptions {
Delegate delegate = CPU; Delegate delegate = CPU;
// Options for CPU.
struct CpuOptions {};
// Options for GPU.
struct GpuOptions {
// Load pre-compiled serialized binary cache to accelerate init process.
// Only available on Android. Kernel caching will only be enabled if this
// path is set. NOTE: binary cache usage may be skipped if valid serialized
// model, specified by "serialized_model_dir", exists.
std::string cached_kernel_path;
// A dir to load from and save to a pre-compiled serialized model used to
// accelerate init process.
// NOTE: serialized model takes precedence over binary cache
// specified by "cached_kernel_path", which still can be used if
// serialized model is invalid or missing.
std::string serialized_model_dir;
// Unique token identifying the model. Used in conjunction with
// "serialized_model_dir". It is the caller's responsibility to ensure
// there is no clash of the tokens.
std::string model_token;
};
// The file descriptor to a file opened with open(2), with optional additional // The file descriptor to a file opened with open(2), with optional additional
// offset and length information. // offset and length information.
struct FileDescriptorMeta { struct FileDescriptorMeta {
@ -67,6 +94,10 @@ struct BaseOptions {
// built-in Ops. // built-in Ops.
std::unique_ptr<tflite::OpResolver> op_resolver = std::unique_ptr<tflite::OpResolver> op_resolver =
absl::make_unique<MediaPipeBuiltinOpResolver>(); absl::make_unique<MediaPipeBuiltinOpResolver>();
// Options for the chosen delegate. If not set, the default delegate options
// is used.
std::optional<std::variant<CpuOptions, GpuOptions>> delegate_options;
}; };
// Converts a BaseOptions to a BaseOptionsProto. // Converts a BaseOptions to a BaseOptionsProto.

View File

@ -1,6 +1,9 @@
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include <memory>
#include <optional>
#include <string> #include <string>
#include <variant>
#include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
@ -11,6 +14,8 @@
constexpr char kTestModelBundlePath[] = constexpr char kTestModelBundlePath[] =
"mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task"; "mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
constexpr char kCachedModelDir[] = "/data/local/tmp";
constexpr char kModelToken[] = "dummy_model_token";
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -40,6 +45,44 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) {
EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu"); EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu");
} }
TEST(DelegateOptionsTest, SucceedCpuOptions) {
BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::CPU;
BaseOptions::CpuOptions cpu_options;
base_options.delegate_options = cpu_options;
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
EXPECT_TRUE(proto.acceleration().has_tflite());
ASSERT_FALSE(proto.acceleration().has_gpu());
}
TEST(DelegateOptionsTest, SucceedGpuOptions) {
BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::GPU;
BaseOptions::GpuOptions gpu_options;
gpu_options.cached_kernel_path = kCachedModelDir;
gpu_options.model_token = kModelToken;
base_options.delegate_options = gpu_options;
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
ASSERT_TRUE(proto.acceleration().has_gpu());
ASSERT_FALSE(proto.acceleration().has_tflite());
EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir);
EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
}
TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) {
BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::CPU;
BaseOptions::GpuOptions gpu_options;
gpu_options.cached_kernel_path = kCachedModelDir;
gpu_options.model_token = kModelToken;
base_options.delegate_options = gpu_options;
ASSERT_DEATH(
{ proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); },
"Specified Delegate type does not match the provided "
"delegate options.");
}
} // namespace } // namespace
} // namespace core } // namespace core
} // namespace tasks } // namespace tasks