Add iOS support for MPMask

PiperOrigin-RevId: 534155657
This commit is contained in:
Sebastian Schmidt 2023-05-22 12:57:32 -07:00 committed by Copybara-Service
parent 102cffdf4c
commit 51730ec25c
3 changed files with 71 additions and 14 deletions

View File

@ -62,7 +62,10 @@ jasmine_node_test(
mediapipe_ts_library(
name = "mask",
srcs = ["mask.ts"],
deps = [":image"],
deps = [
":image",
"//mediapipe/web/graph_runner:platform_utils",
],
)
mediapipe_ts_library(

View File

@ -15,6 +15,7 @@
*/
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {isIOS} from '../../../../web/graph_runner/platform_utils';
/** Number of instances a user can keep alive before we raise a warning. */
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@ -32,6 +33,8 @@ enum MPMaskType {
/** The supported mask formats. For internal usage. */
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
/**
* The wrapper class for MediaPipe segmentation masks.
*
@ -56,6 +59,9 @@ export class MPMask {
*/
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
/** The format used to write pixel values from textures. */
private static texImage2DFormat?: GLenum;
/** @hideconstructor */
constructor(
private readonly containers: MPMaskContainer[],
@ -127,6 +133,29 @@ export class MPMask {
return this.convertToWebGLTexture();
}
/**
* Returns the texture format used for writing float textures on this
* platform.
*/
getTexImage2DFormat(): GLenum {
const gl = this.getGL();
if (!MPMask.texImage2DFormat) {
// Note: This is the same check we use in
// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
if (gl.getExtension('EXT_color_buffer_float') &&
gl.getExtension('OES_texture_float_linear') &&
gl.getExtension('EXT_float_blend')) {
MPMask.texImage2DFormat = gl.R32F;
} else if (gl.getExtension('EXT_color_buffer_half_float')) {
MPMask.texImage2DFormat = gl.R16F;
} else {
throw new Error(
'GPU does not fully support 4-channel float32 or float16 formats');
}
}
return MPMask.texImage2DFormat;
}
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -176,8 +205,9 @@ export class MPMask {
assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
this.configureTextureParams();
const format = this.getTexImage2DFormat();
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, null);
gl.bindTexture(gl.TEXTURE_2D, null);
@ -208,7 +238,7 @@ export class MPMask {
if (!this.canvas) {
throw new Error(
'Conversion to different image formats require that a canvas ' +
'is passed when iniitializing the image.');
'is passed when initializing the image.');
}
if (!this.gl) {
this.gl = assertNotNull(
@ -216,11 +246,6 @@ export class MPMask {
'You cannot use a canvas that is already bound to a different ' +
'type of rendering context.');
}
const ext = this.gl.getExtension('EXT_color_buffer_float');
if (!ext) {
// TODO: Ensure this works on iOS
throw new Error('Missing required EXT_color_buffer_float extension');
}
return this.gl;
}
@ -238,18 +263,34 @@ export class MPMask {
if (uint8Array) {
float32Array = new Float32Array(uint8Array).map(v => v / 255);
} else {
float32Array = new Float32Array(this.width * this.height);
const gl = this.getGL();
const shaderContext = this.getShaderContext();
float32Array = new Float32Array(this.width * this.height);
// Create texture if needed
const webGlTexture = this.convertToWebGLTexture();
// Create a framebuffer from the texture and read back pixels
shaderContext.bindFramebuffer(gl, webGlTexture);
if (isIOS()) {
// WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
// Uint16Array, however, and requires a manual bitwise conversion from
// Uint16 to floating point numbers. This conversion is more expensive
// that reading back a Float32Array from the RGBA image and dropping
// the superfluous data, so we do this instead.
const outputArray = new Float32Array(this.width * this.height * 4);
gl.readPixels(
0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
float32Array[i] = outputArray[j];
}
} else {
gl.readPixels(
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
shaderContext.unbindFramebuffer();
}
}
this.containers.push(float32Array);
}
@ -274,9 +315,9 @@ export class MPMask {
webGLTexture = this.bindTexture();
const data = this.convertToFloat32Array();
// TODO: Add support for R16F to support iOS
const format = this.getTexImage2DFormat();
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
gl.FLOAT, data);
this.unbindTexture();
}

View File

@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
// it uses "CriOS".
return userAgent.includes('Safari') && !userAgent.includes('Chrome');
}
/** Detect if code is running on iOS. */
export function isIOS() {
// Source:
// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
return [
'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
'iPod'
// tslint:disable-next-line:deprecation
].includes(navigator.platform)
// iPad on iOS 13 detection
|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
}