Internal change

PiperOrigin-RevId: 514580892
This commit is contained in:
MediaPipe Team 2023-03-06 18:11:38 -08:00 committed by Copybara-Service
parent 0337c7f52f
commit bd9a2ee1fc
4 changed files with 639 additions and 6 deletions

View File

@ -55,6 +55,27 @@ cc_library(
alwayslink = 1,
)
# OpenGL-based segmentation postprocessor (segmentation_postprocessor_gl.{h,cc}).
# tensors_to_segmentation_calculator.cc includes its header only when
# __EMSCRIPTEN__ is defined.
# NOTE(review): the "nomac" tag presumably excludes this target from macOS
# builds -- confirm against the project's tag conventions.
cc_library(
name = "segmentation_postprocessor_gl",
srcs = ["segmentation_postprocessor_gl.cc"],
hdrs = ["segmentation_postprocessor_gl.h"],
tags = ["nomac"],
deps = [
":tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:shader_util",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
)
cc_test(
name = "tensors_to_segmentation_calculator_test",
srcs = ["tensors_to_segmentation_calculator_test.cc"],

View File

@ -0,0 +1,502 @@
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace {
using mediapipe::kBasicSquareVertices;
using mediapipe::kBasicTextureVertices;
using mediapipe::kBasicVertexShader;
using ::mediapipe::tasks::vision::Shape;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
// Vertex attribute slots shared by every shader program in this file:
// ATTRIB_VERTEX is paired with the "position" attribute and
// ATTRIB_TEXTURE_POSITION with "texture_coordinate" when the programs are
// created. NUM_ATTRIBUTES is the slot count, used to size the attribute
// tables.
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// Fragment shader template for the activation pass. The single `%s` is a
// format placeholder, filled in with absl::StrFormat when the program is built;
// the substituted GLSL statement must define `vec4 out_value` from `in_value`,
// because the shader writes `out_value` to gl_FragColor.
static constexpr char kActivationFragmentShader[] = R"(
DEFAULT_PRECISION(mediump, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
void main() {
vec4 in_value = texture2D(input_texture, sample_coordinate);
// Run activation function over all 4 channels at once.
%s
gl_FragColor = out_value;
})";
// Trivial passthrough fragment shader; do splitting in a custom vertex shader.
// It copies the texel sampled at `sample_coordinate` straight to the output,
// so whatever shifts that coordinate (kSplitVertexShader's x_offset) decides
// which source texel lands in each output pixel.
static constexpr char kPassthroughShader[] = R"(
DEFAULT_PRECISION(mediump, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
void main() {
gl_FragColor = texture2D(input_texture, sample_coordinate);
})";
// Vertex shader for splitting; kLayoutAligned means we just move across x-axis.
// The `x_offset` uniform is added to the x texture coordinate of every vertex;
// the position is passed through unchanged. The caller sets `x_offset` once per
// chunk, expressed as a fraction of the source texture width.
static constexpr char kSplitVertexShader[] = R"(
DEFAULT_PRECISION(highp, float)
attribute vec4 position;
attribute vec4 texture_coordinate;
varying vec2 sample_coordinate;
// We assume kLayoutAligned for now. Everything will be scaled properly, so just
// need offset for decimation iterations.
uniform float x_offset;
void main() {
sample_coordinate = vec2(texture_coordinate.x + x_offset, texture_coordinate.y);
gl_Position = position;
})";
// Fragment shader that broadcasts one channel of the input texel to all four
// output channels. `channel_select` picks the channel: 0 -> r, 1 -> g, 2 -> b,
// and any other value -> a.
// TODO: Consider using MRT to speed this up in the future.
static constexpr char kChannelSelectShader[] = R"(
DEFAULT_PRECISION(mediump, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
uniform int channel_select;
void main() {
vec4 in_value = texture2D(input_texture, sample_coordinate);
float out_value;
if (channel_select == 0) {
out_value = in_value.r;
} else if (channel_select == 1) {
out_value = in_value.g;
} else if (channel_select == 2) {
out_value = in_value.b;
} else {
out_value = in_value.a;
}
gl_FragColor = vec4(out_value, out_value, out_value, out_value);
})";
// Fragment shader computing a per-pixel argmax over up to 12 values held in
// three RGBA textures (texture k carries indices 4k..4k+3). Ties resolve to the
// lowest index because every comparison uses >=. The winning index is written
// to all four output channels as index / 255.0.
//
// Hard-coded for max of 3 textures for now, so num classes must be <= 12, and
// the cost of this shader will be higher than necessary for smaller numbers of
// classes.
// NOTE(review): all three samplers are always read. When the caller leaves a
// slot unbound, the sampled value is presumably (0, 0, 0, 1) rather than all
// zeros per the GL ES rules for incomplete textures, which could let a
// non-existent class win when real scores are below 1 -- confirm.
// TODO: Improve this.
static constexpr char kArgmaxShader[] = R"(
DEFAULT_PRECISION(mediump, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture0;
uniform sampler2D input_texture1;
uniform sampler2D input_texture2;
int argmax4(vec4 vec) {
float aMax = max(vec.x, vec.y);
float bMax = max(vec.z, vec.w);
if (aMax >= bMax) {
if (vec.x >= vec.y) return 0;
return 1;
} else if (vec.z >= vec.w) return 2;
return 3;
}
float max4(vec4 vec) {
return max(max(vec.x, vec.y), max(vec.z, vec.w));
}
void main() {
// Grab all vecs
vec4 pixel0 = texture2D(input_texture0, sample_coordinate);
vec4 pixel1 = texture2D(input_texture1, sample_coordinate);
vec4 pixel2 = texture2D(input_texture2, sample_coordinate);
// Find vector which contains maximum value, and return its argmax
float max0 = max4(pixel0);
float max1 = max4(pixel1);
float max2 = max4(pixel2);
int argmax;
float out_value;
if (max0 >= max1) {
if (max0 >= max2) {
argmax = argmax4(pixel0);
} else {
argmax = argmax4(pixel2) + 8;
}
} else if (max1 >= max2) {
argmax = argmax4(pixel1) + 4;
} else {
argmax = argmax4(pixel2) + 8;
}
out_value = float(argmax) / 255.0;
gl_FragColor = vec4(out_value, out_value, out_value, out_value);
})";
} // namespace
// static
// Forwards the contract to GlCalculatorHelper; this class adds no
// requirements of its own.
absl::Status SegmentationPostprocessorGl::UpdateContract(
    CalculatorContract* cc) {
  absl::Status contract_status = GlCalculatorHelper::UpdateContract(cc);
  return contract_status;
}
// Stores a copy of `options`, opens the GL helper for `cc`, then builds the GL
// programs and buffers. Returns the first error encountered, or OK.
absl::Status SegmentationPostprocessorGl::Initialize(
    CalculatorContext* cc,
    TensorsToSegmentationCalculatorOptions const& options) {
  // Plain copy for now; GlInit() reads the activation and output type from it.
  options_ = options;

  // The helper must be open before GlInit(), which runs in its GL context.
  const absl::Status helper_status = helper_.Open(cc);
  if (!helper_status.ok()) {
    return helper_status;
  }
  return GlInit();
}
// Creates every GL object the postprocessor needs, inside helper_'s GL
// context:
//   * the activation, split, channel-select and argmax shader programs,
//   * the uniform locations of those programs,
//   * the two vertex buffers holding the basic square / texture coordinates.
// Returns an error if no activation snippet is available for the configured
// options, if any program fails to build, or if any uniform cannot be found.
absl::Status SegmentationPostprocessorGl::GlInit() {
  return helper_.RunInGlContext([this]() -> absl::Status {
    // TODO: This part of the setup code is so common, we should really
    // refactor to a helper utility.
    const GLint attr_location[NUM_ATTRIBUTES] = {
        ATTRIB_VERTEX,
        ATTRIB_TEXTURE_POSITION,
    };
    const GLchar* attr_name[NUM_ATTRIBUTES] = {
        "position",
        "texture_coordinate",
    };

    // GLSL statement substituted into kActivationFragmentShader; it has to
    // define `vec4 out_value`.
    std::string activation_fn;
    switch (options_.segmenter_options().activation()) {
      case SegmenterOptions::SIGMOID:
        LOG(INFO) << "SIGMOID activation function chosen on GPU";
        activation_fn = "vec4 out_value = 1.0 / (exp(-in_value) + 1.0);";
        break;
      case SegmenterOptions::SOFTMAX:
        LOG(ERROR) << "SOFTMAX activation function not implemented for GPU";
        // TODO: Softmax algo per-pixel:
        // (1) Find max of all channels
        // (2) For each channel do exp(val - max_value) transform
        // (3) Find sum over all channels
        // (4) Divide by this sum
        break;
      case SegmenterOptions::NONE:
        LOG(INFO) << "NONE activation function chosen on GPU";
        activation_fn = "vec4 out_value = in_value;";
        break;
    }

    // TODO: Skip activation step entirely for "NONE" to save a full
    // renderpass. (And same applies for CATEGORY_MASK mode).
    const bool is_category_mask = options_.segmenter_options().output_type() ==
                                  SegmenterOptions::CATEGORY_MASK;
    if (is_category_mask) {
      LOG(INFO) << "CATEGORY_MASK requested; using NONE activation function.";
      activation_fn = "vec4 out_value = in_value;";
    }

    // Without a snippet the activation shader would reference an undeclared
    // `out_value` and fail to compile; fail here with a precise message
    // instead of a generic shader-compilation error.
    RET_CHECK(!activation_fn.empty())
        << "No GPU implementation for the requested activation function.";

    const std::string activation_shader_source =
        absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble),
                     absl::StrFormat(kActivationFragmentShader, activation_fn));
    const std::string split_fragment_shader_source =
        absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble),
                     std::string(kPassthroughShader));
    const std::string split_vertex_shader_source =
        absl::StrCat(std::string(mediapipe::kMediaPipeVertexShaderPreamble),
                     std::string(kSplitVertexShader));
    const std::string channel_select_shader_source =
        absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble),
                     std::string(kChannelSelectShader));
    const std::string argmax_shader_source =
        absl::StrCat(std::string(mediapipe::kMediaPipeFragmentShaderPreamble),
                     std::string(kArgmaxShader));

    // Compile all our shader programs.
    // Note: we enable `force_log_errors` so that we get full debugging error
    // messages when compiling shaders on web, where normally such errors are
    // suppressed. See //mediapipe/gpu/shader_util.cc for more
    // info.
    mediapipe::GlhCreateProgram(
        kBasicVertexShader, activation_shader_source.c_str(), NUM_ATTRIBUTES,
        &attr_name[0], attr_location, &activation_program_,
        /* force_log_errors */ true);
    RET_CHECK(activation_program_)
        << "Problem initializing the activation program.";

    mediapipe::GlhCreateProgram(split_vertex_shader_source.c_str(),
                                split_fragment_shader_source.c_str(),
                                NUM_ATTRIBUTES, &attr_name[0], attr_location,
                                &split_program_,
                                /* force_log_errors */ true);
    RET_CHECK(split_program_) << "Problem initializing the split program.";

    mediapipe::GlhCreateProgram(
        kBasicVertexShader, channel_select_shader_source.c_str(),
        NUM_ATTRIBUTES, &attr_name[0], attr_location, &channel_select_program_,
        /* force_log_errors */ true);
    RET_CHECK(channel_select_program_)
        << "Problem initializing the channel select program.";

    mediapipe::GlhCreateProgram(kBasicVertexShader,
                                argmax_shader_source.c_str(), NUM_ATTRIBUTES,
                                &attr_name[0], attr_location, &argmax_program_,
                                /* force_log_errors */ true);
    RET_CHECK(argmax_program_) << "Problem initializing the argmax program.";

    // Get uniform locations. glGetUniformLocation reports failure with -1;
    // 0 is a perfectly valid location, so the checks must accept it.
    activation_texture_uniform_ =
        glGetUniformLocation(activation_program_, "input_texture");
    RET_CHECK(activation_texture_uniform_ >= 0)
        << "activation input_texture uniform not found.";

    split_texture_uniform_ =
        glGetUniformLocation(split_program_, "input_texture");
    RET_CHECK(split_texture_uniform_ >= 0)
        << "split input_texture uniform not found.";
    split_x_offset_uniform_ = glGetUniformLocation(split_program_, "x_offset");
    RET_CHECK(split_x_offset_uniform_ >= 0)
        << "split x_offset uniform not found.";

    channel_select_texture_uniform_ =
        glGetUniformLocation(channel_select_program_, "input_texture");
    RET_CHECK(channel_select_texture_uniform_ >= 0)
        << "channel select input_texture uniform not found.";
    channel_select_index_uniform_ =
        glGetUniformLocation(channel_select_program_, "channel_select");
    RET_CHECK(channel_select_index_uniform_ >= 0)
        << "channel select indexing uniform not found.";

    argmax_texture0_uniform_ =
        glGetUniformLocation(argmax_program_, "input_texture0");
    RET_CHECK(argmax_texture0_uniform_ >= 0)
        << "argmax input_texture0 uniform not found.";
    argmax_texture1_uniform_ =
        glGetUniformLocation(argmax_program_, "input_texture1");
    RET_CHECK(argmax_texture1_uniform_ >= 0)
        << "argmax input_texture1 uniform not found.";
    argmax_texture2_uniform_ =
        glGetUniformLocation(argmax_program_, "input_texture2");
    RET_CHECK(argmax_texture2_uniform_ >= 0)
        << "argmax input_texture2 uniform not found.";

    // TODO: If ES3.0+ only, switch to VAO for handling attributes.
    // Upload the static quad geometry used by every render pass.
    glGenBuffers(1, &square_vertices_);
    glBindBuffer(GL_ARRAY_BUFFER, square_vertices_);
    glBufferData(GL_ARRAY_BUFFER, sizeof(kBasicSquareVertices),
                 kBasicSquareVertices, GL_STATIC_DRAW);

    glGenBuffers(1, &texture_vertices_);
    glBindBuffer(GL_ARRAY_BUFFER, texture_vertices_);
    glBufferData(GL_ARRAY_BUFFER, sizeof(kBasicTextureVertices),
                 kBasicTextureVertices, GL_STATIC_DRAW);

    glBindBuffer(GL_ARRAY_BUFFER, 0);

    return absl::OkStatus();
  });
}
// Runs the GPU postprocessing pipeline on `tensor` inside helper_'s GL context
// and returns the resulting masks.
//
// Passes:
//   1. activation: applies the configured activation to the whole input
//      texture;
//   2. split: cuts the activated texture into ceil(channels / 4) RGBA chunks
//      of input_shape.width x input_shape.height;
//   3. either argmax over the chunks (CATEGORY_MASK, one output) or
//      channel-select per channel (otherwise, one output per channel), rendered
//      at output_shape.width x output_shape.height.
//
// Errors are only logged, not returned: on failure the returned vector is
// empty.
std::vector<std::unique_ptr<Image>>
SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape,
                                                      const Shape& output_shape,
                                                      const Tensor& tensor) {
  std::vector<std::unique_ptr<Image>> image_outputs;
  auto status = helper_.RunInGlContext([this, &input_shape, &output_shape,
                                        &tensor,
                                        &image_outputs]() -> absl::Status {
    const bool is_category_mask = options_.segmenter_options().output_type() ==
                                  SegmenterOptions::CATEGORY_MASK;

    const int width = input_shape.width;    // Slice width from shape
    const int height = input_shape.height;  // Slice height from shape
    const int num_outputs = input_shape.channels;  // One output per channel
    const int num_chunks = (input_shape.channels + 3) / 4;  // ceil(channels/4)
    const int output_width = output_shape.width;    // Final output width
    const int output_height = output_shape.height;  // Final output height

    // Validate up front, before any GL state is modified or any render pass
    // is issued, so that bailing out cannot leave attribute arrays enabled or
    // buffers bound.
    if (is_category_mask) {
      RET_CHECK(num_chunks <= 3)
          << "Cannot handle more than 12 classes in argmax shader.";
    }

    // Get Tensor input and image output parameters
    if (!tensor.ready_as_opengl_texture_2d()) {
      LOG(WARNING) << "Tensor wasn't ready on GPU; using slow workaround.";
      (void)tensor.GetCpuReadView();
    }

    // Acquire the GL view once and use it for both the layout query and the
    // texture name.
    Tensor::OpenGlTexture2dView read_view = tensor.GetOpenGlTexture2dReadView();
    int input_width = 0;
    int input_height = 0;
    const auto layout = read_view.GetLayoutDimensions(
        tensor.shape(), &input_width, &input_height);
    if (layout != Tensor::OpenGlTexture2dView::Layout::kAligned) {
      LOG(ERROR) << "Tensor layout not kAligned! Cannot handle.";
    }

    const GpuBufferFormat activation_output_format =
        GpuBufferFormat::kRGBAFloat128;
    const GpuBufferFormat chunk_output_format = GpuBufferFormat::kRGBAFloat128;
    // Uint8 pipeline and conversions are lacking, so for now we just use F32
    // textures even for category masks.
    const GpuBufferFormat final_output_format = GpuBufferFormat::kGrayFloat32;

    // We disable blending or else our alpha channel may destroy our other
    // channels' data.
    glDisable(GL_BLEND);

    // Step 0: bind buffers / textures
    glBindBuffer(GL_ARRAY_BUFFER, square_vertices_);
    glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr);
    glEnableVertexAttribArray(ATTRIB_VERTEX);

    glBindBuffer(GL_ARRAY_BUFFER, texture_vertices_);
    glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr);
    glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION);

    // Step 1: apply activation pass
    glUseProgram(activation_program_);
    glUniform1i(activation_texture_uniform_, 1);
    GlTexture activated_texture = helper_.CreateDestinationTexture(
        input_width, input_height, activation_output_format);
    helper_.BindFramebuffer(activated_texture);

    // All our input source textures are just simple GL_TEXTURE_2D types.
    glActiveTexture(GL_TEXTURE1);
    glBindTexture(GL_TEXTURE_2D, read_view.name());

    // Render
    glClear(GL_COLOR_BUFFER_BIT);
    glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);

    // Step 2: split megatexture into 4-chunks (assume kLayoutAligned for now).
    std::vector<GlTexture> chunks;
    // # chunks: offset in pixels at which taps must be made
    // 1 chunk: 0
    // 2 chunks: -0.5, +0.5
    // 3 chunks: -1,0,1
    // 4 chunks: -1.5, -.5, .5, 1.5
    // ...
    // Step is always 1 pixel, while initial offset is (1 - N) * 0.5
    glUseProgram(split_program_);
    glUniform1i(split_texture_uniform_, 1);
    const float tex_offset = 0.5f * (1.0f - static_cast<float>(num_chunks));
    for (int i = 0; i < num_chunks; i++) {
      chunks.push_back(
          helper_.CreateDestinationTexture(width, height, chunk_output_format));
      helper_.BindFramebuffer(chunks.back());
      // Pixel offset converted to normalized texture coordinates.
      glUniform1f(split_x_offset_uniform_,
                  (static_cast<float>(i) + tex_offset) /
                      static_cast<float>(input_width));
      // Technically duplicated, but fine for now; we want this after the bind
      glBindTexture(GL_TEXTURE_2D, activated_texture.name());
      // Disable HW interpolation
      glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
      glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
      // Render
      glClear(GL_COLOR_BUFFER_BIT);
      glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
    }

    std::vector<GlTexture> outputs;
    if (is_category_mask) {
      // Step 3: For CATEGORY, apply argmax shader with up to 3 textures to
      // extract final index mask. (The chunk count was validated above.)
      glUseProgram(argmax_program_);
      glUniform1i(argmax_texture0_uniform_, 1);
      glUniform1i(argmax_texture1_uniform_, 2);
      glUniform1i(argmax_texture2_uniform_, 3);
      outputs.push_back(helper_.CreateDestinationTexture(
          output_width, output_height, final_output_format));
      helper_.BindFramebuffer(outputs.back());

      // Bind however many chunks we have
      for (int i = 0; i < num_chunks; ++i) {
        glActiveTexture(GL_TEXTURE1 + i);
        glBindTexture(GL_TEXTURE_2D, chunks[i].name());
      }

      for (int i = num_chunks; i < 3; ++i) {  // 3 is hard-coded max chunks
        glActiveTexture(GL_TEXTURE1 + i);
        // If texture is unbound, sampling from it should always give zeros.
        // This is not ideal, but is ok for now for not polluting the argmax
        // shader results too much.
        // NOTE(review): GL ES specifies (0, 0, 0, 1) for incomplete textures,
        // so the alpha channel may not be zero -- confirm on target platforms.
        glBindTexture(GL_TEXTURE_2D, 0);
      }

      glClear(GL_COLOR_BUFFER_BIT);
      glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);

      // Unbind the extra textures here.
      for (int i = 0; i < num_chunks; ++i) {
        glActiveTexture(GL_TEXTURE1 + i);
        glBindTexture(GL_TEXTURE_2D, 0);
      }
    } else {
      // Step 3: For CONFIDENCE, apply channel-select repeatedly to extract
      // final textures.
      glUseProgram(channel_select_program_);
      glUniform1i(channel_select_texture_uniform_, 1);
      for (int i = 0; i < num_outputs; i++) {
        // Channel i lives in chunk i / 4, component i % 4.
        glUniform1i(channel_select_index_uniform_, (i % 4));
        outputs.push_back(helper_.CreateDestinationTexture(
            output_width, output_height, final_output_format));
        helper_.BindFramebuffer(outputs.back());

        // We have to rebind constantly because BindFramebuffer seems to
        // interfere with this.
        glBindTexture(GL_TEXTURE_2D, chunks[i / 4].name());

        glClear(GL_COLOR_BUFFER_BIT);
        glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
      }
    }

    // Unbind everything
    glDisableVertexAttribArray(ATTRIB_VERTEX);
    glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION);
    glBindBuffer(GL_ARRAY_BUFFER, 0);
    glBindFramebuffer(GL_FRAMEBUFFER, 0);
    glBindTexture(GL_TEXTURE_2D, 0);

    // Get Image vector from GlTexture vector
    for (auto& output_texture : outputs) {
      image_outputs.push_back(output_texture.GetFrame<Image>());
    }

    return absl::OkStatus();
  });

  if (!status.ok()) {
    LOG(ERROR) << "Error with rendering: " << status;
  }

  return image_outputs;
}
// Releases the shader programs and vertex buffers in helper_'s GL context and
// zeroes the stored handles afterwards.
SegmentationPostprocessorGl::~SegmentationPostprocessorGl() {
  helper_.RunInGlContext([this] {
    // Shader programs.
    glDeleteProgram(activation_program_);
    activation_program_ = 0;
    glDeleteProgram(argmax_program_);
    argmax_program_ = 0;
    glDeleteProgram(channel_select_program_);
    channel_select_program_ = 0;
    glDeleteProgram(split_program_);
    split_program_ = 0;

    // Vertex buffers.
    glDeleteBuffers(1, &square_vertices_);
    square_vertices_ = 0;
    glDeleteBuffers(1, &texture_vertices_);
    texture_vertices_ = 0;
  });
}
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,66 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SEGMENTATION_POSTPROCESSOR_GL_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SEGMENTATION_POSTPROCESSOR_GL_H_
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
// OpenGL implementation of segmentation postprocessing: converts a
// segmentation output Tensor into one or more mask Images using shader passes.
//
// Usage: call UpdateContract() from the owning calculator's UpdateContract(),
// Initialize() once from Open(), then GetSegmentationResultGpu() per tensor.
// The destructor releases the GL programs and buffers.
class SegmentationPostprocessorGl {
 public:
  // Deletes the GL programs and vertex buffers in the helper's GL context.
  ~SegmentationPostprocessorGl();

