diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index dfbbb9f91..31bad937d 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -31,27 +31,57 @@ mediapipe_ts_library( mediapipe_ts_library( name = "drawing_utils", - srcs = ["drawing_utils.ts"], + srcs = [ + "drawing_utils.ts", + "drawing_utils_category_mask.ts", + ], deps = [ + ":image", + ":image_shader_context", + ":mask", ":types", "//mediapipe/tasks/web/components/containers:bounding_box", "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) mediapipe_ts_library( - name = "image", - srcs = [ - "image.ts", - "image_shader_context.ts", + name = "drawing_utils_test_lib", + testonly = True, + srcs = ["drawing_utils.test.ts"], + deps = [ + ":drawing_utils", + ":image", + ":image_shader_context", + ":mask", ], ) +jasmine_node_test( + name = "drawing_utils_test", + deps = [":drawing_utils_test_lib"], +) + +mediapipe_ts_library( + name = "image", + srcs = ["image.ts"], + deps = ["image_shader_context"], +) + +mediapipe_ts_library( + name = "image_shader_context", + srcs = ["image_shader_context.ts"], +) + mediapipe_ts_library( name = "image_test_lib", testonly = True, srcs = ["image.test.ts"], - deps = [":image"], + deps = [ + ":image", + ":image_shader_context", + ], ) jasmine_node_test( @@ -64,6 +94,7 @@ mediapipe_ts_library( srcs = ["mask.ts"], deps = [ ":image", + ":image_shader_context", "//mediapipe/web/graph_runner:platform_utils", ], ) @@ -74,6 +105,7 @@ mediapipe_ts_library( srcs = ["mask.test.ts"], deps = [ ":image", + ":image_shader_context", ":mask", ], ) @@ -89,6 +121,7 @@ mediapipe_ts_library( deps = [ ":image", ":image_processing_options", + ":image_shader_context", ":mask", ":vision_task_options", "//mediapipe/framework/formats:rect_jspb_proto", diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.test.ts b/mediapipe/tasks/web/vision/core/drawing_utils.test.ts new file mode 100644 index 000000000..b5ba8e9a4 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/drawing_utils.test.ts @@ -0,0 +1,103 @@ +/** + * Copyright 2023 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {DrawingUtils} from './drawing_utils'; +import {MPImageShaderContext} from './image_shader_context'; +import {MPMask} from './mask'; + +const WIDTH = 2; +const HEIGHT = 2; + +const skip = typeof document === 'undefined'; +if (skip) { + console.log('These tests must be run in a browser.'); +} + +(skip ? xdescribe : describe)('DrawingUtils', () => { + let shaderContext = new MPImageShaderContext(); + let canvas2D: HTMLCanvasElement; + let context2D: CanvasRenderingContext2D; + let drawingUtils2D: DrawingUtils; + let canvasWebGL: HTMLCanvasElement; + let contextWebGL: WebGL2RenderingContext; + let drawingUtilsWebGL: DrawingUtils; + + beforeEach(() => { + shaderContext = new MPImageShaderContext(); + + canvasWebGL = document.createElement('canvas'); + canvasWebGL.width = WIDTH; + canvasWebGL.height = HEIGHT; + contextWebGL = canvasWebGL.getContext('webgl2')!; + drawingUtilsWebGL = new DrawingUtils(contextWebGL); + + canvas2D = document.createElement('canvas'); + canvas2D.width = WIDTH; + canvas2D.height = HEIGHT; + context2D = canvas2D.getContext('2d')!; + drawingUtils2D = new DrawingUtils(context2D, contextWebGL); + }); + + afterEach(() => { + shaderContext.close(); + drawingUtils2D.close(); + drawingUtilsWebGL.close(); + }); + + describe('drawCategoryMask() ', () => { + const colors = [ + [0, 0, 0, 255], + [0, 255, 0, 255], + [0, 0, 255, 255], + [255, 255, 255, 255], + ]; + const expectedResult = new Uint8Array( + [0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255], + ); + + it('on 2D canvas', () => { + const categoryMask = new MPMask( + [new Uint8Array([0, 1, 2, 3])], + /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH, + HEIGHT); + + drawingUtils2D.drawCategoryMask(categoryMask, colors); + + const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data; + expect(actualResult) + .toEqual(new Uint8ClampedArray(expectedResult.buffer)); + }); + + it('on WebGL canvas', () => { + const categoryMask = new MPMask( + [new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped + /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH, + HEIGHT); + + drawingUtilsWebGL.drawCategoryMask(categoryMask, colors); + + const actualResult = new Uint8Array(WIDTH * WIDTH * 4); + contextWebGL.readPixels( + 0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE, + actualResult); + expect(actualResult).toEqual(expectedResult); + }); + }); + + // TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox +}); diff --git a/mediapipe/tasks/web/vision/core/drawing_utils.ts b/mediapipe/tasks/web/vision/core/drawing_utils.ts index c1e84fa11..95e376fb2 100644 --- a/mediapipe/tasks/web/vision/core/drawing_utils.ts +++ b/mediapipe/tasks/web/vision/core/drawing_utils.ts @@ -16,7 +16,11 @@ import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {CategoryMaskShaderContext, CategoryToColorMap, RGBAColor} from '../../../../tasks/web/vision/core/drawing_utils_category_mask'; +import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; +import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {Connection} from '../../../../tasks/web/vision/core/types'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; /** * A user-defined callback to take input data and map it to a custom output @@ -24,6 +28,9 @@ import {Connection} from '../../../../tasks/web/vision/core/types'; */ export type Callback = (input: I) => O; +// Used in public API +export {ImageSource}; + /** Data that a user can use to specialize drawing options. */ export declare interface LandmarkData { index?: number; @@ -31,6 +38,32 @@ export declare interface LandmarkData { to?: NormalizedLandmark; } +/** A color map with 22 classes. Used in our demos. */ +export const DEFAULT_CATEGORY_TO_COLOR_MAP = [ + [0, 0, 0, 0], // class 0 is BG = transparent + [255, 0, 0, 255], // class 1 is red + [0, 255, 0, 255], // class 2 is light green + [0, 0, 255, 255], // class 3 is blue + [255, 255, 0, 255], // class 4 is yellow + [255, 0, 255, 255], // class 5 is light purple / magenta + [0, 255, 255, 255], // class 6 is light blue / aqua + [128, 128, 128, 255], // class 7 is gray + [255, 100, 0, 255], // class 8 is dark orange + [128, 0, 255, 255], // class 9 is dark purple + [0, 150, 0, 255], // class 10 is green + [255, 255, 255, 255], // class 11 is white + [255, 105, 180, 255], // class 12 is pink + [255, 150, 0, 255], // class 13 is orange + [255, 250, 224, 255], // class 14 is light yellow + [148, 0, 211, 255], // class 15 is dark violet + [0, 100, 0, 255], // class 16 is dark green + [0, 0, 128, 255], // class 17 is navy blue + [165, 42, 42, 255], // class 18 is brown + [64, 224, 208, 255], // class 19 is turquoise + [255, 218, 185, 255], // class 20 is peach + [192, 192, 192, 255], // class 21 is silver +]; + /** * Options for customizing the drawing routines */ @@ -77,14 +110,47 @@ function resolve(value: O|Callback, data: I): O { return value instanceof Function ? value(data) : value; } +export {RGBAColor, CategoryToColorMap}; + /** Helper class to visualize the result of a MediaPipe Vision task. */ export class DrawingUtils { + private categoryMaskShaderContext?: CategoryMaskShaderContext; + private convertToWebGLTextureShaderContext?: MPImageShaderContext; + private readonly context2d?: CanvasRenderingContext2D; + private readonly contextWebGL?: WebGL2RenderingContext; + /** * Creates a new DrawingUtils class. * - * @param ctx The canvas to render onto. + * @param gpuContext The WebGL canvas rendering context to render into. If + * your Task is using a GPU delegate, the context must be obtained from + * its canvas (provided via `setOptions({ canvas: .. })`). */ - constructor(private readonly ctx: CanvasRenderingContext2D) {} + constructor(gpuContext: WebGL2RenderingContext); + /** + * Creates a new DrawingUtils class. + * + * @param cpuContext The 2D canvas rendering context to render into. If + * you are rendering GPU data you must also provide `gpuContext` to allow + * for data conversion. + * @param gpuContext A WebGL canvas that is used for GPU rendering and for + * converting GPU to CPU data. If your Task is using a GPU delegate, the + * context must be obtained from its canvas (provided via + * `setOptions({ canvas: .. })`). + */ + constructor( + cpuContext: CanvasRenderingContext2D, + gpuContext?: WebGL2RenderingContext); + constructor( + cpuOrGpuGontext: CanvasRenderingContext2D|WebGL2RenderingContext, + gpuContext?: WebGL2RenderingContext) { + if (cpuOrGpuGontext instanceof CanvasRenderingContext2D) { + this.context2d = cpuOrGpuGontext; + this.contextWebGL = gpuContext; + } else { + this.contextWebGL = cpuOrGpuGontext; + } + } /** * Restricts a number between two endpoints (order doesn't matter). @@ -120,9 +186,35 @@ export class DrawingUtils { return DrawingUtils.clamp(out, y0, y1); } + private getCanvasRenderingContext(): CanvasRenderingContext2D { + if (!this.context2d) { + throw new Error( + 'CPU rendering requested but CanvasRenderingContext2D not provided.'); + } + return this.context2d; + } + + private getWebGLRenderingContext(): WebGL2RenderingContext { + if (!this.contextWebGL) { + throw new Error( + 'GPU rendering requested but WebGL2RenderingContext not provided.'); + } + return this.contextWebGL; + } + + private getCategoryMaskShaderContext(): CategoryMaskShaderContext { + if (!this.categoryMaskShaderContext) { + this.categoryMaskShaderContext = new CategoryMaskShaderContext(); + } + return this.categoryMaskShaderContext; + } + /** * Draws circles onto the provided landmarks. * + * This method can only be used when `DrawingUtils` is initialized with a + * `CanvasRenderingContext2D`. + * * @export * @param landmarks The landmarks to draw. * @param style The style to visualize the landmarks. @@ -132,7 +224,7 @@ export class DrawingUtils { if (!landmarks) { return; } - const ctx = this.ctx; + const ctx = this.getCanvasRenderingContext(); const options = addDefaultOptions(style); ctx.save(); const canvas = ctx.canvas; @@ -159,6 +251,9 @@ export class DrawingUtils { /** * Draws lines between landmarks (given a connection graph). * + * This method can only be used when `DrawingUtils` is initialized with a + * `CanvasRenderingContext2D`. + * * @export * @param landmarks The landmarks to draw. * @param connections The connections array that contains the start and the @@ -171,7 +266,7 @@ export class DrawingUtils { if (!landmarks || !connections) { return; } - const ctx = this.ctx; + const ctx = this.getCanvasRenderingContext(); const options = addDefaultOptions(style); ctx.save(); const canvas = ctx.canvas; @@ -195,12 +290,15 @@ export class DrawingUtils { /** * Draws a bounding box. * + * This method can only be used when `DrawingUtils` is initialized with a + * `CanvasRenderingContext2D`. + * * @export * @param boundingBox The bounding box to draw. * @param style The style to visualize the boundin box. */ drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void { - const ctx = this.ctx; + const ctx = this.getCanvasRenderingContext(); const options = addDefaultOptions(style); ctx.save(); ctx.beginPath(); @@ -218,6 +316,118 @@ export class DrawingUtils { ctx.fill(); ctx.restore(); } + + /** Draws a category mask on a CanvasRenderingContext2D. */ + private drawCategoryMask2D( + mask: MPMask, background: RGBAColor|ImageSource, + categoryToColorMap: Map|RGBAColor[]): void { + // Use the WebGL renderer to draw result on our internal canvas. + const gl = this.getWebGLRenderingContext(); + this.runWithWebGLTexture(mask, texture => { + this.drawCategoryMaskWebGL(texture, background, categoryToColorMap); + // Draw the result on the user canvas. + const ctx = this.getCanvasRenderingContext(); + ctx.drawImage(gl.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height); + }); + } + + /** Draws a category mask on a WebGL2RenderingContext2D. */ + private drawCategoryMaskWebGL( + categoryTexture: WebGLTexture, background: RGBAColor|ImageSource, + categoryToColorMap: Map|RGBAColor[]): void { + const shaderContext = this.getCategoryMaskShaderContext(); + const gl = this.getWebGLRenderingContext(); + const backgroundImage = Array.isArray(background) ? + new ImageData(new Uint8ClampedArray(background), 1, 1) : + background; + + shaderContext.run(gl, /* flipTexturesVertically= */ true, () => { + shaderContext.bindAndUploadTextures( + categoryTexture, backgroundImage, categoryToColorMap); + gl.clearColor(0, 0, 0, 0); + gl.clear(gl.COLOR_BUFFER_BIT); + gl.drawArrays(gl.TRIANGLE_FAN, 0, 4); + shaderContext.unbindTextures(); + }); + } + + /** + * Draws a category mask using the provided category-to-color mapping. + * + * @export + * @param mask A category mask that was returned from a segmentation task. + * @param categoryToColorMap A map that maps category indices to RGBA + * values. You must specify a map entry for each category. + * @param background A color or image to use as the background. Defaults to + * black. + */ + drawCategoryMask( + mask: MPMask, categoryToColorMap: Map, + background?: RGBAColor|ImageSource): void; + /** + * Draws a category mask using the provided color array. + * + * @export + * @param mask A category mask that was returned from a segmentation task. + * @param categoryToColorMap An array that maps indices to RGBA values. The + * array's indices must correspond to the category indices of the model + * and an entry must be provided for each category. + * @param background A color or image to use as the background. Defaults to + * black. + */ + drawCategoryMask( + mask: MPMask, categoryToColorMap: RGBAColor[], + background?: RGBAColor|ImageSource): void; + drawCategoryMask( + mask: MPMask, categoryToColorMap: CategoryToColorMap, + background: RGBAColor|ImageSource = [0, 0, 0, 255]): void { + if (this.context2d) { + this.drawCategoryMask2D(mask, background, categoryToColorMap); + } else { + this.drawCategoryMaskWebGL( + mask.getAsWebGLTexture(), background, categoryToColorMap); + } + } + + /** + * Converts the given mask to a WebGLTexture and runs the callback. Cleans + * up any new resources after the callback finished executing. + */ + private runWithWebGLTexture( + mask: MPMask, callback: (texture: WebGLTexture) => void): void { + if (!mask.hasWebGLTexture()) { + // Re-create the MPMask but use our the WebGL canvas so we can draw the + // texture directly. + const data = mask.hasFloat32Array() ? mask.getAsFloat32Array() : + mask.getAsUint8Array(); + this.convertToWebGLTextureShaderContext = + this.convertToWebGLTextureShaderContext ?? new MPImageShaderContext(); + const gl = this.getWebGLRenderingContext(); + + const convertedMask = new MPMask( + [data], + /* ownsWebGlTexture= */ false, + gl.canvas, + this.convertToWebGLTextureShaderContext, + mask.width, + mask.height, + ); + callback(convertedMask.getAsWebGLTexture()); + convertedMask.close(); + } else { + callback(mask.getAsWebGLTexture()); + } + } + /** + * Frees all WebGL resources held by this class. + * @export + */ + close(): void { + this.categoryMaskShaderContext?.close(); + this.categoryMaskShaderContext = undefined; + this.convertToWebGLTextureShaderContext?.close(); + this.convertToWebGLTextureShaderContext = undefined; + } } diff --git a/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts b/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts new file mode 100644 index 000000000..d7706075f --- /dev/null +++ b/mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts @@ -0,0 +1,189 @@ +/** + * Copyright 2023 The MediaPipe Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +/** + * A fragment shader that maps categories to colors based on a background + * texture, a mask texture and a 256x1 "color mapping texture" that contains one + * color for each pixel. + */ +const FRAGMENT_SHADER = ` + precision mediump float; + uniform sampler2D backgroundTexture; + uniform sampler2D maskTexture; + uniform sampler2D colorMappingTexture; + varying vec2 vTex; + void main() { + vec4 backgroundColor = texture2D(backgroundTexture, vTex); + float category = texture2D(maskTexture, vTex).r; + vec4 categoryColor = texture2D(colorMappingTexture, vec2(category, 0.0)); + gl_FragColor = mix(backgroundColor, categoryColor, categoryColor.a); + } + `; + +/** + * A four channel color with values for red, green, blue and alpha + * respectively. + */ +export type RGBAColor = [number, number, number, number]|number[]; + +/** + * A category to color mapping that uses either a map or an array to assign + * category indexes to RGBA colors. + */ +export type CategoryToColorMap = Map|RGBAColor[]; + + +/** Checks CategoryToColorMap maps for deep equality. */ +function isEqualColorMap( + a: CategoryToColorMap, b: CategoryToColorMap): boolean { + if (a !== b) { + return false; + } + + const aEntries = a.entries(); + const bEntries = b.entries(); + for (const [aKey, aValue] of aEntries) { + const bNext = bEntries.next(); + if (bNext.done) { + return false; + } + + const [bKey, bValue] = bNext.value; + if (aKey !== bKey) { + return false; + } + + if (aValue[0] !== bValue[0] || aValue[1] !== bValue[1] || + aValue[2] !== bValue[2] || aValue[3] !== bValue[3]) { + return false; + } + } + return !!bEntries.next().done; +} + + +/** A drawing util class for category masks. */ +export class CategoryMaskShaderContext extends MPImageShaderContext { + backgroundTexture?: WebGLTexture; + colorMappingTexture?: WebGLTexture; + colorMappingTextureUniform?: WebGLUniformLocation; + backgroundTextureUniform?: WebGLUniformLocation; + maskTextureUniform?: WebGLUniformLocation; + currentColorMap?: CategoryToColorMap; + + bindAndUploadTextures( + categoryMask: WebGLTexture, background: ImageSource, + colorMap: Map|number[][]) { + const gl = this.gl!; + + // TODO: We should avoid uploading textures from CPU to GPU + // if the textures haven't changed. This can lead to drastic performance + // slowdowns (~50ms per frame). Users can reduce the penalty by passing a + // canvas object instead of ImageData/HTMLImageElement. + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!); + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background); + + // Bind color mapping texture if changed. + if (!this.currentColorMap || + !isEqualColorMap(this.currentColorMap, colorMap)) { + this.currentColorMap = colorMap; + + const pixels = new Array(256 * 4).fill(0); + colorMap.forEach((rgba, index) => { + if (rgba.length !== 4) { + throw new Error( + `Color at index ${index} is not a four-channel value.`); + } + pixels[index * 4] = rgba[0]; + pixels[index * 4 + 1] = rgba[1]; + pixels[index * 4 + 2] = rgba[2]; + pixels[index * 4 + 3] = rgba[3]; + }); + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!); + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE, + new Uint8Array(pixels)); + } else { + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!); + } + + // Bind category mask + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, categoryMask); + } + + unbindTextures() { + const gl = this.gl!; + gl.activeTexture(gl.TEXTURE0); + gl.bindTexture(gl.TEXTURE_2D, null); + gl.activeTexture(gl.TEXTURE1); + gl.bindTexture(gl.TEXTURE_2D, null); + gl.activeTexture(gl.TEXTURE2); + gl.bindTexture(gl.TEXTURE_2D, null); + } + + protected override getFragmentShader(): string { + return FRAGMENT_SHADER; + } + + protected override setupTextures(): void { + const gl = this.gl!; + gl.activeTexture(gl.TEXTURE0); + this.backgroundTexture = this.createTexture(gl, gl.LINEAR); + // Use `gl.NEAREST` to prevent interpolating values in our category to + // color map. + this.colorMappingTexture = this.createTexture(gl, gl.NEAREST); + } + + protected override setupShaders(): void { + super.setupShaders(); + const gl = this.gl!; + this.backgroundTextureUniform = assertNotNull( + gl.getUniformLocation(this.program!, 'backgroundTexture'), + 'Uniform location'); + this.colorMappingTextureUniform = assertNotNull( + gl.getUniformLocation(this.program!, 'colorMappingTexture'), + 'Uniform location'); + this.maskTextureUniform = assertNotNull( + gl.getUniformLocation(this.program!, 'maskTexture'), + 'Uniform location'); + } + + protected override configureUniforms(): void { + super.configureUniforms(); + const gl = this.gl!; + gl.uniform1i(this.backgroundTextureUniform!, 0); + gl.uniform1i(this.colorMappingTextureUniform!, 1); + gl.uniform1i(this.maskTextureUniform!, 2); + } + + override close(): void { + if (this.backgroundTexture) { + this.gl!.deleteTexture(this.backgroundTexture); + } + if (this.colorMappingTexture) { + this.gl!.deleteTexture(this.colorMappingTexture); + } + super.close(); + } +} diff --git a/mediapipe/tasks/web/vision/core/image_shader_context.ts b/mediapipe/tasks/web/vision/core/image_shader_context.ts index eb17d001a..3dec9da95 100644 --- a/mediapipe/tasks/web/vision/core/image_shader_context.ts +++ b/mediapipe/tasks/web/vision/core/image_shader_context.ts @@ -27,9 +27,9 @@ const FRAGMENT_SHADER = ` precision mediump float; varying vec2 vTex; uniform sampler2D inputTexture; - void main() { - gl_FragColor = texture2D(inputTexture, vTex); - } + void main() { + gl_FragColor = texture2D(inputTexture, vTex); + } `; /** Helper to assert that `value` is not null. */ @@ -73,9 +73,9 @@ class MPImageShaderBuffers { * For internal use only. */ export class MPImageShaderContext { - private gl?: WebGL2RenderingContext; + protected gl?: WebGL2RenderingContext; private framebuffer?: WebGLFramebuffer; - private program?: WebGLProgram; + protected program?: WebGLProgram; private vertexShader?: WebGLShader; private fragmentShader?: WebGLShader; private aVertex?: GLint; @@ -94,6 +94,14 @@ export class MPImageShaderContext { */ private shaderBuffersFlipVertically?: MPImageShaderBuffers; + protected getFragmentShader(): string { + return FRAGMENT_SHADER; + } + + protected getVertexShader(): string { + return VERTEX_SHADER; + } + private compileShader(source: string, type: number): WebGLShader { const gl = this.gl!; const shader = @@ -108,14 +116,15 @@ export class MPImageShaderContext { return shader; } - private setupShaders(): void { + protected setupShaders(): void { const gl = this.gl!; this.program = assertNotNull(gl.createProgram()!, 'Failed to create WebGL program'); - this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER); + this.vertexShader = + this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER); this.fragmentShader = - this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER); + this.compileShader(this.getFragmentShader(), gl.FRAGMENT_SHADER); gl.linkProgram(this.program); const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS); @@ -128,6 +137,10 @@ export class MPImageShaderContext { this.aTex = gl.getAttribLocation(this.program, 'aTex'); } + protected setupTextures(): void {} + + protected configureUniforms(): void {} + private createBuffers(flipVertically: boolean): MPImageShaderBuffers { const gl = this.gl!; const vertexArrayObject = @@ -193,17 +206,44 @@ export class MPImageShaderContext { if (!this.program) { this.setupShaders(); + this.setupTextures(); } const shaderBuffers = this.getShaderBuffers(flipVertically); gl.useProgram(this.program!); shaderBuffers.bind(); + this.configureUniforms(); const result = callback(); shaderBuffers.unbind(); return result; } + /** + * Creates and configures a texture. + * + * @param gl The rendering context. + * @param filter The setting to use for `gl.TEXTURE_MIN_FILTER` and + * `gl.TEXTURE_MAG_FILTER`. Defaults to `gl.LINEAR`. + * @param wrapping The setting to use for `gl.TEXTURE_WRAP_S` and + * `gl.TEXTURE_WRAP_T`. Defaults to `gl.CLAMP_TO_EDGE`. + */ + createTexture(gl: WebGL2RenderingContext, filter?: GLenum, wrapping?: GLenum): + WebGLTexture { + this.maybeInitGL(gl); + const texture = + assertNotNull(gl.createTexture(), 'Failed to create texture'); + gl.bindTexture(gl.TEXTURE_2D, texture); + gl.texParameteri( + gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE); + gl.texParameteri( + gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, wrapping ?? gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter ?? gl.LINEAR); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter ?? gl.LINEAR); + gl.bindTexture(gl.TEXTURE_2D, null); + return texture; + } + /** * Binds a framebuffer to the canvas. If the framebuffer does not yet exist, * creates it first. Binds the provided texture to the framebuffer. diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts index ebb3be16a..3ee981bab 100644 --- a/mediapipe/tasks/web/vision/core/render_utils.ts +++ b/mediapipe/tasks/web/vision/core/render_utils.ts @@ -16,24 +16,6 @@ * limitations under the License. */ -// Pre-baked color table for a maximum of 12 classes. -const CM_ALPHA = 128; -const COLOR_MAP: Array<[number, number, number, number]> = [ - [0, 0, 0, CM_ALPHA], // class 0 is BG = transparent - [255, 0, 0, CM_ALPHA], // class 1 is red - [0, 255, 0, CM_ALPHA], // class 2 is light green - [0, 0, 255, CM_ALPHA], // class 3 is blue - [255, 255, 0, CM_ALPHA], // class 4 is yellow - [255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta - [0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua - [128, 128, 128, CM_ALPHA], // class 7 is gray - [255, 128, 0, CM_ALPHA], // class 8 is orange - [128, 0, 255, CM_ALPHA], // class 9 is dark purple - [0, 128, 0, CM_ALPHA], // class 10 is dark green - [255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead? -]; - - /** Helper function to draw a confidence mask */ export function drawConfidenceMask( ctx: CanvasRenderingContext2D, image: Float32Array, width: number, @@ -47,23 +29,3 @@ export function drawConfidenceMask( } ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0); } - -/** - * Helper function to draw a category mask. For GPU, we only have F32Arrays - * for now. - */ -export function drawCategoryMask( - ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array, - width: number, height: number): void { - const rgbaArray = new Uint8ClampedArray(width * height * 4); - const isFloatArray = image instanceof Float32Array; - for (let i = 0; i < image.length; i++) { - const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; - const color = COLOR_MAP[colorIndex % COLOR_MAP.length]; - rgbaArray[4 * i] = color[0]; - rgbaArray[4 * i + 1] = color[1]; - rgbaArray[4 * i + 2] = color[2]; - rgbaArray[4 * i + 3] = color[3]; - } - ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0); -}