diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index e570ce8ba..35a73fd8f 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -1,6 +1,7 @@ #include "mediapipe/gpu/gpu_buffer.h" #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const { "]"); } -internal::GpuBufferStorage& GpuBuffer::GetStorageForView( +internal::GpuBufferStorage* GpuBuffer::GetStorageForView( TypeId view_provider_type, bool for_writing) const { const std::shared_ptr* chosen_storage = nullptr; @@ -45,38 +46,48 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView( // TODO: choose best conversion. if (!chosen_storage) { for (const auto& s : storages_) { - auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider(view_provider_type, - s->storage_type()); - if (converter) { - storages_.push_back(converter(s)); - chosen_storage = &storages_.back(); + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + if (auto new_storage = converter(s)) { + storages_.push_back(new_storage); + chosen_storage = &storages_.back(); + break; + } } } } if (for_writing) { - if (!chosen_storage) { - // Allocate a new storage supporting the requested view. - auto factory = internal::GpuBufferStorageRegistry::Get() - .StorageFactoryForViewProvider(view_provider_type); - if (factory) { - storages_ = {factory(width(), height(), format())}; - chosen_storage = &storages_.back(); - } - } else { + if (chosen_storage) { // Discard all other storages. storages_ = {*chosen_storage}; chosen_storage = &storages_.back(); + } else { + // Allocate a new storage supporting the requested view. + if (auto factory = + internal::GpuBufferStorageRegistry::Get() + .StorageFactoryForViewProvider(view_provider_type)) { + if (auto new_storage = factory(width(), height(), format())) { + storages_ = {std::move(new_storage)}; + chosen_storage = &storages_.back(); + } + } } } + return chosen_storage ? chosen_storage->get() : nullptr; +} +internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( + TypeId view_provider_type, bool for_writing) const { + auto* chosen_storage = + GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " << absl::StrJoin(storages_, ", ", StorageTypeFormatter()); - DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); - return **chosen_storage; + DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); + return *chosen_storage; } #if !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 57e077151..ad5c130b5 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -105,7 +105,7 @@ class GpuBuffer { // specific view type; see the corresponding ViewProvider. template decltype(auto) GetReadView(Args... args) const { - return GetViewProvider(false)->GetReadView( + return GetViewProviderOrDie(false).GetReadView( internal::types{}, std::make_shared(*this), std::forward(args)...); } @@ -114,7 +114,7 @@ class GpuBuffer { // specific view type; see the corresponding ViewProvider. template decltype(auto) GetWriteView(Args... args) { - return GetViewProvider(true)->GetWriteView( + return GetViewProviderOrDie(true).GetWriteView( internal::types{}, std::make_shared(*this), std::forward(args)...); } @@ -147,13 +147,17 @@ class GpuBuffer { GpuBufferFormat format_ = GpuBufferFormat::kUnknown; }; - internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type, + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, bool for_writing) const; + internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, + bool for_writing) const; + template - internal::ViewProvider* GetViewProvider(bool for_writing) const { + internal::ViewProvider& GetViewProviderOrDie(bool for_writing) const { using VP = internal::ViewProvider; - return GetStorageForView(kTypeId, for_writing).template down_cast(); + return *GetStorageForViewOrDie(kTypeId, for_writing) + .template down_cast(); } std::shared_ptr& no_storage() const {