From 12340a8e82cec65d6086326155e0c39d0f4caa58 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 15 Nov 2023 13:00:08 -0800 Subject: [PATCH] Use gl.LINEAR interpolation for confidence masks PiperOrigin-RevId: 582777383 --- .../web/vision/core/drawing_utils.test.ts | 51 +++++++++++-------- .../tasks/web/vision/core/drawing_utils.ts | 1 + .../core/drawing_utils_category_mask.ts | 23 +++++---- .../core/drawing_utils_confidence_mask.ts | 18 +++---- mediapipe/tasks/web/vision/core/mask.test.ts | 6 +-- mediapipe/tasks/web/vision/core/mask.ts | 34 +++++++++---- .../web/vision/core/vision_task_runner.ts | 7 +-- .../vision/image_segmenter/image_segmenter.ts | 10 +++- .../interactive_segmenter.ts | 8 ++- .../vision/pose_landmarker/pose_landmarker.ts | 3 +- 10 files changed, 98 insertions(+), 63 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.test.ts b/mediapipe/tasks/web/vision/core/drawing_utils.test.ts index aaef42bbf..c32a5fc56 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils.test.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils.test.ts @@ -30,25 +30,20 @@ if (skip) { (skip ? xdescribe : describe)('DrawingUtils', () => { let shaderContext = new MPImageShaderContext(); - let canvas2D: HTMLCanvasElement; - let context2D: CanvasRenderingContext2D; + let canvas2D: OffscreenCanvas; + let context2D: OffscreenCanvasRenderingContext2D; let drawingUtils2D: DrawingUtils; - let canvasWebGL: HTMLCanvasElement; + let canvasWebGL: OffscreenCanvas; let contextWebGL: WebGL2RenderingContext; let drawingUtilsWebGL: DrawingUtils; beforeEach(() => { - shaderContext = new MPImageShaderContext(); + canvas2D = canvas2D ?? new OffscreenCanvas(WIDTH, HEIGHT); + canvasWebGL = canvasWebGL ?? new OffscreenCanvas(WIDTH, HEIGHT); - canvasWebGL = document.createElement('canvas'); - canvasWebGL.width = WIDTH; - canvasWebGL.height = HEIGHT; + shaderContext = new MPImageShaderContext(); contextWebGL = canvasWebGL.getContext('webgl2')!; drawingUtilsWebGL = new DrawingUtils(contextWebGL); - - canvas2D = document.createElement('canvas'); - canvas2D.width = WIDTH; - canvas2D.height = HEIGHT; context2D = canvas2D.getContext('2d')!; drawingUtils2D = new DrawingUtils(context2D, contextWebGL); }); @@ -61,11 +56,11 @@ if (skip) { describe( 'drawConfidenceMask() blends background with foreground color', () => { - const foreground = new ImageData( + const defaultColor = [255, 255, 255, 255]; + const overlayImage = new ImageData( new Uint8ClampedArray( [0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255]), WIDTH, HEIGHT); - const background = [255, 255, 255, 255]; const expectedResult = new Uint8Array([ 255, 255, 255, 255, 178, 178, 178, 255, 102, 102, 102, 255, 0, 0, 0, 255 @@ -74,48 +69,52 @@ if (skip) { it('on 2D canvas', () => { const confidenceMask = new MPMask( [new Float32Array([0.0, 0.3, 0.6, 1.0])], + /* interpolateValues= */ true, /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH, HEIGHT); drawingUtils2D.drawConfidenceMask( - confidenceMask, background, foreground); + confidenceMask, defaultColor, overlayImage); const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data; expect(actualResult) .toEqual(new Uint8ClampedArray(expectedResult.buffer)); + confidenceMask.close(); }); it('on WebGL canvas', () => { const confidenceMask = new MPMask( [new Float32Array( [0.6, 1.0, 0.0, 0.3])], // Note: Vertically flipped + /* interpolateValues= */ true, /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH, HEIGHT); drawingUtilsWebGL.drawConfidenceMask( - confidenceMask, background, foreground); + confidenceMask, defaultColor, overlayImage); const actualResult = new Uint8Array(WIDTH * HEIGHT * 4); contextWebGL.readPixels( 0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE, actualResult); expect(actualResult).toEqual(expectedResult); + confidenceMask.close(); }); }); describe( 'drawConfidenceMask() blends background with foreground image', () => { - const foreground = new ImageData( - new Uint8ClampedArray( - [0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255]), - WIDTH, HEIGHT); - const background = new ImageData( + const defaultImage = new ImageData( new Uint8ClampedArray([ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 ]), WIDTH, HEIGHT); + const overlayImage = new ImageData( + new Uint8ClampedArray( + [0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255]), + WIDTH, HEIGHT); const expectedResult = new Uint8Array([ 255, 255, 255, 255, 178, 178, 178, 255, 102, 102, 102, 255, 0, 0, 0, 255 @@ -124,32 +123,36 @@ if (skip) { it('on 2D canvas', () => { const confidenceMask = new MPMask( [new Float32Array([0.0, 0.3, 0.6, 1.0])], + /* interpolateValues= */ true, /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH, HEIGHT); drawingUtils2D.drawConfidenceMask( - confidenceMask, background, foreground); + confidenceMask, defaultImage, overlayImage); const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data; expect(actualResult) .toEqual(new Uint8ClampedArray(expectedResult.buffer)); + confidenceMask.close(); }); it('on WebGL canvas', () => { const confidenceMask = new MPMask( [new Float32Array( [0.6, 1.0, 0.0, 0.3])], // Note: Vertically flipped + /* interpolateValues= */ true, /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH, HEIGHT); drawingUtilsWebGL.drawConfidenceMask( - confidenceMask, background, foreground); + confidenceMask, defaultImage, overlayImage); const actualResult = new Uint8Array(WIDTH * HEIGHT * 4); contextWebGL.readPixels( 0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE, actualResult); expect(actualResult).toEqual(expectedResult); + confidenceMask.close(); }); }); @@ -167,6 +170,7 @@ if (skip) { it('on 2D canvas', () => { const categoryMask = new MPMask( [new Uint8Array([0, 1, 2, 3])], + /* interpolateValues= */ false, /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH, HEIGHT); @@ -175,11 +179,13 @@ if (skip) { const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data; expect(actualResult) .toEqual(new Uint8ClampedArray(expectedResult.buffer)); + categoryMask.close(); }); it('on WebGL canvas', () => { const categoryMask = new MPMask( [new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped + /* interpolateValues= */ false, /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH, HEIGHT); @@ -190,6 +196,7 @@ if (skip) { 0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE, actualResult); expect(actualResult).toEqual(expectedResult); + categoryMask.close(); }); }); diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.ts b/mediapipe/tasks/web/vision/core/drawing_utils.ts index 154420f6b..520f9e2b3 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils.ts @@ -419,6 +419,7 @@ export class DrawingUtils { const convertedMask = new MPMask( [data], + mask.interpolateValues, /* ownsWebGlTexture= */ false, gl.canvas, this.convertToWebGLTextureShaderContext, diff --git a/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts b/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts index d7706075f..3b7cc0b47 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts @@ -92,11 +92,15 @@ export class CategoryMaskShaderContext extends MPImageShaderContext { colorMap: Map|number[][]) { const gl = this.gl!; + // Bind category mask + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_2D, categoryMask); + // TODO: We should avoid uploading textures from CPU to GPU // if the textures haven't changed. This can lead to drastic performance // slowdowns (~50ms per frame). Users can reduce the penalty by passing a // canvas object instead of ImageData/HTMLImageElement. - gl.activeTexture(gl.TEXTURE0); + gl.activeTexture(gl.TEXTURE1); gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!); gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background); @@ -117,19 +121,15 @@ export class CategoryMaskShaderContext extends MPImageShaderContext { pixels[index * 4 + 2] = rgba[2]; pixels[index * 4 + 3] = rgba[3]; }); - gl.activeTexture(gl.TEXTURE1); + gl.activeTexture(gl.TEXTURE2); gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!); gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE, new Uint8Array(pixels)); } else { - gl.activeTexture(gl.TEXTURE1); + gl.activeTexture(gl.TEXTURE2); gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!); } - - // Bind category mask - gl.activeTexture(gl.TEXTURE2); - gl.bindTexture(gl.TEXTURE_2D, categoryMask); } unbindTextures() { @@ -148,10 +148,11 @@ export class CategoryMaskShaderContext extends MPImageShaderContext { protected override setupTextures(): void { const gl = this.gl!; - gl.activeTexture(gl.TEXTURE0); + gl.activeTexture(gl.TEXTURE1); this.backgroundTexture = this.createTexture(gl, gl.LINEAR); // Use `gl.NEAREST` to prevent interpolating values in our category to // color map. + gl.activeTexture(gl.TEXTURE2); this.colorMappingTexture = this.createTexture(gl, gl.NEAREST); } @@ -172,9 +173,9 @@ export class CategoryMaskShaderContext extends MPImageShaderContext { protected override configureUniforms(): void { super.configureUniforms(); const gl = this.gl!; - gl.uniform1i(this.backgroundTextureUniform!, 0); - gl.uniform1i(this.colorMappingTextureUniform!, 1); - gl.uniform1i(this.maskTextureUniform!, 2); + gl.uniform1i(this.maskTextureUniform!, 0); + gl.uniform1i(this.backgroundTextureUniform!, 1); + gl.uniform1i(this.colorMappingTextureUniform!, 2); } override close(): void { diff --git a/mediapipe/tasks/web/vision/core/drawing_utils_confidence_mask.ts b/mediapipe/tasks/web/vision/core/drawing_utils_confidence_mask.ts index c8d30c9ee..953911f01 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils_confidence_mask.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils_confidence_mask.ts @@ -51,9 +51,9 @@ export class ConfidenceMaskShaderContext extends MPImageShaderContext { protected override setupTextures(): void { const gl = this.gl!; - gl.activeTexture(gl.TEXTURE0); - this.defaultTexture = this.createTexture(gl); gl.activeTexture(gl.TEXTURE1); + this.defaultTexture = this.createTexture(gl); + gl.activeTexture(gl.TEXTURE2); this.overlayTexture = this.createTexture(gl); } @@ -74,9 +74,9 @@ export class ConfidenceMaskShaderContext extends MPImageShaderContext { protected override configureUniforms(): void { super.configureUniforms(); const gl = this.gl!; - gl.uniform1i(this.defaultTextureUniform!, 0); - gl.uniform1i(this.overlayTextureUniform!, 1); - gl.uniform1i(this.maskTextureUniform!, 2); + gl.uniform1i(this.maskTextureUniform!, 0); + gl.uniform1i(this.defaultTextureUniform!, 1); + gl.uniform1i(this.overlayTextureUniform!, 2); } bindAndUploadTextures( @@ -88,17 +88,17 @@ export class ConfidenceMaskShaderContext extends MPImageShaderContext { // canvas object instead of ImageData/HTMLImageElement. const gl = this.gl!; gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_2D, confidenceMask); + + gl.activeTexture(gl.TEXTURE1); gl.bindTexture(gl.TEXTURE_2D, this.defaultTexture!); gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, defaultImage); - gl.activeTexture(gl.TEXTURE1); + gl.activeTexture(gl.TEXTURE2); gl.bindTexture(gl.TEXTURE_2D, this.overlayTexture!); gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, overlayImage); - - gl.activeTexture(gl.TEXTURE2); - gl.bindTexture(gl.TEXTURE_2D, confidenceMask); } unbindTextures() { diff --git a/mediapipe/tasks/web/vision/core/mask.test.ts b/mediapipe/tasks/web/vision/core/mask.test.ts index d2f5ddb09..29ed5ea02 100644 --- a/mediapipe/tasks/web/vision/core/mask.test.ts +++ b/mediapipe/tasks/web/vision/core/mask.test.ts @@ -136,7 +136,7 @@ class MPMaskTestContext { shaderContext: MPImageShaderContext, input: MaskType, width: number, height: number): MPMask { return new MPMask( - [input], + [input], /* interpolateValues= */ false, /* ownsWebGLTexture= */ false, context.canvas, shaderContext, width, height); } @@ -182,7 +182,7 @@ class MPMaskTestContext { const shaderContext = new MPImageShaderContext(); const mask = new MPMask( - [context.webGLTexture], + [context.webGLTexture], /* interpolateValues= */ false, /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH, HEIGHT); @@ -196,7 +196,7 @@ class MPMaskTestContext { const shaderContext = new MPImageShaderContext(); const mask = new MPMask( - [context.webGLTexture], + [context.webGLTexture], /* interpolateValues= */ false, /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH, HEIGHT); diff --git a/mediapipe/tasks/web/vision/core/mask.ts b/mediapipe/tasks/web/vision/core/mask.ts index b463589e4..d08145a2d 100644 --- a/mediapipe/tasks/web/vision/core/mask.ts +++ b/mediapipe/tasks/web/vision/core/mask.ts @@ -62,9 +62,25 @@ export class MPMask { /** The format used to write pixel values from textures. */ private static texImage2DFormat?: GLenum; - /** @hideconstructor */ + /** + * @param containers The data source for this mask as a `WebGLTexture`, + * `Unit8Array` or `Float32Array`. Multiple sources of the same data can + * be provided to reduce conversions. + * @param interpolateValues If enabled, uses `gl.LINEAR` instead of + * `gl.NEAREST` to interpolate between mask values. + * @param ownsWebGLTexture Whether the MPMask should take ownership of the + * `WebGLTexture` and free it when closed. + * @param canvas The canvas to use for rendering and conversion. Must be the + * same canvas for any WebGL resources. + * @param shaderContext A shader context that is shared between all masks from + * a single task. + * @param width The width of the mask. + * @param height The height of the mask. + * @hideconstructor + */ constructor( private readonly containers: MPMaskContainer[], + readonly interpolateValues: boolean, private ownsWebGLTexture: boolean, /** Returns the canvas element that the mask is bound to. */ readonly canvas: HTMLCanvasElement|OffscreenCanvas|undefined, @@ -215,7 +231,8 @@ export class MPMask { // Create a new texture and use it to back a framebuffer gl.activeTexture(gl.TEXTURE1); - destinationContainer = shaderContext.createTexture(gl, gl.NEAREST); + destinationContainer = shaderContext.createTexture( + gl, this.interpolateValues ? gl.LINEAR : gl.NEAREST); gl.bindTexture(gl.TEXTURE_2D, destinationContainer); const format = this.getTexImage2DFormat(); gl.texImage2D( @@ -242,8 +259,8 @@ export class MPMask { } return new MPMask( - destinationContainers, this.hasWebGLTexture(), this.canvas, - this.shaderContext, this.width, this.height); + destinationContainers, this.interpolateValues, this.hasWebGLTexture(), + this.canvas, this.shaderContext, this.width, this.height); } private getGL(): WebGL2RenderingContext { @@ -254,7 +271,7 @@ export class MPMask { } if (!this.gl) { this.gl = assertNotNull( - this.canvas.getContext('webgl2') as WebGL2RenderingContext | null, + this.canvas.getContext('webgl2'), 'You cannot use a canvas that is already bound to a different ' + 'type of rendering context.'); } @@ -350,11 +367,8 @@ export class MPMask { let webGLTexture = this.getContainer(MPMaskType.WEBGL_TEXTURE); if (!webGLTexture) { const shaderContext = this.getShaderContext(); - // `gl.NEAREST` ensures that we do not get interpolated values for - // masks. In some cases, the user might want interpolation (e.g. for - // confidence masks), so we might want to make this user-configurable. - // Note that `MPImage` uses `gl.LINEAR`. - webGLTexture = shaderContext.createTexture(gl, gl.NEAREST); + webGLTexture = shaderContext.createTexture( + gl, this.interpolateValues ? gl.LINEAR : gl.NEAREST); this.containers.push(webGLTexture); this.ownsWebGLTexture = true; } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index b9aa5e352..292a37eec 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -274,8 +274,9 @@ export abstract class VisionTaskRunner extends TaskRunner { } /** Converts a WasmImage to an MPMask. */ - protected convertToMPMask(wasmImage: WasmImage, shouldCopyData: boolean): - MPMask { + protected convertToMPMask( + wasmImage: WasmImage, interpolateValues: boolean, + shouldCopyData: boolean): MPMask { const {data, width, height} = wasmImage; const pixels = width * height; @@ -291,7 +292,7 @@ export abstract class VisionTaskRunner extends TaskRunner { } const mask = new MPMask( - [container], + [container], interpolateValues, /* ownsWebGLTexture= */ false, this.graphRunner.wasmModule.canvas!, this.shaderContext, width, height); return shouldCopyData ? mask.clone() : mask; diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index d8751b9e3..cbd20450b 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -424,7 +424,10 @@ export class ImageSegmenter extends VisionTaskRunner { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { this.confidenceMasks = masks.map( wasmImage => this.convertToMPMask( - wasmImage, /* shouldCopyData= */ !this.userCallback)); + wasmImage, + /* interpolateValues= */ true, + /* shouldCopyData= */ !this.userCallback, + )); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( @@ -442,7 +445,10 @@ export class ImageSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { this.categoryMask = this.convertToMPMask( - mask, /* shouldCopyData= */ !this.userCallback); + mask, + /* interpolateValues= */ false, + /* shouldCopyData= */ !this.userCallback, + ); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index 887f55839..5a37b9ff0 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -341,7 +341,10 @@ export class InteractiveSegmenter extends VisionTaskRunner { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { this.confidenceMasks = masks.map( wasmImage => this.convertToMPMask( - wasmImage, /* shouldCopyData= */ !this.userCallback)); + wasmImage, + /* interpolateValues= */ true, + /* shouldCopyData= */ !this.userCallback, + )); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( @@ -359,7 +362,8 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { this.categoryMask = this.convertToMPMask( - mask, /* shouldCopyData= */ !this.userCallback); + mask, /* interpolateValues= */ false, + /* shouldCopyData= */ !this.userCallback); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts index 8f6531827..262966d72 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts @@ -470,7 +470,8 @@ export class PoseLandmarker extends VisionTaskRunner { SEGMENTATION_MASK_STREAM, (masks, timestamp) => { this.segmentationMasks = masks.map( wasmImage => this.convertToMPMask( - wasmImage, /* shouldCopyData= */ !this.userCallback)); + wasmImage, /* interpolateValues= */ true, + /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener(