// Patch summary (explanatory preamble; the patch proper starts at the first
// "diff --git" header below and is reproduced unchanged).
//
// Files touched:
//   * mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc
//   * mediapipe/util/tflite/tflite_gpu_runner.cc
//   * mediapipe/util/tflite/tflite_gpu_runner.h
//
// inference_calculator_gl_advanced.cc
//   * Adds a const accessor UseSerializedModel() that returns
//     use_serialized_model_, next to the ReadGpuCaches/SaveGpuCaches
//     declarations.
//   * GpuInferenceRunner::Close() no longer calls SaveGpuCaches(); it only
//     resets tflite_gpu_runner_ inside RunInGlContext.
//   * GpuInferenceRunner::InitTFLiteGPURunner(): when
//     on_disk_cache_helper_.UseSerializedModel() is true it calls
//     tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel() before
//     ReadGpuCaches(). Build() is now wrapped in MP_RETURN_IF_ERROR and the
//     function returns the status of SaveGpuCaches() instead of Build().
//     NOTE(review): a SaveGpuCaches() failure is therefore returned from
//     initialisation rather than from Close() -- confirm this is intended.
//
// tflite_gpu_runner.cc
//   * TFLiteGPURunner::InitializeOpenCL(): if serialized_model_ is empty and
//     opencl_init_from_serialized_model_is_forced_ is set, serialized_model_
//     is filled from GetSerializedModel() via ASSIGN_OR_RETURN before the
//     existing "initialize from serialized model" attempt.
//   * TFLiteGPURunner::GetSerializedModel(): the
//     RET_CHECK(runner_) << "Runner is in invalid state." guard is removed.
//     NOTE(review): presumably needed because the new call site above can run
//     before runner_ exists -- verify against the full function body.
//
// tflite_gpu_runner.h
//   * Adds ForceOpenCLInitFromSerializedModel(), which sets
//     opencl_init_from_serialized_model_is_forced_ to true, and the member
//     itself, defaulting to false.
//
// NOTE(review): this copy of the patch has had its line breaks collapsed, and
// the context line "absl::StatusOr> TFLiteGPURunner::GetSerializedModel()"
// appears to have lost its template argument during extraction. Restore both
// from the upstream commit before attempting to apply it.
diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 8aee46185..e265eaee7 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl gpu_delegate_options); absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; + bool UseSerializedModel() const { return use_serialized_model_; } private: bool use_kernel_caching_ = false; @@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( } absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() { - MP_RETURN_IF_ERROR( - on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get())); return gpu_helper_.RunInGlContext([this]() -> absl::Status { tflite_gpu_runner_.reset(); return absl::OkStatus(); @@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner( tflite_gpu_runner_->GetOutputShapes()[i].c}; } + if (on_disk_cache_helper_.UseSerializedModel()) { + tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel(); + } + MP_RETURN_IF_ERROR( on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get())); - return tflite_gpu_runner_->Build(); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); + return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()); } #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index 4e40975cb..c1b272b67 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -234,6 +234,11 @@ absl::Status TFLiteGPURunner::InitializeOpenCL( MP_RETURN_IF_ERROR( cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties)); 
+ if (serialized_model_.empty() && + opencl_init_from_serialized_model_is_forced_) { + ASSIGN_OR_RETURN(serialized_model_, GetSerializedModel()); + } + // Try to initialize from serialized model first. if (!serialized_model_.empty()) { absl::Status init_status = InitializeOpenCLFromSerializedModel(builder); @@ -270,7 +275,6 @@ absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel( } absl::StatusOr> TFLiteGPURunner::GetSerializedModel() { - RET_CHECK(runner_) << "Runner is in invalid state."; if (serialized_model_used_) { return serialized_model_; } diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index 5eeaa230f..c64981ef8 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -62,6 +62,9 @@ class TFLiteGPURunner { void ForceOpenGL() { opengl_is_forced_ = true; } void ForceOpenCL() { opencl_is_forced_ = true; } + void ForceOpenCLInitFromSerializedModel() { + opencl_init_from_serialized_model_is_forced_ = true; + } absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id); absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id); @@ -141,6 +144,7 @@ class TFLiteGPURunner { bool opencl_is_forced_ = false; bool opengl_is_forced_ = false; + bool opencl_init_from_serialized_model_is_forced_ = false; }; } // namespace gpu