InferenceCalculatorAdvancedGL save cache in Open().

PiperOrigin-RevId: 547652481
This commit is contained in:
MediaPipe Team 2023-07-12 18:07:02 -07:00 committed by Copybara-Service
parent a2cd3e7f95
commit cc2aa4f4cc
3 changed files with 16 additions and 4 deletions

View File

@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
gpu_delegate_options); gpu_delegate_options);
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
bool UseSerializedModel() const { return use_serialized_model_; }
private: private:
bool use_kernel_caching_ = false; bool use_kernel_caching_ = false;
@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
} }
absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() { absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
MP_RETURN_IF_ERROR(
on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
return gpu_helper_.RunInGlContext([this]() -> absl::Status { return gpu_helper_.RunInGlContext([this]() -> absl::Status {
tflite_gpu_runner_.reset(); tflite_gpu_runner_.reset();
return absl::OkStatus(); return absl::OkStatus();
@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
tflite_gpu_runner_->GetOutputShapes()[i].c}; tflite_gpu_runner_->GetOutputShapes()[i].c};
} }
if (on_disk_cache_helper_.UseSerializedModel()) {
tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
}
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get())); on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
return tflite_gpu_runner_->Build(); MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
} }
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

View File

@ -234,6 +234,11 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties)); cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
if (serialized_model_.empty() &&
opencl_init_from_serialized_model_is_forced_) {
ASSIGN_OR_RETURN(serialized_model_, GetSerializedModel());
}
// Try to initialize from serialized model first. // Try to initialize from serialized model first.
if (!serialized_model_.empty()) { if (!serialized_model_.empty()) {
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder); absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
@ -270,7 +275,6 @@ absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
} }
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() { absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
RET_CHECK(runner_) << "Runner is in invalid state.";
if (serialized_model_used_) { if (serialized_model_used_) {
return serialized_model_; return serialized_model_;
} }

View File

@ -62,6 +62,9 @@ class TFLiteGPURunner {
void ForceOpenGL() { opengl_is_forced_ = true; } void ForceOpenGL() { opengl_is_forced_ = true; }
void ForceOpenCL() { opencl_is_forced_ = true; } void ForceOpenCL() { opencl_is_forced_ = true; }
void ForceOpenCLInitFromSerializedModel() {
opencl_init_from_serialized_model_is_forced_ = true;
}
absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id); absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id);
absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id); absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id);
@ -141,6 +144,7 @@ class TFLiteGPURunner {
bool opencl_is_forced_ = false; bool opencl_is_forced_ = false;
bool opengl_is_forced_ = false; bool opengl_is_forced_ = false;
bool opencl_init_from_serialized_model_is_forced_ = false;
}; };
} // namespace gpu } // namespace gpu