Add iOS support for MPMask
PiperOrigin-RevId: 534155657
This commit is contained in:
		
							parent
							
								
									102cffdf4c
								
							
						
					
					
						commit
						51730ec25c
					
				| 
						 | 
					@ -62,7 +62,10 @@ jasmine_node_test(
 | 
				
			||||||
mediapipe_ts_library(
 | 
					mediapipe_ts_library(
 | 
				
			||||||
    name = "mask",
 | 
					    name = "mask",
 | 
				
			||||||
    srcs = ["mask.ts"],
 | 
					    srcs = ["mask.ts"],
 | 
				
			||||||
    deps = [":image"],
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":image",
 | 
				
			||||||
 | 
					        "//mediapipe/web/graph_runner:platform_utils",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mediapipe_ts_library(
 | 
					mediapipe_ts_library(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,6 +15,7 @@
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
 | 
					import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
 | 
				
			||||||
 | 
					import {isIOS} from '../../../../web/graph_runner/platform_utils';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Number of instances a user can keep alive before we raise a warning. */
 | 
					/** Number of instances a user can keep alive before we raise a warning. */
 | 
				
			||||||
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
 | 
					const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
 | 
				
			||||||
| 
						 | 
					@ -32,6 +33,8 @@ enum MPMaskType {
 | 
				
			||||||
/** The supported mask formats. For internal usage. */
 | 
					/** The supported mask formats. For internal usage. */
 | 
				
			||||||
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
 | 
					export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * The wrapper class for MediaPipe segmentation masks.
 | 
					 * The wrapper class for MediaPipe segmentation masks.
 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
| 
						 | 
					@ -56,6 +59,9 @@ export class MPMask {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
 | 
					  private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** The format used to write pixel values from textures. */
 | 
				
			||||||
 | 
					  private static texImage2DFormat?: GLenum;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** @hideconstructor */
 | 
					  /** @hideconstructor */
 | 
				
			||||||
  constructor(
 | 
					  constructor(
 | 
				
			||||||
      private readonly containers: MPMaskContainer[],
 | 
					      private readonly containers: MPMaskContainer[],
 | 
				
			||||||
| 
						 | 
					@ -127,6 +133,29 @@ export class MPMask {
 | 
				
			||||||
    return this.convertToWebGLTexture();
 | 
					    return this.convertToWebGLTexture();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Returns the texture format used for writing float textures on this
 | 
				
			||||||
 | 
					   * platform.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  getTexImage2DFormat(): GLenum {
 | 
				
			||||||
 | 
					    const gl = this.getGL();
 | 
				
			||||||
 | 
					    if (!MPMask.texImage2DFormat) {
 | 
				
			||||||
 | 
					      // Note: This is the same check we use in
 | 
				
			||||||
 | 
					      // `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
 | 
				
			||||||
 | 
					      if (gl.getExtension('EXT_color_buffer_float') &&
 | 
				
			||||||
 | 
					          gl.getExtension('OES_texture_float_linear') &&
 | 
				
			||||||
 | 
					          gl.getExtension('EXT_float_blend')) {
 | 
				
			||||||
 | 
					        MPMask.texImage2DFormat = gl.R32F;
 | 
				
			||||||
 | 
					      } else if (gl.getExtension('EXT_color_buffer_half_float')) {
 | 
				
			||||||
 | 
					        MPMask.texImage2DFormat = gl.R16F;
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        throw new Error(
 | 
				
			||||||
 | 
					            'GPU does not fully support 4-channel float32 or float16 formats');
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return MPMask.texImage2DFormat;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
 | 
					  private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
 | 
				
			||||||
  private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
 | 
					  private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
 | 
				
			||||||
  private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
 | 
					  private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
 | 
				
			||||||
| 
						 | 
					@ -176,8 +205,9 @@ export class MPMask {
 | 
				
			||||||
            assertNotNull(gl.createTexture(), 'Failed to create texture');
 | 
					            assertNotNull(gl.createTexture(), 'Failed to create texture');
 | 
				
			||||||
        gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
 | 
					        gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
 | 
				
			||||||
        this.configureTextureParams();
 | 
					        this.configureTextureParams();
 | 
				
			||||||
 | 
					        const format = this.getTexImage2DFormat();
 | 
				
			||||||
        gl.texImage2D(
 | 
					        gl.texImage2D(
 | 
				
			||||||
            gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
 | 
					            gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
 | 
				
			||||||
            gl.FLOAT, null);
 | 
					            gl.FLOAT, null);
 | 
				
			||||||
        gl.bindTexture(gl.TEXTURE_2D, null);
 | 
					        gl.bindTexture(gl.TEXTURE_2D, null);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -208,7 +238,7 @@ export class MPMask {
 | 
				
			||||||
    if (!this.canvas) {
 | 
					    if (!this.canvas) {
 | 
				
			||||||
      throw new Error(
 | 
					      throw new Error(
 | 
				
			||||||
          'Conversion to different image formats require that a canvas ' +
 | 
					          'Conversion to different image formats require that a canvas ' +
 | 
				
			||||||
          'is passed when iniitializing the image.');
 | 
					          'is passed when initializing the image.');
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (!this.gl) {
 | 
					    if (!this.gl) {
 | 
				
			||||||
      this.gl = assertNotNull(
 | 
					      this.gl = assertNotNull(
 | 
				
			||||||
| 
						 | 
					@ -216,11 +246,6 @@ export class MPMask {
 | 
				
			||||||
          'You cannot use a canvas that is already bound to a different ' +
 | 
					          'You cannot use a canvas that is already bound to a different ' +
 | 
				
			||||||
              'type of rendering context.');
 | 
					              'type of rendering context.');
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    const ext = this.gl.getExtension('EXT_color_buffer_float');
 | 
					 | 
				
			||||||
    if (!ext) {
 | 
					 | 
				
			||||||
      // TODO: Ensure this works on iOS
 | 
					 | 
				
			||||||
      throw new Error('Missing required EXT_color_buffer_float extension');
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return this.gl;
 | 
					    return this.gl;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -238,18 +263,34 @@ export class MPMask {
 | 
				
			||||||
      if (uint8Array) {
 | 
					      if (uint8Array) {
 | 
				
			||||||
        float32Array = new Float32Array(uint8Array).map(v => v / 255);
 | 
					        float32Array = new Float32Array(uint8Array).map(v => v / 255);
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
 | 
					        float32Array = new Float32Array(this.width * this.height);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        const gl = this.getGL();
 | 
					        const gl = this.getGL();
 | 
				
			||||||
        const shaderContext = this.getShaderContext();
 | 
					        const shaderContext = this.getShaderContext();
 | 
				
			||||||
        float32Array = new Float32Array(this.width * this.height);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Create texture if needed
 | 
					        // Create texture if needed
 | 
				
			||||||
        const webGlTexture = this.convertToWebGLTexture();
 | 
					        const webGlTexture = this.convertToWebGLTexture();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Create a framebuffer from the texture and read back pixels
 | 
					        // Create a framebuffer from the texture and read back pixels
 | 
				
			||||||
        shaderContext.bindFramebuffer(gl, webGlTexture);
 | 
					        shaderContext.bindFramebuffer(gl, webGlTexture);
 | 
				
			||||||
        gl.readPixels(
 | 
					
 | 
				
			||||||
            0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
 | 
					        if (isIOS()) {
 | 
				
			||||||
        shaderContext.unbindFramebuffer();
 | 
					          // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
 | 
				
			||||||
 | 
					          // (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
 | 
				
			||||||
 | 
					          // Uint16Array, however, and requires a manual bitwise conversion from
 | 
				
			||||||
 | 
					          // Uint16 to floating point numbers. This conversion is more expensive
 | 
				
			||||||
 | 
					          // that reading back a Float32Array from the RGBA image and dropping
 | 
				
			||||||
 | 
					          // the superfluous data, so we do this instead.
 | 
				
			||||||
 | 
					          const outputArray = new Float32Array(this.width * this.height * 4);
 | 
				
			||||||
 | 
					          gl.readPixels(
 | 
				
			||||||
 | 
					              0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
 | 
				
			||||||
 | 
					          for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
 | 
				
			||||||
 | 
					            float32Array[i] = outputArray[j];
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          gl.readPixels(
 | 
				
			||||||
 | 
					              0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      this.containers.push(float32Array);
 | 
					      this.containers.push(float32Array);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -274,9 +315,9 @@ export class MPMask {
 | 
				
			||||||
      webGLTexture = this.bindTexture();
 | 
					      webGLTexture = this.bindTexture();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      const data = this.convertToFloat32Array();
 | 
					      const data = this.convertToFloat32Array();
 | 
				
			||||||
      // TODO: Add support for R16F to support iOS
 | 
					      const format = this.getTexImage2DFormat();
 | 
				
			||||||
      gl.texImage2D(
 | 
					      gl.texImage2D(
 | 
				
			||||||
          gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
 | 
					          gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
 | 
				
			||||||
          gl.FLOAT, data);
 | 
					          gl.FLOAT, data);
 | 
				
			||||||
      this.unbindTexture();
 | 
					      this.unbindTexture();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
 | 
				
			||||||
  // it uses "CriOS".
 | 
					  // it uses "CriOS".
 | 
				
			||||||
  return userAgent.includes('Safari') && !userAgent.includes('Chrome');
 | 
					  return userAgent.includes('Safari') && !userAgent.includes('Chrome');
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Detect if code is running on iOS. */
 | 
				
			||||||
 | 
					export function isIOS() {
 | 
				
			||||||
 | 
					  // Source:
 | 
				
			||||||
 | 
					  // https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
 | 
				
			||||||
 | 
					  return [
 | 
				
			||||||
 | 
					    'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
 | 
				
			||||||
 | 
					    'iPod'
 | 
				
			||||||
 | 
					    // tslint:disable-next-line:deprecation
 | 
				
			||||||
 | 
					  ].includes(navigator.platform)
 | 
				
			||||||
 | 
					      // iPad on iOS 13 detection
 | 
				
			||||||
 | 
					      || (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user