Internal change

PiperOrigin-RevId: 488797407
This commit is contained in:
Camillo Lugaresi 2022-11-15 17:04:39 -08:00 committed by Copybara-Service
parent 6702ef3d57
commit 77b3edbb67
2 changed files with 38 additions and 23 deletions

View File

@ -1,6 +1,7 @@
#include "mediapipe/gpu/gpu_buffer.h"
#include <memory>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const {
"]");
}
internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
internal::GpuBufferStorage* GpuBuffer::GetStorageForView(
TypeId view_provider_type, bool for_writing) const {
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
@ -45,38 +46,48 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
// TODO: choose best conversion.
if (!chosen_storage) {
for (const auto& s : storages_) {
auto converter = internal::GpuBufferStorageRegistry::Get()
.StorageConverterForViewProvider(view_provider_type,
s->storage_type());
if (converter) {
storages_.push_back(converter(s));
if (auto converter = internal::GpuBufferStorageRegistry::Get()
.StorageConverterForViewProvider(
view_provider_type, s->storage_type())) {
if (auto new_storage = converter(s)) {
storages_.push_back(new_storage);
chosen_storage = &storages_.back();
break;
}
}
}
}
if (for_writing) {
if (!chosen_storage) {
// Allocate a new storage supporting the requested view.
auto factory = internal::GpuBufferStorageRegistry::Get()
.StorageFactoryForViewProvider(view_provider_type);
if (factory) {
storages_ = {factory(width(), height(), format())};
chosen_storage = &storages_.back();
}
} else {
if (chosen_storage) {
// Discard all other storages.
storages_ = {*chosen_storage};
chosen_storage = &storages_.back();
} else {
// Allocate a new storage supporting the requested view.
if (auto factory =
internal::GpuBufferStorageRegistry::Get()
.StorageFactoryForViewProvider(view_provider_type)) {
if (auto new_storage = factory(width(), height(), format())) {
storages_ = {std::move(new_storage)};
chosen_storage = &storages_.back();
}
}
}
}
return chosen_storage ? chosen_storage->get() : nullptr;
}
internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
TypeId view_provider_type, bool for_writing) const {
auto* chosen_storage =
GpuBuffer::GetStorageForView(view_provider_type, for_writing);
CHECK(chosen_storage) << "no view provider found for requested view "
<< view_provider_type.name() << "; storages available: "
<< absl::StrJoin(storages_, ", ",
StorageTypeFormatter());
DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type));
return **chosen_storage;
DCHECK(chosen_storage->can_down_cast_to(view_provider_type));
return *chosen_storage;
}
#if !MEDIAPIPE_DISABLE_GPU

View File

@ -105,7 +105,7 @@ class GpuBuffer {
// specific view type; see the corresponding ViewProvider.
template <class View, class... Args>
decltype(auto) GetReadView(Args... args) const {
return GetViewProvider<View>(false)->GetReadView(
return GetViewProviderOrDie<View>(false).GetReadView(
internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
std::forward<Args>(args)...);
}
@ -114,7 +114,7 @@ class GpuBuffer {
// specific view type; see the corresponding ViewProvider.
template <class View, class... Args>
decltype(auto) GetWriteView(Args... args) {
return GetViewProvider<View>(true)->GetWriteView(
return GetViewProviderOrDie<View>(true).GetWriteView(
internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
std::forward<Args>(args)...);
}
@ -147,13 +147,17 @@ class GpuBuffer {
GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
};
internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type,
internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
bool for_writing) const;
internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type,
bool for_writing) const;
template <class View>
internal::ViewProvider<View>* GetViewProvider(bool for_writing) const {
internal::ViewProvider<View>& GetViewProviderOrDie(bool for_writing) const {
using VP = internal::ViewProvider<View>;
return GetStorageForView(kTypeId<VP>, for_writing).template down_cast<VP>();
return *GetStorageForViewOrDie(kTypeId<VP>, for_writing)
.template down_cast<VP>();
}
std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {