Changed the image-to-tensor converter interface so that the "Convert"

function assumes the output Tensor is preallocated before the function is called.

PiperOrigin-RevId: 481752678
This commit is contained in:
MediaPipe Team 2022-10-17 15:18:35 -07:00 committed by Copybara-Service
parent a5e4219590
commit 58e5cc3c88
6 changed files with 181 additions and 106 deletions

View File

@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
}
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
const Size size{image->width(), image->height()};
RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
options_.output_tensor_height(),
options_.keep_aspect_ratio(), &roi));
@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
}
if (kOutMatrix(cc).IsConnected()) {
std::array<float, 16> matrix;
GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
/*flip_horizontaly=*/false,
&matrix);
GetRotatedSubRectToRectTransformMatrix(
roi, image->width(), image->height(),
/*flip_horizontaly=*/false, &matrix);
kOutMatrix(cc).Send(std::move(matrix));
}
// Lazy initialization of the GPU or CPU converter.
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
ASSIGN_OR_RETURN(Tensor tensor,
(image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, {output_width_, output_height_},
range_min_, range_max_));
Tensor::ElementType output_tensor_type =
GetOutputTensorType(image->UsesGpu());
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
GetNumOutputChannels(*image)});
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, range_min_, range_max_,
/*tensor_buffer_offset=*/0, tensor));
auto result = std::make_unique<std::vector<Tensor>>();
result->push_back(std::move(tensor));
@ -292,15 +295,31 @@ class ImageToTensorCalculator : public Node {
}
}
Tensor::ElementType GetOutputTensorType() {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
if (!uses_gpu) {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
// Returns the channel count for the output tensor produced from `image`.
// The result is 4 only when the build has GPU support, Metal is enabled, and
// the image reports GPU usage; in every other configuration it is 3.
int GetNumOutputChannels(const Image& image) {
  // All of the processors except for Metal expect 3 channels.
  int num_channels = 3;
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
  // GPU-backed images on Metal builds use 4 channels instead.
  if (image.UsesGpu()) {
    num_channels = 4;
  }
#endif  // MEDIAPIPE_METAL_ENABLED
#endif  // !MEDIAPIPE_DISABLE_GPU
  return num_channels;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(
cpu_converter_,
CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
CreateOpenCvConverter(cc, GetBorderMode(),
GetOutputTensorType(/*uses_gpu=*/false)));
#else
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined.";

View File

@ -42,13 +42,16 @@ class ImageToTensorConverter {
// @image contains image to extract from.
// @roi describes region of interest within the image to extract (absolute
// values).
// @output_dims dimensions of output tensor.
// @range_min/max describes the output tensor range that image pixels should
// be converted to.
virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims,
float range_min, float range_max) = 0;
// @tensor_buffer_offset an integer representing the offset of the tensor
// buffer the result should be written to.
// @output_tensor a tensor with pre-defined shape. The "Convert" is
// responsible for populating the content into the output tensor.
virtual absl::Status Convert(const mediapipe::Image& input,
const RotatedRect& roi, float range_min,
float range_max, int tensor_buffer_offset,
Tensor& output_tensor) = 0;
};
} // namespace mediapipe

View File

@ -264,10 +264,10 @@ class GlProcessor : public ImageToTensorConverter {
});
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@ -275,46 +275,46 @@ class GlProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
constexpr int kNumChannels = 3;
Tensor tensor(Tensor::ElementType::kFloat32,
{1, output_dims.height, output_dims.width, kNumChannels});
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
[this, &output_tensor, &input, &roi, &output_shape, range_min,
range_max, tensor_buffer_offset]() -> absl::Status {
constexpr int kRgbaNumChannels = 4;
auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture(
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
source_texture.width() * source_texture.height() *
kRgbaNumChannels * sizeof(uint8_t),
/*layer=*/0,
/*owned=*/false);
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi,
&output_dims, range_min,
range_max]() -> absl::Status {
constexpr int kRgbaNumChannels = 4;
auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture(
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
source_texture.width() * source_texture.height() * kRgbaNumChannels *
sizeof(uint8_t),
/*layer=*/0,
/*owned=*/false);
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(auto transform,
GetValueRangeTransformation(kInputImageRangeMin,
kInputImageRangeMax,
range_min, range_max));
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(
auto transform,
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
range_min, range_max));
const int output_size = output_tensor.bytes() / output_shape.dims[0];
auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
buffer_view.name(), output_size,
/*offset=*/tensor_buffer_offset,
/*has_ownership=*/false);
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
input_texture,
tflite::gpu::HW(source_texture.height(), source_texture.width()),
roi,
/*flip_horizontaly=*/false, transform.scale, transform.offset,
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
command_queue_.get(), &output));
auto buffer_view = tensor.GetOpenGlBufferWriteView();
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
buffer_view.name(), tensor.bytes(),
/*offset=*/0,
/*has_ownership=*/false);
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
input_texture,
tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
/*flip_horizontaly=*/false, transform.scale, transform.offset,
tflite::gpu::HW(output_dims.height, output_dims.width),
command_queue_.get(), &output));
return absl::OkStatus();
}));
return absl::OkStatus();
}));
return tensor;
return absl::OkStatus();
}
~GlProcessor() override {
@ -326,6 +326,17 @@ class GlProcessor : public ImageToTensorConverter {
}
private:
// Checks that the preallocated output tensor has a shape this converter can
// write into: exactly 4 dimensions, a first (batch) dimension of 1, and a last
// (channel) dimension of 3. dims[1] and dims[2] are not validated here.
// Returns absl::OkStatus() when all checks pass. On a mismatch the
// RET_CHECK_EQ macro (defined outside this view) is presumably what produces
// the non-OK status carrying the streamed message.
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
// Rank is checked first; assuming RET_CHECK_EQ returns early on failure, the
// dims[0] and dims[3] accesses below only run on a 4-element dims.
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
// Only a single image per tensor (batch size 1) is handled.
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
// This converter requires exactly 3 output channels.
RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
std::unique_ptr<SubRectExtractorGl> extractor_;
mediapipe::GlCalculatorHelper gl_helper_;

View File

@ -168,10 +168,10 @@ class GlProcessor : public ImageToTensorConverter {
});
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@ -179,15 +179,15 @@ class GlProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
// TODO: support tensor_buffer_offset > 0 scenario.
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
constexpr int kNumChannels = 3;
Tensor tensor(
Tensor::ElementType::kFloat32,
Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});
MP_RETURN_IF_ERROR(
gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
range_min, range_max]() -> absl::Status {
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
[this, &output_tensor, &input, &roi, &output_shape, range_min,
range_max]() -> absl::Status {
auto input_texture = gl_helper_.CreateSourceTexture(input);
constexpr float kInputImageRangeMin = 0.0f;
@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
GetValueRangeTransformation(kInputImageRangeMin,
kInputImageRangeMax,
range_min, range_max));
auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
/*flip_horizontaly=*/false,
transform.scale, transform.offset,
output_dims, &tensor_view));
output_shape, &tensor_view));
return absl::OkStatus();
}));
return tensor;
return absl::OkStatus();
}
absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
const RotatedRect& sub_rect,
bool flip_horizontaly, float alpha, float beta,
const Size& output_dims,
const Tensor::Shape& output_shape,
Tensor::OpenGlTexture2dView* output) {
const int output_height = output_shape.dims[1];
const int output_width = output_shape.dims[2];
std::array<float, 16> transform_mat;
glDisable(GL_DEPTH_TEST);
glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
glViewport(0, 0, output_dims.width, output_dims.height);
glViewport(0, 0, output_width, output_height);
glActiveTexture(GL_TEXTURE0);
glBindTexture(GL_TEXTURE_2D, output->name());
@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
}
private:
// Checks that the preallocated output tensor has a shape this converter can
// render into: exactly 4 dimensions, a first (batch) dimension of 1, and a
// last (channel) dimension of 3. dims[1] and dims[2] are not validated here.
// Returns absl::OkStatus() when all checks pass. On a mismatch the
// RET_CHECK_EQ macro (defined outside this view) is presumably what produces
// the non-OK status carrying the streamed message.
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
// Rank is checked first; assuming RET_CHECK_EQ returns early on failure, the
// dims[0] and dims[3] accesses below only run on a 4-element dims.
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
// Only a single image per tensor (batch size 1) is handled.
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
// This converter requires exactly 3 output channels.
RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
mediapipe::GlCalculatorHelper gl_helper_;
bool use_custom_zero_border_ = false;
BorderMode border_mode_ = BorderMode::kReplicate;

