Internal change
PiperOrigin-RevId: 488797407
This commit is contained in:
parent
6702ef3d57
commit
77b3edbb67
|
@ -1,6 +1,7 @@
|
|||
#include "mediapipe/gpu/gpu_buffer.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
|
@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const {
|
|||
"]");
|
||||
}
|
||||
|
||||
internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
|
||||
internal::GpuBufferStorage* GpuBuffer::GetStorageForView(
|
||||
TypeId view_provider_type, bool for_writing) const {
|
||||
const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
|
||||
|
||||
|
@ -45,38 +46,48 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
|
|||
// TODO: choose best conversion.
|
||||
if (!chosen_storage) {
|
||||
for (const auto& s : storages_) {
|
||||
auto converter = internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageConverterForViewProvider(view_provider_type,
|
||||
s->storage_type());
|
||||
if (converter) {
|
||||
storages_.push_back(converter(s));
|
||||
chosen_storage = &storages_.back();
|
||||
if (auto converter = internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageConverterForViewProvider(
|
||||
view_provider_type, s->storage_type())) {
|
||||
if (auto new_storage = converter(s)) {
|
||||
storages_.push_back(new_storage);
|
||||
chosen_storage = &storages_.back();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (for_writing) {
|
||||
if (!chosen_storage) {
|
||||
// Allocate a new storage supporting the requested view.
|
||||
auto factory = internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageFactoryForViewProvider(view_provider_type);
|
||||
if (factory) {
|
||||
storages_ = {factory(width(), height(), format())};
|
||||
chosen_storage = &storages_.back();
|
||||
}
|
||||
} else {
|
||||
if (chosen_storage) {
|
||||
// Discard all other storages.
|
||||
storages_ = {*chosen_storage};
|
||||
chosen_storage = &storages_.back();
|
||||
} else {
|
||||
// Allocate a new storage supporting the requested view.
|
||||
if (auto factory =
|
||||
internal::GpuBufferStorageRegistry::Get()
|
||||
.StorageFactoryForViewProvider(view_provider_type)) {
|
||||
if (auto new_storage = factory(width(), height(), format())) {
|
||||
storages_ = {std::move(new_storage)};
|
||||
chosen_storage = &storages_.back();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return chosen_storage ? chosen_storage->get() : nullptr;
|
||||
}
|
||||
|
||||
internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
|
||||
TypeId view_provider_type, bool for_writing) const {
|
||||
auto* chosen_storage =
|
||||
GpuBuffer::GetStorageForView(view_provider_type, for_writing);
|
||||
CHECK(chosen_storage) << "no view provider found for requested view "
|
||||
<< view_provider_type.name() << "; storages available: "
|
||||
<< absl::StrJoin(storages_, ", ",
|
||||
StorageTypeFormatter());
|
||||
DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type));
|
||||
return **chosen_storage;
|
||||
DCHECK(chosen_storage->can_down_cast_to(view_provider_type));
|
||||
return *chosen_storage;
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
|
|
@ -105,7 +105,7 @@ class GpuBuffer {
|
|||
// specific view type; see the corresponding ViewProvider.
|
||||
template <class View, class... Args>
|
||||
decltype(auto) GetReadView(Args... args) const {
|
||||
return GetViewProvider<View>(false)->GetReadView(
|
||||
return GetViewProviderOrDie<View>(false).GetReadView(
|
||||
internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ class GpuBuffer {
|
|||
// specific view type; see the corresponding ViewProvider.
|
||||
template <class View, class... Args>
|
||||
decltype(auto) GetWriteView(Args... args) {
|
||||
return GetViewProvider<View>(true)->GetWriteView(
|
||||
return GetViewProviderOrDie<View>(true).GetWriteView(
|
||||
internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
@ -147,13 +147,17 @@ class GpuBuffer {
|
|||
GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
|
||||
};
|
||||
|
||||
internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type,
|
||||
internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
|
||||
bool for_writing) const;
|
||||
|
||||
internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type,
|
||||
bool for_writing) const;
|
||||
|
||||
template <class View>
|
||||
internal::ViewProvider<View>* GetViewProvider(bool for_writing) const {
|
||||
internal::ViewProvider<View>& GetViewProviderOrDie(bool for_writing) const {
|
||||
using VP = internal::ViewProvider<View>;
|
||||
return GetStorageForView(kTypeId<VP>, for_writing).template down_cast<VP>();
|
||||
return *GetStorageForViewOrDie(kTypeId<VP>, for_writing)
|
||||
.template down_cast<VP>();
|
||||
}
|
||||
|
||||
std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {
|
||||
|
|
Loading…
Reference in New Issue
Block a user