Tensor: Fix use_ahwb_ flag and tests on local device involved.

PiperOrigin-RevId: 498249332
This commit is contained in:
Nikolay Chirkov 2022-12-28 14:27:42 -08:00 committed by Copybara-Service
parent 9580f04571
commit 1924f1cdff
3 changed files with 22 additions and 36 deletions

View File

@ -458,7 +458,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
} }
} }
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); // Keep flag value if it was set previously.
use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
} }
#else // MEDIAPIPE_TENSOR_USE_AHWB #else // MEDIAPIPE_TENSOR_USE_AHWB

View File

@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
}; };
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{
// Request Ahwb first to get Ahwb storage allocated internally.
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetWritingFinishedFD(-1, [](bool) { return true; });
}
RunInGlContext([&tensor] { RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name(); auto ssbo_name = ssbo_view.name();
@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
} }
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
{
// Request Ahwb first to get Ahwb storage allocated internally.
auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
}
RunInGlContext([&tensor] { RunInGlContext([&tensor] {
auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_view = tensor.GetOpenGlBufferWriteView();
auto ssbo_name = ssbo_view.name(); auto ssbo_name = ssbo_view.name();
@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
// Request the CPU view to get the memory to be allocated. // Request the CPU view to get the memory to be allocated.
// Request Ahwb view then to transform the storage into Ahwb. // Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
{ {
@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
// Request the GPU view to get the ssbo allocated internally. // Request the GPU view to get the ssbo allocated internally.
// Request Ahwb view then to transform the storage into Ahwb. // Request Ahwb view then to transform the storage into Ahwb.
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
constexpr size_t num_elements = 20; constexpr size_t num_elements = 20;
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
RunInGlContext([&tensor] { RunInGlContext([&tensor] {

View File

@ -1,34 +1,28 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "testing/base/public/gmock.h" #include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h" #include "testing/base/public/gunit.h"
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#if !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
class TensorAhwbTest : public mediapipe::GpuTestBase { TEST(TensorAhwbTest, TestCpuThenAHWB) {
public:
};
TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{ {
auto ptr = tensor.GetCpuWriteView().buffer<float>(); auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr); EXPECT_NE(ptr, nullptr);
} }
{ {
auto ahwb = tensor.GetAHardwareBufferReadView().handle(); auto view = tensor.GetAHardwareBufferReadView();
EXPECT_NE(ahwb, nullptr); EXPECT_NE(view.handle(), nullptr);
view.SetReadingFinishedFunc([](bool) { return true; });
} }
} }
TEST_F(TensorAhwbTest, TestAHWBThenCpu) { TEST(TensorAhwbTest, TestAHWBThenCpu) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{ {
auto ahwb = tensor.GetAHardwareBufferWriteView().handle(); auto view = tensor.GetAHardwareBufferWriteView();
EXPECT_NE(ahwb, nullptr); EXPECT_NE(view.handle(), nullptr);
view.SetWritingFinishedFD(-1, [](bool) { return true; });
} }
{ {
auto ptr = tensor.GetCpuReadView().buffer<float>(); auto ptr = tensor.GetCpuReadView().buffer<float>();
@ -36,21 +30,4 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
} }
} }
TEST_F(TensorAhwbTest, TestCpuThenGl) {
RunInGlContext([] {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{
auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr);
}
{
auto ssbo = tensor.GetOpenGlBufferReadView().name();
EXPECT_GT(ssbo, 0);
}
});
}
} // namespace mediapipe } // namespace mediapipe
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TENSOR_USE_AHWB