View File

@ -347,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
return absl::OkStatus();
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@ -358,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@autoreleasepool {
id<MTLTexture> texture =
[metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
constexpr int kNumChannels = 4;
Tensor tensor(Tensor::ElementType::kFloat32,
Tensor::Shape{1, output_dims.height, output_dims.width,
kNumChannels});
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(
@ -376,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
range_min, range_max));
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
const auto& buffer_view =
output_tensor.GetMtlBufferWriteView(command_buffer);
MP_RETURN_IF_ERROR(extractor_->Execute(
texture, roi,
/*flip_horizontaly=*/false, transform.scale, transform.offset,
tflite::gpu::HW(output_dims.height, output_dims.width),
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
command_buffer, buffer_view.buffer()));
[command_buffer commit];
return tensor;
return absl::OkStatus();
}
}
private:
// Checks that the preallocated output tensor has a shape the Metal converter
// can write into: exactly 4 dimensions, a first (batch) dimension of 1, and a
// last (channel) dimension of 4. dims[1] and dims[2] are not validated here.
// Returns absl::OkStatus() when all checks pass. On a mismatch the
// RET_CHECK_EQ macro (defined outside this view) is presumably what produces
// the non-OK status carrying the streamed message.
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
// Rank is checked first; assuming RET_CHECK_EQ returns early on failure, the
// dims[0] and dims[3] accesses below only run on a 4-element dims.
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
// Only a single image per tensor (batch size 1) is handled.
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
// Unlike the 3-channel check in the other converters, this one requires
// exactly 4 output channels.
RET_CHECK_EQ(output_shape.dims[3], 4)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
MPPMetalHelper* metal_helper_ = nil;
std::unique_ptr<SubRectExtractorMetal> extractor_;
};

