diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h index 6ec53603b..94842c769 100644 --- a/mediapipe/gpu/gpu_test_base.h +++ b/mediapipe/gpu/gpu_test_base.h @@ -15,6 +15,9 @@ #ifndef MEDIAPIPE_GPU_GPU_TEST_BASE_H_ #define MEDIAPIPE_GPU_GPU_TEST_BASE_H_ +#include +#include + #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/gpu/gl_calculator_helper.h" @@ -22,9 +25,9 @@ namespace mediapipe { -class GpuTestBase : public ::testing::Test { +class GpuTestEnvironment { protected: - GpuTestBase() { helper_.InitializeForTest(gpu_resources_.get()); } + GpuTestEnvironment() { helper_.InitializeForTest(gpu_resources_.get()); } void RunInGlContext(std::function gl_func) { helper_.RunInGlContext(std::move(gl_func)); @@ -35,6 +38,12 @@ class GpuTestBase : public ::testing::Test { GlCalculatorHelper helper_; }; +class GpuTestBase : public testing::Test, public GpuTestEnvironment {}; + +template +class GpuTestWithParamBase : public testing::TestWithParam, + public GpuTestEnvironment {}; + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_TEST_BASE_H_