Special treatment for 1-class segmentation category mask output on GPU.
PiperOrigin-RevId: 523271622
This commit is contained in:
		
							parent
							
								
									87ec846ed6
								
							
						
					
					
						commit
						3bb411e99d
					
				|  | @ -185,6 +185,28 @@ void main() { | |||
|     } | ||||
| })"; | ||||
| 
 | ||||
| // Special argmax shader for N=1 classes. We don't need to worry about softmax
 | ||||
| // activation (it is assumed softmax requires N > 1 classes), but this should
 | ||||
| // occur after SIGMOID activation if specified. Instead of a true argmax, we
 | ||||
| // simply use 0.5 as the cutoff, assigning 1 (foreground) or 0 (background)
 | ||||
| // based on whether the confidence value reaches this cutoff or not,
 | ||||
| // respectively.
 | ||||
| static constexpr char kArgmaxOneClassShader[] = R"( | ||||
| DEFAULT_PRECISION(mediump, float) | ||||
| in vec2 sample_coordinate; | ||||
| uniform sampler2D input_texture; | ||||
| 
 | ||||
| void main() { | ||||
|   float input_val = texture2D(input_texture, sample_coordinate).x; | ||||
|   // Category is just value rounded to nearest integer; then we map to either
 | ||||
|   // 0 or 1/255 accordingly. If the input has been activated properly, then the
 | ||||
|   // values should always be in the range [0, 1]. But just in case it hasn't, to
 | ||||
|   // avoid category overflow issues when the activation function is not properly
 | ||||
|   // chosen, we add an extra clamp here, as performance hit is minimal.
 | ||||
|   float category = clamp(floor(input_val + 0.5), 0.0, 1.0); | ||||
|   gl_FragColor = vec4(category / 255.0, 0.0, 0.0, 1.0); | ||||
| })"; | ||||
| 
 | ||||
| // Softmax is in 3 steps:
 | ||||
| // - First we find max over all masks
 | ||||
| // - Then we transform all masks to be exp(val - maxval), and also add to
 | ||||
|  | @ -377,11 +399,14 @@ absl::Status SegmentationPostprocessorGl::GlInit() { | |||
|         "softmax normalization", kNormalizationShader, | ||||
|         {"sum_texture", "current_chunk"}, &softmax_normalization_shader_)); | ||||
| 
 | ||||
|     // Category mask shaders (Argmax)
 | ||||
|     // Category mask shaders (Argmax and special 1-class fg/bg argmax)
 | ||||
|     MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( | ||||
|         "argmax", kArgmaxShader, | ||||
|         {"prev_max_texture", "current_chunk", "num_channels", "argmax_offset"}, | ||||
|         &argmax_shader_)); | ||||
|     MP_RETURN_IF_ERROR(CreateBasicFragmentShaderProgram( | ||||
|         "one-class argmax", kArgmaxOneClassShader, {"input_texture"}, | ||||
|         &argmax_one_class_shader_)); | ||||
| 
 | ||||
|     // Split shader. This is created separately since it uses a custom vertex
 | ||||
|     // shader. TODO: Refactor so this shares common init code as well.
 | ||||
|  | @ -646,10 +671,25 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, | |||
| 
 | ||||
|     std::vector<GlTexture> outputs; | ||||
|     if (is_category_mask) { | ||||
|       // Step 3: For CATEGORY, apply argmax shader iteratively with each chunk
 | ||||
|       // to get a 2-channel texture representing "combined maxval" and "argmax",
 | ||||
|       // and then slice off the second channel for the category mask output,
 | ||||
|       // using our usual channel_select program.
 | ||||
|       // Step 3, N = 1: For CATEGORY with 1 class, use special FG/BG argmax
 | ||||
|       // shader instead of our usual N-class system.
 | ||||
|       if (num_outputs == 1) { | ||||
|         outputs.push_back(helper_.CreateDestinationTexture( | ||||
|             output_width, output_height, final_output_format)); | ||||
|         helper_.BindFramebuffer(outputs.back()); | ||||
|         glUseProgram(argmax_one_class_shader_.program); | ||||
|         glUniform1i(argmax_one_class_shader_.uniforms["input_texture"], 1); | ||||
|         glActiveTexture(GL_TEXTURE1); | ||||
|         // Only one chunk, and softmax cannot be applied to 1-class models
 | ||||
|         // anyways.
 | ||||
|         glBindTexture(GL_TEXTURE_2D, chunks[0].name()); | ||||
|         glClear(GL_COLOR_BUFFER_BIT); | ||||
|         glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); | ||||
|       } else { | ||||
|         // Step 3, N > 1: For CATEGORY with N classes, apply argmax shader
 | ||||
|         // iteratively with each chunk to get a 2-channel texture representing
 | ||||
|         // "combined maxval" and "argmax", and then slice off the second channel
 | ||||
|         // for the category mask output, using our usual channel_select program.
 | ||||
|         glUseProgram(argmax_shader_.program); | ||||
|         glUniform1i(argmax_shader_.uniforms["current_chunk"], 1); | ||||
|         glUniform1i(argmax_shader_.uniforms["prev_max_texture"], 2); | ||||
|  | @ -702,6 +742,7 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(const Shape& input_shape, | |||
|         glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); | ||||
|         glClear(GL_COLOR_BUFFER_BIT); | ||||
|         glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); | ||||
|       } | ||||
|     } else { | ||||
|       // Step 3: For CONFIDENCE, apply channel-select repeatedly to extract
 | ||||
|       // final textures.
 | ||||
|  | @ -760,6 +801,7 @@ SegmentationPostprocessorGl::~SegmentationPostprocessorGl() { | |||
| 
 | ||||
|     glDeleteProgram(activation_shader_.program); | ||||
|     glDeleteProgram(argmax_shader_.program); | ||||
|     glDeleteProgram(argmax_one_class_shader_.program); | ||||
|     glDeleteProgram(channel_select_shader_.program); | ||||
|     glDeleteProgram(softmax_max_shader_.program); | ||||
|     glDeleteProgram(softmax_transform_and_sum_shader_.program); | ||||
|  |  | |||
|  | @ -63,6 +63,7 @@ class SegmentationPostprocessorGl { | |||
| 
 | ||||
|   GlShader activation_shader_; | ||||
|   GlShader argmax_shader_; | ||||
|   GlShader argmax_one_class_shader_; | ||||
|   GlShader channel_select_shader_; | ||||
|   GlShader softmax_max_shader_; | ||||
|   GlShader softmax_transform_and_sum_shader_; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user