Internal change

PiperOrigin-RevId: 535751178
This commit is contained in:
MediaPipe Team 2023-05-26 17:24:03 -07:00 committed by Copybara-Service
parent fddc3facf0
commit d4c7ed2217
6 changed files with 338 additions and 32 deletions

View File

@ -50,7 +50,12 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
] + select({
"//conditions:default": [],
"//mediapipe:android": [
":segmentation_postprocessor_gl",
], ],
}),
alwayslink = 1, alwayslink = 1,
) )
@ -72,6 +77,29 @@ cc_library(
"//mediapipe/tasks/cc/vision/utils:image_utils", "//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
] + select({
"//conditions:default": [],
"//mediapipe:android": [
"ssbo_to_texture_converter",
],
}),
)
cc_library(
name = "ssbo_to_texture_converter",
srcs = ["ssbo_to_texture_converter.cc"],
hdrs = ["ssbo_to_texture_converter.h"],
tags = [
"nomac",
"notap",
],
deps = [
"//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gl_base",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util",
], ],
) )

View File

@ -16,6 +16,14 @@ namespace mediapipe {
namespace tasks { namespace tasks {
namespace { namespace {
// On most platforms, glGetUniformLocation returns -1 for an error status, but
// on web we'll see 0 instead.
#ifdef __EMSCRIPTEN__
const GLint kUniformErrorStatus = 0;
#else
const GLint kUniformErrorStatus = -1;
#endif // __EMSCRIPTEN__
using mediapipe::kBasicSquareVertices; using mediapipe::kBasicSquareVertices;
using mediapipe::kBasicTextureVertices; using mediapipe::kBasicTextureVertices;
using mediapipe::kBasicVertexShader; using mediapipe::kBasicVertexShader;
@ -341,7 +349,7 @@ absl::Status SegmentationPostprocessorGl::CreateBasicFragmentShaderProgram(
for (const auto& uniform_name : uniform_names) { for (const auto& uniform_name : uniform_names) {
shader_struct_ptr->uniforms[uniform_name] = shader_struct_ptr->uniforms[uniform_name] =
glGetUniformLocation(shader_struct_ptr->program, uniform_name.c_str()); glGetUniformLocation(shader_struct_ptr->program, uniform_name.c_str());
RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > 0) RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > kUniformErrorStatus)
<< uniform_name << " uniform not found for " << program_name << uniform_name << " uniform not found for " << program_name
<< " program"; << " program";
} }
@ -427,10 +435,10 @@ absl::Status SegmentationPostprocessorGl::GlInit(
// Get split program uniform locations. // Get split program uniform locations.
split_texture_uniform_ = split_texture_uniform_ =
glGetUniformLocation(split_program_, "input_texture"); glGetUniformLocation(split_program_, "input_texture");
RET_CHECK(split_texture_uniform_ > 0) RET_CHECK(split_texture_uniform_ > kUniformErrorStatus)
<< "split input_texture uniform not found."; << "split input_texture uniform not found.";
split_x_offset_uniform_ = glGetUniformLocation(split_program_, "x_offset"); split_x_offset_uniform_ = glGetUniformLocation(split_program_, "x_offset");
RET_CHECK(split_x_offset_uniform_ > 0) RET_CHECK(split_x_offset_uniform_ > kUniformErrorStatus)
<< "split x_offset uniform not found."; << "split x_offset uniform not found.";
// TODO: If ES3.0+ only, switch to VAO for handling attributes. // TODO: If ES3.0+ only, switch to VAO for handling attributes.
@ -445,10 +453,24 @@ absl::Status SegmentationPostprocessorGl::GlInit(
kBasicTextureVertices, GL_STATIC_DRAW); kBasicTextureVertices, GL_STATIC_DRAW);
glBindBuffer(GL_ARRAY_BUFFER, 0); glBindBuffer(GL_ARRAY_BUFFER, 0);
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
MP_RETURN_IF_ERROR(ssbo_to_texture_converter_.Init());
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
return absl::OkStatus(); return absl::OkStatus();
}); });
} }
// On Android, the extensions are prefixed by GL_, whereas on web they are not.
bool SegmentationPostprocessorGl::HasGlExtension(std::string const& extension) {
#ifdef __EMSCRIPTEN__
return helper_.GetGlContext().HasGlExtension(extension);
#else
return helper_.GetGlContext().HasGlExtension("GL_" + extension);
#endif // __EMSCRIPTEN__
}
std::vector<std::unique_ptr<Image>> std::vector<std::unique_ptr<Image>>
SegmentationPostprocessorGl::GetSegmentationResultGpu( SegmentationPostprocessorGl::GetSegmentationResultGpu(
const Shape& input_shape, const Shape& output_shape, const Tensor& tensor, const Shape& input_shape, const Shape& output_shape, const Tensor& tensor,
@ -459,18 +481,35 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
produce_category_mask, produce_category_mask,
&image_outputs]() -> absl::Status { &image_outputs]() -> absl::Status {
// Get Tensor input and image output parameters // Get Tensor input and image output parameters
const int width = input_shape.width; // Slice width from shape
const int height = input_shape.height; // Slice height from shape
const int num_outputs = input_shape.channels; // One output per channel
const int num_chunks = (input_shape.channels + 3) / 4; // ceil(channels/4)
const int output_width = output_shape.width; // Final output width
const int output_height = output_shape.height; // Final output height
int input_width, input_height; int input_width, input_height;
if (!tensor.ready_as_opengl_texture_2d()) { if (!tensor.ready_on_gpu()) {
LOG(WARNING) << "Tensor wasn't ready on GPU; using slow workaround."; LOG(WARNING) << "Tensor wasn't ready on GPU; using slow workaround.";
(void)tensor.GetCpuReadView(); (void)tensor.GetCpuReadView();
} }
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
// If our Tensor is an SSBO, then it's also linearized, so we convert to a
// kAligned 2d texture using a special converter and then proceed as before.
GLuint ssbo_tex_id;
ASSIGN_OR_RETURN(ssbo_tex_id,
ssbo_to_texture_converter_.ConvertTensorToGlTexture(
tensor, width, height, num_outputs));
std::tie(input_width, input_height) =
ssbo_to_texture_converter_.GetTextureSize();
#else
const auto layout = tensor.GetOpenGlTexture2dReadView().GetLayoutDimensions( const auto layout = tensor.GetOpenGlTexture2dReadView().GetLayoutDimensions(
tensor.shape(), &input_width, &input_height); tensor.shape(), &input_width, &input_height);
if (layout != Tensor::OpenGlTexture2dView::Layout::kAligned) { if (layout != Tensor::OpenGlTexture2dView::Layout::kAligned) {
LOG(ERROR) << "Tensor layout not kAligned! Cannot handle."; LOG(ERROR) << "Tensor layout not kAligned! Cannot handle.";
} }
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
// Optimization: Only apply SOFTMAX when producing confidence masks, since // Optimization: Only apply SOFTMAX when producing confidence masks, since
// SOFTMAX errors out when num_classes = 1, so we don't have to worry about // SOFTMAX errors out when num_classes = 1, so we don't have to worry about
@ -486,14 +525,12 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
// (3) blending // (3) blending
// Otherwise, we just try for F16. See b/277656755 for more information. // Otherwise, we just try for F16. See b/277656755 for more information.
// TODO: In the future, separate these 3 different restrictions. // TODO: In the future, separate these 3 different restrictions.
// TODO: Also, we should extend this logic to non-web platforms. // TODO: Also, we should extend this logic to all platforms.
static bool can_use_f32 = static bool can_use_f32 = HasGlExtension("EXT_color_buffer_float") &&
helper_.GetGlContext().HasGlExtension("EXT_color_buffer_float") && HasGlExtension("OES_texture_float_linear") &&
helper_.GetGlContext().HasGlExtension("OES_texture_float_linear") && HasGlExtension("EXT_float_blend");
helper_.GetGlContext().HasGlExtension("EXT_float_blend");
static bool can_use_f16_backup = static bool can_use_f16_backup =
helper_.GetGlContext().HasGlExtension("EXT_color_buffer_half_float"); HasGlExtension("EXT_color_buffer_half_float");
RET_CHECK(can_use_f32 || can_use_f16_backup) RET_CHECK(can_use_f32 || can_use_f16_backup)
<< "Segmentation postprocessing error: GPU does not fully support " << "Segmentation postprocessing error: GPU does not fully support "
<< "4-channel float32 or float16 formats."; << "4-channel float32 or float16 formats.";
@ -510,15 +547,6 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
const GpuBufferFormat final_output_format = const GpuBufferFormat final_output_format =
can_use_f32 ? GpuBufferFormat::kGrayFloat32 can_use_f32 ? GpuBufferFormat::kGrayFloat32
: GpuBufferFormat::kGrayHalf16; : GpuBufferFormat::kGrayHalf16;
const Tensor::OpenGlTexture2dView read_view =
tensor.GetOpenGlTexture2dReadView();
const int width = input_shape.width; // Slice width from shape
const int height = input_shape.height; // Slice height from chape
const int num_outputs = input_shape.channels; // One output per channel
const int num_chunks = (input_shape.channels + 3) / 4; // ceil(channels/4)
const int output_width = output_shape.width; // Final output width
const int output_height = output_shape.height; // Final output height
// We disable blending or else our alpha channel may destroy our other // We disable blending or else our alpha channel may destroy our other
// channels' data. // channels' data.
@ -540,9 +568,16 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
input_width, input_height, activation_output_format); input_width, input_height, activation_output_format);
helper_.BindFramebuffer(activated_texture); helper_.BindFramebuffer(activated_texture);
// All our input source textures are just simple GL_TEXTURE_2D types. // All our input source textures will be just simple GL_TEXTURE_2D types.
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
glBindTexture(GL_TEXTURE_2D, ssbo_tex_id);
#else
const Tensor::OpenGlTexture2dView read_view =
tensor.GetOpenGlTexture2dReadView();
glBindTexture(GL_TEXTURE_2D, read_view.name()); glBindTexture(GL_TEXTURE_2D, read_view.name());
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
// Render // Render
glClear(GL_COLOR_BUFFER_BIT); glClear(GL_COLOR_BUFFER_BIT);
@ -841,6 +876,10 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() {
glDeleteProgram(softmax_max_shader_.program); glDeleteProgram(softmax_max_shader_.program);
glDeleteProgram(softmax_transform_and_sum_shader_.program); glDeleteProgram(softmax_transform_and_sum_shader_.program);
glDeleteProgram(softmax_normalization_shader_.program); glDeleteProgram(softmax_normalization_shader_.program);
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
ssbo_to_texture_converter_.Close();
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
}); });
} }