View File

@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
}
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
input.image_format() != mediapipe::ImageFormat::SRGBA) {
return InvalidArgumentError(
absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
static_cast<uint32_t>(input.image_format())));
}
auto src = mediapipe::formats::MatView(&input);
// TODO: Remove the check once tensor_buffer_offset > 0 is
// supported.
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
constexpr int kNumChannels = 3;
Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
output_dims.width, kNumChannels});
auto buffer_view = tensor.GetCpuWriteView();
const int output_height = output_shape.dims[1];
const int output_width = output_shape.dims[2];
const int output_channels = output_shape.dims[3];
auto buffer_view = output_tensor.GetCpuWriteView();
cv::Mat dst;
switch (tensor_type_) {
case Tensor::ElementType::kInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
dst = cv::Mat(output_height, output_width, mat_type_,
buffer_view.buffer<int8>());
break;
case Tensor::ElementType::kFloat32:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
dst = cv::Mat(output_height, output_width, mat_type_,
buffer_view.buffer<float>());
break;
case Tensor::ElementType::kUInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
dst = cv::Mat(output_height, output_width, mat_type_,
buffer_view.buffer<uint8>());
break;
default:
@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
cv::Mat src_points;
cv::boxPoints(rotated_rect, src_points);
const float dst_width = output_dims.width;
const float dst_height = output_dims.height;
const float dst_width = output_width;
const float dst_height = output_height;
/* clang-format off */
float dst_corners[8] = {0.0f, dst_height,
0.0f, 0.0f,
@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
dst_width, dst_height};
/* clang-format on */
auto src = mediapipe::formats::MatView(&input);
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
cv::Mat projection_matrix =
cv::getPerspectiveTransform(src_points, dst_points);
@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
/*flags=*/cv::INTER_LINEAR,
/*borderMode=*/border_mode_);
if (transformed.channels() > kNumChannels) {
if (transformed.channels() > output_channels) {
cv::Mat proper_channels_mat;
cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
transformed = proper_channels_mat;
@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
range_min, range_max));
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
return tensor;
return absl::OkStatus();
}
private:
// Checks that the preallocated output tensor has a shape the OpenCV converter
// can write into: exactly 4 dimensions, a first (batch) dimension of 1, and a
// last (channel) dimension of 3. dims[1] and dims[2] are not validated here.
// Returns absl::OkStatus() when all checks pass. On a mismatch the
// RET_CHECK_EQ macro (defined outside this view) is presumably what produces
// the non-OK status carrying the streamed message.
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
// Rank is checked first; assuming RET_CHECK_EQ returns early on failure, the
// dims[0] and dims[3] accesses below only run on a 4-element dims.
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
// Only a single image per tensor (batch size 1) is handled.
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
// This converter requires exactly 3 output channels.
RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
enum cv::BorderTypes border_mode_;
Tensor::ElementType tensor_type_;
int mat_type_;