Internal change
PiperOrigin-RevId: 535751178
This commit is contained in:
parent
fddc3facf0
commit
d4c7ed2217
|
@ -50,7 +50,12 @@ cc_library(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": [
|
||||
":segmentation_postprocessor_gl",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
|
@ -72,6 +77,29 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/utils:image_utils",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": [
|
||||
"ssbo_to_texture_converter",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ssbo_to_texture_converter",
|
||||
srcs = ["ssbo_to_texture_converter.cc"],
|
||||
hdrs = ["ssbo_to_texture_converter.h"],
|
||||
tags = [
|
||||
"nomac",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/gpu:gl_base",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -16,6 +16,14 @@ namespace mediapipe {
|
|||
namespace tasks {
|
||||
namespace {
|
||||
|
||||
// On most platforms, glGetUniformLocation returns -1 for an error status, but
|
||||
// on web we'll see 0 instead.
|
||||
#ifdef __EMSCRIPTEN__
|
||||
const GLint kUniformErrorStatus = 0;
|
||||
#else
|
||||
const GLint kUniformErrorStatus = -1;
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
using mediapipe::kBasicSquareVertices;
|
||||
using mediapipe::kBasicTextureVertices;
|
||||
using mediapipe::kBasicVertexShader;
|
||||
|
@ -341,7 +349,7 @@ absl::Status SegmentationPostprocessorGl::CreateBasicFragmentShaderProgram(
|
|||
for (const auto& uniform_name : uniform_names) {
|
||||
shader_struct_ptr->uniforms[uniform_name] =
|
||||
glGetUniformLocation(shader_struct_ptr->program, uniform_name.c_str());
|
||||
RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > 0)
|
||||
RET_CHECK(shader_struct_ptr->uniforms[uniform_name] > kUniformErrorStatus)
|
||||
<< uniform_name << " uniform not found for " << program_name
|
||||
<< " program";
|
||||
}
|
||||
|
@ -427,10 +435,10 @@ absl::Status SegmentationPostprocessorGl::GlInit(
|
|||
// Get split program uniform locations.
|
||||
split_texture_uniform_ =
|
||||
glGetUniformLocation(split_program_, "input_texture");
|
||||
RET_CHECK(split_texture_uniform_ > 0)
|
||||
RET_CHECK(split_texture_uniform_ > kUniformErrorStatus)
|
||||
<< "split input_texture uniform not found.";
|
||||
split_x_offset_uniform_ = glGetUniformLocation(split_program_, "x_offset");
|
||||
RET_CHECK(split_x_offset_uniform_ > 0)
|
||||
RET_CHECK(split_x_offset_uniform_ > kUniformErrorStatus)
|
||||
<< "split x_offset uniform not found.";
|
||||
|
||||
// TODO: If ES3.0+ only, switch to VAO for handling attributes.
|
||||
|
@ -445,10 +453,24 @@ absl::Status SegmentationPostprocessorGl::GlInit(
|
|||
kBasicTextureVertices, GL_STATIC_DRAW);
|
||||
|
||||
glBindBuffer(GL_ARRAY_BUFFER, 0);
|
||||
|
||||
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
MP_RETURN_IF_ERROR(ssbo_to_texture_converter_.Init());
|
||||
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
|
||||
return absl::OkStatus();
|
||||
});
|
||||
}
|
||||
|
||||
// On Android, the extensions are prefixed by GL_, whereas on web they are not.
|
||||
bool SegmentationPostprocessorGl::HasGlExtension(std::string const& extension) {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
return helper_.GetGlContext().HasGlExtension(extension);
|
||||
#else
|
||||
return helper_.GetGlContext().HasGlExtension("GL_" + extension);
|
||||
#endif // __EMSCRIPTEN__
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<Image>>
|
||||
SegmentationPostprocessorGl::GetSegmentationResultGpu(
|
||||
const Shape& input_shape, const Shape& output_shape, const Tensor& tensor,
|
||||
|
@ -459,18 +481,35 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
|
|||
produce_category_mask,
|
||||
&image_outputs]() -> absl::Status {
|
||||
// Get Tensor input and image output parameters
|
||||
const int width = input_shape.width; // Slice width from shape
|
||||
const int height = input_shape.height; // Slice height from chape
|
||||
const int num_outputs = input_shape.channels; // One output per channel
|
||||
const int num_chunks = (input_shape.channels + 3) / 4; // ceil(channels/4)
|
||||
const int output_width = output_shape.width; // Final output width
|
||||
const int output_height = output_shape.height; // Final output height
|
||||
int input_width, input_height;
|
||||
|
||||
if (!tensor.ready_as_opengl_texture_2d()) {
|
||||
if (!tensor.ready_on_gpu()) {
|
||||
LOG(WARNING) << "Tensor wasn't ready on GPU; using slow workaround.";
|
||||
(void)tensor.GetCpuReadView();
|
||||
}
|
||||
|
||||
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
// If our Tensor is an SSBO, then it's also linearized, so we convert to a
|
||||
// kAligned 2d texture using a special converter and then proceed as before.
|
||||
GLuint ssbo_tex_id;
|
||||
ASSIGN_OR_RETURN(ssbo_tex_id,
|
||||
ssbo_to_texture_converter_.ConvertTensorToGlTexture(
|
||||
tensor, width, height, num_outputs));
|
||||
std::tie(input_width, input_height) =
|
||||
ssbo_to_texture_converter_.GetTextureSize();
|
||||
#else
|
||||
const auto layout = tensor.GetOpenGlTexture2dReadView().GetLayoutDimensions(
|
||||
tensor.shape(), &input_width, &input_height);
|
||||
if (layout != Tensor::OpenGlTexture2dView::Layout::kAligned) {
|
||||
LOG(ERROR) << "Tensor layout not kAligned! Cannot handle.";
|
||||
}
|
||||
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
|
||||
// Optimization: Only apply SOFTMAX when producing confidence masks, since
|
||||
// SOFTMAX errors out when num_classes = 1, so we don't have to worry about
|
||||
|
@ -486,14 +525,12 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
|
|||
// (3) blending
|
||||
// Otherwise, we just try for F16. See b/277656755 for more information.
|
||||
// TODO: In the future, separate these 3 different restrictions.
|
||||
// TODO: Also, we should extend this logic to non-web platforms.
|
||||
static bool can_use_f32 =
|
||||
helper_.GetGlContext().HasGlExtension("EXT_color_buffer_float") &&
|
||||
helper_.GetGlContext().HasGlExtension("OES_texture_float_linear") &&
|
||||
helper_.GetGlContext().HasGlExtension("EXT_float_blend");
|
||||
// TODO: Also, we should extend this logic to all platforms.
|
||||
static bool can_use_f32 = HasGlExtension("EXT_color_buffer_float") &&
|
||||
HasGlExtension("OES_texture_float_linear") &&
|
||||
HasGlExtension("EXT_float_blend");
|
||||
static bool can_use_f16_backup =
|
||||
helper_.GetGlContext().HasGlExtension("EXT_color_buffer_half_float");
|
||||
|
||||
HasGlExtension("EXT_color_buffer_half_float");
|
||||
RET_CHECK(can_use_f32 || can_use_f16_backup)
|
||||
<< "Segmentation postprocessing error: GPU does not fully support "
|
||||
<< "4-channel float32 or float16 formats.";
|
||||
|
@ -510,15 +547,6 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
|
|||
const GpuBufferFormat final_output_format =
|
||||
can_use_f32 ? GpuBufferFormat::kGrayFloat32
|
||||
: GpuBufferFormat::kGrayHalf16;
|
||||
const Tensor::OpenGlTexture2dView read_view =
|
||||
tensor.GetOpenGlTexture2dReadView();
|
||||
|
||||
const int width = input_shape.width; // Slice width from shape
|
||||
const int height = input_shape.height; // Slice height from chape
|
||||
const int num_outputs = input_shape.channels; // One output per channel
|
||||
const int num_chunks = (input_shape.channels + 3) / 4; // ceil(channels/4)
|
||||
const int output_width = output_shape.width; // Final output width
|
||||
const int output_height = output_shape.height; // Final output height
|
||||
|
||||
// We disable blending or else our alpha channel may destroy our other
|
||||
// channels' data.
|
||||
|
@ -540,9 +568,16 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
|
|||
input_width, input_height, activation_output_format);
|
||||
helper_.BindFramebuffer(activated_texture);
|
||||
|
||||
// All our input source textures are just simple GL_TEXTURE_2D types.
|
||||
// All our input source textures will be just simple GL_TEXTURE_2D types.
|
||||
glActiveTexture(GL_TEXTURE1);
|
||||
|
||||
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
glBindTexture(GL_TEXTURE_2D, ssbo_tex_id);
|
||||
#else
|
||||
const Tensor::OpenGlTexture2dView read_view =
|
||||
tensor.GetOpenGlTexture2dReadView();
|
||||
glBindTexture(GL_TEXTURE_2D, read_view.name());
|
||||
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
|
||||
// Render
|
||||
glClear(GL_COLOR_BUFFER_BIT);
|
||||
|
@ -841,6 +876,10 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() {
|
|||
glDeleteProgram(softmax_max_shader_.program);
|
||||
glDeleteProgram(softmax_transform_and_sum_shader_.program);
|
||||
glDeleteProgram(softmax_normalization_shader_.program);
|
||||
|
||||
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
ssbo_to_texture_converter_.Close();
|
||||
#endif // TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,14 @@
|
|||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
||||
// On Android with compute shaders we include the SSBO-to-texture converter
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && \
|
||||
defined(__ANDROID__)
|
||||
#define TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING 1
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h"
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 &&
|
||||
// defined(__ANDROID__)
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
||||
|
@ -45,6 +53,7 @@ class SegmentationPostprocessorGl {
|
|||
};
|
||||
|
||||
absl::Status GlInit(const bool produce_confidence_masks);
|
||||
bool HasGlExtension(std::string const& extension);
|
||||
absl::Status CreateBasicFragmentShaderProgram(
|
||||
std::string const& program_name,
|
||||
std::string const& fragment_shader_source,
|
||||
|
@ -69,6 +78,10 @@ class SegmentationPostprocessorGl {
|
|||
GlShader softmax_max_shader_;
|
||||
GlShader softmax_transform_and_sum_shader_;
|
||||
GlShader softmax_normalization_shader_;
|
||||
|
||||
#ifdef TASK_SEGMENTATION_USE_GLES_31_POSTPROCESSING
|
||||
SsboToTextureConverter ssbo_to_texture_converter_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace tasks
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/ssbo_to_texture_converter.h"
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
|
||||
|
||||
// Quick compile-time warning to ensure usage on the proper platform.
|
||||
#if !(MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31)
|
||||
#warning "SsboToTextureConverter should be used with OpenGL ES 3.1 or above"
|
||||
#endif
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace {
|
||||
|
||||
using ::tflite::gpu::gl::GlProgram;
|
||||
using ::tflite::gpu::gl::GlShader;
|
||||
|
||||
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
|
||||
const tflite::gpu::uint3 workgroup_size = {kWorkgroupSize, kWorkgroupSize, 1};
|
||||
|
||||
// "Delinearization" shader:
|
||||
// Example data using n=5 channels: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14 -->
|
||||
// 0,1,2,3 | 4,X,X,X | 5,6,7,8 | 9,X,X,X | 10,11,12,13 | 14,X,X,X
|
||||
const char delinearization_shader_source[] = R"(
|
||||
precision highp float;
|
||||
layout(rgba32f, binding = 0) writeonly uniform highp image2D output_texture;
|
||||
|
||||
uniform ivec2 out_size;
|
||||
uniform int num_channels;
|
||||
uniform int num_channels_padded; // ^ rounded up to nearest multiple of 4
|
||||
|
||||
layout(std430, binding = 2) readonly buffer B0 {
|
||||
float elements[];
|
||||
} input_data; // data tensor
|
||||
|
||||
void main() {
|
||||
int out_width = out_size.x;
|
||||
int out_height = out_size.y;
|
||||
|
||||
ivec2 gid = ivec2(gl_GlobalInvocationID.xy);
|
||||
if (gid.x >= out_width || gid.y >= out_height) { return; }
|
||||
int linear_index_pixels = gid.y * out_width + gid.x;
|
||||
int linear_index = linear_index_pixels * 4;
|
||||
|
||||
int num_completed_chunks = linear_index / num_channels_padded;
|
||||
int offset = linear_index % num_channels_padded;
|
||||
int data_index = num_completed_chunks * num_channels + offset;
|
||||
|
||||
// Early exit if fully outside buffer
|
||||
int data_size = input_data.elements.length();
|
||||
if (data_index >= data_size) return;
|
||||
|
||||
// We add some extra logic here just to ensure we don't overrun buffer and get
|
||||
// undefined behavior. TODO: Come up with nicer way around this if
|
||||
// we end up needing this sort of patch more frequently.
|
||||
float x = input_data.elements[data_index];
|
||||
float y = 0.0;
|
||||
float z = 0.0;
|
||||
float w = 0.0;
|
||||
if (data_index + 3 < data_size) {
|
||||
w = input_data.elements[data_index + 3];
|
||||
z = input_data.elements[data_index + 2];
|
||||
y = input_data.elements[data_index + 1];
|
||||
} else if (data_index + 2 < data_size) {
|
||||
z = input_data.elements[data_index + 2];
|
||||
y = input_data.elements[data_index + 1];
|
||||
} else if (data_index + 1 < data_size) {
|
||||
y = input_data.elements[data_index + 1];
|
||||
}
|
||||
|
||||
ivec2 output_coordinate = ivec2(gid.x, gid.y);
|
||||
vec4 out_value = vec4(x, y, z, w);
|
||||
imageStore(output_texture, output_coordinate, out_value);
|
||||
})";
|
||||
|
||||
// Commonly used to compute the number of blocks to launch in a kernel.
|
||||
int NumGroups(const int size, const int group_size) { // NOLINT
|
||||
return (size + group_size - 1) / group_size;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::Status SsboToTextureConverter::Init() {
|
||||
GlShader delinearization_shader;
|
||||
std::string delinearization_shader_source_with_headers =
|
||||
absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size),
|
||||
delinearization_shader_source);
|
||||
MP_RETURN_IF_ERROR(GlShader::CompileShader(
|
||||
GL_COMPUTE_SHADER, delinearization_shader_source_with_headers,
|
||||
&delinearization_shader));
|
||||
delinearization_program_ = absl::make_unique<GlProgram>();
|
||||
MP_RETURN_IF_ERROR(GlProgram::CreateWithShader(
|
||||
delinearization_shader, delinearization_program_.get()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void SsboToTextureConverter::Close() { delinearization_program_.reset(); }
|
||||
|
||||
std::pair<const uint32_t, const uint32_t>
|
||||
SsboToTextureConverter::GetTextureSize() {
|
||||
return std::make_pair(texture_width_, texture_height_);
|
||||
}
|
||||
|
||||
absl::StatusOr<GLuint> SsboToTextureConverter::ConvertTensorToGlTexture(
|
||||
const Tensor& tensor, const uint32_t width, const uint32_t height,
|
||||
const uint32_t channels) {
|
||||
// The tflite::gpu:: namespace looks like it's much simpler and older-- it
|
||||
// doesn't tap into any memory pools, and doesn't allow linearF32 filtering
|
||||
// where available, for example. The key difference is that it uses
|
||||
// glTexStorage2D for allocation instead of glTexImage2D, which is necessary
|
||||
// in order to create an immutable format (as required by glBindImageTexture).
|
||||
// MP will automatically use this for RGBA16F but not RGBA32F textures
|
||||
// currently, oddly enough. So options are:
|
||||
// (1) extend MP to similarly handle RGBA32F
|
||||
// (2) just make our own texture here and keep reusing, recreating if the size
|
||||
// changes, which should generally not happen. (This is ok because we use
|
||||
// the texture immediately and never output it from the calculator).
|
||||
// (3) Change glBindImageTexture call to alternative so we can just use
|
||||
// existing MP glTexImage2D storage creation? This seems less than
|
||||
// ideal since it's rather nice to keep the above program in compute
|
||||
// shader format.
|
||||
// TODO: To be safe for this initial implementation, we go with
|
||||
// option #2, as it's simplest/easiest, but this should be cleaned up later.
|
||||
const uint32_t num_pixels_per_element = ((channels + 3) / 4);
|
||||
const uint32_t padded_channels = 4 * num_pixels_per_element;
|
||||
const uint32_t texture_width = width * num_pixels_per_element;
|
||||
const uint32_t texture_height = height;
|
||||
if (texture_width != texture_width_ || texture_height != texture_height_) {
|
||||
// tflite::gpu::gl::GlTexture autoreleases, so we don't have to worry about
|
||||
// freeing memory.
|
||||
MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture(
|
||||
tflite::gpu::DataType::FLOAT32, {texture_width, texture_height},
|
||||
&out_texture_));
|
||||
texture_width_ = texture_width;
|
||||
texture_height_ = texture_height;
|
||||
}
|
||||
|
||||
glBindImageTexture(0 /* output index */, out_texture_.id(), 0, GL_FALSE, 0,
|
||||
GL_WRITE_ONLY, GL_RGBA32F);
|
||||
auto read_view = tensor.GetOpenGlBufferReadView();
|
||||
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2 /* input index */,
|
||||
read_view.name());
|
||||
|
||||
glUseProgram(delinearization_program_->id());
|
||||
glUniform2i(glGetUniformLocation(delinearization_program_->id(), "out_size"),
|
||||
texture_width, texture_height);
|
||||
glUniform1i(
|
||||
glGetUniformLocation(delinearization_program_->id(), "num_channels"),
|
||||
channels);
|
||||
glUniform1i(glGetUniformLocation(delinearization_program_->id(),
|
||||
"num_channels_padded"),
|
||||
padded_channels);
|
||||
|
||||
const tflite::gpu::uint3 workgroups = {
|
||||
NumGroups(texture_width, kWorkgroupSize),
|
||||
NumGroups(texture_height, kWorkgroupSize), 1};
|
||||
MP_RETURN_IF_ERROR(delinearization_program_->Dispatch(workgroups));
|
||||
return out_texture_.id();
|
||||
}
|
||||
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
||||
// Helper class for converting Android and Linux Tensors from OpenGL ES >=3.1
|
||||
// SSBO objects into OpenGL ES <=3.0 2D textures. Cannot be used with other
|
||||
// Tensor backends.
|
||||
class SsboToTextureConverter {
|
||||
public:
|
||||
SsboToTextureConverter() = default;
|
||||
~SsboToTextureConverter() = default;
|
||||
absl::Status Init();
|
||||
void Close();
|
||||
absl::StatusOr<GLuint> ConvertTensorToGlTexture(const Tensor& tensor,
|
||||
const uint32_t width,
|
||||
const uint32_t height,
|
||||
const uint32_t channels);
|
||||
|
||||
// Should only be called after ConvertTensorToGlTexture
|
||||
std::pair<const uint32_t, const uint32_t> GetTextureSize();
|
||||
|
||||
private:
|
||||
uint32_t texture_width_;
|
||||
uint32_t texture_height_;
|
||||
tflite::gpu::gl::GlTexture out_texture_;
|
||||
std::unique_ptr<tflite::gpu::gl::GlProgram> delinearization_program_;
|
||||
};
|
||||
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SSBO_TO_TEXTURE_CONVERTER_H_
|
|
@ -43,9 +43,18 @@ limitations under the License.
|
|||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h"
|
||||
#define TASK_SEGMENTATION_USE_GL_POSTPROCESSING 1
|
||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 && \
|
||||
!MEDIAPIPE_USING_SWIFTSHADER && defined(MEDIAPIPE_ANDROID)
|
||||
#define TASK_SEGMENTATION_USE_GL_POSTPROCESSING 1
|
||||
#else
|
||||
#undef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h"
|
||||
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
|
||||
// TODO: consolidate TensorToSegmentationCalculator.
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -308,19 +317,19 @@ class TensorsToSegmentationCalculator : public Node {
|
|||
const float* tensors_buffer);
|
||||
TensorsToSegmentationCalculatorOptions options_;
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
SegmentationPostprocessorGl postprocessor_;
|
||||
#endif // __EMSCRIPTEN__
|
||||
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
};
|
||||
|
||||
// static
|
||||
absl::Status TensorsToSegmentationCalculator::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
return SegmentationPostprocessorGl::UpdateContract(cc);
|
||||
#else
|
||||
return absl::OkStatus();
|
||||
#endif // __EMSCRIPTEN__
|
||||
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
}
|
||||
|
||||
absl::Status TensorsToSegmentationCalculator::Open(
|
||||
|
@ -340,9 +349,9 @@ absl::Status TensorsToSegmentationCalculator::Open(
|
|||
"connected.");
|
||||
}
|
||||
}
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
|
||||
#endif // __EMSCRIPTEN__
|
||||
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -390,11 +399,11 @@ absl::Status TensorsToSegmentationCalculator::Process(
|
|||
}
|
||||
|
||||
// Use GPU postprocessing on web when Tensor is there already.
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#ifdef TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
Shape output_shape = {/* height= */ output_height,
|
||||
/* width= */ output_width,
|
||||
/* channels= */ input_shape.channels};
|
||||
if (input_tensor.ready_as_opengl_texture_2d()) {
|
||||
if (input_tensor.ready_on_gpu()) {
|
||||
bool produce_category_mask = options_.segmenter_options().output_type() ==
|
||||
SegmenterOptions::CATEGORY_MASK ||
|
||||
cc->Outputs().HasTag("CATEGORY_MASK");
|
||||
|
@ -428,7 +437,7 @@ absl::Status TensorsToSegmentationCalculator::Process(
|
|||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
#endif // __EMSCRIPTEN__
|
||||
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING
|
||||
|
||||
// Otherwise, use CPU postprocessing.
|
||||
const float* tensors_buffer = input_tensor.GetCpuReadView().buffer<float>();
|
||||
|
|
Loading…
Reference in New Issue
Block a user