Add iOS support for MPMask
PiperOrigin-RevId: 534155657
parent 102cffdf4c
commit 51730ec25c
mediapipe/tasks/web/vision/core/BUILD
@@ -62,7 +62,10 @@ jasmine_node_test(
 mediapipe_ts_library(
     name = "mask",
     srcs = ["mask.ts"],
-    deps = [":image"],
+    deps = [
+        ":image",
+        "//mediapipe/web/graph_runner:platform_utils",
+    ],
 )

 mediapipe_ts_library(

mediapipe/tasks/web/vision/core/mask.ts
@@ -15,6 +15,7 @@
  */

 import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
+import {isIOS} from '../../../../web/graph_runner/platform_utils';

 /** Number of instances a user can keep alive before we raise a warning. */
 const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@@ -32,6 +33,8 @@ enum MPMaskType {
 /** The supported mask formats. For internal usage. */
 export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;

+
+
 /**
  * The wrapper class for MediaPipe segmentation masks.
  *
@@ -56,6 +59,9 @@ export class MPMask {
    */
   private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;

+  /** The format used to write pixel values from textures. */
+  private static texImage2DFormat?: GLenum;
+
   /** @hideconstructor */
   constructor(
       private readonly containers: MPMaskContainer[],
@@ -127,6 +133,29 @@ export class MPMask {
     return this.convertToWebGLTexture();
   }

+  /**
+   * Returns the texture format used for writing float textures on this
+   * platform.
+   */
+  getTexImage2DFormat(): GLenum {
+    const gl = this.getGL();
+    if (!MPMask.texImage2DFormat) {
+      // Note: This is the same check we use in
+      // `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
+      if (gl.getExtension('EXT_color_buffer_float') &&
+          gl.getExtension('OES_texture_float_linear') &&
+          gl.getExtension('EXT_float_blend')) {
+        MPMask.texImage2DFormat = gl.R32F;
+      } else if (gl.getExtension('EXT_color_buffer_half_float')) {
+        MPMask.texImage2DFormat = gl.R16F;
+      } else {
+        throw new Error(
+            'GPU does not fully support 4-channel float32 or float16 formats');
+      }
+    }
+    return MPMask.texImage2DFormat;
+  }
+
   private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
   private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
   private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
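
For readers who want to see which path a given browser takes, the same capability probe can be run outside MPMask. The sketch below is illustrative only (the helper name, canvas setup, and logging are assumptions, not part of the commit); it mirrors the extension checks above. Note that both R32F and R16F textures accept uploads with format RED and type FLOAT per the WebGL2 format tables, which is why the texImage2D() call sites later in this commit only swap the internal-format argument.

// Minimal standalone sketch of the renderable-float-format probe that
// MPMask.getTexImage2DFormat() performs. Helper name is an illustrative
// assumption.
function probeFloatTextureFormat(gl: WebGL2RenderingContext): GLenum {
  const fullFloat = gl.getExtension('EXT_color_buffer_float') &&
      gl.getExtension('OES_texture_float_linear') &&
      gl.getExtension('EXT_float_blend');
  if (fullFloat) {
    return gl.R32F;  // Full 32-bit float render targets are available.
  }
  if (gl.getExtension('EXT_color_buffer_half_float')) {
    return gl.R16F;  // Fall back to 16-bit half-float render targets.
  }
  throw new Error('No renderable single-channel float format available');
}

// Usage (assumes a DOM is available):
const canvas = document.createElement('canvas');
const gl = canvas.getContext('webgl2')!;
console.log(probeFloatTextureFormat(gl) === gl.R16F ? 'R16F' : 'R32F');
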
@@ -176,8 +205,9 @@ export class MPMask {
           assertNotNull(gl.createTexture(), 'Failed to create texture');
       gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
       this.configureTextureParams();
+      const format = this.getTexImage2DFormat();
       gl.texImage2D(
-          gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
+          gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
           gl.FLOAT, null);
       gl.bindTexture(gl.TEXTURE_2D, null);

@@ -208,7 +238,7 @@ export class MPMask {
     if (!this.canvas) {
       throw new Error(
           'Conversion to different image formats require that a canvas ' +
-          'is passed when iniitializing the image.');
+          'is passed when initializing the image.');
     }
     if (!this.gl) {
       this.gl = assertNotNull(
@@ -216,11 +246,6 @@ export class MPMask {
           'You cannot use a canvas that is already bound to a different ' +
               'type of rendering context.');
     }
-    const ext = this.gl.getExtension('EXT_color_buffer_float');
-    if (!ext) {
-      // TODO: Ensure this works on iOS
-      throw new Error('Missing required EXT_color_buffer_float extension');
-    }
     return this.gl;
   }

@@ -238,18 +263,34 @@ export class MPMask {
       if (uint8Array) {
         float32Array = new Float32Array(uint8Array).map(v => v / 255);
       } else {
-        float32Array = new Float32Array(this.width * this.height);
-
         const gl = this.getGL();
         const shaderContext = this.getShaderContext();
+        float32Array = new Float32Array(this.width * this.height);
+
         // Create texture if needed
         const webGlTexture = this.convertToWebGLTexture();

         // Create a framebuffer from the texture and read back pixels
         shaderContext.bindFramebuffer(gl, webGlTexture);
-        gl.readPixels(
-            0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
+
+        if (isIOS()) {
+          // WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
+          // (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
+          // Uint16Array, however, and requires a manual bitwise conversion from
+          // Uint16 to floating point numbers. This conversion is more expensive
+          // than reading back a Float32Array from the RGBA image and dropping
+          // the superfluous data, so we do this instead.
+          const outputArray = new Float32Array(this.width * this.height * 4);
+          gl.readPixels(
+              0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
+          for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
+            float32Array[i] = outputArray[j];
+          }
+        } else {
+          gl.readPixels(
+              0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
+        }
         shaderContext.unbindFramebuffer();
       }
       this.containers.push(float32Array);
     }
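
For contrast, the single-channel HALF_FLOAT read-back that the comment above rules out would look roughly like the sketch below. This is hypothetical code, not part of the commit; it only illustrates the per-pixel binary16 decoding that makes that path more expensive than reading RGBA floats and keeping every fourth value.

// Hypothetical alternative (not used by MPMask): read the R channel as
// HALF_FLOAT into a Uint16Array and decode the IEEE 754 binary16 bits in
// JavaScript.
function halfToFloat(bits: number): number {
  const sign = bits & 0x8000 ? -1 : 1;
  const exponent = (bits >> 10) & 0x1f;
  const mantissa = bits & 0x03ff;
  if (exponent === 0) {
    return sign * mantissa * 2 ** -24;        // Subnormals and zero.
  }
  if (exponent === 0x1f) {
    return mantissa ? NaN : sign * Infinity;  // NaN and infinities.
  }
  return sign * (1 + mantissa / 1024) * 2 ** (exponent - 15);
}

function readMaskViaHalfFloat(
    gl: WebGL2RenderingContext, width: number, height: number): Float32Array {
  const raw = new Uint16Array(width * height);
  gl.readPixels(0, 0, width, height, gl.RED, gl.HALF_FLOAT, raw);
  const out = new Float32Array(raw.length);
  for (let i = 0; i < raw.length; ++i) {
    out[i] = halfToFloat(raw[i]);  // Per-pixel bitwise decode.
  }
  return out;
}
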
@@ -274,9 +315,9 @@ export class MPMask {
       webGLTexture = this.bindTexture();

       const data = this.convertToFloat32Array();
-      // TODO: Add support for R16F to support iOS
+      const format = this.getTexImage2DFormat();
       gl.texImage2D(
-          gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
+          gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
           gl.FLOAT, data);
       this.unbindTexture();
     }

mediapipe/web/graph_runner/platform_utils.ts
@@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
   // it uses "CriOS".
   return userAgent.includes('Safari') && !userAgent.includes('Chrome');
 }
+
+/** Detect if code is running on iOS. */
+export function isIOS() {
+  // Source:
+  // https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
+  return [
+    'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
+    'iPod'
+    // tslint:disable-next-line:deprecation
+  ].includes(navigator.platform)
+      // iPad on iOS 13 detection
+      || (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
+}
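
A small usage sketch for the two helpers. isWebKit() takes an injectable browser object, so it can be exercised with a fabricated user agent; isIOS() reads the global navigator, so exercising it off-device means shadowing navigator.platform (the Object.defineProperty trick below works in typical jsdom-based test setups, but that, like the import path, is an assumption rather than part of the commit).

import {isIOS, isWebKit} from './platform_utils';  // Import path is illustrative.

// isWebKit() accepts the browser object as a parameter, so a fake suffices.
const fakeSafari = {
  userAgent: 'Mozilla/5.0 (iPhone; CPU iPhone OS 16_4 like Mac OS X) ' +
      'AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.4 Mobile/15E148 ' +
      'Safari/604.1'
} as Navigator;
console.log(isWebKit(fakeSafari));  // Expected: true ("Safari" without "Chrome").

// isIOS() reads the global `navigator`; shadow navigator.platform to exercise
// the iOS branch off-device (environment-dependent).
Object.defineProperty(navigator, 'platform', {value: 'iPad', configurable: true});
console.log(isIOS());  // Expected: true.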