Special treatment for 1-class segmentation category mask output on GPU.
PiperOrigin-RevId: 523271622
This commit is contained in:
		
							parent
							
								
									87ec846ed6
								
							
						
					
					
						commit
						3bb411e99d
					
				| 
						 | 
					@ -185,6 +185,28 @@ void main() {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
})";
 | 
					})";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Special argmax shader for N=1 classes. We don't need to worry about softmax
 | 
				
			||||||
 | 
					// activation (it is assumed softmax requires N > 1 classes), but this should
 | 
				
			||||||
 | 
					// occur after SIGMOID activation if specified. Instead of a true argmax, we
 | 
				
			||||||
 | 
					// simply use 0.5 as the cutoff, assigning 1 (foreground) or 0 (background)
 | 
				
			||||||
 | 
					// based on whether the confidence value reaches this cutoff or not,
 | 
				
			||||||
 | 
					// respectively.
 | 
				
			||||||
 | 
					static constexpr char kArgmaxOneClassShader[] = R"(
 | 
				
			||||||
 | 
					DEFAULT_PRECISION(mediump, float)
 | 
				
			||||||
 | 
					in vec2 sample_coordinate;
 | 
				
			||||||
 | 
					uniform sampler2D input_texture;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void main() {
 | 
				
			||||||
 | 
					  float input_val = texture2D(input_texture, sample_coordinate).x;
 | 
				
			||||||
 | 
					  // Category is just value rounded to nearest integer; then we map to either
 | 
				
			||||||
 | 
					  // 0 or 1/255 accordingly. If the input has been activated properly, then the
 | 
				
			||||||
 | 
					  // values should always be in the range [0, 1]. But just in case it hasn't, to
 | 
				
			||||||
 | 
					  // avoid category overflow issues when the activation function is not properly
 | 
				
			||||||
 | 
					  // chosen, we add an extra clamp here, as performance hit is minimal.
 | 
				
			||||||
 | 
					  float category = clamp(floor(input_val + 0.5), 0.0, 1.0);
 | 
				
			||||||
 | 
					  gl_FragColor = vec4(category / 255.0, 0.0, 0.0, 1.0);
 | 
				
			||||||
 | 
					})";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Softmax is in 3 steps:
 | 
					// Softmax is in 3 steps:
 | 
				
			||||||
// - First we find max over all masks
 | 
					// - First we find max over all masks
 | 
				
			||||||
// - Then we transform all masks to be exp(val - maxval), and also add to
 | 
					// - Then we transform all masks to be exp(val - maxval), and also add to
 | 
				
			||||||
| 
						 | 
					@ -377,11 +399,14 @@ absl::Status SegmentationPostprocessorGl::GlInit() {
 | 
				
			||||||
        "softmax normalization", kNormalizationShader,
 | 
					        "softmax normalization", kNormalizationShader,
 | 
				
			||||||
        {"sum_texture", "current_chunk"}, &softmax_normalization_shader_));
 | 
					        {"sum_texture", "current_chunk"}, &softmax_normalization_shader_));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Category mask shaders (Argmax)
 | 
					    // Category mask shaders (Argmax and special 1-class fg/bg argmax)
 | 
				
			||||||
    MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram(
 | 
					    MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram(
 | 
				
			||||||
        "argmax", kArgmaxShader,
 | 
					        "argmax", kArgmaxShader,
 | 
				
			||||||
        {"prev_max_texture", "current_chunk", "num_channels", "argmax_offset"},
 | 
					        {"prev_max_texture", "current_chunk", "num_channels", "argmax_offset"},
 | 
				
			||||||
        &argmax_shader_));
 | 
					        &argmax_shader_));
 | 
				
			||||||
 | 
					    MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram(
 | 
				
			||||||
 | 
					        "one-class argmax", kArgmaxOneClassShader, {"input_texture"},
 | 
				
			||||||
 | 
					        &argmax_one_class_shader_));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Split shader. This is created separately since it uses a custom vertex
 | 
					    // Split shader. This is created separately since it uses a custom vertex
 | 
				
			||||||
    // shader. TODO: Refactor so this shares common init code as well.
 | 
					    // shader. TODO: Refactor so this shares common init code as well.
 | 
				
			||||||