View File

@ -21,6 +21,14 @@
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
// On Android with compute shaders we include the SSBO-to-texture converter
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && \
defined(__ANDROID__)
#define TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING 1
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h"
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 &&
// defined(__ANDROID__)
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -45,6 +53,7 @@ class SegmentationPostprocessorGl {
}; };
absl::Status GlInit(const bool produce_confidence_masks); absl::Status GlInit(const bool produce_confidence_masks);
bool HasGlExtension(std::string const& extension);
absl::Status CreateBasicFragmentShaderProgram( absl::Status CreateBasicFragmentShaderProgram(
std::string const& program_name, std::string const& program_name,
std::string const& fragment_shader_source, std::string const& fragment_shader_source,
@ -69,6 +78,10 @@ class SegmentationPostprocessorGl {
GlShader softmax_max_shader_; GlShader softmax_max_shader_;
GlShader softmax_transform_and_sum_shader_; GlShader softmax_transform_and_sum_shader_;
GlShader softmax_normalization_shader_; GlShader softmax_normalization_shader_;
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
SsboToTextureConverter ssbo_to_texture_converter_;
#endif
}; };
} // namespace tasks } // namespace tasks

View File

@ -0,0 +1,162 @@
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h"
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
// Quick compile-time warning to ensure usage on the proper platform.
#if !(MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31)
#warning "SsboToTextureConverter should be used with OpenGL ES 3.1 or above"
#endif
namespace mediapipe {
namespace tasks {
namespace {
using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader;
// Edge length of the square compute workgroup; also used to compute how many
// workgroups to dispatch along x and y.
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
// Workgroup dimensions handed to tflite::gpu::gl::GetShaderHeader when the
// shader source is assembled.
const tflite::gpu::uint3 workgroup_size = {kWorkgroupSize, kWorkgroupSize, 1};
// "Delinearization" compute shader.
//
// Reads a tightly packed float buffer (SSBO at binding 2) holding
// `num_channels` consecutive values per element and writes an RGBA32F image
// (image unit 0) in which every element occupies `num_channels_padded / 4`
// consecutive texels. Each invocation produces exactly one output texel.
//
// Example data using n=5 channels: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14 -->
// 0,1,2,3 | 4,X,X,X | 5,6,7,8 | 9,X,X,X | 10,11,12,13 | 14,X,X,X
// "X" marks padding lanes: the shader fills them with whatever follows in the
// buffer, or with 0.0 when that would read past the end of the buffer.
const char delinearization_shader_source[] = R"(
precision highp float;
layout(rgba32f, binding = 0) writeonly uniform highp image2D output_texture;
uniform ivec2 out_size;
uniform int num_channels;
uniform int num_channels_padded; // ^ rounded up to nearest multiple of 4
layout(std430, binding = 2) readonly buffer B0 {
float elements[];
} input_data; // data tensor
void main() {
int out_width = out_size.x;
int out_height = out_size.y;
ivec2 gid = ivec2(gl_GlobalInvocationID.xy);
if (gid.x >= out_width || gid.y >= out_height) { return; }
int linear_index_pixels = gid.y * out_width + gid.x;
int linear_index = linear_index_pixels * 4;
int num_completed_chunks = linear_index / num_channels_padded;
int offset = linear_index % num_channels_padded;
int data_index = num_completed_chunks * num_channels + offset;
// Early exit if fully outside buffer
int data_size = input_data.elements.length();
if (data_index >= data_size) return;
// We add some extra logic here just to ensure we don't overrun buffer and get
// undefined behavior. TODO: Come up with nicer way around this if
// we end up needing this sort of patch more frequently.
float x = input_data.elements[data_index];
float y = 0.0;
float z = 0.0;
float w = 0.0;
if (data_index + 3 < data_size) {
w = input_data.elements[data_index + 3];
z = input_data.elements[data_index + 2];
y = input_data.elements[data_index + 1];
} else if (data_index + 2 < data_size) {
z = input_data.elements[data_index + 2];
y = input_data.elements[data_index + 1];
} else if (data_index + 1 < data_size) {
y = input_data.elements[data_index + 1];
}
ivec2 output_coordinate = ivec2(gid.x, gid.y);
vec4 out_value = vec4(x, y, z, w);
imageStore(output_texture, output_coordinate, out_value);
})";
// Returns how many groups of `group_size` items are needed to cover `size`
// items, i.e. size / group_size rounded up. Used to size compute dispatches.
int NumGroups(const int size, const int group_size) {  // NOLINT
  const int rounded_up_size = size + group_size - 1;
  return rounded_up_size / group_size;
}
} // namespace
// Compiles the delinearization compute shader and wraps it in the program that
// ConvertTensorToGlTexture dispatches. Returns the compile or program-creation
// error if either step fails, and OkStatus otherwise.
absl::Status SsboToTextureConverter::Init() {
  // The shader body needs the TFLite GL header (which carries the workgroup
  // dimensions) prepended before it can be compiled.
  const std::string full_shader_source =
      absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size),
                   delinearization_shader_source);

  GlShader compute_shader;
  MP_RETURN_IF_ERROR(GlShader::CompileShader(
      GL_COMPUTE_SHADER, full_shader_source, &compute_shader));

  // The program object is only allocated once the shader compiled cleanly.
  delinearization_program_ = absl::make_unique<GlProgram>();
  MP_RETURN_IF_ERROR(GlProgram::CreateWithShader(
      compute_shader, delinearization_program_.get()));

  return absl::OkStatus();
}
void SsboToTextureConverter::Close() { delinearization_program_.reset(); }
// Returns the (width, height), in texels, of the texture produced by the most
// recent ConvertTensorToGlTexture call.
std::pair<const uint32_t, const uint32_t>
SsboToTextureConverter::GetTextureSize() {
  return std::pair<const uint32_t, const uint32_t>(texture_width_,
                                                   texture_height_);
}
// Copies the linearized contents of `tensor` (read through its OpenGL buffer
// view) into an RGBA32F 2D texture owned by this converter and returns the
// texture's GL id.
//
// Args:
//   tensor:   source tensor; read via GetOpenGlBufferReadView().
//   width:    number of elements per row.
//   height:   number of rows.
//   channels: number of float values per element.
//
// The output texture is `width * ceil(channels / 4)` texels wide and `height`
// texels tall; query it afterwards with GetTextureSize(). The texture is
// reused across calls and only re-created when that size changes.
//
// Returns FailedPrecondition if no program is available (Init() has not been
// called, or Close() has already been called), or the error reported by
// texture creation / compute dispatch.
absl::StatusOr<GLuint> SsboToTextureConverter::ConvertTensorToGlTexture(
    const Tensor& tensor, const uint32_t width, const uint32_t height,
    const uint32_t channels) {
  // Guard against dereferencing a null program below.
  if (delinearization_program_ == nullptr) {
    return absl::FailedPreconditionError(
        "SsboToTextureConverter is not initialized; call Init() first.");
  }

  // The tflite::gpu:: namespace looks like it's much simpler and older-- it
  // doesn't tap into any memory pools, and doesn't allow linearF32 filtering
  // where available, for example. The key difference is that it uses
  // glTexStorage2D for allocation instead of glTexImage2D, which is necessary
  // in order to create an immutable format (as required by
  // glBindImageTexture). MP will automatically use this for RGBA16F but not
  // RGBA32F textures currently, oddly enough. So options are:
  // (1) extend MP to similarly handle RGBA32F
  // (2) just make our own texture here and keep reusing, recreating if the
  //     size changes, which should generally not happen. (This is ok because
  //     we use the texture immediately and never output it from the
  //     calculator).
  // (3) Change glBindImageTexture call to alternative so we can just use
  //     existing MP glTexImage2D storage creation? This seems less than
  //     ideal since it's rather nice to keep the above program in compute
  //     shader format.
  // TODO: To be safe for this initial implementation, we go with
  // option #2, as it's simplest/easiest, but this should be cleaned up later.
  const uint32_t num_pixels_per_element = ((channels + 3) / 4);
  const uint32_t padded_channels = 4 * num_pixels_per_element;
  const uint32_t texture_width = width * num_pixels_per_element;
  const uint32_t texture_height = height;
  if (texture_width != texture_width_ || texture_height != texture_height_) {
    // tflite::gpu::gl::GlTexture autoreleases, so we don't have to worry about
    // freeing memory.
    MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture(
        tflite::gpu::DataType::FLOAT32, {texture_width, texture_height},
        &out_texture_));
    // Only remember the size once the texture was created successfully.
    texture_width_ = texture_width;
    texture_height_ = texture_height;
  }

  // Image unit 0 and SSBO binding 2 match the `binding` qualifiers declared in
  // the delinearization shader.
  glBindImageTexture(0 /* output index */, out_texture_.id(), 0, GL_FALSE, 0,
                     GL_WRITE_ONLY, GL_RGBA32F);
  auto read_view = tensor.GetOpenGlBufferReadView();
  glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2 /* input index */,
                   read_view.name());

  const GLuint program_id = delinearization_program_->id();
  glUseProgram(program_id);
  glUniform2i(glGetUniformLocation(program_id, "out_size"), texture_width,
              texture_height);
  glUniform1i(glGetUniformLocation(program_id, "num_channels"), channels);
  glUniform1i(glGetUniformLocation(program_id, "num_channels_padded"),
              padded_channels);

  // One invocation per output texel, rounded up to whole workgroups; the
  // shader itself discards invocations that fall outside the texture.
  const tflite::gpu::uint3 workgroups = {
      NumGroups(texture_width, kWorkgroupSize),
      NumGroups(texture_height, kWorkgroupSize), 1};
  MP_RETURN_IF_ERROR(delinearization_program_->Dispatch(workgroups));
  return out_texture_.id();
}
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,55 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_
#include <utility>
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gl_base.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h"
namespace mediapipe {
namespace tasks {
// Helper class for converting Android and Linux Tensors from OpenGL ES >=3.1
// SSBO objects into OpenGL ES <=3.0 2D textures. Cannot be used with other
// Tensor backends.
//
// Expected call order: Init(), then any number of ConvertTensorToGlTexture()
// calls (each optionally followed by GetTextureSize()), then Close().
class SsboToTextureConverter {
 public:
  SsboToTextureConverter() = default;
  ~SsboToTextureConverter() = default;

