Internal change

PiperOrigin-RevId: 488797407
This commit is contained in:
Camillo Lugaresi 2022-11-15 17:04:39 -08:00 committed by Copybara-Service
parent 6702ef3d57
commit 77b3edbb67
2 changed files with 38 additions and 23 deletions

View File

@ -1,6 +1,7 @@
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include <memory> #include <memory>
#include <utility>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const {
"]"); "]");
} }
internal::GpuBufferStorage& GpuBuffer::GetStorageForView( internal::GpuBufferStorage* GpuBuffer::GetStorageForView(
TypeId view_provider_type, bool for_writing) const { TypeId view_provider_type, bool for_writing) const {
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr; const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
@ -45,38 +46,48 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
// TODO: choose best conversion. // TODO: choose best conversion.
if (!chosen_storage) { if (!chosen_storage) {
for (const auto& s : storages_) { for (const auto& s : storages_) {
auto converter = internal::GpuBufferStorageRegistry::Get() if (auto converter = internal::GpuBufferStorageRegistry::Get()
.StorageConverterForViewProvider(view_provider_type, .StorageConverterForViewProvider(
s->storage_type()); view_provider_type, s->storage_type())) {
if (converter) { if (auto new_storage = converter(s)) {
storages_.push_back(converter(s)); storages_.push_back(new_storage);
chosen_storage = &storages_.back(); chosen_storage = &storages_.back();
break;
}
} }
} }
} }
if (for_writing) { if (for_writing) {
if (!chosen_storage) { if (chosen_storage) {
// Allocate a new storage supporting the requested view.
auto factory = internal::GpuBufferStorageRegistry::Get()
.StorageFactoryForViewProvider(view_provider_type);
if (factory) {
storages_ = {factory(width(), height(), format())};
chosen_storage = &storages_.back();
}
} else {
// Discard all other storages. // Discard all other storages.
storages_ = {*chosen_storage}; storages_ = {*chosen_storage};
chosen_storage = &storages_.back(); chosen_storage = &storages_.back();
} else {
// Allocate a new storage supporting the requested view.
if (auto factory =
internal::GpuBufferStorageRegistry::Get()
.StorageFactoryForViewProvider(view_provider_type)) {
if (auto new_storage = factory(width(), height(), format())) {
storages_ = {std::move(new_storage)};
chosen_storage = &storages_.back();
} }
} }
}
}
return chosen_storage ? chosen_storage->get() : nullptr;
}
internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
TypeId view_provider_type, bool for_writing) const {
auto* chosen_storage =
GpuBuffer::GetStorageForView(view_provider_type, for_writing);
CHECK(chosen_storage) << "no view provider found for requested view " CHECK(chosen_storage) << "no view provider found for requested view "
<< view_provider_type.name() << "; storages available: " << view_provider_type.name() << "; storages available: "
<< absl::StrJoin(storages_, ", ", << absl::StrJoin(storages_, ", ",
StorageTypeFormatter()); StorageTypeFormatter());
DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); DCHECK(chosen_storage->can_down_cast_to(view_provider_type));
return **chosen_storage; return *chosen_storage;
} }
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU

View File

@ -105,7 +105,7 @@ class GpuBuffer {
// specific view type; see the corresponding ViewProvider. // specific view type; see the corresponding ViewProvider.
template <class View, class... Args> template <class View, class... Args>
decltype(auto) GetReadView(Args... args) const { decltype(auto) GetReadView(Args... args) const {
return GetViewProvider<View>(false)->GetReadView( return GetViewProviderOrDie<View>(false).GetReadView(
internal::types<View>{}, std::make_shared<GpuBuffer>(*this), internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
@ -114,7 +114,7 @@ class GpuBuffer {
// specific view type; see the corresponding ViewProvider. // specific view type; see the corresponding ViewProvider.
template <class View, class... Args> template <class View, class... Args>
decltype(auto) GetWriteView(Args... args) { decltype(auto) GetWriteView(Args... args) {
return GetViewProvider<View>(true)->GetWriteView( return GetViewProviderOrDie<View>(true).GetWriteView(
internal::types<View>{}, std::make_shared<GpuBuffer>(*this), internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
@ -147,13 +147,17 @@ class GpuBuffer {
GpuBufferFormat format_ = GpuBufferFormat::kUnknown; GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
}; };
internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type, internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
bool for_writing) const;
internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type,
bool for_writing) const; bool for_writing) const;
template <class View> template <class View>
internal::ViewProvider<View>* GetViewProvider(bool for_writing) const { internal::ViewProvider<View>& GetViewProviderOrDie(bool for_writing) const {
using VP = internal::ViewProvider<View>; using VP = internal::ViewProvider<View>;
return GetStorageForView(kTypeId<VP>, for_writing).template down_cast<VP>(); return *GetStorageForViewOrDie(kTypeId<VP>, for_writing)
.template down_cast<VP>();
} }
std::shared_ptr<internal::GpuBufferStorage>& no_storage() const { std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {