Support more GPU formats in tensor converter calculator.

PiperOrigin-RevId: 556987807
commit b6f5414b3d (parent a183212a13)
Author: MediaPipe Team
Date:   2023-08-14 19:54:28 -07:00
Committed by: Copybara-Service

3 changed files with 38 additions and 22 deletions

mediapipe/calculators/tensor/BUILD

@@ -652,6 +652,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:statusor",
+        "//mediapipe/gpu:gpu_buffer_format",
         "//mediapipe/gpu:gpu_origin_cc_proto",
         "//mediapipe/util:resource_util",
         "@com_google_absl//absl/strings:str_format",
@@ -704,6 +705,7 @@ cc_test(
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:parse_text_proto",
         "//mediapipe/framework/tool:validate_type",
         "@com_google_absl//absl/memory",

mediapipe/calculators/tensor/tensor_converter_calculator.cc

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <cstdint>
 #include <string>
 #include <vector>
@@ -25,6 +26,7 @@
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/gpu/gpu_buffer_format.h"
 #include "mediapipe/gpu/gpu_origin.pb.h"

 #if !MEDIAPIPE_DISABLE_GPU
@@ -406,16 +408,27 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) {
   // Get input image sizes.
   const auto& input =
       cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
-  mediapipe::ImageFormat::Format format =
-      mediapipe::ImageFormatForGpuBufferFormat(input.format());
+  mediapipe::GpuBufferFormat format = input.format();
   const bool include_alpha = (max_num_channels_ == 4);
   const bool single_channel = (max_num_channels_ == 1);
-  if (!(format == mediapipe::ImageFormat::GRAY8 ||
-        format == mediapipe::ImageFormat::SRGB ||
-        format == mediapipe::ImageFormat::SRGBA))
-    RET_CHECK_FAIL() << "Unsupported GPU input format.";
-  if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
-    RET_CHECK_FAIL() << "Num input channels is less than desired output.";
+
+  RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 ||
+            format == mediapipe::GpuBufferFormat::kRGB24 ||
+            format == mediapipe::GpuBufferFormat::kRGBA32 ||
+            format == mediapipe::GpuBufferFormat::kRGBAFloat128 ||
+            format == mediapipe::GpuBufferFormat::kRGBAHalf64 ||
+            format == mediapipe::GpuBufferFormat::kGrayFloat32 ||
+            format == mediapipe::GpuBufferFormat::kGrayHalf16 ||
+            format == mediapipe::GpuBufferFormat::kOneComponent8)
+      << "Unsupported GPU input format: " << static_cast<uint32_t>(format);
+  if (include_alpha) {
+    RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 ||
+              format == mediapipe::GpuBufferFormat::kRGBA32 ||
+              format == mediapipe::GpuBufferFormat::kRGBAFloat128 ||
+              format == mediapipe::GpuBufferFormat::kRGBAHalf64)
+        << "Num input channels is less than desired output, input format: "
+        << static_cast<uint32_t>(format);
+  }

 #if MEDIAPIPE_METAL_ENABLED
   id<MTLDevice> device = gpu_helper_.mtlDevice;
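Note on the new check (not part of the diff): the accepted formats group by channel count, which is why the include_alpha branch admits only the four-channel variants. Below is a minimal sketch of that grouping, using only the GpuBufferFormat values listed in the hunk above; the helper name is illustrative, not from the calculator.

#include "mediapipe/gpu/gpu_buffer_format.h"

// Illustrative only: channel count implied by each format the converter's
// GPU path now accepts. Only the four-channel formats can satisfy
// max_num_channels_ == 4 (include_alpha).
inline int NumChannelsForAcceptedFormat(mediapipe::GpuBufferFormat format) {
  switch (format) {
    case mediapipe::GpuBufferFormat::kBGRA32:
    case mediapipe::GpuBufferFormat::kRGBA32:
    case mediapipe::GpuBufferFormat::kRGBAFloat128:
    case mediapipe::GpuBufferFormat::kRGBAHalf64:
      return 4;
    case mediapipe::GpuBufferFormat::kRGB24:
      return 3;
    case mediapipe::GpuBufferFormat::kGrayFloat32:
    case mediapipe::GpuBufferFormat::kGrayHalf16:
    case mediapipe::GpuBufferFormat::kOneComponent8:
      return 1;
    default:
      return 0;  // Rejected by the RET_CHECK above.
  }
}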

mediapipe/calculators/tensor/tensor_converter_calculator_test.cc

@@ -12,7 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <cstdint>
+#include <memory>
 #include <random>
+#include <utility>
 #include <vector>

 #include "absl/memory/memory.h"
@@ -24,8 +27,10 @@
 #include "mediapipe/framework/formats/image_frame_opencv.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"  // NOLINT
 #include "mediapipe/framework/tool/validate_type.h"
@@ -40,7 +45,6 @@ constexpr char kTransposeOptionsString[] =
 }  // namespace

 using RandomEngine = std::mt19937_64;
-using testing::Eq;
 const uint32_t kSeed = 1234;
 const int kNumSizes = 8;
 const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
@@ -127,7 +131,7 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixColMajor) {
   auto tensor_buffer = view.buffer<float>();
   for (int i = 0; i < num_rows * num_columns; ++i) {
     const float expected = uniform_dist(random);
-    EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
+    EXPECT_FLOAT_EQ(tensor_buffer[i], expected) << "at i = " << i;
   }

   // Fully close graph at end, otherwise calculator+tensors are destroyed
@@ -189,7 +193,7 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixRowMajor) {
   auto tensor_buffer = view.buffer<float>();
   for (int i = 0; i < num_rows * num_columns; ++i) {
     const float expected = uniform_dist(random);
-    EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
+    EXPECT_EQ(tensor_buffer[i], expected) << "at i = " << i;
   }

   // Fully close graph at end, otherwise calculator+tensors are destroyed
@@ -244,7 +248,7 @@ TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) {
   const Tensor* tensor = &tensor_vec[0];
   EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type());
   auto view = tensor->GetCpuReadView();
-  EXPECT_FLOAT_EQ(67.0f, *view.buffer<float>());
+  EXPECT_FLOAT_EQ(*view.buffer<float>(), 67.0f);

   // Fully close graph at end, otherwise calculator+tensors are destroyed
   // after calling WaitUntilDone().
@@ -299,16 +303,13 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
   const Tensor* tensor = &tensor_vec[0];

   // Calculate the expected normalized value:
-  float normalized_value =
+  float expected_value =
       range.first + (200 * (range.second - range.first)) / 255.0;

   EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
   auto view = tensor->GetCpuReadView();
-  float dataf = *view.buffer<float>();
-  EXPECT_THAT(
-      normalized_value,
-      testing::FloatNear(dataf, 2.0f * std::abs(dataf) *
-                                    std::numeric_limits<float>::epsilon()));
+  float actual_value = *view.buffer<float>();
+  EXPECT_FLOAT_EQ(actual_value, expected_value);

   // Fully close graph at end, otherwise calculator+tensors are destroyed
   // after calling WaitUntilDone().
@@ -362,8 +363,8 @@ TEST_F(TensorConverterCalculatorTest, FlipVertically) {
   EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
   const float* dataf = tensor->GetCpuReadView().buffer<float>();
-  EXPECT_EQ(kY1Value, static_cast<int>(roundf(dataf[0])));  // Y0, Y1 flipped!
-  EXPECT_EQ(kY0Value, static_cast<int>(roundf(dataf[1])));
+  EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY1Value);  // Y0, Y1 flipped!
+  EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY0Value);

   // Fully close graph at end, otherwise calculator+tensors are destroyed
   // after calling WaitUntilDone().
@@ -417,8 +418,8 @@ TEST_F(TensorConverterCalculatorTest, GpuOriginOverridesFlipVertically) {
   EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
   const float* dataf = tensor->GetCpuReadView().buffer<float>();
-  EXPECT_EQ(kY0Value, static_cast<int>(roundf(dataf[0])));  // Not flipped!
-  EXPECT_EQ(kY1Value, static_cast<int>(roundf(dataf[1])));
+  EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY0Value);  // Not flipped!
+  EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY1Value);

   // Fully close graph at end, otherwise calculator+tensors are destroyed
   // after calling WaitUntilDone().
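The remaining test edits are mechanical: expectation arguments are reordered so the value under test comes first, and exact float comparisons are replaced with EXPECT_FLOAT_EQ, which accepts values within 4 ULPs of each other rather than requiring bit-exact equality. A standalone sketch of why that matters for computed floats (test and variable names are illustrative, not from the file; it includes gtest directly, whereas MediaPipe tests go through "mediapipe/framework/port/gtest.h"):

#include "gtest/gtest.h"

TEST(FloatCompareSketch, UlpToleranceVsExactEquality) {
  float actual = 0.0f;
  for (int i = 0; i < 10; ++i) actual += 0.1f;  // accumulates rounding error
  // EXPECT_EQ(actual, 1.0f) demands exact equality and can fail on results
  // like this. EXPECT_FLOAT_EQ tolerates up to 4 ULPs, which is the right
  // check for computed floats; the (actual, expected) argument order matches
  // the reordered expectations in the diff above.
  EXPECT_FLOAT_EQ(actual, 1.0f);
}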