InferenceCalculatorAdvancedGL save cache in Open().
PiperOrigin-RevId: 547652481
This commit is contained in:
parent
a2cd3e7f95
commit
cc2aa4f4cc
|
@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl
|
||||||
gpu_delegate_options);
|
gpu_delegate_options);
|
||||||
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
|
absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
|
||||||
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
|
absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
|
||||||
|
bool UseSerializedModel() const { return use_serialized_model_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool use_kernel_caching_ = false;
|
bool use_kernel_caching_ = false;
|
||||||
|
@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
|
absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() {
|
||||||
MP_RETURN_IF_ERROR(
|
|
||||||
on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()));
|
|
||||||
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
|
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
|
||||||
tflite_gpu_runner_.reset();
|
tflite_gpu_runner_.reset();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
|
||||||
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (on_disk_cache_helper_.UseSerializedModel()) {
|
||||||
|
tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel();
|
||||||
|
}
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
|
on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
|
||||||
return tflite_gpu_runner_->Build();
|
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
|
||||||
|
return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
|
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
|
||||||
|
|
|
@ -234,6 +234,11 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
|
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
|
||||||
|
|
||||||
|
if (serialized_model_.empty() &&
|
||||||
|
opencl_init_from_serialized_model_is_forced_) {
|
||||||
|
ASSIGN_OR_RETURN(serialized_model_, GetSerializedModel());
|
||||||
|
}
|
||||||
|
|
||||||
// Try to initialize from serialized model first.
|
// Try to initialize from serialized model first.
|
||||||
if (!serialized_model_.empty()) {
|
if (!serialized_model_.empty()) {
|
||||||
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
|
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
|
||||||
|
@ -270,7 +275,6 @@ absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
|
||||||
RET_CHECK(runner_) << "Runner is in invalid state.";
|
|
||||||
if (serialized_model_used_) {
|
if (serialized_model_used_) {
|
||||||
return serialized_model_;
|
return serialized_model_;
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,9 @@ class TFLiteGPURunner {
|
||||||
|
|
||||||
void ForceOpenGL() { opengl_is_forced_ = true; }
|
void ForceOpenGL() { opengl_is_forced_ = true; }
|
||||||
void ForceOpenCL() { opencl_is_forced_ = true; }
|
void ForceOpenCL() { opencl_is_forced_ = true; }
|
||||||
|
void ForceOpenCLInitFromSerializedModel() {
|
||||||
|
opencl_init_from_serialized_model_is_forced_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id);
|
absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id);
|
||||||
absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id);
|
absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id);
|
||||||
|
@ -141,6 +144,7 @@ class TFLiteGPURunner {
|
||||||
|
|
||||||
bool opencl_is_forced_ = false;
|
bool opencl_is_forced_ = false;
|
||||||
bool opengl_is_forced_ = false;
|
bool opengl_is_forced_ = false;
|
||||||
|
bool opencl_init_from_serialized_model_is_forced_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
|
Loading…
Reference in New Issue
Block a user