diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 325603353..fa28e04a5 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -62,7 +62,10 @@ jasmine_node_test( mediapipe_ts_library( name = "mask", srcs = ["mask.ts"], - deps = [":image"], + deps = [ + ":image", + "//mediapipe/web/graph_runner:platform_utils", + ], ) mediapipe_ts_library( diff --git a/mediapipe/tasks/web/vision/core/mask.ts b/mediapipe/tasks/web/vision/core/mask.ts index 3f37e804f..9622b638f 100644 --- a/mediapipe/tasks/web/vision/core/mask.ts +++ b/mediapipe/tasks/web/vision/core/mask.ts @@ -15,6 +15,7 @@ */ import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; +import {isIOS} from '../../../../web/graph_runner/platform_utils'; /** Number of instances a user can keep alive before we raise a warning. */ const INSTANCE_COUNT_WARNING_THRESHOLD = 250; @@ -32,6 +33,8 @@ enum MPMaskType { /** The supported mask formats. For internal usage. */ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture; + + /** * The wrapper class for MediaPipe segmentation masks. * @@ -56,6 +59,9 @@ export class MPMask { */ private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD; + /** Cached float texture format shared by all masks; set on first use. */ + private static texImage2DFormat?: GLenum; + /** @hideconstructor */ constructor( private readonly containers: MPMaskContainer[], @@ -127,6 +133,29 @@ export class MPMask { return this.convertToWebGLTexture(); } + /** + * Returns the internal format (R32F or R16F) used for float mask textures. + * Detected once from the WebGL extensions and cached; throws if unsupported. + */ + getTexImage2DFormat(): GLenum { + const gl = this.getGL(); + if (!MPMask.texImage2DFormat) { + // Note: This is the same check we use in + // `SegmentationPostprocessorGl::GetSegmentationResultGpu()`. 
+ if (gl.getExtension('EXT_color_buffer_float') && + gl.getExtension('OES_texture_float_linear') && + gl.getExtension('EXT_float_blend')) { + MPMask.texImage2DFormat = gl.R32F; + } else if (gl.getExtension('EXT_color_buffer_half_float')) { + MPMask.texImage2DFormat = gl.R16F; + } else { + throw new Error( + 'GPU does not fully support 4-channel float32 or float16 formats'); + } + } + return MPMask.texImage2DFormat; + } + private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined; private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined; private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined; @@ -176,8 +205,9 @@ export class MPMask { assertNotNull(gl.createTexture(), 'Failed to create texture'); gl.bindTexture(gl.TEXTURE_2D, destinationContainer); this.configureTextureParams(); + const format = this.getTexImage2DFormat(); gl.texImage2D( - gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, + gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED, gl.FLOAT, null); gl.bindTexture(gl.TEXTURE_2D, null); @@ -208,7 +238,7 @@ export class MPMask { if (!this.canvas) { throw new Error( 'Conversion to different image formats require that a canvas ' + - 'is passed when iniitializing the image.'); + 'is passed when initializing the image.'); } if (!this.gl) { this.gl = assertNotNull( @@ -216,11 +246,6 @@ export class MPMask { 'You cannot use a canvas that is already bound to a different ' + 'type of rendering context.'); } - const ext = this.gl.getExtension('EXT_color_buffer_float'); - if (!ext) { - // TODO: Ensure this works on iOS - throw new Error('Missing required EXT_color_buffer_float extension'); - } return this.gl; } @@ -238,18 +263,34 @@ export class MPMask { if (uint8Array) { float32Array = new Float32Array(uint8Array).map(v => v / 255); } else { + float32Array = new Float32Array(this.width * this.height); + const gl = this.getGL(); const shaderContext = this.getShaderContext(); - float32Array = 
new Float32Array(this.width * this.height); // Create texture if needed const webGlTexture = this.convertToWebGLTexture(); // Create a framebuffer from the texture and read back pixels shaderContext.bindFramebuffer(gl, webGlTexture); - gl.readPixels( - 0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array); - shaderContext.unbindFramebuffer(); + if (isIOS()) { + // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads + // (as tested on iOS 16.4). HALF_FLOAT requires reading data into a + // Uint16Array, however, and requires a manual bitwise conversion from + // Uint16 to floating point numbers. This conversion is more expensive + // than reading back a Float32Array from the RGBA image and dropping + // the superfluous data, so we do this instead. + const outputArray = new Float32Array(this.width * this.height * 4); + gl.readPixels( + 0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray); + for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) { + float32Array[i] = outputArray[j]; + } + } else { + gl.readPixels( + 0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array); + } + shaderContext.unbindFramebuffer(); } this.containers.push(float32Array); } @@ -274,9 +315,9 @@ export class MPMask { webGLTexture = this.bindTexture(); const data = this.convertToFloat32Array(); - // TODO: Add support for R16F to support iOS + const format = this.getTexImage2DFormat(); gl.texImage2D( - gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, + gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED, gl.FLOAT, data); this.unbindTexture(); } diff --git a/mediapipe/web/graph_runner/platform_utils.ts b/mediapipe/web/graph_runner/platform_utils.ts index 71239abab..7e1decf34 100644 --- a/mediapipe/web/graph_runner/platform_utils.ts +++ b/mediapipe/web/graph_runner/platform_utils.ts @@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) { // it uses "CriOS". 
return userAgent.includes('Safari') && !userAgent.includes('Chrome'); } + +/** Detect if code is running on iOS. Safe to call from a Web Worker. */ +export function isIOS() { + // Source: https://stackoverflow.com/questions/9038625/detect-if-device-is-ios + return [ + 'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone', + 'iPod' + // tslint:disable-next-line:deprecation + ].includes(navigator.platform) + // iPad on iOS 13 detection. `document` is undefined in Web Workers. + || (navigator.userAgent.includes('Mac') && + typeof document !== 'undefined' && 'ontouchend' in document); +}