Add drawCategoryMask() to our public API
PiperOrigin-RevId: 578526413
This commit is contained in:
parent
3a55f1156a
commit
9474394768
|
@ -31,27 +31,57 @@ mediapipe_ts_library(
|
|||
|
||||
mediapipe_ts_library(
|
||||
name = "drawing_utils",
|
||||
srcs = ["drawing_utils.ts"],
|
||||
srcs = [
|
||||
"drawing_utils.ts",
|
||||
"drawing_utils_category_mask.ts",
|
||||
],
|
||||
deps = [
|
||||
":image",
|
||||
":image_shader_context",
|
||||
":mask",
|
||||
":types",
|
||||
"//mediapipe/tasks/web/components/containers:bounding_box",
|
||||
"//mediapipe/tasks/web/components/containers:landmark",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "image",
|
||||
srcs = [
|
||||
"image.ts",
|
||||
"image_shader_context.ts",
|
||||
name = "drawing_utils_test_lib",
|
||||
testonly = True,
|
||||
srcs = ["drawing_utils.test.ts"],
|
||||
deps = [
|
||||
":drawing_utils",
|
||||
":image",
|
||||
":image_shader_context",
|
||||
":mask",
|
||||
],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
name = "drawing_utils_test",
|
||||
deps = [":drawing_utils_test_lib"],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "image",
|
||||
srcs = ["image.ts"],
|
||||
deps = ["image_shader_context"],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "image_shader_context",
|
||||
srcs = ["image_shader_context.ts"],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "image_test_lib",
|
||||
testonly = True,
|
||||
srcs = ["image.test.ts"],
|
||||
deps = [":image"],
|
||||
deps = [
|
||||
":image",
|
||||
":image_shader_context",
|
||||
],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
|
@ -64,6 +94,7 @@ mediapipe_ts_library(
|
|||
srcs = ["mask.ts"],
|
||||
deps = [
|
||||
":image",
|
||||
":image_shader_context",
|
||||
"//mediapipe/web/graph_runner:platform_utils",
|
||||
],
|
||||
)
|
||||
|
@ -74,6 +105,7 @@ mediapipe_ts_library(
|
|||
srcs = ["mask.test.ts"],
|
||||
deps = [
|
||||
":image",
|
||||
":image_shader_context",
|
||||
":mask",
|
||||
],
|
||||
)
|
||||
|
@ -89,6 +121,7 @@ mediapipe_ts_library(
|
|||
deps = [
|
||||
":image",
|
||||
":image_processing_options",
|
||||
":image_shader_context",
|
||||
":mask",
|
||||
":vision_task_options",
|
||||
"//mediapipe/framework/formats:rect_jspb_proto",
|
||||
|
|
103
mediapipe/tasks/web/vision/core/drawing_utils.test.ts
Normal file
103
mediapipe/tasks/web/vision/core/drawing_utils.test.ts
Normal file
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import 'jasmine';
|
||||
|
||||
import {DrawingUtils} from './drawing_utils';
|
||||
import {MPImageShaderContext} from './image_shader_context';
|
||||
import {MPMask} from './mask';
|
||||
|
||||
const WIDTH = 2;
|
||||
const HEIGHT = 2;
|
||||
|
||||
const skip = typeof document === 'undefined';
|
||||
if (skip) {
|
||||
console.log('These tests must be run in a browser.');
|
||||
}
|
||||
|
||||
(skip ? xdescribe : describe)('DrawingUtils', () => {
|
||||
let shaderContext = new MPImageShaderContext();
|
||||
let canvas2D: HTMLCanvasElement;
|
||||
let context2D: CanvasRenderingContext2D;
|
||||
let drawingUtils2D: DrawingUtils;
|
||||
let canvasWebGL: HTMLCanvasElement;
|
||||
let contextWebGL: WebGL2RenderingContext;
|
||||
let drawingUtilsWebGL: DrawingUtils;
|
||||
|
||||
beforeEach(() => {
|
||||
shaderContext = new MPImageShaderContext();
|
||||
|
||||
canvasWebGL = document.createElement('canvas');
|
||||
canvasWebGL.width = WIDTH;
|
||||
canvasWebGL.height = HEIGHT;
|
||||
contextWebGL = canvasWebGL.getContext('webgl2')!;
|
||||
drawingUtilsWebGL = new DrawingUtils(contextWebGL);
|
||||
|
||||
canvas2D = document.createElement('canvas');
|
||||
canvas2D.width = WIDTH;
|
||||
canvas2D.height = HEIGHT;
|
||||
context2D = canvas2D.getContext('2d')!;
|
||||
drawingUtils2D = new DrawingUtils(context2D, contextWebGL);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
shaderContext.close();
|
||||
drawingUtils2D.close();
|
||||
drawingUtilsWebGL.close();
|
||||
});
|
||||
|
||||
describe('drawCategoryMask() ', () => {
|
||||
const colors = [
|
||||
[0, 0, 0, 255],
|
||||
[0, 255, 0, 255],
|
||||
[0, 0, 255, 255],
|
||||
[255, 255, 255, 255],
|
||||
];
|
||||
const expectedResult = new Uint8Array(
|
||||
[0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255],
|
||||
);
|
||||
|
||||
it('on 2D canvas', () => {
|
||||
const categoryMask = new MPMask(
|
||||
[new Uint8Array([0, 1, 2, 3])],
|
||||
/* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH,
|
||||
HEIGHT);
|
||||
|
||||
drawingUtils2D.drawCategoryMask(categoryMask, colors);
|
||||
|
||||
const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data;
|
||||
expect(actualResult)
|
||||
.toEqual(new Uint8ClampedArray(expectedResult.buffer));
|
||||
});
|
||||
|
||||
it('on WebGL canvas', () => {
|
||||
const categoryMask = new MPMask(
|
||||
[new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped
|
||||
/* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH,
|
||||
HEIGHT);
|
||||
|
||||
drawingUtilsWebGL.drawCategoryMask(categoryMask, colors);
|
||||
|
||||
const actualResult = new Uint8Array(WIDTH * WIDTH * 4);
|
||||
contextWebGL.readPixels(
|
||||
0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE,
|
||||
actualResult);
|
||||
expect(actualResult).toEqual(expectedResult);
|
||||
});
|
||||
});
|
||||
|
||||
// TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox
|
||||
});
|
|
@ -16,7 +16,11 @@
|
|||
|
||||
import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
|
||||
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
|
||||
import {CategoryMaskShaderContext, CategoryToColorMap, RGBAColor} from '../../../../tasks/web/vision/core/drawing_utils_category_mask';
|
||||
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||
import {Connection} from '../../../../tasks/web/vision/core/types';
|
||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||
|
||||
/**
|
||||
* A user-defined callback to take input data and map it to a custom output
|
||||
|
@ -24,6 +28,9 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
|
|||
*/
|
||||
export type Callback<I, O> = (input: I) => O;
|
||||
|
||||
// Used in public API
|
||||
export {ImageSource};
|
||||
|
||||
/** Data that a user can use to specialize drawing options. */
|
||||
export declare interface LandmarkData {
|
||||
index?: number;
|
||||
|
@ -31,6 +38,32 @@ export declare interface LandmarkData {
|
|||
to?: NormalizedLandmark;
|
||||
}
|
||||
|
||||
/** A color map with 22 classes. Used in our demos. */
|
||||
export const DEFAULT_CATEGORY_TO_COLOR_MAP = [
|
||||
[0, 0, 0, 0], // class 0 is BG = transparent
|
||||
[255, 0, 0, 255], // class 1 is red
|
||||
[0, 255, 0, 255], // class 2 is light green
|
||||
[0, 0, 255, 255], // class 3 is blue
|
||||
[255, 255, 0, 255], // class 4 is yellow
|
||||
[255, 0, 255, 255], // class 5 is light purple / magenta
|
||||
[0, 255, 255, 255], // class 6 is light blue / aqua
|
||||
[128, 128, 128, 255], // class 7 is gray
|
||||
[255, 100, 0, 255], // class 8 is dark orange
|
||||
[128, 0, 255, 255], // class 9 is dark purple
|
||||
[0, 150, 0, 255], // class 10 is green
|
||||
[255, 255, 255, 255], // class 11 is white
|
||||
[255, 105, 180, 255], // class 12 is pink
|
||||
[255, 150, 0, 255], // class 13 is orange
|
||||
[255, 250, 224, 255], // class 14 is light yellow
|
||||
[148, 0, 211, 255], // class 15 is dark violet
|
||||
[0, 100, 0, 255], // class 16 is dark green
|
||||
[0, 0, 128, 255], // class 17 is navy blue
|
||||
[165, 42, 42, 255], // class 18 is brown
|
||||
[64, 224, 208, 255], // class 19 is turquoise
|
||||
[255, 218, 185, 255], // class 20 is peach
|
||||
[192, 192, 192, 255], // class 21 is silver
|
||||
];
|
||||
|
||||
/**
|
||||
* Options for customizing the drawing routines
|
||||
*/
|
||||
|
@ -77,14 +110,47 @@ function resolve<O, I>(value: O|Callback<I, O>, data: I): O {
|
|||
return value instanceof Function ? value(data) : value;
|
||||
}
|
||||
|
||||
export {RGBAColor, CategoryToColorMap};
|
||||
|
||||
/** Helper class to visualize the result of a MediaPipe Vision task. */
|
||||
export class DrawingUtils {
|
||||
private categoryMaskShaderContext?: CategoryMaskShaderContext;
|
||||
private convertToWebGLTextureShaderContext?: MPImageShaderContext;
|
||||
private readonly context2d?: CanvasRenderingContext2D;
|
||||
private readonly contextWebGL?: WebGL2RenderingContext;
|
||||
|
||||
/**
|
||||
* Creates a new DrawingUtils class.
|
||||
*
|
||||
* @param ctx The canvas to render onto.
|
||||
* @param gpuContext The WebGL canvas rendering context to render into. If
|
||||
* your Task is using a GPU delegate, the context must be obtained from
|
||||
* its canvas (provided via `setOptions({ canvas: .. })`).
|
||||
*/
|
||||
constructor(private readonly ctx: CanvasRenderingContext2D) {}
|
||||
constructor(gpuContext: WebGL2RenderingContext);
|
||||
/**
|
||||
* Creates a new DrawingUtils class.
|
||||
*
|
||||
* @param cpuContext The 2D canvas rendering context to render into. If
|
||||
* you are rendering GPU data you must also provide `gpuContext` to allow
|
||||
* for data conversion.
|
||||
* @param gpuContext A WebGL canvas that is used for GPU rendering and for
|
||||
* converting GPU to CPU data. If your Task is using a GPU delegate, the
|
||||
* context must be obtained from its canvas (provided via
|
||||
* `setOptions({ canvas: .. })`).
|
||||
*/
|
||||
constructor(
|
||||
cpuContext: CanvasRenderingContext2D,
|
||||
gpuContext?: WebGL2RenderingContext);
|
||||
constructor(
|
||||
cpuOrGpuGontext: CanvasRenderingContext2D|WebGL2RenderingContext,
|
||||
gpuContext?: WebGL2RenderingContext) {
|
||||
if (cpuOrGpuGontext instanceof CanvasRenderingContext2D) {
|
||||
this.context2d = cpuOrGpuGontext;
|
||||
this.contextWebGL = gpuContext;
|
||||
} else {
|
||||
this.contextWebGL = cpuOrGpuGontext;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Restricts a number between two endpoints (order doesn't matter).
|
||||
|
@ -120,9 +186,35 @@ export class DrawingUtils {
|
|||
return DrawingUtils.clamp(out, y0, y1);
|
||||
}
|
||||
|
||||
private getCanvasRenderingContext(): CanvasRenderingContext2D {
|
||||
if (!this.context2d) {
|
||||
throw new Error(
|
||||
'CPU rendering requested but CanvasRenderingContext2D not provided.');
|
||||
}
|
||||
return this.context2d;
|
||||
}
|
||||
|
||||
private getWebGLRenderingContext(): WebGL2RenderingContext {
|
||||
if (!this.contextWebGL) {
|
||||
throw new Error(
|
||||
'GPU rendering requested but WebGL2RenderingContext not provided.');
|
||||
}
|
||||
return this.contextWebGL;
|
||||
}
|
||||
|
||||
private getCategoryMaskShaderContext(): CategoryMaskShaderContext {
|
||||
if (!this.categoryMaskShaderContext) {
|
||||
this.categoryMaskShaderContext = new CategoryMaskShaderContext();
|
||||
}
|
||||
return this.categoryMaskShaderContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Draws circles onto the provided landmarks.
|
||||
*
|
||||
* This method can only be used when `DrawingUtils` is initialized with a
|
||||
* `CanvasRenderingContext2D`.
|
||||
*
|
||||
* @export
|
||||
* @param landmarks The landmarks to draw.
|
||||
* @param style The style to visualize the landmarks.
|
||||
|
@ -132,7 +224,7 @@ export class DrawingUtils {
|
|||
if (!landmarks) {
|
||||
return;
|
||||
}
|
||||
const ctx = this.ctx;
|
||||
const ctx = this.getCanvasRenderingContext();
|
||||
const options = addDefaultOptions(style);
|
||||
ctx.save();
|
||||
const canvas = ctx.canvas;
|
||||
|
@ -159,6 +251,9 @@ export class DrawingUtils {
|
|||
/**
|
||||
* Draws lines between landmarks (given a connection graph).
|
||||
*
|
||||
* This method can only be used when `DrawingUtils` is initialized with a
|
||||
* `CanvasRenderingContext2D`.
|
||||
*
|
||||
* @export
|
||||
* @param landmarks The landmarks to draw.
|
||||
* @param connections The connections array that contains the start and the
|
||||
|
@ -171,7 +266,7 @@ export class DrawingUtils {
|
|||
if (!landmarks || !connections) {
|
||||
return;
|
||||
}
|
||||
const ctx = this.ctx;
|
||||
const ctx = this.getCanvasRenderingContext();
|
||||
const options = addDefaultOptions(style);
|
||||
ctx.save();
|
||||
const canvas = ctx.canvas;
|
||||
|
@ -195,12 +290,15 @@ export class DrawingUtils {
|
|||
/**
|
||||
* Draws a bounding box.
|
||||
*
|
||||
* This method can only be used when `DrawingUtils` is initialized with a
|
||||
* `CanvasRenderingContext2D`.
|
||||
*
|
||||
* @export
|
||||
* @param boundingBox The bounding box to draw.
|
||||
* @param style The style to visualize the boundin box.
|
||||
*/
|
||||
drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void {
|
||||
const ctx = this.ctx;
|
||||
const ctx = this.getCanvasRenderingContext();
|
||||
const options = addDefaultOptions(style);
|
||||
ctx.save();
|
||||
ctx.beginPath();
|
||||
|
@ -218,6 +316,118 @@ export class DrawingUtils {
|
|||
ctx.fill();
|
||||
ctx.restore();
|
||||
}
|
||||
|
||||
/** Draws a category mask on a CanvasRenderingContext2D. */
|
||||
private drawCategoryMask2D(
|
||||
mask: MPMask, background: RGBAColor|ImageSource,
|
||||
categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
|
||||
// Use the WebGL renderer to draw result on our internal canvas.
|
||||
const gl = this.getWebGLRenderingContext();
|
||||
this.runWithWebGLTexture(mask, texture => {
|
||||
this.drawCategoryMaskWebGL(texture, background, categoryToColorMap);
|
||||
// Draw the result on the user canvas.
|
||||
const ctx = this.getCanvasRenderingContext();
|
||||
ctx.drawImage(gl.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height);
|
||||
});
|
||||
}
|
||||
|
||||
/** Draws a category mask on a WebGL2RenderingContext2D. */
|
||||
private drawCategoryMaskWebGL(
|
||||
categoryTexture: WebGLTexture, background: RGBAColor|ImageSource,
|
||||
categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
|
||||
const shaderContext = this.getCategoryMaskShaderContext();
|
||||
const gl = this.getWebGLRenderingContext();
|
||||
const backgroundImage = Array.isArray(background) ?
|
||||
new ImageData(new Uint8ClampedArray(background), 1, 1) :
|
||||
background;
|
||||
|
||||
shaderContext.run(gl, /* flipTexturesVertically= */ true, () => {
|
||||
shaderContext.bindAndUploadTextures(
|
||||
categoryTexture, backgroundImage, categoryToColorMap);
|
||||
gl.clearColor(0, 0, 0, 0);
|
||||
gl.clear(gl.COLOR_BUFFER_BIT);
|
||||
gl.drawArrays(gl.TRIANGLE_FAN, 0, 4);
|
||||
shaderContext.unbindTextures();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Draws a category mask using the provided category-to-color mapping.
|
||||
*
|
||||
* @export
|
||||
* @param mask A category mask that was returned from a segmentation task.
|
||||
* @param categoryToColorMap A map that maps category indices to RGBA
|
||||
* values. You must specify a map entry for each category.
|
||||
* @param background A color or image to use as the background. Defaults to
|
||||
* black.
|
||||
*/
|
||||
drawCategoryMask(
|
||||
mask: MPMask, categoryToColorMap: Map<number, RGBAColor>,
|
||||
background?: RGBAColor|ImageSource): void;
|
||||
/**
|
||||
* Draws a category mask using the provided color array.
|
||||
*
|
||||
* @export
|
||||
* @param mask A category mask that was returned from a segmentation task.
|
||||
* @param categoryToColorMap An array that maps indices to RGBA values. The
|
||||
* array's indices must correspond to the category indices of the model
|
||||
* and an entry must be provided for each category.
|
||||
* @param background A color or image to use as the background. Defaults to
|
||||
* black.
|
||||
*/
|
||||
drawCategoryMask(
|
||||
mask: MPMask, categoryToColorMap: RGBAColor[],
|
||||
background?: RGBAColor|ImageSource): void;
|
||||
drawCategoryMask(
|
||||
mask: MPMask, categoryToColorMap: CategoryToColorMap,
|
||||
background: RGBAColor|ImageSource = [0, 0, 0, 255]): void {
|
||||
if (this.context2d) {
|
||||
this.drawCategoryMask2D(mask, background, categoryToColorMap);
|
||||
} else {
|
||||
this.drawCategoryMaskWebGL(
|
||||
mask.getAsWebGLTexture(), background, categoryToColorMap);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts the given mask to a WebGLTexture and runs the callback. Cleans
|
||||
* up any new resources after the callback finished executing.
|
||||
*/
|
||||
private runWithWebGLTexture(
|
||||
mask: MPMask, callback: (texture: WebGLTexture) => void): void {
|
||||
if (!mask.hasWebGLTexture()) {
|
||||
// Re-create the MPMask but use our the WebGL canvas so we can draw the
|
||||
// texture directly.
|
||||
const data = mask.hasFloat32Array() ? mask.getAsFloat32Array() :
|
||||
mask.getAsUint8Array();
|
||||
this.convertToWebGLTextureShaderContext =
|
||||
this.convertToWebGLTextureShaderContext ?? new MPImageShaderContext();
|
||||
const gl = this.getWebGLRenderingContext();
|
||||
|
||||
const convertedMask = new MPMask(
|
||||
[data],
|
||||
/* ownsWebGlTexture= */ false,
|
||||
gl.canvas,
|
||||
this.convertToWebGLTextureShaderContext,
|
||||
mask.width,
|
||||
mask.height,
|
||||
);
|
||||
callback(convertedMask.getAsWebGLTexture());
|
||||
convertedMask.close();
|
||||
} else {
|
||||
callback(mask.getAsWebGLTexture());
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Frees all WebGL resources held by this class.
|
||||
* @export
|
||||
*/
|
||||
close(): void {
|
||||
this.categoryMaskShaderContext?.close();
|
||||
this.categoryMaskShaderContext = undefined;
|
||||
this.convertToWebGLTextureShaderContext?.close();
|
||||
this.convertToWebGLTextureShaderContext = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
189
mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts
Normal file
189
mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts
Normal file
|
@ -0,0 +1,189 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||
|
||||
/**
|
||||
* A fragment shader that maps categories to colors based on a background
|
||||
* texture, a mask texture and a 256x1 "color mapping texture" that contains one
|
||||
* color for each pixel.
|
||||
*/
|
||||
const FRAGMENT_SHADER = `
|
||||
precision mediump float;
|
||||
uniform sampler2D backgroundTexture;
|
||||
uniform sampler2D maskTexture;
|
||||
uniform sampler2D colorMappingTexture;
|
||||
varying vec2 vTex;
|
||||
void main() {
|
||||
vec4 backgroundColor = texture2D(backgroundTexture, vTex);
|
||||
float category = texture2D(maskTexture, vTex).r;
|
||||
vec4 categoryColor = texture2D(colorMappingTexture, vec2(category, 0.0));
|
||||
gl_FragColor = mix(backgroundColor, categoryColor, categoryColor.a);
|
||||
}
|
||||
`;
|
||||
|
||||
/**
|
||||
* A four channel color with values for red, green, blue and alpha
|
||||
* respectively.
|
||||
*/
|
||||
export type RGBAColor = [number, number, number, number]|number[];
|
||||
|
||||
/**
|
||||
* A category to color mapping that uses either a map or an array to assign
|
||||
* category indexes to RGBA colors.
|
||||
*/
|
||||
export type CategoryToColorMap = Map<number, RGBAColor>|RGBAColor[];
|
||||
|
||||
|
||||
/** Checks CategoryToColorMap maps for deep equality. */
|
||||
function isEqualColorMap(
|
||||
a: CategoryToColorMap, b: CategoryToColorMap): boolean {
|
||||
if (a !== b) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const aEntries = a.entries();
|
||||
const bEntries = b.entries();
|
||||
for (const [aKey, aValue] of aEntries) {
|
||||
const bNext = bEntries.next();
|
||||
if (bNext.done) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const [bKey, bValue] = bNext.value;
|
||||
if (aKey !== bKey) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (aValue[0] !== bValue[0] || aValue[1] !== bValue[1] ||
|
||||
aValue[2] !== bValue[2] || aValue[3] !== bValue[3]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return !!bEntries.next().done;
|
||||
}
|
||||
|
||||
|
||||
/** A drawing util class for category masks. */
|
||||
export class CategoryMaskShaderContext extends MPImageShaderContext {
|
||||
backgroundTexture?: WebGLTexture;
|
||||
colorMappingTexture?: WebGLTexture;
|
||||
colorMappingTextureUniform?: WebGLUniformLocation;
|
||||
backgroundTextureUniform?: WebGLUniformLocation;
|
||||
maskTextureUniform?: WebGLUniformLocation;
|
||||
currentColorMap?: CategoryToColorMap;
|
||||
|
||||
bindAndUploadTextures(
|
||||
categoryMask: WebGLTexture, background: ImageSource,
|
||||
colorMap: Map<number, number[]>|number[][]) {
|
||||
const gl = this.gl!;
|
||||
|
||||
// TODO: We should avoid uploading textures from CPU to GPU
|
||||
// if the textures haven't changed. This can lead to drastic performance
|
||||
// slowdowns (~50ms per frame). Users can reduce the penalty by passing a
|
||||
// canvas object instead of ImageData/HTMLImageElement.
|
||||
gl.activeTexture(gl.TEXTURE0);
|
||||
gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!);
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background);
|
||||
|
||||
// Bind color mapping texture if changed.
|
||||
if (!this.currentColorMap ||
|
||||
!isEqualColorMap(this.currentColorMap, colorMap)) {
|
||||
this.currentColorMap = colorMap;
|
||||
|
||||
const pixels = new Array(256 * 4).fill(0);
|
||||
colorMap.forEach((rgba, index) => {
|
||||
if (rgba.length !== 4) {
|
||||
throw new Error(
|
||||
`Color at index ${index} is not a four-channel value.`);
|
||||
}
|
||||
pixels[index * 4] = rgba[0];
|
||||
pixels[index * 4 + 1] = rgba[1];
|
||||
pixels[index * 4 + 2] = rgba[2];
|
||||
pixels[index * 4 + 3] = rgba[3];
|
||||
});
|
||||
gl.activeTexture(gl.TEXTURE1);
|
||||
gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE,
|
||||
new Uint8Array(pixels));
|
||||
} else {
|
||||
gl.activeTexture(gl.TEXTURE1);
|
||||
gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
|
||||
}
|
||||
|
||||
// Bind category mask
|
||||
gl.activeTexture(gl.TEXTURE2);
|
||||
gl.bindTexture(gl.TEXTURE_2D, categoryMask);
|
||||
}
|
||||
|
||||
unbindTextures() {
|
||||
const gl = this.gl!;
|
||||
gl.activeTexture(gl.TEXTURE0);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
gl.activeTexture(gl.TEXTURE1);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
gl.activeTexture(gl.TEXTURE2);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
}
|
||||
|
||||
protected override getFragmentShader(): string {
|
||||
return FRAGMENT_SHADER;
|
||||
}
|
||||
|
||||
protected override setupTextures(): void {
|
||||
const gl = this.gl!;
|
||||
gl.activeTexture(gl.TEXTURE0);
|
||||
this.backgroundTexture = this.createTexture(gl, gl.LINEAR);
|
||||
// Use `gl.NEAREST` to prevent interpolating values in our category to
|
||||
// color map.
|
||||
this.colorMappingTexture = this.createTexture(gl, gl.NEAREST);
|
||||
}
|
||||
|
||||
protected override setupShaders(): void {
|
||||
super.setupShaders();
|
||||
const gl = this.gl!;
|
||||
this.backgroundTextureUniform = assertNotNull(
|
||||
gl.getUniformLocation(this.program!, 'backgroundTexture'),
|
||||
'Uniform location');
|
||||
this.colorMappingTextureUniform = assertNotNull(
|
||||
gl.getUniformLocation(this.program!, 'colorMappingTexture'),
|
||||
'Uniform location');
|
||||
this.maskTextureUniform = assertNotNull(
|
||||
gl.getUniformLocation(this.program!, 'maskTexture'),
|
||||
'Uniform location');
|
||||
}
|
||||
|
||||
protected override configureUniforms(): void {
|
||||
super.configureUniforms();
|
||||
const gl = this.gl!;
|
||||
gl.uniform1i(this.backgroundTextureUniform!, 0);
|
||||
gl.uniform1i(this.colorMappingTextureUniform!, 1);
|
||||
gl.uniform1i(this.maskTextureUniform!, 2);
|
||||
}
|
||||
|
||||
override close(): void {
|
||||
if (this.backgroundTexture) {
|
||||
this.gl!.deleteTexture(this.backgroundTexture);
|
||||
}
|
||||
if (this.colorMappingTexture) {
|
||||
this.gl!.deleteTexture(this.colorMappingTexture);
|
||||
}
|
||||
super.close();
|
||||
}
|
||||
}
|
|
@ -73,9 +73,9 @@ class MPImageShaderBuffers {
|
|||
* For internal use only.
|
||||
*/
|
||||
export class MPImageShaderContext {
|
||||
private gl?: WebGL2RenderingContext;
|
||||
protected gl?: WebGL2RenderingContext;
|
||||
private framebuffer?: WebGLFramebuffer;
|
||||
private program?: WebGLProgram;
|
||||
protected program?: WebGLProgram;
|
||||
private vertexShader?: WebGLShader;
|
||||
private fragmentShader?: WebGLShader;
|
||||
private aVertex?: GLint;
|
||||
|
@ -94,6 +94,14 @@ export class MPImageShaderContext {
|
|||
*/
|
||||
private shaderBuffersFlipVertically?: MPImageShaderBuffers;
|
||||
|
||||
protected getFragmentShader(): string {
|
||||
return FRAGMENT_SHADER;
|
||||
}
|
||||
|
||||
protected getVertexShader(): string {
|
||||
return VERTEX_SHADER;
|
||||
}
|
||||
|
||||
private compileShader(source: string, type: number): WebGLShader {
|
||||
const gl = this.gl!;
|
||||
const shader =
|
||||
|
@ -108,14 +116,15 @@ export class MPImageShaderContext {
|
|||
return shader;
|
||||
}
|
||||
|
||||
private setupShaders(): void {
|
||||
protected setupShaders(): void {
|
||||
const gl = this.gl!;
|
||||
this.program =
|
||||
assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
|
||||
|
||||
this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER);
|
||||
this.vertexShader =
|
||||
this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER);
|
||||
this.fragmentShader =
|
||||
this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER);
|
||||
this.compileShader(this.getFragmentShader(), gl.FRAGMENT_SHADER);
|
||||
|
||||
gl.linkProgram(this.program);
|
||||
const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS);
|
||||
|
@ -128,6 +137,10 @@ export class MPImageShaderContext {
|
|||
this.aTex = gl.getAttribLocation(this.program, 'aTex');
|
||||
}
|
||||
|
||||
protected setupTextures(): void {}
|
||||
|
||||
protected configureUniforms(): void {}
|
||||
|
||||
private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
|
||||
const gl = this.gl!;
|
||||
const vertexArrayObject =
|
||||
|
@ -193,17 +206,44 @@ export class MPImageShaderContext {
|
|||
|
||||
if (!this.program) {
|
||||
this.setupShaders();
|
||||
this.setupTextures();
|
||||
}
|
||||
|
||||
const shaderBuffers = this.getShaderBuffers(flipVertically);
|
||||
gl.useProgram(this.program!);
|
||||
shaderBuffers.bind();
|
||||
this.configureUniforms();
|
||||
const result = callback();
|
||||
shaderBuffers.unbind();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and configures a texture.
|
||||
*
|
||||
* @param gl The rendering context.
|
||||
* @param filter The setting to use for `gl.TEXTURE_MIN_FILTER` and
|
||||
* `gl.TEXTURE_MAG_FILTER`. Defaults to `gl.LINEAR`.
|
||||
* @param wrapping The setting to use for `gl.TEXTURE_WRAP_S` and
|
||||
* `gl.TEXTURE_WRAP_T`. Defaults to `gl.CLAMP_TO_EDGE`.
|
||||
*/
|
||||
createTexture(gl: WebGL2RenderingContext, filter?: GLenum, wrapping?: GLenum):
|
||||
WebGLTexture {
|
||||
this.maybeInitGL(gl);
|
||||
const texture =
|
||||
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
gl.texParameteri(
|
||||
gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(
|
||||
gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, wrapping ?? gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter ?? gl.LINEAR);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter ?? gl.LINEAR);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
return texture;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds a framebuffer to the canvas. If the framebuffer does not yet exist,
|
||||
* creates it first. Binds the provided texture to the framebuffer.
|
||||
|
|
|
@ -16,24 +16,6 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Pre-baked color table for a maximum of 12 classes.
|
||||
const CM_ALPHA = 128;
|
||||
const COLOR_MAP: Array<[number, number, number, number]> = [
|
||||
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
|
||||
[255, 0, 0, CM_ALPHA], // class 1 is red
|
||||
[0, 255, 0, CM_ALPHA], // class 2 is light green
|
||||
[0, 0, 255, CM_ALPHA], // class 3 is blue
|
||||
[255, 255, 0, CM_ALPHA], // class 4 is yellow
|
||||
[255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta
|
||||
[0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua
|
||||
[128, 128, 128, CM_ALPHA], // class 7 is gray
|
||||
[255, 128, 0, CM_ALPHA], // class 8 is orange
|
||||
[128, 0, 255, CM_ALPHA], // class 9 is dark purple
|
||||
[0, 128, 0, CM_ALPHA], // class 10 is dark green
|
||||
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
|
||||
];
|
||||
|
||||
|
||||
/** Helper function to draw a confidence mask */
|
||||
export function drawConfidenceMask(
|
||||
ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
|
||||
|
@ -47,23 +29,3 @@ export function drawConfidenceMask(
|
|||
}
|
||||
ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to draw a category mask. For GPU, we only have F32Arrays
|
||||
* for now.
|
||||
*/
|
||||
export function drawCategoryMask(
|
||||
ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
|
||||
width: number, height: number): void {
|
||||
const rgbaArray = new Uint8ClampedArray(width * height * 4);
|
||||
const isFloatArray = image instanceof Float32Array;
|
||||
for (let i = 0; i < image.length; i++) {
|
||||
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
|
||||
const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
|
||||
rgbaArray[4 * i] = color[0];
|
||||
rgbaArray[4 * i + 1] = color[1];
|
||||
rgbaArray[4 * i + 2] = color[2];
|
||||
rgbaArray[4 * i + 3] = color[3];
|
||||
}
|
||||
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user