Make cache writes optional in InferenceCalculatorAdvancedGL

Previously, caches were always written, and an error would cause the graph to close abruptly. This prevented services with read-only access to the cache from using the calculator.

The new behavior allows services to choose whether or not to write caches.
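
For example, a service whose cache directory is read-only can opt out of
cache writes entirely (a minimal graph-config sketch; the model path, stream
names, and cache locations are illustrative):

  node {
    calculator: "InferenceCalculator"
    input_stream: "TENSORS:input_tensors"
    output_stream: "TENSORS:output_tensors"
    options {
      [mediapipe.InferenceCalculatorOptions.ext] {
        model_path: "model.tflite"
        delegate {
          gpu {
            # Read existing caches if present, but never write them back.
            cache_writing_behavior: NO_WRITE
            serialized_model_dir: "/readonly/cache"
            model_token: "my_model"
          }
        }
      }
    }
  }

TRY_WRITE logs a warning and continues if a write fails; leaving the field
unset keeps the previous behavior (WRITE_OR_ERROR).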

PiperOrigin-RevId: 561866791
Commit: de0c7f2a30 (parent: dea6ccba25)
Author: MediaPipe Team, 2023-08-31 23:29:04 -07:00
Committed by: Copybara-Service
3 changed files with 55 additions and 2 deletions

mediapipe/calculators/tensor/inference_calculator.proto

@@ -88,6 +88,20 @@ message InferenceCalculatorOptions {
       // serialized model is invalid or missing.
       optional string serialized_model_dir = 7;
 
+      enum CacheWritingBehavior {
+        // Do not write any caches.
+        NO_WRITE = 0;
+
+        // Try to write caches, log on failure.
+        TRY_WRITE = 1;
+
+        // Write caches or return an error if write fails.
+        WRITE_OR_ERROR = 2;
+      }
+      // Specifies how GPU caches are written to disk.
+      optional CacheWritingBehavior cache_writing_behavior = 10
+          [default = WRITE_OR_ERROR];
+
       // Unique token identifying the model. Used in conjunction with
       // "serialized_model_dir". It is the caller's responsibility to ensure
       // there is no clash of the tokens.

mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cstdint>
 #include <cstring>
 #include <memory>
 #include <string>
@@ -26,6 +27,7 @@
 #include "mediapipe/util/tflite/tflite_gpu_runner.h"
 
 #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
+#include "absl/log/absl_log.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/util/android/file/base/file.h"
 #include "mediapipe/util/android/file/base/filesystem.h"
@@ -68,14 +70,21 @@ class InferenceCalculatorGlAdvancedImpl
       const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
           gpu_delegate_options);
   absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
-  absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
+  // Writes caches to disk based on |cache_writing_behavior_|.
+  absl::Status SaveGpuCachesBasedOnBehavior(
+      tflite::gpu::TFLiteGPURunner* gpu_runner) const;
   bool UseSerializedModel() const { return use_serialized_model_; }
 
  private:
+  // Writes caches to disk, returns error on failure.
+  absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
+
   bool use_kernel_caching_ = false;
   std::string cached_kernel_filename_;
   bool use_serialized_model_ = false;
   std::string serialized_model_path_;
+  mediapipe::InferenceCalculatorOptions::Delegate::Gpu::CacheWritingBehavior
+      cache_writing_behavior_;
 };
 
 // Helper class that wraps everything related to GPU inference acceleration.
@@ -232,7 +241,8 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
   MP_RETURN_IF_ERROR(
       on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
   MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
-  return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
+  return on_disk_cache_helper_.SaveGpuCachesBasedOnBehavior(
+      tflite_gpu_runner_.get());
 }
 
 #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
@@ -261,9 +271,36 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
         mediapipe::file::JoinPath(gpu_delegate_options.serialized_model_dir(),
                                   gpu_delegate_options.model_token());
   }
+  cache_writing_behavior_ = gpu_delegate_options.has_cache_writing_behavior()
+                                ? gpu_delegate_options.cache_writing_behavior()
+                                : mediapipe::InferenceCalculatorOptions::
+                                      Delegate::Gpu::WRITE_OR_ERROR;
   return absl::OkStatus();
 }
 
+absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::
+    SaveGpuCachesBasedOnBehavior(
+        tflite::gpu::TFLiteGPURunner* gpu_runner) const {
+  switch (cache_writing_behavior_) {
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::NO_WRITE:
+      return absl::OkStatus();
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::TRY_WRITE: {
+      auto status = SaveGpuCaches(gpu_runner);
+      if (!status.ok()) {
+        ABSL_LOG_FIRST_N(WARNING, 1) << "Failed to save gpu caches: " << status;
+      }
+      return absl::OkStatus();
+    }
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::WRITE_OR_ERROR:
+      return SaveGpuCaches(gpu_runner);
+    default:
+      ABSL_LOG_FIRST_N(ERROR, 1)
+          << "Unknown cache writing behavior: "
+          << static_cast<uint32_t>(cache_writing_behavior_);
+      return absl::InvalidArgumentError("Unknown cache writing behavior.");
+  }
+}
+
 absl::Status
 InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
     tflite::gpu::TFLiteGPURunner* gpu_runner) const {

mediapipe/tasks/web/core/task_runner_test.ts

@@ -122,6 +122,8 @@ describe('TaskRunner', () => {
       allowPrecisionLoss: true,
       cachedKernelPath: undefined,
       serializedModelDir: undefined,
+      cacheWritingBehavior: InferenceCalculatorOptions.Delegate.Gpu
+          .CacheWritingBehavior.WRITE_OR_ERROR,
       modelToken: undefined,
       usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage
           .SUSTAINED_SPEED,