Changed the image to tensor converter interface such that the "Convert"
function assumes the Tensor is preallocated before calling the function. PiperOrigin-RevId: 481752678
This commit is contained in:
parent
a5e4219590
commit
58e5cc3c88
|
@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
|
||||||
}
|
}
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
|
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
|
||||||
const Size size{image->width(), image->height()};
|
|
||||||
RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
|
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
|
||||||
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
|
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
|
||||||
options_.output_tensor_height(),
|
options_.output_tensor_height(),
|
||||||
options_.keep_aspect_ratio(), &roi));
|
options_.keep_aspect_ratio(), &roi));
|
||||||
|
@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
|
||||||
}
|
}
|
||||||
if (kOutMatrix(cc).IsConnected()) {
|
if (kOutMatrix(cc).IsConnected()) {
|
||||||
std::array<float, 16> matrix;
|
std::array<float, 16> matrix;
|
||||||
GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
|
GetRotatedSubRectToRectTransformMatrix(
|
||||||
/*flip_horizontaly=*/false,
|
roi, image->width(), image->height(),
|
||||||
&matrix);
|
/*flip_horizontaly=*/false, &matrix);
|
||||||
kOutMatrix(cc).Send(std::move(matrix));
|
kOutMatrix(cc).Send(std::move(matrix));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lazy initialization of the GPU or CPU converter.
|
// Lazy initialization of the GPU or CPU converter.
|
||||||
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
|
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(Tensor tensor,
|
Tensor::ElementType output_tensor_type =
|
||||||
(image->UsesGpu() ? gpu_converter_ : cpu_converter_)
|
GetOutputTensorType(image->UsesGpu());
|
||||||
->Convert(*image, roi, {output_width_, output_height_},
|
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
|
||||||
range_min_, range_max_));
|
GetNumOutputChannels(*image)});
|
||||||
|
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
|
||||||
|
->Convert(*image, roi, range_min_, range_max_,
|
||||||
|
/*tensor_buffer_offset=*/0, tensor));
|
||||||
|
|
||||||
auto result = std::make_unique<std::vector<Tensor>>();
|
auto result = std::make_unique<std::vector<Tensor>>();
|
||||||
result->push_back(std::move(tensor));
|
result->push_back(std::move(tensor));
|
||||||
|
@ -292,7 +295,8 @@ class ImageToTensorCalculator : public Node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor::ElementType GetOutputTensorType() {
|
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
|
||||||
|
if (!uses_gpu) {
|
||||||
if (is_float_output_) {
|
if (is_float_output_) {
|
||||||
return Tensor::ElementType::kFloat32;
|
return Tensor::ElementType::kFloat32;
|
||||||
}
|
}
|
||||||
|
@ -302,6 +306,21 @@ class ImageToTensorCalculator : public Node {
|
||||||
return Tensor::ElementType::kUInt8;
|
return Tensor::ElementType::kUInt8;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Always use float32 when GPU is enabled.
|
||||||
|
return Tensor::ElementType::kFloat32;
|
||||||
|
}
|
||||||
|
|
||||||
|
int GetNumOutputChannels(const Image& image) {
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
|
if (image.UsesGpu()) {
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
// All of the processors except for Metal expect 3 channels.
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
|
@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
|
||||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
cpu_converter_,
|
cpu_converter_,
|
||||||
CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
|
CreateOpenCvConverter(cc, GetBorderMode(),
|
||||||
|
GetOutputTensorType(/*uses_gpu=*/false)));
|
||||||
#else
|
#else
|
||||||
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
|
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
|
||||||
"MEDIAPIPE_DISABLE_OPENCV is defined.";
|
"MEDIAPIPE_DISABLE_OPENCV is defined.";
|
||||||
|
|
|
@ -42,13 +42,16 @@ class ImageToTensorConverter {
|
||||||
// @image contains image to extract from.
|
// @image contains image to extract from.
|
||||||
// @roi describes region of interest within the image to extract (absolute
|
// @roi describes region of interest within the image to extract (absolute
|
||||||
// values).
|
// values).
|
||||||
// @output_dims dimensions of output tensor.
|
|
||||||
// @range_min/max describes output tensor range image pixels should be converted
|
// @range_min/max describes output tensor range image pixels should be converted
|
||||||
// to.
|
// to.
|
||||||
virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
// @tensor_buffer_offset an integer representing the offset of the tensor
|
||||||
const RotatedRect& roi,
|
// buffer the result should be written to.
|
||||||
const Size& output_dims,
|
// @output_tensor a tensor with pre-defined shape. The "Convert" is
|
||||||
float range_min, float range_max) = 0;
|
// responsible for populating the content into the output tensor.
|
||||||
|
virtual absl::Status Convert(const mediapipe::Image& input,
|
||||||
|
const RotatedRect& roi, float range_min,
|
||||||
|
float range_max, int tensor_buffer_offset,
|
||||||
|
Tensor& output_tensor) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -264,10 +264,10 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||||
const RotatedRect& roi,
|
float range_min, float range_max,
|
||||||
const Size& output_dims, float range_min,
|
int tensor_buffer_offset,
|
||||||
float range_max) override {
|
Tensor& output_tensor) override {
|
||||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||||
|
@ -275,46 +275,46 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
"Only 4-channel texture input formats are supported, passed format: ",
|
"Only 4-channel texture input formats are supported, passed format: ",
|
||||||
static_cast<uint32_t>(input.format())));
|
static_cast<uint32_t>(input.format())));
|
||||||
}
|
}
|
||||||
|
const auto& output_shape = output_tensor.shape();
|
||||||
|
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||||
|
|
||||||
constexpr int kNumChannels = 3;
|
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
|
||||||
Tensor tensor(Tensor::ElementType::kFloat32,
|
[this, &output_tensor, &input, &roi, &output_shape, range_min,
|
||||||
{1, output_dims.height, output_dims.width, kNumChannels});
|
range_max, tensor_buffer_offset]() -> absl::Status {
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi,
|
|
||||||
&output_dims, range_min,
|
|
||||||
range_max]() -> absl::Status {
|
|
||||||
constexpr int kRgbaNumChannels = 4;
|
constexpr int kRgbaNumChannels = 4;
|
||||||
auto source_texture = gl_helper_.CreateSourceTexture(input);
|
auto source_texture = gl_helper_.CreateSourceTexture(input);
|
||||||
tflite::gpu::gl::GlTexture input_texture(
|
tflite::gpu::gl::GlTexture input_texture(
|
||||||
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
|
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
|
||||||
source_texture.width() * source_texture.height() * kRgbaNumChannels *
|
source_texture.width() * source_texture.height() *
|
||||||
sizeof(uint8_t),
|
kRgbaNumChannels * sizeof(uint8_t),
|
||||||
/*layer=*/0,
|
/*layer=*/0,
|
||||||
/*owned=*/false);
|
/*owned=*/false);
|
||||||
|
|
||||||
constexpr float kInputImageRangeMin = 0.0f;
|
constexpr float kInputImageRangeMin = 0.0f;
|
||||||
constexpr float kInputImageRangeMax = 1.0f;
|
constexpr float kInputImageRangeMax = 1.0f;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(auto transform,
|
||||||
auto transform,
|
GetValueRangeTransformation(kInputImageRangeMin,
|
||||||
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
|
kInputImageRangeMax,
|
||||||
range_min, range_max));
|
range_min, range_max));
|
||||||
|
|
||||||
auto buffer_view = tensor.GetOpenGlBufferWriteView();
|
const int output_size = output_tensor.bytes() / output_shape.dims[0];
|
||||||
|
auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
|
||||||
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
|
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
|
||||||
buffer_view.name(), tensor.bytes(),
|
buffer_view.name(), output_size,
|
||||||
/*offset=*/0,
|
/*offset=*/tensor_buffer_offset,
|
||||||
/*has_ownership=*/false);
|
/*has_ownership=*/false);
|
||||||
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
|
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
|
||||||
input_texture,
|
input_texture,
|
||||||
tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
|
tflite::gpu::HW(source_texture.height(), source_texture.width()),
|
||||||
|
roi,
|
||||||
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
||||||
tflite::gpu::HW(output_dims.height, output_dims.width),
|
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
|
||||||
command_queue_.get(), &output));
|
command_queue_.get(), &output));
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return tensor;
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
~GlProcessor() override {
|
~GlProcessor() override {
|
||||||
|
@ -326,6 +326,17 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||||
|
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||||
|
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||||
|
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||||
|
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||||
|
"converter.";
|
||||||
|
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||||
|
<< "Wrong output channel: " << output_shape.dims[3];
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
|
std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
|
||||||
std::unique_ptr<SubRectExtractorGl> extractor_;
|
std::unique_ptr<SubRectExtractorGl> extractor_;
|
||||||
mediapipe::GlCalculatorHelper gl_helper_;
|
mediapipe::GlCalculatorHelper gl_helper_;
|
||||||
|
|
|
@ -168,10 +168,10 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||||
const RotatedRect& roi,
|
float range_min, float range_max,
|
||||||
const Size& output_dims, float range_min,
|
int tensor_buffer_offset,
|
||||||
float range_max) override {
|
Tensor& output_tensor) override {
|
||||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||||
|
@ -179,15 +179,15 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
"Only 4-channel texture input formats are supported, passed format: ",
|
"Only 4-channel texture input formats are supported, passed format: ",
|
||||||
static_cast<uint32_t>(input.format())));
|
static_cast<uint32_t>(input.format())));
|
||||||
}
|
}
|
||||||
|
// TODO: support tensor_buffer_offset > 0 scenario.
|
||||||
|
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||||
|
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||||
|
const auto& output_shape = output_tensor.shape();
|
||||||
|
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||||
|
|
||||||
constexpr int kNumChannels = 3;
|
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
|
||||||
Tensor tensor(
|
[this, &output_tensor, &input, &roi, &output_shape, range_min,
|
||||||
Tensor::ElementType::kFloat32,
|
range_max]() -> absl::Status {
|
||||||
Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});
|
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(
|
|
||||||
gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
|
|
||||||
range_min, range_max]() -> absl::Status {
|
|
||||||
auto input_texture = gl_helper_.CreateSourceTexture(input);
|
auto input_texture = gl_helper_.CreateSourceTexture(input);
|
||||||
|
|
||||||
constexpr float kInputImageRangeMin = 0.0f;
|
constexpr float kInputImageRangeMin = 0.0f;
|
||||||
|
@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
GetValueRangeTransformation(kInputImageRangeMin,
|
GetValueRangeTransformation(kInputImageRangeMin,
|
||||||
kInputImageRangeMax,
|
kInputImageRangeMax,
|
||||||
range_min, range_max));
|
range_min, range_max));
|
||||||
auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
|
auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
|
||||||
MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
|
MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
|
||||||
/*flip_horizontaly=*/false,
|
/*flip_horizontaly=*/false,
|
||||||
transform.scale, transform.offset,
|
transform.scale, transform.offset,
|
||||||
output_dims, &tensor_view));
|
output_shape, &tensor_view));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return tensor;
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
|
absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
|
||||||
const RotatedRect& sub_rect,
|
const RotatedRect& sub_rect,
|
||||||
bool flip_horizontaly, float alpha, float beta,
|
bool flip_horizontaly, float alpha, float beta,
|
||||||
const Size& output_dims,
|
const Tensor::Shape& output_shape,
|
||||||
Tensor::OpenGlTexture2dView* output) {
|
Tensor::OpenGlTexture2dView* output) {
|
||||||
|
const int output_height = output_shape.dims[1];
|
||||||
|
const int output_width = output_shape.dims[2];
|
||||||
std::array<float, 16> transform_mat;
|
std::array<float, 16> transform_mat;
|
||||||
|
|
||||||
glDisable(GL_DEPTH_TEST);
|
glDisable(GL_DEPTH_TEST);
|
||||||
glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
|
glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
|
||||||
glViewport(0, 0, output_dims.width, output_dims.height);
|
glViewport(0, 0, output_width, output_height);
|
||||||
|
|
||||||
glActiveTexture(GL_TEXTURE0);
|
glActiveTexture(GL_TEXTURE0);
|
||||||
glBindTexture(GL_TEXTURE_2D, output->name());
|
glBindTexture(GL_TEXTURE_2D, output->name());
|
||||||
|
@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||||
|
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||||
|
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||||
|
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||||
|
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||||
|
"converter.";
|
||||||
|
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||||
|
<< "Wrong output channel: " << output_shape.dims[3];
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
mediapipe::GlCalculatorHelper gl_helper_;
|
mediapipe::GlCalculatorHelper gl_helper_;
|
||||||
bool use_custom_zero_border_ = false;
|
bool use_custom_zero_border_ = false;
|
||||||
BorderMode border_mode_ = BorderMode::kReplicate;
|
BorderMode border_mode_ = BorderMode::kReplicate;
|
||||||
|
|
|
@ -347,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||||
const RotatedRect& roi,
|
float range_min, float range_max,
|
||||||
const Size& output_dims, float range_min,
|
int tensor_buffer_offset,
|
||||||
float range_max) override {
|
Tensor& output_tensor) override {
|
||||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||||
|
@ -358,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
|
||||||
"Only 4-channel texture input formats are supported, passed format: ",
|
"Only 4-channel texture input formats are supported, passed format: ",
|
||||||
static_cast<uint32_t>(input.format())));
|
static_cast<uint32_t>(input.format())));
|
||||||
}
|
}
|
||||||
|
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||||
|
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||||
|
const auto& output_shape = output_tensor.shape();
|
||||||
|
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||||
|
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
id<MTLTexture> texture =
|
id<MTLTexture> texture =
|
||||||
[metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
|
[metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
|
||||||
|
|
||||||
constexpr int kNumChannels = 4;
|
|
||||||
Tensor tensor(Tensor::ElementType::kFloat32,
|
|
||||||
Tensor::Shape{1, output_dims.height, output_dims.width,
|
|
||||||
kNumChannels});
|
|
||||||
|
|
||||||
constexpr float kInputImageRangeMin = 0.0f;
|
constexpr float kInputImageRangeMin = 0.0f;
|
||||||
constexpr float kInputImageRangeMax = 1.0f;
|
constexpr float kInputImageRangeMax = 1.0f;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
|
@ -376,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
|
||||||
range_min, range_max));
|
range_min, range_max));
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
|
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
|
||||||
const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
|
const auto& buffer_view =
|
||||||
|
output_tensor.GetMtlBufferWriteView(command_buffer);
|
||||||
MP_RETURN_IF_ERROR(extractor_->Execute(
|
MP_RETURN_IF_ERROR(extractor_->Execute(
|
||||||
texture, roi,
|
texture, roi,
|
||||||
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
||||||
tflite::gpu::HW(output_dims.height, output_dims.width),
|
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
|
||||||
command_buffer, buffer_view.buffer()));
|
command_buffer, buffer_view.buffer()));
|
||||||
[command_buffer commit];
|
[command_buffer commit];
|
||||||
return tensor;
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||||
|
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||||
|
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||||
|
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||||
|
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||||
|
"converter.";
|
||||||
|
RET_CHECK_EQ(output_shape.dims[3], 4)
|
||||||
|
<< "Wrong output channel: " << output_shape.dims[3];
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
MPPMetalHelper* metal_helper_ = nil;
|
MPPMetalHelper* metal_helper_ = nil;
|
||||||
std::unique_ptr<SubRectExtractorMetal> extractor_;
|
std::unique_ptr<SubRectExtractorMetal> extractor_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
|
||||||
const RotatedRect& roi,
|
float range_min, float range_max,
|
||||||
const Size& output_dims, float range_min,
|
int tensor_buffer_offset,
|
||||||
float range_max) override {
|
Tensor& output_tensor) override {
|
||||||
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
|
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
|
||||||
input.image_format() != mediapipe::ImageFormat::SRGBA) {
|
input.image_format() != mediapipe::ImageFormat::SRGBA) {
|
||||||
return InvalidArgumentError(
|
return InvalidArgumentError(
|
||||||
absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
|
absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
|
||||||
static_cast<uint32_t>(input.image_format())));
|
static_cast<uint32_t>(input.image_format())));
|
||||||
}
|
}
|
||||||
auto src = mediapipe::formats::MatView(&input);
|
// TODO: Remove the check once tensor_buffer_offset > 0 is
|
||||||
|
// supported.
|
||||||
|
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||||
|
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||||
|
const auto& output_shape = output_tensor.shape();
|
||||||
|
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||||
|
|
||||||
constexpr int kNumChannels = 3;
|
const int output_height = output_shape.dims[1];
|
||||||
Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
|
const int output_width = output_shape.dims[2];
|
||||||
output_dims.width, kNumChannels});
|
const int output_channels = output_shape.dims[3];
|
||||||
auto buffer_view = tensor.GetCpuWriteView();
|
auto buffer_view = output_tensor.GetCpuWriteView();
|
||||||
cv::Mat dst;
|
cv::Mat dst;
|
||||||
switch (tensor_type_) {
|
switch (tensor_type_) {
|
||||||
case Tensor::ElementType::kInt8:
|
case Tensor::ElementType::kInt8:
|
||||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||||
buffer_view.buffer<int8>());
|
buffer_view.buffer<int8>());
|
||||||
break;
|
break;
|
||||||
case Tensor::ElementType::kFloat32:
|
case Tensor::ElementType::kFloat32:
|
||||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||||
buffer_view.buffer<float>());
|
buffer_view.buffer<float>());
|
||||||
break;
|
break;
|
||||||
case Tensor::ElementType::kUInt8:
|
case Tensor::ElementType::kUInt8:
|
||||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||||
buffer_view.buffer<uint8>());
|
buffer_view.buffer<uint8>());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
cv::Mat src_points;
|
cv::Mat src_points;
|
||||||
cv::boxPoints(rotated_rect, src_points);
|
cv::boxPoints(rotated_rect, src_points);
|
||||||
|
|
||||||
const float dst_width = output_dims.width;
|
const float dst_width = output_width;
|
||||||
const float dst_height = output_dims.height;
|
const float dst_height = output_height;
|
||||||
/* clang-format off */
|
/* clang-format off */
|
||||||
float dst_corners[8] = {0.0f, dst_height,
|
float dst_corners[8] = {0.0f, dst_height,
|
||||||
0.0f, 0.0f,
|
0.0f, 0.0f,
|
||||||
|
@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
dst_width, dst_height};
|
dst_width, dst_height};
|
||||||
/* clang-format on */
|
/* clang-format on */
|
||||||
|
|
||||||
|
auto src = mediapipe::formats::MatView(&input);
|
||||||
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
|
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
|
||||||
cv::Mat projection_matrix =
|
cv::Mat projection_matrix =
|
||||||
cv::getPerspectiveTransform(src_points, dst_points);
|
cv::getPerspectiveTransform(src_points, dst_points);
|
||||||
|
@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
/*flags=*/cv::INTER_LINEAR,
|
/*flags=*/cv::INTER_LINEAR,
|
||||||
/*borderMode=*/border_mode_);
|
/*borderMode=*/border_mode_);
|
||||||
|
|
||||||
if (transformed.channels() > kNumChannels) {
|
if (transformed.channels() > output_channels) {
|
||||||
cv::Mat proper_channels_mat;
|
cv::Mat proper_channels_mat;
|
||||||
cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
|
cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
|
||||||
transformed = proper_channels_mat;
|
transformed = proper_channels_mat;
|
||||||
|
@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
|
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
|
||||||
range_min, range_max));
|
range_min, range_max));
|
||||||
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
|
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
|
||||||
return tensor;
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||||
|
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||||
|
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||||
|
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||||
|
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||||
|
"converter.";
|
||||||
|
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||||
|
<< "Wrong output channel: " << output_shape.dims[3];
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
enum cv::BorderTypes border_mode_;
|
enum cv::BorderTypes border_mode_;
|
||||||
Tensor::ElementType tensor_type_;
|
Tensor::ElementType tensor_type_;
|
||||||
int mat_type_;
|
int mat_type_;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user