InferenceCalculatorAdvancedGL save cache in Open().
PiperOrigin-RevId: 547652481
This commit is contained in:
parent
a2cd3e7f95
commit
cc2aa4f4cc
|
@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
|
|||
gpu_delegate_options);
|
||||
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
|
||||
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
|
||||
bool UseSerializedModel() const { return use_serialized_model_; }
|
||||
|
||||
private:
|
||||
bool use_kernel_caching_ = false;
|
||||
|
@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
|
|||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
|
||||
MP_RETURN_IF_ERROR(
|
||||
on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
|
||||
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
|
||||
tflite_gpu_runner_.reset();
|
||||
return absl::OkStatus();
|
||||
|
@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
|
|||
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
||||
}
|
||||
|
||||
if (on_disk_cache_helper_.UseSerializedModel()) {
|
||||
tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(
|
||||
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
|
||||
return tflite_gpu_runner_->Build();
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
|
||||
return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
|
||||
}
|
||||
|
||||
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||
|
|
|
@ -234,6 +234,11 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
|
|||
MP_RETURN_IF_ERROR(
|
||||
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
|
||||
|
||||
if (serialized_model_.empty() &&
|
||||
opencl_init_from_serialized_model_is_forced_) {
|
||||
ASSIGN_OR_RETURN(serialized_model_, GetSerializedModel());
|
||||
}
|
||||
|
||||
// Try to initialize from serialized model first.
|
||||
if (!serialized_model_.empty()) {
|
||||
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
|
||||
|
@ -270,7 +275,6 @@ absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
|||
}
|
||||
|
||||
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
||||
RET_CHECK(runner_) << "Runner is in invalid state.";
|
||||
if (serialized_model_used_) {
|
||||
return serialized_model_;
|
||||
}
|
||||
|
|
|
@ -62,6 +62,9 @@ class TFLiteGPURunner {
|
|||
|
||||
void ForceOpenGL() { opengl_is_forced_ = true; }
|
||||
void ForceOpenCL() { opencl_is_forced_ = true; }
|
||||
void ForceOpenCLInitFromSerializedModel() {
|
||||
opencl_init_from_serialized_model_is_forced_ = true;
|
||||
}
|
||||
|
||||
absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id);
|
||||
absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id);
|
||||
|
@ -141,6 +144,7 @@ class TFLiteGPURunner {
|
|||
|
||||
bool opencl_is_forced_ = false;
|
||||
bool opengl_is_forced_ = false;
|
||||
bool opencl_init_from_serialized_model_is_forced_ = false;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
|
Loading…
Reference in New Issue
Block a user