Tensor: Fix use_ahwb_ flag and tests on local device involved.
PiperOrigin-RevId: 498249332
This commit is contained in:
parent
9580f04571
commit
1924f1cdff
|
@ -458,7 +458,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
|
|||
ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
|
||||
}
|
||||
}
|
||||
use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
|
||||
// Keep flag value if it was set previously.
|
||||
use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);
|
||||
}
|
||||
|
||||
#else // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
|
|
@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase {
|
|||
};
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
{
|
||||
// Request Ahwb first to get Ahwb storage allocated internally.
|
||||
auto view = tensor.GetAHardwareBufferWriteView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetWritingFinishedFD(-1, [](bool) { return true; });
|
||||
}
|
||||
RunInGlContext([&tensor] {
|
||||
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
|
||||
auto ssbo_name = ssbo_view.name();
|
||||
|
@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
|
|||
}
|
||||
|
||||
TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})};
|
||||
{
|
||||
// Request Ahwb first to get Ahwb storage allocated internally.
|
||||
auto view = tensor.GetAHardwareBufferWriteView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetReadingFinishedFunc([](bool) { return true; });
|
||||
}
|
||||
RunInGlContext([&tensor] {
|
||||
auto ssbo_view = tensor.GetOpenGlBufferWriteView();
|
||||
auto ssbo_name = ssbo_view.name();
|
||||
|
@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
|
|||
TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
|
||||
// Request the CPU view to get the memory to be allocated.
|
||||
// Request Ahwb view then to transform the storage into Ahwb.
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
{
|
||||
|
@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
|
|||
TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
|
||||
// Request the GPU view to get the ssbo allocated internally.
|
||||
// Request Ahwb view then to transform the storage into Ahwb.
|
||||
Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
|
||||
constexpr size_t num_elements = 20;
|
||||
Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
|
||||
RunInGlContext([&tensor] {
|
||||
|
|
|
@ -1,34 +1,28 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/gpu/gpu_test_base.h"
|
||||
#include "testing/base/public/gmock.h"
|
||||
#include "testing/base/public/gunit.h"
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class TensorAhwbTest : public mediapipe::GpuTestBase {
|
||||
public:
|
||||
};
|
||||
|
||||
TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
|
||||
TEST(TensorAhwbTest, TestCpuThenAHWB) {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
|
||||
{
|
||||
auto ptr = tensor.GetCpuWriteView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
}
|
||||
{
|
||||
auto ahwb = tensor.GetAHardwareBufferReadView().handle();
|
||||
EXPECT_NE(ahwb, nullptr);
|
||||
auto view = tensor.GetAHardwareBufferReadView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetReadingFinishedFunc([](bool) { return true; });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
|
||||
TEST(TensorAhwbTest, TestAHWBThenCpu) {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
|
||||
{
|
||||
auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
|
||||
EXPECT_NE(ahwb, nullptr);
|
||||
auto view = tensor.GetAHardwareBufferWriteView();
|
||||
EXPECT_NE(view.handle(), nullptr);
|
||||
view.SetWritingFinishedFD(-1, [](bool) { return true; });
|
||||
}
|
||||
{
|
||||
auto ptr = tensor.GetCpuReadView().buffer<float>();
|
||||
|
@ -36,21 +30,4 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorAhwbTest, TestCpuThenGl) {
|
||||
RunInGlContext([] {
|
||||
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
|
||||
{
|
||||
auto ptr = tensor.GetCpuWriteView().buffer<float>();
|
||||
EXPECT_NE(ptr, nullptr);
|
||||
}
|
||||
{
|
||||
auto ssbo = tensor.GetOpenGlBufferReadView().name();
|
||||
EXPECT_GT(ssbo, 0);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
|
Loading…
Reference in New Issue
Block a user