Make cache writes optional in InferenceCalculatorGlAdvanced
Previously, caches were always written, and a write error would cause the graph to close abruptly. This prevented services with read-only access to the cache from using the calculator. The new behavior lets services choose whether caches are written.

PiperOrigin-RevId: 561866791
commit de0c7f2a30
parent dea6ccba25
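For illustration, a minimal sketch of how a service with a read-only cache volume might opt out of cache writes, assuming the generated protobuf setters for the fields added in the proto change below; the helper name, directory, and model token are placeholders, not part of this commit:

```cpp
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"

// Hypothetical helper: builds GPU delegate options for a service whose
// cache directory is mounted read-only, so no disk write is ever attempted.
mediapipe::InferenceCalculatorOptions MakeReadOnlyCacheOptions() {
  mediapipe::InferenceCalculatorOptions options;
  auto* gpu = options.mutable_delegate()->mutable_gpu();
  gpu->set_use_advanced_gpu_api(true);
  gpu->set_serialized_model_dir("/srv/model_cache");  // placeholder path
  gpu->set_model_token("my_model_v1");                // placeholder token
  // NO_WRITE: read existing caches but never write them back.
  gpu->set_cache_writing_behavior(
      mediapipe::InferenceCalculatorOptions::Delegate::Gpu::NO_WRITE);
  return options;
}
```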
@@ -88,6 +88,20 @@ message InferenceCalculatorOptions {
     // serialized model is invalid or missing.
     optional string serialized_model_dir = 7;
 
+    enum CacheWritingBehavior {
+      // Do not write any caches.
+      NO_WRITE = 0;
+
+      // Try to write caches, log on failure.
+      TRY_WRITE = 1;
+
+      // Write caches or return an error if write fails.
+      WRITE_OR_ERROR = 2;
+    }
+    // Specifies how GPU caches are written to disk.
+    optional CacheWritingBehavior cache_writing_behavior = 10
+        [default = WRITE_OR_ERROR];
+
     // Unique token identifying the model. Used in conjunction with
     // "serialized_model_dir". It is the caller's responsibility to ensure
     // there is no clash of the tokens.
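A hedged example of how the new field could be set in a graph config, using MediaPipe's text-proto helper; the node layout, stream names, and model path are illustrative, not taken from this commit:

```cpp
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Illustrative graph snippet: TRY_WRITE attempts the cache write and only
// logs on failure, so a flaky or full cache volume cannot close the graph.
mediapipe::CalculatorGraphConfig MakeGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    node {
      calculator: "InferenceCalculatorGlAdvanced"
      input_stream: "TENSORS:input_tensors"  # placeholder streams
      output_stream: "TENSORS:output_tensors"
      options {
        [mediapipe.InferenceCalculatorOptions.ext] {
          model_path: "model.tflite"  # placeholder model
          delegate {
            gpu {
              use_advanced_gpu_api: true
              cache_writing_behavior: TRY_WRITE
            }
          }
        }
      }
    }
  )pb");
}
```

WRITE_OR_ERROR remains the default, so graphs that do not set the field keep the old fail-fast behavior.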
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cstdint>
 #include <cstring>
 #include <memory>
 #include <string>

@@ -26,6 +27,7 @@
 #include "mediapipe/util/tflite/tflite_gpu_runner.h"
 
 #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
+#include "absl/log/absl_log.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/util/android/file/base/file.h"
 #include "mediapipe/util/android/file/base/filesystem.h"

@@ -68,14 +70,21 @@ class InferenceCalculatorGlAdvancedImpl
         const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
             gpu_delegate_options);
     absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
-    absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
+    // Writes caches to disk based on |cache_writing_behavior_|.
+    absl::Status SaveGpuCachesBasedOnBehavior(
+        tflite::gpu::TFLiteGPURunner* gpu_runner) const;
     bool UseSerializedModel() const { return use_serialized_model_; }
 
    private:
+    // Writes caches to disk, returns error on failure.
+    absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const;
+
     bool use_kernel_caching_ = false;
     std::string cached_kernel_filename_;
     bool use_serialized_model_ = false;
     std::string serialized_model_path_;
+    mediapipe::InferenceCalculatorOptions::Delegate::Gpu::CacheWritingBehavior
+        cache_writing_behavior_;
   };
 
 // Helper class that wraps everything related to GPU inference acceleration.

@@ -232,7 +241,8 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
   MP_RETURN_IF_ERROR(
       on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get()));
   MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
-  return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get());
+  return on_disk_cache_helper_.SaveGpuCachesBasedOnBehavior(
+      tflite_gpu_runner_.get());
 }
 
 #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

@@ -261,9 +271,36 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
         mediapipe::file::JoinPath(gpu_delegate_options.serialized_model_dir(),
                                   gpu_delegate_options.model_token());
   }
+  cache_writing_behavior_ = gpu_delegate_options.has_cache_writing_behavior()
+                                ? gpu_delegate_options.cache_writing_behavior()
+                                : mediapipe::InferenceCalculatorOptions::
+                                      Delegate::Gpu::WRITE_OR_ERROR;
   return absl::OkStatus();
 }
 
+absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::
+    SaveGpuCachesBasedOnBehavior(
+        tflite::gpu::TFLiteGPURunner* gpu_runner) const {
+  switch (cache_writing_behavior_) {
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::NO_WRITE:
+      return absl::OkStatus();
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::TRY_WRITE: {
+      auto status = SaveGpuCaches(gpu_runner);
+      if (!status.ok()) {
+        ABSL_LOG_FIRST_N(WARNING, 1) << "Failed to save gpu caches: " << status;
+      }
+      return absl::OkStatus();
+    }
+    case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::WRITE_OR_ERROR:
+      return SaveGpuCaches(gpu_runner);
+    default:
+      ABSL_LOG_FIRST_N(ERROR, 1)
+          << "Unknown cache writing behavior: "
+          << static_cast<uint32_t>(cache_writing_behavior_);
+      return absl::InvalidArgumentError("Unknown cache writing behavior.");
+  }
+}
+
 absl::Status
 InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
     tflite::gpu::TFLiteGPURunner* gpu_runner) const {
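The SaveGpuCaches body itself is truncated on this page; for orientation only, here is a rough, non-authoritative sketch of what a cache writer of this shape might do. The runner accessors and file helper are assumptions based on the headers included above, not the file's actual contents:

```cpp
#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"

// Sketch only; not the real SaveGpuCaches body. Assumes the runner exposes
// serialized kernel/model bytes and that mediapipe::file::SetContents writes
// a file and returns a status.
absl::Status SaveGpuCachesSketch(tflite::gpu::TFLiteGPURunner* gpu_runner,
                                 const std::string& kernel_path,
                                 const std::string& model_path) {
  // Persist the compiled kernel cache; any failure propagates to the caller,
  // where SaveGpuCachesBasedOnBehavior decides whether it is fatal.
  ASSIGN_OR_RETURN(std::vector<uint8_t> kernels,
                   gpu_runner->GetSerializedBinaryCache());
  MP_RETURN_IF_ERROR(mediapipe::file::SetContents(
      kernel_path, std::string(kernels.begin(), kernels.end())));
  // Persist the serialized model for faster startup on the next run.
  ASSIGN_OR_RETURN(std::vector<uint8_t> model,
                   gpu_runner->GetSerializedModel());
  return mediapipe::file::SetContents(
      model_path, std::string(model.begin(), model.end()));
}
```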
@@ -122,6 +122,8 @@ describe('TaskRunner', () => {
       allowPrecisionLoss: true,
       cachedKernelPath: undefined,
       serializedModelDir: undefined,
+      cacheWritingBehavior: InferenceCalculatorOptions.Delegate.Gpu
+                                .CacheWritingBehavior.WRITE_OR_ERROR,
       modelToken: undefined,
       usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage
                 .SUSTAINED_SPEED,