Add iOS support for MPMask
PiperOrigin-RevId: 534155657
This commit is contained in:
parent
102cffdf4c
commit
51730ec25c
|
@ -62,7 +62,10 @@ jasmine_node_test(
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "mask",
|
name = "mask",
|
||||||
srcs = ["mask.ts"],
|
srcs = ["mask.ts"],
|
||||||
deps = [":image"],
|
deps = [
|
||||||
|
":image",
|
||||||
|
"//mediapipe/web/graph_runner:platform_utils",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||||
|
import {isIOS} from '../../../../web/graph_runner/platform_utils';
|
||||||
|
|
||||||
/** Number of instances a user can keep alive before we raise a warning. */
|
/** Number of instances a user can keep alive before we raise a warning. */
|
||||||
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
|
const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
|
||||||
|
@ -32,6 +33,8 @@ enum MPMaskType {
|
||||||
/** The supported mask formats. For internal usage. */
|
/** The supported mask formats. For internal usage. */
|
||||||
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The wrapper class for MediaPipe segmentation masks.
|
* The wrapper class for MediaPipe segmentation masks.
|
||||||
*
|
*
|
||||||
|
@ -56,6 +59,9 @@ export class MPMask {
|
||||||
*/
|
*/
|
||||||
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
|
private static instancesBeforeWarning = INSTANCE_COUNT_WARNING_THRESHOLD;
|
||||||
|
|
||||||
|
/** The format used to write pixel values from textures. */
|
||||||
|
private static texImage2DFormat?: GLenum;
|
||||||
|
|
||||||
/** @hideconstructor */
|
/** @hideconstructor */
|
||||||
constructor(
|
constructor(
|
||||||
private readonly containers: MPMaskContainer[],
|
private readonly containers: MPMaskContainer[],
|
||||||
|
@ -127,6 +133,29 @@ export class MPMask {
|
||||||
return this.convertToWebGLTexture();
|
return this.convertToWebGLTexture();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the texture format used for writing float textures on this
|
||||||
|
* platform.
|
||||||
|
*/
|
||||||
|
getTexImage2DFormat(): GLenum {
|
||||||
|
const gl = this.getGL();
|
||||||
|
if (!MPMask.texImage2DFormat) {
|
||||||
|
// Note: This is the same check we use in
|
||||||
|
// `SegmentationPostprocessorGl::GetSegmentationResultGpu()`.
|
||||||
|
if (gl.getExtension('EXT_color_buffer_float') &&
|
||||||
|
gl.getExtension('OES_texture_float_linear') &&
|
||||||
|
gl.getExtension('EXT_float_blend')) {
|
||||||
|
MPMask.texImage2DFormat = gl.R32F;
|
||||||
|
} else if (gl.getExtension('EXT_color_buffer_half_float')) {
|
||||||
|
MPMask.texImage2DFormat = gl.R16F;
|
||||||
|
} else {
|
||||||
|
throw new Error(
|
||||||
|
'GPU does not fully support 4-channel float32 or float16 formats');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MPMask.texImage2DFormat;
|
||||||
|
}
|
||||||
|
|
||||||
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
|
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
|
||||||
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
|
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
|
||||||
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
|
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
|
||||||
|
@ -176,8 +205,9 @@ export class MPMask {
|
||||||
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||||
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
|
gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
|
||||||
this.configureTextureParams();
|
this.configureTextureParams();
|
||||||
|
const format = this.getTexImage2DFormat();
|
||||||
gl.texImage2D(
|
gl.texImage2D(
|
||||||
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
|
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
|
||||||
gl.FLOAT, null);
|
gl.FLOAT, null);
|
||||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||||
|
|
||||||
|
@ -208,7 +238,7 @@ export class MPMask {
|
||||||
if (!this.canvas) {
|
if (!this.canvas) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'Conversion to different image formats require that a canvas ' +
|
'Conversion to different image formats require that a canvas ' +
|
||||||
'is passed when iniitializing the image.');
|
'is passed when initializing the image.');
|
||||||
}
|
}
|
||||||
if (!this.gl) {
|
if (!this.gl) {
|
||||||
this.gl = assertNotNull(
|
this.gl = assertNotNull(
|
||||||
|
@ -216,11 +246,6 @@ export class MPMask {
|
||||||
'You cannot use a canvas that is already bound to a different ' +
|
'You cannot use a canvas that is already bound to a different ' +
|
||||||
'type of rendering context.');
|
'type of rendering context.');
|
||||||
}
|
}
|
||||||
const ext = this.gl.getExtension('EXT_color_buffer_float');
|
|
||||||
if (!ext) {
|
|
||||||
// TODO: Ensure this works on iOS
|
|
||||||
throw new Error('Missing required EXT_color_buffer_float extension');
|
|
||||||
}
|
|
||||||
return this.gl;
|
return this.gl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,18 +263,34 @@ export class MPMask {
|
||||||
if (uint8Array) {
|
if (uint8Array) {
|
||||||
float32Array = new Float32Array(uint8Array).map(v => v / 255);
|
float32Array = new Float32Array(uint8Array).map(v => v / 255);
|
||||||
} else {
|
} else {
|
||||||
|
float32Array = new Float32Array(this.width * this.height);
|
||||||
|
|
||||||
const gl = this.getGL();
|
const gl = this.getGL();
|
||||||
const shaderContext = this.getShaderContext();
|
const shaderContext = this.getShaderContext();
|
||||||
float32Array = new Float32Array(this.width * this.height);
|
|
||||||
|
|
||||||
// Create texture if needed
|
// Create texture if needed
|
||||||
const webGlTexture = this.convertToWebGLTexture();
|
const webGlTexture = this.convertToWebGLTexture();
|
||||||
|
|
||||||
// Create a framebuffer from the texture and read back pixels
|
// Create a framebuffer from the texture and read back pixels
|
||||||
shaderContext.bindFramebuffer(gl, webGlTexture);
|
shaderContext.bindFramebuffer(gl, webGlTexture);
|
||||||
gl.readPixels(
|
|
||||||
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
|
if (isIOS()) {
|
||||||
shaderContext.unbindFramebuffer();
|
// WebKit on iOS only supports gl.HALF_FLOAT for single channel reads
|
||||||
|
// (as tested on iOS 16.4). HALF_FLOAT requires reading data into a
|
||||||
|
// Uint16Array, however, and requires a manual bitwise conversion from
|
||||||
|
// Uint16 to floating point numbers. This conversion is more expensive
|
||||||
|
// that reading back a Float32Array from the RGBA image and dropping
|
||||||
|
// the superfluous data, so we do this instead.
|
||||||
|
const outputArray = new Float32Array(this.width * this.height * 4);
|
||||||
|
gl.readPixels(
|
||||||
|
0, 0, this.width, this.height, gl.RGBA, gl.FLOAT, outputArray);
|
||||||
|
for (let i = 0, j = 0; i < float32Array.length; ++i, j += 4) {
|
||||||
|
float32Array[i] = outputArray[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gl.readPixels(
|
||||||
|
0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
this.containers.push(float32Array);
|
this.containers.push(float32Array);
|
||||||
}
|
}
|
||||||
|
@ -274,9 +315,9 @@ export class MPMask {
|
||||||
webGLTexture = this.bindTexture();
|
webGLTexture = this.bindTexture();
|
||||||
|
|
||||||
const data = this.convertToFloat32Array();
|
const data = this.convertToFloat32Array();
|
||||||
// TODO: Add support for R16F to support iOS
|
const format = this.getTexImage2DFormat();
|
||||||
gl.texImage2D(
|
gl.texImage2D(
|
||||||
gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
|
gl.TEXTURE_2D, 0, format, this.width, this.height, 0, gl.RED,
|
||||||
gl.FLOAT, data);
|
gl.FLOAT, data);
|
||||||
this.unbindTexture();
|
this.unbindTexture();
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,3 +21,16 @@ export function isWebKit(browser = navigator) {
|
||||||
// it uses "CriOS".
|
// it uses "CriOS".
|
||||||
return userAgent.includes('Safari') && !userAgent.includes('Chrome');
|
return userAgent.includes('Safari') && !userAgent.includes('Chrome');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Detect if code is running on iOS. */
|
||||||
|
export function isIOS() {
|
||||||
|
// Source:
|
||||||
|
// https://stackoverflow.com/questions/9038625/detect-if-device-is-ios
|
||||||
|
return [
|
||||||
|
'iPad Simulator', 'iPhone Simulator', 'iPod Simulator', 'iPad', 'iPhone',
|
||||||
|
'iPod'
|
||||||
|
// tslint:disable-next-line:deprecation
|
||||||
|
].includes(navigator.platform)
|
||||||
|
// iPad on iOS 13 detection
|
||||||
|
|| (navigator.userAgent.includes('Mac') && 'ontouchend' in document);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user