Support more GPU formats in tensor converter calculator.
PiperOrigin-RevId: 556987807
This commit is contained in:
parent
a183212a13
commit
b6f5414b3d
|
@ -652,6 +652,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"//mediapipe/gpu:gpu_buffer_format",
|
||||
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||
"//mediapipe/util:resource_util",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -704,6 +705,7 @@ cc_test(
|
|||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:validate_type",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
@ -25,6 +26,7 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/gpu/gpu_buffer_format.h"
|
||||
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -406,16 +408,27 @@ absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) {
|
|||
// Get input image sizes.
|
||||
const auto& input =
|
||||
cc->Inputs().Tag(kGpuBufferTag).Get<mediapipe::GpuBuffer>();
|
||||
mediapipe::ImageFormat::Format format =
|
||||
mediapipe::ImageFormatForGpuBufferFormat(input.format());
|
||||
mediapipe::GpuBufferFormat format = input.format();
|
||||
const bool include_alpha = (max_num_channels_ == 4);
|
||||
const bool single_channel = (max_num_channels_ == 1);
|
||||
if (!(format == mediapipe::ImageFormat::GRAY8 ||
|
||||
format == mediapipe::ImageFormat::SRGB ||
|
||||
format == mediapipe::ImageFormat::SRGBA))
|
||||
RET_CHECK_FAIL() << "Unsupported GPU input format.";
|
||||
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
|
||||
RET_CHECK_FAIL() << "Num input channels is less than desired output.";
|
||||
|
||||
RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGB24 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGBA32 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGBAFloat128 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGBAHalf64 ||
|
||||
format == mediapipe::GpuBufferFormat::kGrayFloat32 ||
|
||||
format == mediapipe::GpuBufferFormat::kGrayHalf16 ||
|
||||
format == mediapipe::GpuBufferFormat::kOneComponent8)
|
||||
<< "Unsupported GPU input format: " << static_cast<uint32_t>(format);
|
||||
if (include_alpha) {
|
||||
RET_CHECK(format == mediapipe::GpuBufferFormat::kBGRA32 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGBA32 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGBAFloat128 ||
|
||||
format == mediapipe::GpuBufferFormat::kRGBAHalf64)
|
||||
<< "Num input channels is less than desired output, input format: "
|
||||
<< static_cast<uint32_t>(format);
|
||||
}
|
||||
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
|
|
|
@ -12,7 +12,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
|
@ -24,8 +27,10 @@
|
|||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
|
@ -40,7 +45,6 @@ constexpr char kTransposeOptionsString[] =
|
|||
} // namespace
|
||||
|
||||
using RandomEngine = std::mt19937_64;
|
||||
using testing::Eq;
|
||||
const uint32_t kSeed = 1234;
|
||||
const int kNumSizes = 8;
|
||||
const int sizes[kNumSizes][2] = {{1, 1}, {12, 1}, {1, 9}, {2, 2},
|
||||
|
@ -127,7 +131,7 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixColMajor) {
|
|||
auto tensor_buffer = view.buffer<float>();
|
||||
for (int i = 0; i < num_rows * num_columns; ++i) {
|
||||
const float expected = uniform_dist(random);
|
||||
EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
|
||||
EXPECT_FLOAT_EQ(tensor_buffer[i], expected) << "at i = " << i;
|
||||
}
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
|
@ -189,7 +193,7 @@ TEST_F(TensorConverterCalculatorTest, RandomMatrixRowMajor) {
|
|||
auto tensor_buffer = view.buffer<float>();
|
||||
for (int i = 0; i < num_rows * num_columns; ++i) {
|
||||
const float expected = uniform_dist(random);
|
||||
EXPECT_EQ(expected, tensor_buffer[i]) << "at i = " << i;
|
||||
EXPECT_EQ(tensor_buffer[i], expected) << "at i = " << i;
|
||||
}
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
|
@ -244,7 +248,7 @@ TEST_F(TensorConverterCalculatorTest, CustomDivAndSub) {
|
|||
const Tensor* tensor = &tensor_vec[0];
|
||||
EXPECT_EQ(Tensor::ElementType::kFloat32, tensor->element_type());
|
||||
auto view = tensor->GetCpuReadView();
|
||||
EXPECT_FLOAT_EQ(67.0f, *view.buffer<float>());
|
||||
EXPECT_FLOAT_EQ(*view.buffer<float>(), 67.0f);
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
|
@ -299,16 +303,13 @@ TEST_F(TensorConverterCalculatorTest, SetOutputRange) {
|
|||
const Tensor* tensor = &tensor_vec[0];
|
||||
|
||||
// Calculate the expected normalized value:
|
||||
float normalized_value =
|
||||
float expected_value =
|
||||
range.first + (200 * (range.second - range.first)) / 255.0;
|
||||
|
||||
EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
|
||||
auto view = tensor->GetCpuReadView();
|
||||
float dataf = *view.buffer<float>();
|
||||
EXPECT_THAT(
|
||||
normalized_value,
|
||||
testing::FloatNear(dataf, 2.0f * std::abs(dataf) *
|
||||
std::numeric_limits<float>::epsilon()));
|
||||
float actual_value = *view.buffer<float>();
|
||||
EXPECT_FLOAT_EQ(actual_value, expected_value);
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
|
@ -362,8 +363,8 @@ TEST_F(TensorConverterCalculatorTest, FlipVertically) {
|
|||
|
||||
EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
|
||||
const float* dataf = tensor->GetCpuReadView().buffer<float>();
|
||||
EXPECT_EQ(kY1Value, static_cast<int>(roundf(dataf[0]))); // Y0, Y1 flipped!
|
||||
EXPECT_EQ(kY0Value, static_cast<int>(roundf(dataf[1])));
|
||||
EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY1Value); // Y0, Y1 flipped!
|
||||
EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY0Value);
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
|
@ -417,8 +418,8 @@ TEST_F(TensorConverterCalculatorTest, GpuOriginOverridesFlipVertically) {
|
|||
|
||||
EXPECT_EQ(tensor->element_type(), Tensor::ElementType::kFloat32);
|
||||
const float* dataf = tensor->GetCpuReadView().buffer<float>();
|
||||
EXPECT_EQ(kY0Value, static_cast<int>(roundf(dataf[0]))); // Not flipped!
|
||||
EXPECT_EQ(kY1Value, static_cast<int>(roundf(dataf[1])));
|
||||
EXPECT_EQ(static_cast<int>(roundf(dataf[0])), kY0Value); // Not flipped!
|
||||
EXPECT_EQ(static_cast<int>(roundf(dataf[1])), kY1Value);
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
|
|
Loading…
Reference in New Issue
Block a user