Special treatment for 1-class segmentation category mask output on GPU.
PiperOrigin-RevId: 523271622
This commit is contained in:
parent
87ec846ed6
commit
3bb411e99d
|
@ -185,6 +185,28 @@ void main() {
|
||||||
}
|
}
|
||||||
})";
|
})";
|
||||||
|
|
||||||
|
// Special argmax shader for N=1 classes. We don't need to worry about softmax
|
||||||
|
// activation (it is assumed softmax requires N > 1 classes), but this should
|
||||||
|
// occur after SIGMOID activation if specified. Instead of a true argmax, we
|
||||||
|
// simply use 0.5 as the cutoff, assigning 1 (foreground) or 0 (background)
|
||||||
|
// based on whether the confidence value reaches this cutoff or not,
|
||||||
|
// respectively.
|
||||||
|
static constexpr char kArgmaxOneClassShader[] = R"(
|
||||||
|
DEFAULT_PRECISION(mediump, float)
|
||||||
|
in vec2 sample_coordinate;
|
||||||
|
uniform sampler2D input_texture;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
float input_val = texture2D(input_texture, sample_coordinate).x;
|
||||||
|
// Category is just value rounded to nearest integer; then we map to either
|
||||||
|
// 0 or 1/255 accordingly. If the input has been activated properly, then the
|
||||||
|
// values should always be in the range [0, 1]. But just in case it hasn't, to
|
||||||
|
// avoid category overflow issues when the activation function is not properly
|
||||||
|
// chosen, we add an extra clamp here, as performance hit is minimal.
|
||||||
|
float category = clamp(floor(input_val + 0.5), 0.0, 1.0);
|
||||||
|
gl_FragColor = vec4(category / 255.0, 0.0, 0.0, 1.0);
|
||||||
|
})";
|
||||||
|
|
||||||
// Softmax is in 3 steps:
|
// Softmax is in 3 steps:
|
||||||
// - First we find max over all masks
|
// - First we find max over all masks
|
||||||
// - Then we transform all masks to be exp(val - maxval), and also add to
|
// - Then we transform all masks to be exp(val - maxval), and also add to
|
||||||
|
@ -377,11 +399,14 @@ absl::Status SegmentationPostprocessorGl::GlInit() {
|
||||||
"softmax normalization", kNormalizationShader,
|
"softmax normalization", kNormalizationShader,
|
||||||
{"sum_texture", "current_chunk"}, &softmax_normalization_shader_));
|
{"sum_texture", "current_chunk"}, &softmax_normalization_shader_));
|
||||||
|
|
||||||
// Category mask shaders (Argmax)
|
// Category mask shaders (Argmax and special 1-class fg/bg argmax)
|
||||||
MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram(
|
MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram(
|
||||||
"argmax", kArgmaxShader,
|
"argmax", kArgmaxShader,
|
||||||
{"prev_max_texture", "current_chunk", "num_channels", "argmax_offset"},
|
{"prev_max_texture", "current_chunk", "num_channels", "argmax_offset"},
|
||||||
&argmax_shader_));
|
&argmax_shader_));
|
||||||
|
MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram(
|
||||||
|
"one-class argmax", kArgmaxOneClassShader, {"input_texture"},
|
||||||
|
&argmax_one_class_shader_));
|
||||||
|
|
||||||
// Split shader. This is created separately since it uses a custom vertex
|
// Split shader. This is created separately since it uses a custom vertex
|
||||||
// shader. TODO: Refactor so this shares common init code as well.
|
// shader. TODO: Refactor so this shares common init code as well.
|
||||||
|
@ -646,10 +671,25 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape,
|
||||||
|
|
||||||
std::vector<GlTexture> outputs;
|
std::vector<GlTexture> outputs;
|
||||||
if (is_category_mask) {
|
if (is_category_mask) {
|
||||||
// Step 3: For CATEGORY, apply argmax shader iteratively with each chunk
|
// Step 3, N = 1: For CATEGORY with 1 class, use special FG/BG argmax
|
||||||
// to get a 2-channel texture representing "combined maxval" and "argmax",
|
// shader instead of our usual N-class system.
|
||||||
// and then slice off the second channel for the category mask output,
|
if (num_outputs == 1) {
|
||||||
// using our usual channel_select program.
|
outputs.push_back(helper_.CreateDestinationTexture(
|
||||||
|
output_width, output_height, final_output_format));
|
||||||
|
helper_.BindFramebuffer(outputs.back());
|
||||||
|
glUseProgram(argmax_one_class_shader_.program);
|
||||||
|
glUniform1i(argmax_one_class_shader_.uniforms["input_texture"], 1);
|
||||||
|
glActiveTexture(GL_TEXTURE1);
|
||||||
|
// Only one chunk, and softmax cannot be applied to 1-class models
|
||||||
|
// anyways.
|
||||||
|
glBindTexture(GL_TEXTURE_2D, chunks[0].name());
|
||||||
|
glClear(GL_COLOR_BUFFER_BIT);
|
||||||
|
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
|
||||||
|
} else {
|
||||||
|
// Step 3, N > 1: For CATEGORY with N classes, apply argmax shader
|
||||||
|
// iteratively with each chunk to get a 2-channel texture representing
|
||||||
|
// "combined maxval" and "argmax", and then slice off the second channel
|
||||||
|
// for the category mask output, using our usual channel_select program.
|
||||||
glUseProgram(argmax_shader_.program);
|
glUseProgram(argmax_shader_.program);
|
||||||
glUniform1i(argmax_shader_.uniforms["current_chunk"], 1);
|
glUniform1i(argmax_shader_.uniforms["current_chunk"], 1);
|
||||||
glUniform1i(argmax_shader_.uniforms["prev_max_texture"], 2);
|
glUniform1i(argmax_shader_.uniforms["prev_max_texture"], 2);
|
||||||
|
@ -702,6 +742,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape,
|
||||||
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
|
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
|
||||||
glClear(GL_COLOR_BUFFER_BIT);
|
glClear(GL_COLOR_BUFFER_BIT);
|
||||||
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
|
glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Step 3: For CONFIDENCE, apply channel-select repeatedly to extract
|
// Step 3: For CONFIDENCE, apply channel-select repeatedly to extract
|
||||||
// final textures.
|
// final textures.
|
||||||
|
@ -760,6 +801,7 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() {
|
||||||
|
|
||||||
glDeleteProgram(activation_shader_.program);
|
glDeleteProgram(activation_shader_.program);
|
||||||
glDeleteProgram(argmax_shader_.program);
|
glDeleteProgram(argmax_shader_.program);
|
||||||
|
glDeleteProgram(argmax_one_class_shader_.program);
|
||||||
glDeleteProgram(channel_select_shader_.program);
|
glDeleteProgram(channel_select_shader_.program);
|
||||||
glDeleteProgram(softmax_max_shader_.program);
|
glDeleteProgram(softmax_max_shader_.program);
|
||||||
glDeleteProgram(softmax_transform_and_sum_shader_.program);
|
glDeleteProgram(softmax_transform_and_sum_shader_.program);
|
||||||
|
|
|
@ -63,6 +63,7 @@ class SegmentationPostprocessorGl {
|
||||||
|
|
||||||
GlShader activation_shader_;
|
GlShader activation_shader_;
|
||||||
GlShader argmax_shader_;
|
GlShader argmax_shader_;
|
||||||
|
GlShader argmax_one_class_shader_;
|
||||||
GlShader channel_select_shader_;
|
GlShader channel_select_shader_;
|
||||||
GlShader softmax_max_shader_;
|
GlShader softmax_max_shader_;
|
||||||
GlShader softmax_transform_and_sum_shader_;
|
GlShader softmax_transform_and_sum_shader_;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user