Add drawCategoryMask() to our public API

PiperOrigin-RevId: 578526413
This commit is contained in:
Sebastian Schmidt 2023-11-01 08:31:11 -07:00 committed by Copybara-Service
parent 3a55f1156a
commit 9474394768
6 changed files with 594 additions and 57 deletions

View File

@ -31,27 +31,57 @@ mediapipe_ts_library(
mediapipe_ts_library( mediapipe_ts_library(
name = "drawing_utils", name = "drawing_utils",
srcs = ["drawing_utils.ts"], srcs = [
"drawing_utils.ts",
"drawing_utils_category_mask.ts",
],
deps = [ deps = [
":image",
":image_shader_context",
":mask",
":types", ":types",
"//mediapipe/tasks/web/components/containers:bounding_box", "//mediapipe/tasks/web/components/containers:bounding_box",
"//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )
mediapipe_ts_library( mediapipe_ts_library(
name = "image", name = "drawing_utils_test_lib",
srcs = [ testonly = True,
"image.ts", srcs = ["drawing_utils.test.ts"],
"image_shader_context.ts", deps = [
":drawing_utils",
":image",
":image_shader_context",
":mask",
], ],
) )
jasmine_node_test(
name = "drawing_utils_test",
deps = [":drawing_utils_test_lib"],
)
mediapipe_ts_library(
name = "image",
srcs = ["image.ts"],
deps = ["image_shader_context"],
)
mediapipe_ts_library(
name = "image_shader_context",
srcs = ["image_shader_context.ts"],
)
mediapipe_ts_library( mediapipe_ts_library(
name = "image_test_lib", name = "image_test_lib",
testonly = True, testonly = True,
srcs = ["image.test.ts"], srcs = ["image.test.ts"],
deps = [":image"], deps = [
":image",
":image_shader_context",
],
) )
jasmine_node_test( jasmine_node_test(
@ -64,6 +94,7 @@ mediapipe_ts_library(
srcs = ["mask.ts"], srcs = ["mask.ts"],
deps = [ deps = [
":image", ":image",
":image_shader_context",
"//mediapipe/web/graph_runner:platform_utils", "//mediapipe/web/graph_runner:platform_utils",
], ],
) )
@ -74,6 +105,7 @@ mediapipe_ts_library(
srcs = ["mask.test.ts"], srcs = ["mask.test.ts"],
deps = [ deps = [
":image", ":image",
":image_shader_context",
":mask", ":mask",
], ],
) )
@ -89,6 +121,7 @@ mediapipe_ts_library(
deps = [ deps = [
":image", ":image",
":image_processing_options", ":image_processing_options",
":image_shader_context",
":mask", ":mask",
":vision_task_options", ":vision_task_options",
"//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto",

View File

@ -0,0 +1,103 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {DrawingUtils} from './drawing_utils';
import {MPImageShaderContext} from './image_shader_context';
import {MPMask} from './mask';
const WIDTH = 2;
const HEIGHT = 2;
const skip = typeof document === 'undefined';
if (skip) {
console.log('These tests must be run in a browser.');
}
(skip ? xdescribe : describe)('DrawingUtils', () => {
let shaderContext = new MPImageShaderContext();
let canvas2D: HTMLCanvasElement;
let context2D: CanvasRenderingContext2D;
let drawingUtils2D: DrawingUtils;
let canvasWebGL: HTMLCanvasElement;
let contextWebGL: WebGL2RenderingContext;
let drawingUtilsWebGL: DrawingUtils;
beforeEach(() => {
shaderContext = new MPImageShaderContext();
canvasWebGL = document.createElement('canvas');
canvasWebGL.width = WIDTH;
canvasWebGL.height = HEIGHT;
contextWebGL = canvasWebGL.getContext('webgl2')!;
drawingUtilsWebGL = new DrawingUtils(contextWebGL);
canvas2D = document.createElement('canvas');
canvas2D.width = WIDTH;
canvas2D.height = HEIGHT;
context2D = canvas2D.getContext('2d')!;
drawingUtils2D = new DrawingUtils(context2D, contextWebGL);
});
afterEach(() => {
shaderContext.close();
drawingUtils2D.close();
drawingUtilsWebGL.close();
});
describe('drawCategoryMask() ', () => {
const colors = [
[0, 0, 0, 255],
[0, 255, 0, 255],
[0, 0, 255, 255],
[255, 255, 255, 255],
];
const expectedResult = new Uint8Array(
[0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255],
);
it('on 2D canvas', () => {
const categoryMask = new MPMask(
[new Uint8Array([0, 1, 2, 3])],
/* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH,
HEIGHT);
drawingUtils2D.drawCategoryMask(categoryMask, colors);
const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data;
expect(actualResult)
.toEqual(new Uint8ClampedArray(expectedResult.buffer));
});
it('on WebGL canvas', () => {
const categoryMask = new MPMask(
[new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped
/* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH,
HEIGHT);
drawingUtilsWebGL.drawCategoryMask(categoryMask, colors);
const actualResult = new Uint8Array(WIDTH * WIDTH * 4);
contextWebGL.readPixels(
0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE,
actualResult);
expect(actualResult).toEqual(expectedResult);
});
});
// TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox
});

View File

@ -16,7 +16,11 @@
import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {CategoryMaskShaderContext, CategoryToColorMap, RGBAColor} from '../../../../tasks/web/vision/core/drawing_utils_category_mask';
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {Connection} from '../../../../tasks/web/vision/core/types'; import {Connection} from '../../../../tasks/web/vision/core/types';
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
/** /**
* A user-defined callback to take input data and map it to a custom output * A user-defined callback to take input data and map it to a custom output
@ -24,6 +28,9 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
*/ */
export type Callback<I, O> = (input: I) => O; export type Callback<I, O> = (input: I) => O;
// Used in public API
export {ImageSource};
/** Data that a user can use to specialize drawing options. */ /** Data that a user can use to specialize drawing options. */
export declare interface LandmarkData { export declare interface LandmarkData {
index?: number; index?: number;
@ -31,6 +38,32 @@ export declare interface LandmarkData {
to?: NormalizedLandmark; to?: NormalizedLandmark;
} }
/** A color map with 22 classes. Used in our demos. */
export const DEFAULT_CATEGORY_TO_COLOR_MAP = [
[0, 0, 0, 0], // class 0 is BG = transparent
[255, 0, 0, 255], // class 1 is red
[0, 255, 0, 255], // class 2 is light green
[0, 0, 255, 255], // class 3 is blue
[255, 255, 0, 255], // class 4 is yellow
[255, 0, 255, 255], // class 5 is light purple / magenta
[0, 255, 255, 255], // class 6 is light blue / aqua
[128, 128, 128, 255], // class 7 is gray
[255, 100, 0, 255], // class 8 is dark orange
[128, 0, 255, 255], // class 9 is dark purple
[0, 150, 0, 255], // class 10 is green
[255, 255, 255, 255], // class 11 is white
[255, 105, 180, 255], // class 12 is pink
[255, 150, 0, 255], // class 13 is orange
[255, 250, 224, 255], // class 14 is light yellow
[148, 0, 211, 255], // class 15 is dark violet
[0, 100, 0, 255], // class 16 is dark green
[0, 0, 128, 255], // class 17 is navy blue
[165, 42, 42, 255], // class 18 is brown
[64, 224, 208, 255], // class 19 is turquoise
[255, 218, 185, 255], // class 20 is peach
[192, 192, 192, 255], // class 21 is silver
];
/** /**
* Options for customizing the drawing routines * Options for customizing the drawing routines
*/ */
@ -77,14 +110,47 @@ function resolve<O, I>(value: O|Callback<I, O>, data: I): O {
return value instanceof Function ? value(data) : value; return value instanceof Function ? value(data) : value;
} }
export {RGBAColor, CategoryToColorMap};
/** Helper class to visualize the result of a MediaPipe Vision task. */ /** Helper class to visualize the result of a MediaPipe Vision task. */
export class DrawingUtils { export class DrawingUtils {
private categoryMaskShaderContext?: CategoryMaskShaderContext;
private convertToWebGLTextureShaderContext?: MPImageShaderContext;
private readonly context2d?: CanvasRenderingContext2D;
private readonly contextWebGL?: WebGL2RenderingContext;
/** /**
* Creates a new DrawingUtils class. * Creates a new DrawingUtils class.
* *
* @param ctx The canvas to render onto. * @param gpuContext The WebGL canvas rendering context to render into. If
* your Task is using a GPU delegate, the context must be obtained from
* its canvas (provided via `setOptions({ canvas: .. })`).
*/ */
constructor(private readonly ctx: CanvasRenderingContext2D) {} constructor(gpuContext: WebGL2RenderingContext);
/**
* Creates a new DrawingUtils class.
*
* @param cpuContext The 2D canvas rendering context to render into. If
* you are rendering GPU data you must also provide `gpuContext` to allow
* for data conversion.
* @param gpuContext A WebGL canvas that is used for GPU rendering and for
* converting GPU to CPU data. If your Task is using a GPU delegate, the
* context must be obtained from its canvas (provided via
* `setOptions({ canvas: .. })`).
*/
constructor(
cpuContext: CanvasRenderingContext2D,
gpuContext?: WebGL2RenderingContext);
constructor(
cpuOrGpuGontext: CanvasRenderingContext2D|WebGL2RenderingContext,
gpuContext?: WebGL2RenderingContext) {
if (cpuOrGpuGontext instanceof CanvasRenderingContext2D) {
this.context2d = cpuOrGpuGontext;
this.contextWebGL = gpuContext;
} else {
this.contextWebGL = cpuOrGpuGontext;
}
}
/** /**
* Restricts a number between two endpoints (order doesn't matter). * Restricts a number between two endpoints (order doesn't matter).
@ -120,9 +186,35 @@ export class DrawingUtils {
return DrawingUtils.clamp(out, y0, y1); return DrawingUtils.clamp(out, y0, y1);
} }
private getCanvasRenderingContext(): CanvasRenderingContext2D {
if (!this.context2d) {
throw new Error(
'CPU rendering requested but CanvasRenderingContext2D not provided.');
}
return this.context2d;
}
private getWebGLRenderingContext(): WebGL2RenderingContext {
if (!this.contextWebGL) {
throw new Error(
'GPU rendering requested but WebGL2RenderingContext not provided.');
}
return this.contextWebGL;
}
private getCategoryMaskShaderContext(): CategoryMaskShaderContext {
if (!this.categoryMaskShaderContext) {
this.categoryMaskShaderContext = new CategoryMaskShaderContext();
}
return this.categoryMaskShaderContext;
}
/** /**
* Draws circles onto the provided landmarks. * Draws circles onto the provided landmarks.
* *
* This method can only be used when `DrawingUtils` is initialized with a
* `CanvasRenderingContext2D`.
*
* @export * @export
* @param landmarks The landmarks to draw. * @param landmarks The landmarks to draw.
* @param style The style to visualize the landmarks. * @param style The style to visualize the landmarks.
@ -132,7 +224,7 @@ export class DrawingUtils {
if (!landmarks) { if (!landmarks) {
return; return;
} }
const ctx = this.ctx; const ctx = this.getCanvasRenderingContext();
const options = addDefaultOptions(style); const options = addDefaultOptions(style);
ctx.save(); ctx.save();
const canvas = ctx.canvas; const canvas = ctx.canvas;
@ -159,6 +251,9 @@ export class DrawingUtils {
/** /**
* Draws lines between landmarks (given a connection graph). * Draws lines between landmarks (given a connection graph).
* *
* This method can only be used when `DrawingUtils` is initialized with a
* `CanvasRenderingContext2D`.
*
* @export * @export
* @param landmarks The landmarks to draw. * @param landmarks The landmarks to draw.
* @param connections The connections array that contains the start and the * @param connections The connections array that contains the start and the
@ -171,7 +266,7 @@ export class DrawingUtils {
if (!landmarks || !connections) { if (!landmarks || !connections) {
return; return;
} }
const ctx = this.ctx; const ctx = this.getCanvasRenderingContext();
const options = addDefaultOptions(style); const options = addDefaultOptions(style);
ctx.save(); ctx.save();
const canvas = ctx.canvas; const canvas = ctx.canvas;
@ -195,12 +290,15 @@ export class DrawingUtils {
/** /**
* Draws a bounding box. * Draws a bounding box.
* *
* This method can only be used when `DrawingUtils` is initialized with a
* `CanvasRenderingContext2D`.
*
* @export * @export
* @param boundingBox The bounding box to draw. * @param boundingBox The bounding box to draw.
* @param style The style to visualize the boundin box. * @param style The style to visualize the boundin box.
*/ */
drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void { drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void {
const ctx = this.ctx; const ctx = this.getCanvasRenderingContext();
const options = addDefaultOptions(style); const options = addDefaultOptions(style);
ctx.save(); ctx.save();
ctx.beginPath(); ctx.beginPath();
@ -218,6 +316,118 @@ export class DrawingUtils {
ctx.fill(); ctx.fill();
ctx.restore(); ctx.restore();
} }
/** Draws a category mask on a CanvasRenderingContext2D. */
private drawCategoryMask2D(
mask: MPMask, background: RGBAColor|ImageSource,
categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
// Use the WebGL renderer to draw result on our internal canvas.
const gl = this.getWebGLRenderingContext();
this.runWithWebGLTexture(mask, texture => {
this.drawCategoryMaskWebGL(texture, background, categoryToColorMap);
// Draw the result on the user canvas.
const ctx = this.getCanvasRenderingContext();
ctx.drawImage(gl.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height);
});
}
/** Draws a category mask on a WebGL2RenderingContext2D. */
private drawCategoryMaskWebGL(
categoryTexture: WebGLTexture, background: RGBAColor|ImageSource,
categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
const shaderContext = this.getCategoryMaskShaderContext();
const gl = this.getWebGLRenderingContext();
const backgroundImage = Array.isArray(background) ?
new ImageData(new Uint8ClampedArray(background), 1, 1) :
background;
shaderContext.run(gl, /* flipTexturesVertically= */ true, () => {
shaderContext.bindAndUploadTextures(
categoryTexture, backgroundImage, categoryToColorMap);
gl.clearColor(0, 0, 0, 0);
gl.clear(gl.COLOR_BUFFER_BIT);
gl.drawArrays(gl.TRIANGLE_FAN, 0, 4);
shaderContext.unbindTextures();
});
}
/**
* Draws a category mask using the provided category-to-color mapping.
*
* @export
* @param mask A category mask that was returned from a segmentation task.
* @param categoryToColorMap A map that maps category indices to RGBA
* values. You must specify a map entry for each category.
* @param background A color or image to use as the background. Defaults to
* black.
*/
drawCategoryMask(
mask: MPMask, categoryToColorMap: Map<number, RGBAColor>,
background?: RGBAColor|ImageSource): void;
/**
* Draws a category mask using the provided color array.
*
* @export
* @param mask A category mask that was returned from a segmentation task.
* @param categoryToColorMap An array that maps indices to RGBA values. The
* array's indices must correspond to the category indices of the model
* and an entry must be provided for each category.
* @param background A color or image to use as the background. Defaults to
* black.
*/
drawCategoryMask(
mask: MPMask, categoryToColorMap: RGBAColor[],
background?: RGBAColor|ImageSource): void;
drawCategoryMask(
mask: MPMask, categoryToColorMap: CategoryToColorMap,
background: RGBAColor|ImageSource = [0, 0, 0, 255]): void {
if (this.context2d) {
this.drawCategoryMask2D(mask, background, categoryToColorMap);
} else {
this.drawCategoryMaskWebGL(
mask.getAsWebGLTexture(), background, categoryToColorMap);
}
}
/**
* Converts the given mask to a WebGLTexture and runs the callback. Cleans
* up any new resources after the callback finished executing.
*/
private runWithWebGLTexture(
mask: MPMask, callback: (texture: WebGLTexture) => void): void {
if (!mask.hasWebGLTexture()) {
// Re-create the MPMask but use our the WebGL canvas so we can draw the
// texture directly.
const data = mask.hasFloat32Array() ? mask.getAsFloat32Array() :
mask.getAsUint8Array();
this.convertToWebGLTextureShaderContext =
this.convertToWebGLTextureShaderContext ?? new MPImageShaderContext();
const gl = this.getWebGLRenderingContext();
const convertedMask = new MPMask(
[data],
/* ownsWebGlTexture= */ false,
gl.canvas,
this.convertToWebGLTextureShaderContext,
mask.width,
mask.height,
);
callback(convertedMask.getAsWebGLTexture());
convertedMask.close();
} else {
callback(mask.getAsWebGLTexture());
}
}
/**
* Frees all WebGL resources held by this class.
* @export
*/
close(): void {
this.categoryMaskShaderContext?.close();
this.categoryMaskShaderContext = undefined;
this.convertToWebGLTextureShaderContext?.close();
this.convertToWebGLTextureShaderContext = undefined;
}
} }

View File

@ -0,0 +1,189 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
/**
* A fragment shader that maps categories to colors based on a background
* texture, a mask texture and a 256x1 "color mapping texture" that contains one
* color for each pixel.
*/
const FRAGMENT_SHADER = `
precision mediump float;
uniform sampler2D backgroundTexture;
uniform sampler2D maskTexture;
uniform sampler2D colorMappingTexture;
varying vec2 vTex;
void main() {
vec4 backgroundColor = texture2D(backgroundTexture, vTex);
float category = texture2D(maskTexture, vTex).r;
vec4 categoryColor = texture2D(colorMappingTexture, vec2(category, 0.0));
gl_FragColor = mix(backgroundColor, categoryColor, categoryColor.a);
}
`;
/**
* A four channel color with values for red, green, blue and alpha
* respectively.
*/
export type RGBAColor = [number, number, number, number]|number[];
/**
* A category to color mapping that uses either a map or an array to assign
* category indexes to RGBA colors.
*/
export type CategoryToColorMap = Map<number, RGBAColor>|RGBAColor[];
/** Checks CategoryToColorMap maps for deep equality. */
function isEqualColorMap(
a: CategoryToColorMap, b: CategoryToColorMap): boolean {
if (a !== b) {
return false;
}
const aEntries = a.entries();
const bEntries = b.entries();
for (const [aKey, aValue] of aEntries) {
const bNext = bEntries.next();
if (bNext.done) {
return false;
}
const [bKey, bValue] = bNext.value;
if (aKey !== bKey) {
return false;
}
if (aValue[0] !== bValue[0] || aValue[1] !== bValue[1] ||
aValue[2] !== bValue[2] || aValue[3] !== bValue[3]) {
return false;
}
}
return !!bEntries.next().done;
}
/** A drawing util class for category masks. */
export class CategoryMaskShaderContext extends MPImageShaderContext {
backgroundTexture?: WebGLTexture;
colorMappingTexture?: WebGLTexture;
colorMappingTextureUniform?: WebGLUniformLocation;
backgroundTextureUniform?: WebGLUniformLocation;
maskTextureUniform?: WebGLUniformLocation;
currentColorMap?: CategoryToColorMap;
bindAndUploadTextures(
categoryMask: WebGLTexture, background: ImageSource,
colorMap: Map<number, number[]>|number[][]) {
const gl = this.gl!;
// TODO: We should avoid uploading textures from CPU to GPU
// if the textures haven't changed. This can lead to drastic performance
// slowdowns (~50ms per frame). Users can reduce the penalty by passing a
// canvas object instead of ImageData/HTMLImageElement.
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!);
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background);
// Bind color mapping texture if changed.
if (!this.currentColorMap ||
!isEqualColorMap(this.currentColorMap, colorMap)) {
this.currentColorMap = colorMap;
const pixels = new Array(256 * 4).fill(0);
colorMap.forEach((rgba, index) => {
if (rgba.length !== 4) {
throw new Error(
`Color at index ${index} is not a four-channel value.`);
}
pixels[index * 4] = rgba[0];
pixels[index * 4 + 1] = rgba[1];
pixels[index * 4 + 2] = rgba[2];
pixels[index * 4 + 3] = rgba[3];
});
gl.activeTexture(gl.TEXTURE1);
gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE,
new Uint8Array(pixels));
} else {
gl.activeTexture(gl.TEXTURE1);
gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
}
// Bind category mask
gl.activeTexture(gl.TEXTURE2);
gl.bindTexture(gl.TEXTURE_2D, categoryMask);
}
unbindTextures() {
const gl = this.gl!;
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, null);
gl.activeTexture(gl.TEXTURE1);
gl.bindTexture(gl.TEXTURE_2D, null);
gl.activeTexture(gl.TEXTURE2);
gl.bindTexture(gl.TEXTURE_2D, null);
}
protected override getFragmentShader(): string {
return FRAGMENT_SHADER;
}
protected override setupTextures(): void {
const gl = this.gl!;
gl.activeTexture(gl.TEXTURE0);
this.backgroundTexture = this.createTexture(gl, gl.LINEAR);
// Use `gl.NEAREST` to prevent interpolating values in our category to
// color map.
this.colorMappingTexture = this.createTexture(gl, gl.NEAREST);
}
protected override setupShaders(): void {
super.setupShaders();
const gl = this.gl!;
this.backgroundTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'backgroundTexture'),
'Uniform location');
this.colorMappingTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'colorMappingTexture'),
'Uniform location');
this.maskTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'maskTexture'),
'Uniform location');
}
protected override configureUniforms(): void {
super.configureUniforms();
const gl = this.gl!;
gl.uniform1i(this.backgroundTextureUniform!, 0);
gl.uniform1i(this.colorMappingTextureUniform!, 1);
gl.uniform1i(this.maskTextureUniform!, 2);
}
override close(): void {
if (this.backgroundTexture) {
this.gl!.deleteTexture(this.backgroundTexture);
}
if (this.colorMappingTexture) {
this.gl!.deleteTexture(this.colorMappingTexture);
}
super.close();
}
}

View File

@ -73,9 +73,9 @@ class MPImageShaderBuffers {
* For internal use only. * For internal use only.
*/ */
export class MPImageShaderContext { export class MPImageShaderContext {
private gl?: WebGL2RenderingContext; protected gl?: WebGL2RenderingContext;
private framebuffer?: WebGLFramebuffer; private framebuffer?: WebGLFramebuffer;
private program?: WebGLProgram; protected program?: WebGLProgram;
private vertexShader?: WebGLShader; private vertexShader?: WebGLShader;
private fragmentShader?: WebGLShader; private fragmentShader?: WebGLShader;
private aVertex?: GLint; private aVertex?: GLint;
@ -94,6 +94,14 @@ export class MPImageShaderContext {
*/ */
private shaderBuffersFlipVertically?: MPImageShaderBuffers; private shaderBuffersFlipVertically?: MPImageShaderBuffers;
protected getFragmentShader(): string {
return FRAGMENT_SHADER;
}
protected getVertexShader(): string {
return VERTEX_SHADER;
}
private compileShader(source: string, type: number): WebGLShader { private compileShader(source: string, type: number): WebGLShader {
const gl = this.gl!; const gl = this.gl!;
const shader = const shader =
@ -108,14 +116,15 @@ export class MPImageShaderContext {
return shader; return shader;
} }
private setupShaders(): void { protected setupShaders(): void {
const gl = this.gl!; const gl = this.gl!;
this.program = this.program =
assertNotNull(gl.createProgram()!, 'Failed to create WebGL program'); assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER); this.vertexShader =
this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER);
this.fragmentShader = this.fragmentShader =
this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER); this.compileShader(this.getFragmentShader(), gl.FRAGMENT_SHADER);
gl.linkProgram(this.program); gl.linkProgram(this.program);
const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS); const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS);
@ -128,6 +137,10 @@ export class MPImageShaderContext {
this.aTex = gl.getAttribLocation(this.program, 'aTex'); this.aTex = gl.getAttribLocation(this.program, 'aTex');
} }
protected setupTextures(): void {}
protected configureUniforms(): void {}
private createBuffers(flipVertically: boolean): MPImageShaderBuffers { private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
const gl = this.gl!; const gl = this.gl!;
const vertexArrayObject = const vertexArrayObject =
@ -193,17 +206,44 @@ export class MPImageShaderContext {
if (!this.program) { if (!this.program) {
this.setupShaders(); this.setupShaders();
this.setupTextures();
} }
const shaderBuffers = this.getShaderBuffers(flipVertically); const shaderBuffers = this.getShaderBuffers(flipVertically);
gl.useProgram(this.program!); gl.useProgram(this.program!);
shaderBuffers.bind(); shaderBuffers.bind();
this.configureUniforms();
const result = callback(); const result = callback();
shaderBuffers.unbind(); shaderBuffers.unbind();
return result; return result;
} }
/**
* Creates and configures a texture.
*
* @param gl The rendering context.
* @param filter The setting to use for `gl.TEXTURE_MIN_FILTER` and
* `gl.TEXTURE_MAG_FILTER`. Defaults to `gl.LINEAR`.
* @param wrapping The setting to use for `gl.TEXTURE_WRAP_S` and
* `gl.TEXTURE_WRAP_T`. Defaults to `gl.CLAMP_TO_EDGE`.
*/
createTexture(gl: WebGL2RenderingContext, filter?: GLenum, wrapping?: GLenum):
WebGLTexture {
this.maybeInitGL(gl);
const texture =
assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, texture);
gl.texParameteri(
gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE);
gl.texParameteri(
gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, wrapping ?? gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter ?? gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter ?? gl.LINEAR);
gl.bindTexture(gl.TEXTURE_2D, null);
return texture;
}
/** /**
* Binds a framebuffer to the canvas. If the framebuffer does not yet exist, * Binds a framebuffer to the canvas. If the framebuffer does not yet exist,
* creates it first. Binds the provided texture to the framebuffer. * creates it first. Binds the provided texture to the framebuffer.

View File

@ -16,24 +16,6 @@
* limitations under the License. * limitations under the License.
*/ */
// Pre-baked color table for a maximum of 12 classes.
const CM_ALPHA = 128;
const COLOR_MAP: Array<[number, number, number, number]> = [
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
[255, 0, 0, CM_ALPHA], // class 1 is red
[0, 255, 0, CM_ALPHA], // class 2 is light green
[0, 0, 255, CM_ALPHA], // class 3 is blue
[255, 255, 0, CM_ALPHA], // class 4 is yellow
[255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta
[0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua
[128, 128, 128, CM_ALPHA], // class 7 is gray
[255, 128, 0, CM_ALPHA], // class 8 is orange
[128, 0, 255, CM_ALPHA], // class 9 is dark purple
[0, 128, 0, CM_ALPHA], // class 10 is dark green
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
];
/** Helper function to draw a confidence mask */ /** Helper function to draw a confidence mask */
export function drawConfidenceMask( export function drawConfidenceMask(
ctx: CanvasRenderingContext2D, image: Float32Array, width: number, ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
@ -47,23 +29,3 @@ export function drawConfidenceMask(
} }
ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0); ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
} }
/**
* Helper function to draw a category mask. For GPU, we only have F32Arrays
* for now.
*/
export function drawCategoryMask(
ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
width: number, height: number): void {
const rgbaArray = new Uint8ClampedArray(width * height * 4);
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
rgbaArray[4 * i] = color[0];
rgbaArray[4 * i + 1] = color[1];
rgbaArray[4 * i + 2] = color[2];
rgbaArray[4 * i + 3] = color[3];
}
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
}