Add support for two things in tensors_to_image_calculator:

1) Single-channel (grayscale) support for tensor-to-image conversion after inference.
2) Multitask support: a new `tensor_position` option selects which output tensor to convert when a model produces several (see the usage sketch below).

PiperOrigin-RevId: 549412331
Author: Steven Hickson (committed by Copybara-Service)
Date: 2023-07-19 13:38:43 -07:00
Commit: e47af74b15 (parent: 085840388b)
2 changed files with 59 additions and 29 deletions
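As a usage sketch (not part of the diff below): selecting the second head of a multitask model comes down to setting the new option. The include path and helper name are assumptions for illustration; only `set_tensor_position()` itself comes from this commit.

    #include "mediapipe/calculators/tensor/tensors_to_image_calculator.pb.h"  // assumed path

    // Hypothetical helper: build options that convert input_tensors[1]
    // instead of the default input_tensors[0].
    mediapipe::TensorsToImageCalculatorOptions MakeSecondHeadOptions() {
      mediapipe::TensorsToImageCalculatorOptions options;
      options.set_tensor_position(1);  // Applies to the CPU, Metal, and GL paths alike.
      return options;
    }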

tensors_to_image_calculator.cc

@@ -111,6 +111,7 @@ class TensorsToImageCalculator : public Node {
  private:
   TensorsToImageCalculatorOptions options_;
   absl::Status CpuProcess(CalculatorContext* cc);
+  int tensor_position_;

 #if !MEDIAPIPE_DISABLE_GPU
 #if MEDIAPIPE_METAL_ENABLED
@@ -166,6 +167,7 @@ absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) {
         << "Must specify either `input_tensor_float_range` or "
            "`input_tensor_uint_range` in the calculator options";
   }
+  tensor_position_ = options_.tensor_position();
   return absl::OkStatus();
 }
@@ -202,17 +204,23 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) {
     return absl::OkStatus();
   }
   const auto& input_tensors = kInputTensors(cc).Get();
-  RET_CHECK_EQ(input_tensors.size(), 1)
-      << "Expect 1 input tensor, but have " << input_tensors.size();
-  const auto& input_tensor = input_tensors[0];
+  RET_CHECK_GT(input_tensors.size(), tensor_position_)
+      << "Expect input tensor at position " << tensor_position_
+      << ", but have tensors of size " << input_tensors.size();
+  const auto& input_tensor = input_tensors[tensor_position_];
   const int tensor_in_height = input_tensor.shape().dims[1];
   const int tensor_in_width = input_tensor.shape().dims[2];
   const int tensor_in_channels = input_tensor.shape().dims[3];
-  RET_CHECK_EQ(tensor_in_channels, 3);
-  auto output_frame = std::make_shared<ImageFrame>(
-      mediapipe::ImageFormat::SRGB, tensor_in_width, tensor_in_height);
+  RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1);
+
+  auto format = mediapipe::ImageFormat::SRGB;
+  if (tensor_in_channels == 1) {
+    format = mediapipe::ImageFormat::GRAY8;
+  }
+  auto output_frame =
+      std::make_shared<ImageFrame>(format, tensor_in_width, tensor_in_height);
   cv::Mat output_matview = mediapipe::formats::MatView(output_frame.get());
   constexpr float kOutputImageRangeMin = 0.0f;
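Background for the convertTo() change in the next two hunks: OpenCV's fixed-channel type constants are themselves built with CV_MAKETYPE, so deriving the output type from tensor_in_channels reproduces the old hard-coded CV_8UC3 in the 3-channel case and extends it to 1 channel. A compile-time sanity check (illustrative, not from the commit):

    #include <opencv2/core.hpp>

    // CV_8UCn is defined as CV_MAKETYPE(CV_8U, n), so the derived type matches
    // the previous RGB behavior and adds the grayscale case.
    static_assert(CV_MAKETYPE(CV_8U, 1) == CV_8UC1, "grayscale output type");
    static_assert(CV_MAKETYPE(CV_8U, 3) == CV_8UC3, "RGB output type");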
@@ -227,8 +235,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) {
                      GetValueRangeTransformation(
                          input_range.min(), input_range.max(),
                          kOutputImageRangeMin, kOutputImageRangeMax));
-    tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale,
-                             transform.offset);
+    tensor_matview.convertTo(output_matview,
+                             CV_MAKETYPE(CV_8U, tensor_in_channels),
+                             transform.scale, transform.offset);
   } else if (input_tensor.element_type() == Tensor::ElementType::kUInt8) {
     cv::Mat tensor_matview(
         cv::Size(tensor_in_width, tensor_in_height),
@@ -239,8 +248,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) {
                      GetValueRangeTransformation(
                          input_range.min(), input_range.max(),
                          kOutputImageRangeMin, kOutputImageRangeMax));
-    tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale,
-                             transform.offset);
+    tensor_matview.convertTo(output_matview,
+                             CV_MAKETYPE(CV_8U, tensor_in_channels),
+                             transform.scale, transform.offset);
   } else {
     return absl::InvalidArgumentError(
         absl::Substitute("Type of tensor must be kFloat32 or kUInt8, got: $0",
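Both element-type branches pass convertTo() a scale and offset obtained from GetValueRangeTransformation. Assuming that helper implements the usual linear map from [from_min, from_max] to [to_min, to_max] (the struct and function names below are illustrative, not MediaPipe's), the arithmetic is:

    // Assumed behavior: out = in * scale + offset maps the input range
    // onto [kOutputImageRangeMin, kOutputImageRangeMax].
    struct RangeTransform {
      float scale;
      float offset;
    };

    RangeTransform MakeRangeTransform(float from_min, float from_max,
                                      float to_min, float to_max) {
      const float scale = (to_max - to_min) / (from_max - from_min);
      return {scale, to_min - from_min * scale};
    }

    // Example: a float tensor in [0, 1] maps to the 8-bit image range
    // [0, 255] with scale = 255 and offset = 0.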
@@ -264,10 +274,14 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
     return absl::OkStatus();
   }
   const auto& input_tensors = kInputTensors(cc).Get();
-  RET_CHECK_EQ(input_tensors.size(), 1)
-      << "Expect 1 input tensor, but have " << input_tensors.size();
-  const int tensor_width = input_tensors[0].shape().dims[2];
-  const int tensor_height = input_tensors[0].shape().dims[1];
+  RET_CHECK_GT(input_tensors.size(), tensor_position_)
+      << "Expect input tensor at position " << tensor_position_
+      << ", but have tensors of size " << input_tensors.size();
+  const int tensor_width = input_tensors[tensor_position_].shape().dims[2];
+  const int tensor_height = input_tensors[tensor_position_].shape().dims[1];
+  const int tensor_channels = input_tensors[tensor_position_].shape().dims[3];
+  // TODO: Add 1 channel support.
+  RET_CHECK(tensor_channels == 3);

   // TODO: Fix unused variable
   [[maybe_unused]] id<MTLDevice> device = gpu_helper_.mtlDevice;
@@ -277,8 +291,8 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
       [command_buffer computeCommandEncoder];
   [compute_encoder setComputePipelineState:to_buffer_program_];
-  auto input_view =
-      mediapipe::MtlBufferView::GetReadView(input_tensors[0], command_buffer);
+  auto input_view = mediapipe::MtlBufferView::GetReadView(
+      input_tensors[tensor_position_], command_buffer);
   [compute_encoder setBuffer:input_view.buffer() offset:0 atIndex:0];

   mediapipe::GpuBuffer output =
@@ -355,7 +369,7 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) {
       absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size_), R"(
   precision highp float;
   layout(rgba8, binding = 0) writeonly uniform highp image2D output_texture;
-  uniform ivec2 out_size;
+  uniform ivec3 out_size;
   )");

   const std::string shader_body = R"(
@@ -366,10 +380,11 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) {
   void main() {
     int out_width = out_size.x;
     int out_height = out_size.y;
+    int out_channels = out_size.z;
     ivec2 gid = ivec2(gl_GlobalInvocationID.xy);
     if (gid.x >= out_width || gid.y >= out_height) { return; }
-    int linear_index = 3 * (gid.y * out_width + gid.x);
+    int linear_index = out_channels * (gid.y * out_width + gid.x);

 #ifdef FLIP_Y_COORD
     int y_coord = out_height - gid.y - 1;
@@ -377,8 +392,14 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) {
     int y_coord = gid.y;
 #endif  // defined(FLIP_Y_COORD)

+    vec4 out_value;
     ivec2 out_coordinate = ivec2(gid.x, y_coord);
-    vec4 out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0);
+    if (out_channels == 3) {
+      out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0);
+    } else {
+      float in_value = input_data.elements[linear_index];
+      out_value = vec4(in_value, in_value, in_value, 1.0);
+    }
     imageStore(output_texture, out_coordinate, out_value);
   })";
@@ -438,10 +459,15 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
     return absl::OkStatus();
   }
   const auto& input_tensors = kInputTensors(cc).Get();
-  RET_CHECK_EQ(input_tensors.size(), 1)
-      << "Expect 1 input tensor, but have " << input_tensors.size();
-  const int tensor_width = input_tensors[0].shape().dims[2];
-  const int tensor_height = input_tensors[0].shape().dims[1];
+  RET_CHECK_GT(input_tensors.size(), tensor_position_)
+      << "Expect input tensor at position " << tensor_position_
+      << ", but have tensors of size " << input_tensors.size();
+
+  const auto& input_tensor = input_tensors[tensor_position_];
+  const int tensor_width = input_tensor.shape().dims[2];
+  const int tensor_height = input_tensor.shape().dims[1];
+  const int tensor_in_channels = input_tensor.shape().dims[3];
+  RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1);

 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
@@ -454,7 +480,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
     glBindImageTexture(output_index, out_texture->id(), 0, GL_FALSE, 0,
                        GL_WRITE_ONLY, GL_RGBA8);
-    auto read_view = input_tensors[0].GetOpenGlBufferReadView();
+    auto read_view = input_tensor.GetOpenGlBufferReadView();
     glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name());

     const tflite::gpu::uint3 workload = {tensor_width, tensor_height, 1};
@@ -462,8 +488,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
         tflite::gpu::DivideRoundUp(workload, workgroup_size_);

     glUseProgram(gl_compute_program_->id());
-    glUniform2i(glGetUniformLocation(gl_compute_program_->id(), "out_size"),
-                tensor_width, tensor_height);
+    glUniform3i(glGetUniformLocation(gl_compute_program_->id(), "out_size"),
+                tensor_width, tensor_height, tensor_in_channels);
     MP_RETURN_IF_ERROR(gl_compute_program_->Dispatch(workgroups));
@@ -481,8 +507,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
 #else
-  if (!input_tensors[0].ready_as_opengl_texture_2d()) {
-    (void)input_tensors[0].GetCpuReadView();
+  if (!input_tensor.ready_as_opengl_texture_2d()) {
+    (void)input_tensor.GetCpuReadView();
   }

   auto output_texture =
@@ -490,7 +516,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
   gl_helper_.BindFramebuffer(output_texture);  // GL_TEXTURE0
   glActiveTexture(GL_TEXTURE1);
   glBindTexture(GL_TEXTURE_2D,
-                input_tensors[0].GetOpenGlTexture2dReadView().name());
+                input_tensor.GetOpenGlTexture2dReadView().name());
   MP_RETURN_IF_ERROR(gl_renderer_->GlRender(
       tensor_width, tensor_height, output_texture.width(),

tensors_to_image_calculator.proto

@@ -48,4 +48,8 @@ message TensorsToImageCalculatorOptions {
     FloatRange input_tensor_float_range = 2;
     UIntRange input_tensor_uint_range = 3;
   }
+
+  // Determines which output tensor to convert when multiple output tensors
+  // are available (e.g. the network has multiple heads).
+  optional int32 tensor_position = 4 [default = 0];
 }
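Since tensor_position defaults to 0, existing graphs that produce a single output tensor keep their previous behavior; only multi-head graphs need to set the field. Note that the Metal path above still requires 3 channels; single-channel support there is left as a TODO in this commit.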