Add iOS support for MPMask

PiperOrigin-RevId: 534155657
This commit is contained in:
Sebastian Schmidt 2023-05-22 12:57:32 -07:00 committed by Copybara-Service
parent 102cffdf4c
commit 51730ec25c
3 changed files with 71 additions and 14 deletions

View File

@ -62,7 +62,10 @@ jasmine_node_test(
mediapipe_ts_library( mediapipe_ts_library(
name = "mask", name = "mask",
srcs = ["mask.ts"], srcs = ["mask.ts"],
deps = [":image"], deps = [
":image",
"//mediapipe/web/graph_runner:platform_utils",
],
) )
mediapipe_ts_library( mediapipe_ts_library(

View File

@ -15,6 +15,7 @@
*/ */
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {isIOS} from '../../../../web/graph_runner/platform_utils';
/** Number of instances a user can keep alive before we raise a warning. */ /** Number of instances a user can keep alive before we raise a warning. */
const INSTANCE_COUNT_WARNING_THRESHOLD = 250; const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@ -32,6 +33,8 @@ enum MPMaskType {
/** The supported mask formats. For internal usage. */ /** The supported mask formats. For internal usage. */
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture; export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
/** /**
* The wrapper class for MediaPipe segmentation masks. * The wrapper class for MediaPipe segmentation masks.
* *
@ -56,6 +59,9 @@ export class MPMask {
*/ */
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD; private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
/** The format used to write pixel values from textures. */
private static texImage2DFormat?: GLenum;
/** @hideconstructor */ /** @hideconstructor */
constructor( constructor(
private readonly containers: MPMaskContainer[], private readonly containers: MPMaskContainer[],
@ -127,6 +133,29 @@ export class MPMask {
return this.convertToWebGLTexture(); return this.convertToWebGLTexture();
} }
/**
* Returns the texture format used for writing float textures on this
* platform.
*/
getTexImage2DFormat(): GLenum {
const gl = this.getGL();
if (!MPMask.texImage2DFormat) {
// Note: This is the same check we use in
// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
if (gl.getExtension('EXT_color_buffer_float') &&
gl.getExtension('OES_texture_float_linear') &&
gl.getExtension('EXT_float_blend')) {
MPMask.texImage2DFormat = gl.R32F;
} else if (gl.getExtension('EXT_color_buffer_half_float')) {
MPMask.texImage2DFormat = gl.R16F;
} else {
throw new Error(
'GPU does not fully support 4-channel float32 or float16 formats');
}
}
return MPMask.texImage2DFormat;
}
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined; private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined; private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined; private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -176,8 +205,9 @@ export class MPMask {
assertNotNull(gl.createTexture(), 'Failed to create texture'); assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, destinationContainer); gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
this.configureTextureParams(); this.configureTextureParams();
const format = this.getTexImage2DFormat();
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, null); gl.FLOAT, null);
gl.bindTexture(gl.TEXTURE_2D, null); gl.bindTexture(gl.TEXTURE_2D, null);
@ -208,7 +238,7 @@ export class MPMask {
if (!this.canvas) { if (!this.canvas) {
throw new Error( throw new Error(
'Conversion to different image formats require that a canvas ' + 'Conversion to different image formats require that a canvas ' +
'is passed when iniitializing the image.'); 'is passed when initializing the image.');
} }
if (!this.gl) { if (!this.gl) {
this.gl = assertNotNull( this.gl = assertNotNull(
@ -216,11 +246,6 @@ export class MPMask {
'You cannot use a canvas that is already bound to a different ' + 'You cannot use a canvas that is already bound to a different ' +
'type of rendering context.'); 'type of rendering context.');
} }
const ext = this.gl.getExtension('EXT_color_buffer_float');
if (!ext) {
// TODO: Ensure this works on iOS
throw new Error('Missing required EXT_color_buffer_float extension');
}
return this.gl; return this.gl;
} }
@ -238,18 +263,34 @@ export class MPMask {
if (uint8Array) { if (uint8Array) {
float32Array = new Float32Array(uint8Array).map(v => v / 255); float32Array = new Float32Array(uint8Array).map(v => v / 255);
} else { } else {
float32Array = new Float32Array(this.width * this.height);
const gl = this.getGL(); const gl = this.getGL();
const shaderContext = this.getShaderContext(); const shaderContext = this.getShaderContext();
float32Array = new Float32Array(this.width * this.height);
// Create texture if needed // Create texture if needed
const webGlTexture = this.convertToWebGLTexture(); const webGlTexture = this.convertToWebGLTexture();
// Create a framebuffer from the texture and read back pixels // Create a framebuffer from the texture and read back pixels
shaderContext.bindFramebuffer(gl, webGlTexture); shaderContext.bindFramebuffer(gl, webGlTexture);
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array); if (isIOS()) {
shaderContext.unbindFramebuffer(); // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
// Uint16Array, however, and requires a manual bitwise conversion from
// Uint16 to floating point numbers. This conversion is more expensive
// that reading back a Float32Array from the RGBA image and dropping
// the superfluous data, so we do this instead.
const outputArray = new Float32Array(this.width * this.height * 4);
gl.readPixels(
0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
float32Array[i] = outputArray[j];
}
} else {
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
}
} }
this.containers.push(float32Array); this.containers.push(float32Array);
} }
@ -274,9 +315,9 @@ export class MPMask {
webGLTexture = this.bindTexture(); webGLTexture = this.bindTexture();
const data = this.convertToFloat32Array(); const data = this.convertToFloat32Array();
// TODO: Add support for R16F to support iOS const format = this.getTexImage2DFormat();
gl.texImage2D( gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED, gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, data); gl.FLOAT, data);
this.unbindTexture(); this.unbindTexture();
} }

View File

@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
// it uses "CriOS". // it uses "CriOS".
return userAgent.includes('Safari') && !userAgent.includes('Chrome'); return userAgent.includes('Safari') && !userAgent.includes('Chrome');
} }
/** Detect if code is running on iOS. */
export function isIOS() {
// Source:
// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
return [
'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
'iPod'
// tslint:disable-next-line:deprecation
].includes(navigator.platform)
// iPad on iOS 13 detection
|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
}