  // Registers the GL helper's requirements on the calculator contract.
  static absl::Status UpdateContract(CalculatorContract* cc);

  // Copies `options`, opens the GL helper for `cc`, and builds all shader
  // programs and buffers. Returns an error if any of those steps fails.
  absl::Status Initialize(
      CalculatorContext* cc,
      TensorsToSegmentationCalculatorOptions const& options);

  // Postprocesses `tensor` on the GPU and returns the resulting masks.
  // Rendering errors are logged rather than returned; callers should be
  // prepared for an empty result.
  std::vector<std::unique_ptr<Image>> GetSegmentationResultGpu(
      const vision::Shape& input_shape, const vision::Shape& output_shape,
      const Tensor& tensor);

 private:
  // Builds the programs, looks up uniforms and uploads the vertex buffers.
  absl::Status GlInit();

  TensorsToSegmentationCalculatorOptions options_;
  GlCalculatorHelper helper_;

  // GL references (programs, buffers, uniforms)
  // Object names; 0 means "not created".
  GLuint activation_program_ = 0;
  GLuint argmax_program_ = 0;
  GLuint channel_select_program_ = 0;
  GLuint split_program_ = 0;
  GLuint square_vertices_ = 0;
  GLuint texture_vertices_ = 0;
  // Uniform locations; -1 is the value GL uses for "no such uniform", so the
  // members are never read uninitialized before GlInit() assigns them.
  GLint activation_texture_uniform_ = -1;
  GLint argmax_texture0_uniform_ = -1;
  GLint argmax_texture1_uniform_ = -1;
  GLint argmax_texture2_uniform_ = -1;
  GLint channel_select_texture_uniform_ = -1;
  GLint channel_select_index_uniform_ = -1;
  GLint split_texture_uniform_ = -1;
  GLint split_x_offset_uniform_ = -1;
};
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SEGMENTATION_POSTPROCESSOR_GL_H_