| 
						 | 
					@ -646,62 +671,78 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    std::vector<GlTexture> outputs;
 | 
					    std::vector<GlTexture> outputs;
 | 
				
			||||||
    if (is_category_mask) {
 | 
					    if (is_category_mask) {
 | 
				
			||||||
      // Step 3: For CATEGORY, apply argmax shader iteratively with each chunk
 | 
					      // Step 3, N = 1: For CATEGORY with 1 class, use special FG/BG argmax
 | 
				
			||||||
      // to get a 2-channel texture representing "combined maxval" and "argmax",
 | 
					      // shader instead of our usual N-class system.
 | 
				
			||||||
      // and then slice off the second channel for the category mask output,
 | 
					      if (num_outputs == 1) {
 | 
				
			||||||
      // using our usual channel_select program.
 | 
					        outputs.push_back(helper_.CreateDestinationTexture(
 | 
				
			||||||
      glUseProgram(argmax_shader_.program);
 | 
					            output_width, output_height, final_output_format));
 | 
				
			||||||
      glUniform1i(argmax_shader_.uniforms["current_chunk"], 1);
 | 
					        helper_.BindFramebuffer(outputs.back());
 | 
				
			||||||
      glUniform1i(argmax_shader_.uniforms["prev_max_texture"], 2);
 | 
					        glUseProgram(argmax_one_class_shader_.program);
 | 
				
			||||||
 | 
					        glUniform1i(argmax_one_class_shader_.uniforms["input_texture"], 1);
 | 
				
			||||||
      GlTexture max_texture = helper_.CreateDestinationTexture(
 | 
					 | 
				
			||||||
          output_width, output_height, chunk_output_format);
 | 
					 | 
				
			||||||
      GlTexture next_max_texture = helper_.CreateDestinationTexture(
 | 
					 | 
				
			||||||
          output_width, output_height, chunk_output_format);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      // GLSL uses IEEE 754 single-precision floating-point for encoding its
 | 
					 | 
				
			||||||
      // floats (at least for number representation, although not necessarily
 | 
					 | 
				
			||||||
      // for operations). So we can clear to a reasonable minimum float value
 | 
					 | 
				
			||||||
      // accordingly.
 | 
					 | 
				
			||||||
      const float kFloatMin32 = -3.402823466e+38;
 | 
					 | 
				
			||||||
      glClearColor(kFloatMin32, -1.0, 0.0, 1.0);
 | 
					 | 
				
			||||||
      helper_.BindFramebuffer(max_texture);
 | 
					 | 
				
			||||||
      glClear(GL_COLOR_BUFFER_BIT);
 | 
					 | 
				
			||||||
      // Set our clear color back to a "normal" default.
 | 
					 | 
				
			||||||
      glClearColor(0.0, 0.0, 0.0, 0.0);
 | 
					 | 
				
			||||||
      for (int i = 0; i < num_chunks; ++i) {
 | 
					 | 
				
			||||||
        int num_channels = 4;
 | 
					 | 
				
			||||||
        if ((i + 1) * 4 > num_outputs) num_channels = num_outputs % 4;
 | 
					 | 
				
			||||||
        glUniform1i(argmax_shader_.uniforms["num_channels"], num_channels);
 | 
					 | 
				
			||||||
        glUniform1i(argmax_shader_.uniforms["argmax_offset"], i * 4);
 | 
					 | 
				
			||||||
        helper_.BindFramebuffer(next_max_texture);
 | 
					 | 
				
			||||||
        glActiveTexture(GL_TEXTURE2);
 | 
					 | 
				
			||||||
        glBindTexture(GL_TEXTURE_2D, max_texture.name());
 | 
					 | 
				
			||||||
        glActiveTexture(GL_TEXTURE1);
 | 
					        glActiveTexture(GL_TEXTURE1);
 | 
				
			||||||
        glBindTexture(GL_TEXTURE_2D, chunks[i].name());
 | 
					        // Only one chunk, and softmax cannot be applied to 1-class models
 | 
				
			||||||
        // TODO: We probably don't actually need all these clears.
 | 
					        // anyway.
 | 
				
			||||||
 | 
					        glBindTexture(GL_TEXTURE_2D, chunks[0].name());
 | 
				
			||||||
        glClear(GL_COLOR_BUFFER_BIT);
 | 
					        glClear(GL_COLOR_BUFFER_BIT);
 | 
				
			||||||
        glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
 | 
					        glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        // Step 3, N > 1: For CATEGORY with N classes, apply argmax shader
 | 
				
			||||||
 | 
					        // iteratively with each chunk to get a 2-channel texture representing
 | 
				
			||||||
 | 
					        // "combined maxval" and "argmax", and then slice off the second channel
 | 
				
			||||||
 | 
					        // for the category mask output, using our usual channel_select program.
 | 
				
			||||||
 | 
					        glUseProgram(argmax_shader_.program);
 | 
				
			||||||
 | 
					        glUniform1i(argmax_shader_.uniforms["current_chunk"], 1);
 | 
				
			||||||
 | 
					        glUniform1i(argmax_shader_.uniforms["prev_max_texture"], 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Put results into max_texture, so we can repeat the process easily.
 | 
					        GlTexture max_texture = helper_.CreateDestinationTexture(
 | 
				
			||||||
        std::swap(max_texture, next_max_texture);
 | 
					            output_width, output_height, chunk_output_format);
 | 
				
			||||||
 | 
					        GlTexture next_max_texture = helper_.CreateDestinationTexture(
 | 
				
			||||||
 | 
					            output_width, output_height, chunk_output_format);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // GLSL uses IEEE 754 single-precision floating-point for encoding its
 | 
				
			||||||
 | 
					        // floats (at least for number representation, although not necessarily
 | 
				
			||||||
 | 
					        // for operations). So we can clear to a reasonable minimum float value
 | 
				
			||||||
 | 
					        // accordingly.
 | 
				
			||||||
 | 
					        const float kFloatMin32 = -3.402823466e+38;
 | 
				
			||||||
 | 
					        glClearColor(kFloatMin32, -1.0, 0.0, 1.0);
 | 
				
			||||||
 | 
					        helper_.BindFramebuffer(max_texture);
 | 
				
			||||||
 | 
					        glClear(GL_COLOR_BUFFER_BIT);
 | 
				
			||||||
 | 
					        // Set our clear color back to a "normal" default.
 | 
				
			||||||
 | 
					        glClearColor(0.0, 0.0, 0.0, 0.0);
 | 
				
			||||||
 | 
					        for (int i = 0; i < num_chunks; ++i) {
 | 
				
			||||||
 | 
					          int num_channels = 4;
 | 
				
			||||||
 | 
					          if ((i + 1) * 4 > num_outputs) num_channels = num_outputs % 4;
 | 
				
			||||||
 | 
					          glUniform1i(argmax_shader_.uniforms["num_channels"], num_channels);
 | 
				
			||||||
 | 
					          glUniform1i(argmax_shader_.uniforms["argmax_offset"], i * 4);
 | 
				
			||||||
 | 
					          helper_.BindFramebuffer(next_max_texture);
 | 
				
			||||||
 | 
					          glActiveTexture(GL_TEXTURE2);
 | 
				
			||||||
 | 
					          glBindTexture(GL_TEXTURE_2D, max_texture.name());
 | 
				
			||||||
 | 
					          glActiveTexture(GL_TEXTURE1);
 | 
				
			||||||
 | 
					          glBindTexture(GL_TEXTURE_2D, chunks[i].name());
 | 
				
			||||||
 | 
					          // TODO: We probably don't actually need all these clears.
 | 
				
			||||||
 | 
					          glClear(GL_COLOR_BUFFER_BIT);
 | 
				
			||||||
 | 
					          glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          // Put results into max_texture, so we can repeat the process easily.
 | 
				
			||||||
 | 
					          std::swap(max_texture, next_max_texture);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Do final channel-select on max_texture below, selecting for argmax
 | 
				
			||||||
 | 
					        outputs.push_back(helper_.CreateDestinationTexture(
 | 
				
			||||||
 | 
					            output_width, output_height, final_output_format));
 | 
				
			||||||
 | 
					        helper_.BindFramebuffer(outputs.back());
 | 
				
			||||||
 | 
					        glUseProgram(channel_select_shader_.program);
 | 
				
			||||||
 | 
					        glUniform1i(channel_select_shader_.uniforms["input_texture"], 1);
 | 
				
			||||||
 | 
					        // 0:max_val, 1:argmax
 | 
				
			||||||
 | 
					        glUniform1i(channel_select_shader_.uniforms["channel_select"], 1);
 | 
				
			||||||
 | 
					        glBindTexture(GL_TEXTURE_2D, max_texture.name());
 | 
				
			||||||
 | 
					        // We can't interpolate across argmax values, so we disable linear
 | 
				
			||||||
 | 
					        // interpolation there for this upsampling step.
 | 
				
			||||||
 | 
					        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
 | 
				
			||||||
 | 
					        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
 | 
				
			||||||
 | 
					        glClear(GL_COLOR_BUFFER_BIT);
 | 
				
			||||||
 | 
					        glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					 | 
				
			||||||
      // Do final channel-select on max_texture below, selecting for argmax
 | 
					 | 
				
			||||||
      outputs.push_back(helper_.CreateDestinationTexture(
 | 
					 | 
				
			||||||
          output_width, output_height, final_output_format));
 | 
					 | 
				
			||||||
      helper_.BindFramebuffer(outputs.back());
 | 
					 | 
				
			||||||
      glUseProgram(channel_select_shader_.program);
 | 
					 | 
				
			||||||
      glUniform1i(channel_select_shader_.uniforms["input_texture"], 1);
 | 
					 | 
				
			||||||
      // 0:max_val, 1:argmax
 | 
					 | 
				
			||||||
      glUniform1i(channel_select_shader_.uniforms["channel_select"], 1);
 | 
					 | 
				
			||||||
      glBindTexture(GL_TEXTURE_2D, max_texture.name());
 | 
					 | 
				
			||||||
      // We can't interpolate across argmax values, so we disable linear
 | 
					 | 
				
			||||||
      // interpolation there for this upsampling step.
 | 
					 | 
				
			||||||
      glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
 | 
					 | 
				
			||||||
      glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
 | 
					 | 
				
			||||||
      glClear(GL_COLOR_BUFFER_BIT);
 | 
					 | 
				
			||||||
      glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
 | 
					 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      // Step 3: For CONFIDENCE, apply channel-select repeatedly to extract
 | 
					      // Step 3: For CONFIDENCE, apply channel-select repeatedly to extract
 | 
				
			||||||
      // final textures.
 | 
					      // final textures.
 | 
				
			||||||
| 
						 | 
					@ -760,6 +801,7 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    glDeleteProgram(activation_shader_.program);
 | 
					    glDeleteProgram(activation_shader_.program);
 | 
				
			||||||
    glDeleteProgram(argmax_shader_.program);
 | 
					    glDeleteProgram(argmax_shader_.program);
 | 
				
			||||||
 | 
					    glDeleteProgram(argmax_one_class_shader_.program);
 | 
				
			||||||
    glDeleteProgram(channel_select_shader_.program);
 | 
					    glDeleteProgram(channel_select_shader_.program);
 | 
				
			||||||
    glDeleteProgram(softmax_max_shader_.program);
 | 
					    glDeleteProgram(softmax_max_shader_.program);
 | 
				
			||||||
    glDeleteProgram(softmax_transform_and_sum_shader_.program);
 | 
					    glDeleteProgram(softmax_transform_and_sum_shader_.program);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -63,6 +63,7 @@ class SegmentationPostprocessorGl {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  GlShader activation_shader_;
 | 
					  GlShader activation_shader_;
 | 
				
			||||||
  GlShader argmax_shader_;
 | 
					  GlShader argmax_shader_;
 | 
				
			||||||
 | 
					  GlShader argmax_one_class_shader_;
 | 
				
			||||||
  GlShader channel_select_shader_;
 | 
					  GlShader channel_select_shader_;
 | 
				
			||||||
  GlShader softmax_max_shader_;
 | 
					  GlShader softmax_max_shader_;
 | 
				
			||||||
  GlShader softmax_transform_and_sum_shader_;
 | 
					  GlShader softmax_transform_and_sum_shader_;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user