Verify that kernel cache is only used when OpenCL is active

PiperOrigin-RevId: 491463306
This commit is contained in:
MediaPipe Team 2022-11-28 15:46:30 -08:00 committed by Copybara-Service
parent 26a7ca5c64
commit 7b74fd53f5
3 changed files with 9 additions and 7 deletions

View File

@ -258,9 +258,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
tflite::gpu::TFLiteGPURunner* gpu_runner) const { tflite::gpu::TFLiteGPURunner* gpu_runner) const {
if (use_kernel_caching_) { if (use_kernel_caching_) {
// Save kernel file. // Save kernel file.
auto kernel_cache = absl::make_unique<std::vector<uint8_t>>( ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache,
gpu_runner->GetSerializedBinaryCache()); gpu_runner->GetSerializedBinaryCache());
std::string cache_str(kernel_cache->begin(), kernel_cache->end()); std::string cache_str(kernel_cache.begin(), kernel_cache.end());
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
} }

View File

@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() {
#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID)
if (use_kernel_caching_) { if (use_kernel_caching_) {
// Save kernel file. // Save kernel file.
auto kernel_cache = absl::make_unique<std::vector<uint8_t>>( ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache,
tflite_gpu_runner_->GetSerializedBinaryCache()); tflite_gpu_runner_->GetSerializedBinaryCache());
std::string cache_str(kernel_cache->begin(), kernel_cache->end()); std::string cache_str(kernel_cache.begin(), kernel_cache.end());
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
} }

View File

@ -21,6 +21,7 @@
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -89,7 +90,8 @@ class TFLiteGPURunner {
serialized_binary_cache_ = std::move(cache); serialized_binary_cache_ = std::move(cache);
} }
std::vector<uint8_t> GetSerializedBinaryCache() { absl::StatusOr<std::vector<uint8_t>> GetSerializedBinaryCache() {
RET_CHECK(cl_environment_) << "CL environment is not initialized.";
return cl_environment_->GetSerializedBinaryCache(); return cl_environment_->GetSerializedBinaryCache();
} }