diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc index d9825b15f..9e3fdc0ca 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc @@ -111,6 +111,7 @@ class TensorsToImageCalculator : public Node { private: TensorsToImageCalculatorOptions options_; absl::Status CpuProcess(CalculatorContext* cc); + int tensor_position_; #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_METAL_ENABLED @@ -166,6 +167,7 @@ absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) { << "Must specify either `input_tensor_float_range` or " "`input_tensor_uint_range` in the calculator options"; } + tensor_position_ = options_.tensor_position(); return absl::OkStatus(); } @@ -202,17 +204,23 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); - const auto& input_tensor = input_tensors[0]; + const auto& input_tensor = input_tensors[tensor_position_]; const int tensor_in_height = input_tensor.shape().dims[1]; const int tensor_in_width = input_tensor.shape().dims[2]; const int tensor_in_channels = input_tensor.shape().dims[3]; - RET_CHECK_EQ(tensor_in_channels, 3); + RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1); - auto output_frame = std::make_shared( - mediapipe::ImageFormat::SRGB, tensor_in_width, tensor_in_height); + auto format = mediapipe::ImageFormat::SRGB; + if (tensor_in_channels == 1) { + format = mediapipe::ImageFormat::GRAY8; + } + + auto output_frame = + std::make_shared(format, tensor_in_width, tensor_in_height); cv::Mat output_matview = mediapipe::formats::MatView(output_frame.get()); constexpr float kOutputImageRangeMin = 0.0f; @@ -227,8 +235,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { GetValueRangeTransformation( input_range.min(), input_range.max(), kOutputImageRangeMin, kOutputImageRangeMax)); - tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale, - transform.offset); + tensor_matview.convertTo(output_matview, + CV_MAKETYPE(CV_8U, tensor_in_channels), + transform.scale, transform.offset); } else if (input_tensor.element_type() == Tensor::ElementType::kUInt8) { cv::Mat tensor_matview( cv::Size(tensor_in_width, tensor_in_height), @@ -239,8 +248,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { GetValueRangeTransformation( input_range.min(), input_range.max(), kOutputImageRangeMin, kOutputImageRangeMax)); - tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale, - transform.offset); + tensor_matview.convertTo(output_matview, + CV_MAKETYPE(CV_8U, tensor_in_channels), + transform.scale, transform.offset); } else { return absl::InvalidArgumentError( absl::Substitute("Type of tensor must be kFloat32 or kUInt8, got: $0", @@ -264,10 +274,14 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); - const int tensor_width = input_tensors[0].shape().dims[2]; - const int tensor_height = input_tensors[0].shape().dims[1]; + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); + const int tensor_width = input_tensors[tensor_position_].shape().dims[2]; + const int tensor_height = input_tensors[tensor_position_].shape().dims[1]; + const int tensor_channels = input_tensors[tensor_position_].shape().dims[3]; + // TODO: Add 1 channel support. + RET_CHECK(tensor_channels == 3); // TODO: Fix unused variable [[maybe_unused]] id device = gpu_helper_.mtlDevice; @@ -277,8 +291,8 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { [command_buffer computeCommandEncoder]; [compute_encoder setComputePipelineState:to_buffer_program_]; - auto input_view = - mediapipe::MtlBufferView::GetReadView(input_tensors[0], command_buffer); + auto input_view = mediapipe::MtlBufferView::GetReadView( + input_tensors[tensor_position_], command_buffer); [compute_encoder setBuffer:input_view.buffer() offset:0 atIndex:0]; mediapipe::GpuBuffer output = @@ -355,7 +369,7 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size_), R"( precision highp float; layout(rgba8, binding = 0) writeonly uniform highp image2D output_texture; - uniform ivec2 out_size; + uniform ivec3 out_size; )"); const std::string shader_body = R"( @@ -366,10 +380,11 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { void main() { int out_width = out_size.x; int out_height = out_size.y; + int out_channels = out_size.z; ivec2 gid = ivec2(gl_GlobalInvocationID.xy); if (gid.x >= out_width || gid.y >= out_height) { return; } - int linear_index = 3 * (gid.y * out_width + gid.x); + int linear_index = out_channels * (gid.y * out_width + gid.x); #ifdef FLIP_Y_COORD int y_coord = out_height - gid.y - 1; @@ -377,8 +392,14 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { int y_coord = gid.y; #endif // defined(FLIP_Y_COORD) + vec4 out_value; ivec2 out_coordinate = ivec2(gid.x, y_coord); - vec4 out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0); + if (out_channels == 3) { + out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0); + } else { + float in_value = input_data.elements[linear_index]; + out_value = vec4(in_value, in_value, in_value, 1.0); + } imageStore(output_texture, out_coordinate, out_value); })"; @@ -438,10 +459,15 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); - const int tensor_width = input_tensors[0].shape().dims[2]; - const int tensor_height = input_tensors[0].shape().dims[1]; + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); + + const auto& input_tensor = input_tensors[tensor_position_]; + const int tensor_width = input_tensor.shape().dims[2]; + const int tensor_height = input_tensor.shape().dims[1]; + const int tensor_in_channels = input_tensor.shape().dims[3]; + RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 @@ -454,7 +480,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { glBindImageTexture(output_index, out_texture->id(), 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_RGBA8); - auto read_view = input_tensors[0].GetOpenGlBufferReadView(); + auto read_view = input_tensor.GetOpenGlBufferReadView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name()); const tflite::gpu::uint3 workload = {tensor_width, tensor_height, 1}; @@ -462,8 +488,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { tflite::gpu::DivideRoundUp(workload, workgroup_size_); glUseProgram(gl_compute_program_->id()); - glUniform2i(glGetUniformLocation(gl_compute_program_->id(), "out_size"), - tensor_width, tensor_height); + glUniform3i(glGetUniformLocation(gl_compute_program_->id(), "out_size"), + tensor_width, tensor_height, tensor_in_channels); MP_RETURN_IF_ERROR(gl_compute_program_->Dispatch(workgroups)); @@ -481,8 +507,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { #else - if (!input_tensors[0].ready_as_opengl_texture_2d()) { - (void)input_tensors[0].GetCpuReadView(); + if (!input_tensor.ready_as_opengl_texture_2d()) { + (void)input_tensor.GetCpuReadView(); } auto output_texture = @@ -490,7 +516,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { gl_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, - input_tensors[0].GetOpenGlTexture2dReadView().name()); + input_tensor.GetOpenGlTexture2dReadView().name()); MP_RETURN_IF_ERROR(gl_renderer_->GlRender( tensor_width, tensor_height, output_texture.width(), diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto index 6bca86265..b0ecb8b5a 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto @@ -48,4 +48,8 @@ message TensorsToImageCalculatorOptions { FloatRange input_tensor_float_range = 2; UIntRange input_tensor_uint_range = 3; } + + // Determines which output tensor to slice when there are multiple output + // tensors available (e.g. network has multiple heads) + optional int32 tensor_position = 4 [default = 0]; }