View File

@ -39,6 +39,10 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "mediapipe/util/label_map.pb.h"
#ifdef __EMSCRIPTEN__
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/segmentation_postprocessor_gl.h"
#endif // __EMSCRIPTEN__
// TODO: consolidate TensorToSegmentationCalculator.
namespace mediapipe {
namespace tasks {
@ -118,23 +122,41 @@ class TensorsToSegmentationCalculator : public Node {
static constexpr Output<Image>::Multiple kSegmentationOut{"SEGMENTATION"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc);
absl::Status Process(CalculatorContext* cc);
private:
std::vector<Image> GetSegmentationResult(const Shape& input_shape,
const Shape& output_shape,
const float* tensors_buffer);
std::vector<Image> GetSegmentationResultCpu(const Shape& input_shape,
const Shape& output_shape,
const float* tensors_buffer);
TensorsToSegmentationCalculatorOptions options_;
#ifdef __EMSCRIPTEN__
SegmentationPostprocessorGl postprocessor_;
#endif // __EMSCRIPTEN__
};
// static
// Only Emscripten builds carry the GL postprocessor, so only they have
// additional contract requirements to register.
absl::Status TensorsToSegmentationCalculator::UpdateContract(
    CalculatorContract* cc) {
#ifndef __EMSCRIPTEN__
  return absl::OkStatus();
#else
  return SegmentationPostprocessorGl::UpdateContract(cc);
#endif  // __EMSCRIPTEN__
}
// Reads the calculator options and rejects an UNSPECIFIED output_type. On
// Emscripten builds it additionally initializes the GL postprocessor with the
// same options and propagates any error from that step.
absl::Status TensorsToSegmentationCalculator::Open(
mediapipe::CalculatorContext* cc) {
options_ = cc->Options<TensorsToSegmentationCalculatorOptions>();
RET_CHECK_NE(options_.segmenter_options().output_type(),
SegmenterOptions::UNSPECIFIED)
<< "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK].";
#ifdef __EMSCRIPTEN__
// Sets up the GL helper, shader programs and buffers used by Process().
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
#endif // __EMSCRIPTEN__
return absl::OkStatus();
}
@ -167,7 +189,29 @@ absl::Status TensorsToSegmentationCalculator::Process(
? 1
: input_shape.channels};
std::vector<Image> segmented_masks = GetSegmentationResult(
// Use GPU postprocessing on web when Tensor is there already and has <= 12
// categories.
#ifdef __EMSCRIPTEN__
if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) {
std::vector<std::unique_ptr<Image>> segmented_masks =
postprocessor_.GetSegmentationResultGpu(input_shape, output_shape,
input_tensor);
for (int i = 0; i < segmented_masks.size(); ++i) {
// Real output on GPU.
// kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i]));
// Reformat as CPU for now for testing.
// TODO: Switch to real GPU output when GPU output pipeline is
// ready.
Image new_image(segmented_masks[i]->GetImageFrameSharedPtr());
kSegmentationOut(cc)[i].Send(std::move(new_image));
}
return absl::OkStatus();
}
#endif // __EMSCRIPTEN__
// Otherwise, use CPU postprocessing.
std::vector<Image> segmented_masks = GetSegmentationResultCpu(
input_shape, output_shape, input_tensor.GetCpuReadView().buffer<float>());
for (int i = 0; i < segmented_masks.size(); ++i) {
kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i]));
@ -175,7 +219,7 @@ absl::Status TensorsToSegmentationCalculator::Process(
return absl::OkStatus();
}
std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResult(
std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
const Shape& input_shape, const Shape& output_shape,
const float* tensors_buffer) {
std::function<void(absl::Span<const float> values,