  // Builds the compute program used for the conversion. Must be called before
  // ConvertTensorToGlTexture.
  absl::Status Init();

  // Destroys the compute program created by Init().
  void Close();

  // Converts `tensor`, interpreted as `height` rows of `width` elements with
  // `channels` floats each, into a 2D texture owned by this object and returns
  // the texture's GL id.
  absl::StatusOr<GLuint> ConvertTensorToGlTexture(const Tensor& tensor,
                                                  const uint32_t width,
                                                  const uint32_t height,
                                                  const uint32_t channels);

  // Returns the (width, height) of the texture produced by the last
  // conversion. Should only be called after ConvertTensorToGlTexture.
  std::pair<const uint32_t, const uint32_t> GetTextureSize();

 private:
  // Size of the currently allocated `out_texture_`. Zero means no texture has
  // been created yet; the explicit initializers matter because the first
  // conversion compares against these values before ever writing them.
  uint32_t texture_width_ = 0;
  uint32_t texture_height_ = 0;
  // Output texture, reused across conversions while the size is unchanged.
  tflite::gpu::gl::GlTexture out_texture_;
  // Compute program running the delinearization shader; null until Init().
  std::unique_ptr<tflite::gpu::gl::GlProgram> delinearization_program_;
};
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_

View File

@ -43,9 +43,18 @@ limitations under the License.
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#ifdef __EMSCRIPTEN__ #ifdef __EMSCRIPTEN__
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h" #define TASK_SEGMENTATION_USE_GL_POSTPROCESSING 1
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && \
!MEDIAPIPE_USING_SWIFTSHADER && defined(MEDIAPIPE_ANDROID)
#define TASK_SEGMENTATION_USE_GL_POSTPROCESSING 1
#else
#undef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h"
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
// TODO: consolidate TensorToSegmentationCalculator. // TODO: consolidate TensorToSegmentationCalculator.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -308,19 +317,19 @@ class TensorsToSegmentationCalculator : public Node {
const float* tensors_buffer); const float* tensors_buffer);
TensorsToSegmentationCalculatorOptions options_; TensorsToSegmentationCalculatorOptions options_;
#ifdef __EMSCRIPTEN__ #ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
SegmentationPostprocessorGl postprocessor_; SegmentationPostprocessorGl postprocessor_;
#endif // __EMSCRIPTEN__ #endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
}; };
// static // static
absl::Status TensorsToSegmentationCalculator::UpdateContract( absl::Status TensorsToSegmentationCalculator::UpdateContract(
CalculatorContract* cc) { CalculatorContract* cc) {
#ifdef __EMSCRIPTEN__ #ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
return SegmentationPostprocessorGl::UpdateContract(cc); return SegmentationPostprocessorGl::UpdateContract(cc);
#else #else
return absl::OkStatus(); return absl::OkStatus();
#endif // __EMSCRIPTEN__ #endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
} }
absl::Status TensorsToSegmentationCalculator::Open( absl::Status TensorsToSegmentationCalculator::Open(
@ -340,9 +349,9 @@ absl::Status TensorsToSegmentationCalculator::Open(
"connected."); "connected.");
} }
} }
#ifdef __EMSCRIPTEN__ #ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_)); MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
#endif // __EMSCRIPTEN__ #endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
return absl::OkStatus(); return absl::OkStatus();
} }
@ -390,11 +399,11 @@ absl::Status TensorsToSegmentationCalculator::Process(
} }
// Use GPU postprocessing on web when Tensor is there already. // Use GPU postprocessing on web when Tensor is there already.
#ifdef __EMSCRIPTEN__ #ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
Shape output_shape = {/* height= */ output_height, Shape output_shape = {/* height= */ output_height,
/* width= */ output_width, /* width= */ output_width,
/* channels= */ input_shape.channels}; /* channels= */ input_shape.channels};
if (input_tensor.ready_as_opengl_texture_2d()) { if (input_tensor.ready_on_gpu()) {
bool produce_category_mask = options_.segmenter_options().output_type() == bool produce_category_mask = options_.segmenter_options().output_type() ==
SegmenterOptions::CATEGORY_MASK || SegmenterOptions::CATEGORY_MASK ||
cc->Outputs().HasTag("CATEGORY_MASK"); cc->Outputs().HasTag("CATEGORY_MASK");
@ -428,7 +437,7 @@ absl::Status TensorsToSegmentationCalculator::Process(
} }
return absl::OkStatus(); return absl::OkStatus();
} }
#endif // __EMSCRIPTEN__ #endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
// Otherwise, use CPU postprocessing. // Otherwise, use CPU postprocessing.
const float* tensors_buffer = input_tensor.GetCpuReadView().buffer<float>(); const float* tensors_buffer = input_tensor.GetCpuReadView().buffer<float>();