Add drawCategoryMask() to our public API
PiperOrigin-RevId: 578526413
This commit is contained in:
parent
3a55f1156a
commit
9474394768
|
@ -31,27 +31,57 @@ mediapipe_ts_library(
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "drawing_utils",
|
name = "drawing_utils",
|
||||||
srcs = ["drawing_utils.ts"],
|
srcs = [
|
||||||
|
"drawing_utils.ts",
|
||||||
|
"drawing_utils_category_mask.ts",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":image",
|
||||||
|
":image_shader_context",
|
||||||
|
":mask",
|
||||||
":types",
|
":types",
|
||||||
"//mediapipe/tasks/web/components/containers:bounding_box",
|
"//mediapipe/tasks/web/components/containers:bounding_box",
|
||||||
"//mediapipe/tasks/web/components/containers:landmark",
|
"//mediapipe/tasks/web/components/containers:landmark",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "image",
|
name = "drawing_utils_test_lib",
|
||||||
srcs = [
|
testonly = True,
|
||||||
"image.ts",
|
srcs = ["drawing_utils.test.ts"],
|
||||||
"image_shader_context.ts",
|
deps = [
|
||||||
|
":drawing_utils",
|
||||||
|
":image",
|
||||||
|
":image_shader_context",
|
||||||
|
":mask",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
jasmine_node_test(
|
||||||
|
name = "drawing_utils_test",
|
||||||
|
deps = [":drawing_utils_test_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "image",
|
||||||
|
srcs = ["image.ts"],
|
||||||
|
deps = ["image_shader_context"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "image_shader_context",
|
||||||
|
srcs = ["image_shader_context.ts"],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "image_test_lib",
|
name = "image_test_lib",
|
||||||
testonly = True,
|
testonly = True,
|
||||||
srcs = ["image.test.ts"],
|
srcs = ["image.test.ts"],
|
||||||
deps = [":image"],
|
deps = [
|
||||||
|
":image",
|
||||||
|
":image_shader_context",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
jasmine_node_test(
|
jasmine_node_test(
|
||||||
|
@ -64,6 +94,7 @@ mediapipe_ts_library(
|
||||||
srcs = ["mask.ts"],
|
srcs = ["mask.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
":image",
|
":image",
|
||||||
|
":image_shader_context",
|
||||||
"//mediapipe/web/graph_runner:platform_utils",
|
"//mediapipe/web/graph_runner:platform_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -74,6 +105,7 @@ mediapipe_ts_library(
|
||||||
srcs = ["mask.test.ts"],
|
srcs = ["mask.test.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
":image",
|
":image",
|
||||||
|
":image_shader_context",
|
||||||
":mask",
|
":mask",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -89,6 +121,7 @@ mediapipe_ts_library(
|
||||||
deps = [
|
deps = [
|
||||||
":image",
|
":image",
|
||||||
":image_processing_options",
|
":image_processing_options",
|
||||||
|
":image_shader_context",
|
||||||
":mask",
|
":mask",
|
||||||
":vision_task_options",
|
":vision_task_options",
|
||||||
"//mediapipe/framework/formats:rect_jspb_proto",
|
"//mediapipe/framework/formats:rect_jspb_proto",
|
||||||
|
|
103
mediapipe/tasks/web/vision/core/drawing_utils.test.ts
Normal file
103
mediapipe/tasks/web/vision/core/drawing_utils.test.ts
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
import {DrawingUtils} from './drawing_utils';
|
||||||
|
import {MPImageShaderContext} from './image_shader_context';
|
||||||
|
import {MPMask} from './mask';
|
||||||
|
|
||||||
|
const WIDTH = 2;
|
||||||
|
const HEIGHT = 2;
|
||||||
|
|
||||||
|
const skip = typeof document === 'undefined';
|
||||||
|
if (skip) {
|
||||||
|
console.log('These tests must be run in a browser.');
|
||||||
|
}
|
||||||
|
|
||||||
|
(skip ? xdescribe : describe)('DrawingUtils', () => {
|
||||||
|
let shaderContext = new MPImageShaderContext();
|
||||||
|
let canvas2D: HTMLCanvasElement;
|
||||||
|
let context2D: CanvasRenderingContext2D;
|
||||||
|
let drawingUtils2D: DrawingUtils;
|
||||||
|
let canvasWebGL: HTMLCanvasElement;
|
||||||
|
let contextWebGL: WebGL2RenderingContext;
|
||||||
|
let drawingUtilsWebGL: DrawingUtils;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
shaderContext = new MPImageShaderContext();
|
||||||
|
|
||||||
|
canvasWebGL = document.createElement('canvas');
|
||||||
|
canvasWebGL.width = WIDTH;
|
||||||
|
canvasWebGL.height = HEIGHT;
|
||||||
|
contextWebGL = canvasWebGL.getContext('webgl2')!;
|
||||||
|
drawingUtilsWebGL = new DrawingUtils(contextWebGL);
|
||||||
|
|
||||||
|
canvas2D = document.createElement('canvas');
|
||||||
|
canvas2D.width = WIDTH;
|
||||||
|
canvas2D.height = HEIGHT;
|
||||||
|
context2D = canvas2D.getContext('2d')!;
|
||||||
|
drawingUtils2D = new DrawingUtils(context2D, contextWebGL);
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
shaderContext.close();
|
||||||
|
drawingUtils2D.close();
|
||||||
|
drawingUtilsWebGL.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('drawCategoryMask() ', () => {
|
||||||
|
const colors = [
|
||||||
|
[0, 0, 0, 255],
|
||||||
|
[0, 255, 0, 255],
|
||||||
|
[0, 0, 255, 255],
|
||||||
|
[255, 255, 255, 255],
|
||||||
|
];
|
||||||
|
const expectedResult = new Uint8Array(
|
||||||
|
[0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255],
|
||||||
|
);
|
||||||
|
|
||||||
|
it('on 2D canvas', () => {
|
||||||
|
const categoryMask = new MPMask(
|
||||||
|
[new Uint8Array([0, 1, 2, 3])],
|
||||||
|
/* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH,
|
||||||
|
HEIGHT);
|
||||||
|
|
||||||
|
drawingUtils2D.drawCategoryMask(categoryMask, colors);
|
||||||
|
|
||||||
|
const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data;
|
||||||
|
expect(actualResult)
|
||||||
|
.toEqual(new Uint8ClampedArray(expectedResult.buffer));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('on WebGL canvas', () => {
|
||||||
|
const categoryMask = new MPMask(
|
||||||
|
[new Uint8Array([2, 3, 0, 1])], // Note: Vertically flipped
|
||||||
|
/* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH,
|
||||||
|
HEIGHT);
|
||||||
|
|
||||||
|
drawingUtilsWebGL.drawCategoryMask(categoryMask, colors);
|
||||||
|
|
||||||
|
const actualResult = new Uint8Array(WIDTH * WIDTH * 4);
|
||||||
|
contextWebGL.readPixels(
|
||||||
|
0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE,
|
||||||
|
actualResult);
|
||||||
|
expect(actualResult).toEqual(expectedResult);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox
|
||||||
|
});
|
|
@ -16,7 +16,11 @@
|
||||||
|
|
||||||
import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
|
import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
|
||||||
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
|
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
|
||||||
|
import {CategoryMaskShaderContext, CategoryToColorMap, RGBAColor} from '../../../../tasks/web/vision/core/drawing_utils_category_mask';
|
||||||
|
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||||
|
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||||
import {Connection} from '../../../../tasks/web/vision/core/types';
|
import {Connection} from '../../../../tasks/web/vision/core/types';
|
||||||
|
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A user-defined callback to take input data and map it to a custom output
|
* A user-defined callback to take input data and map it to a custom output
|
||||||
|
@ -24,6 +28,9 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
|
||||||
*/
|
*/
|
||||||
export type Callback<I, O> = (input: I) => O;
|
export type Callback<I, O> = (input: I) => O;
|
||||||
|
|
||||||
|
// Used in public API
|
||||||
|
export {ImageSource};
|
||||||
|
|
||||||
/** Data that a user can use to specialize drawing options. */
|
/** Data that a user can use to specialize drawing options. */
|
||||||
export declare interface LandmarkData {
|
export declare interface LandmarkData {
|
||||||
index?: number;
|
index?: number;
|
||||||
|
@ -31,6 +38,32 @@ export declare interface LandmarkData {
|
||||||
to?: NormalizedLandmark;
|
to?: NormalizedLandmark;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** A color map with 22 classes. Used in our demos. */
|
||||||
|
export const DEFAULT_CATEGORY_TO_COLOR_MAP = [
|
||||||
|
[0, 0, 0, 0], // class 0 is BG = transparent
|
||||||
|
[255, 0, 0, 255], // class 1 is red
|
||||||
|
[0, 255, 0, 255], // class 2 is light green
|
||||||
|
[0, 0, 255, 255], // class 3 is blue
|
||||||
|
[255, 255, 0, 255], // class 4 is yellow
|
||||||
|
[255, 0, 255, 255], // class 5 is light purple / magenta
|
||||||
|
[0, 255, 255, 255], // class 6 is light blue / aqua
|
||||||
|
[128, 128, 128, 255], // class 7 is gray
|
||||||
|
[255, 100, 0, 255], // class 8 is dark orange
|
||||||
|
[128, 0, 255, 255], // class 9 is dark purple
|
||||||
|
[0, 150, 0, 255], // class 10 is green
|
||||||
|
[255, 255, 255, 255], // class 11 is white
|
||||||
|
[255, 105, 180, 255], // class 12 is pink
|
||||||
|
[255, 150, 0, 255], // class 13 is orange
|
||||||
|
[255, 250, 224, 255], // class 14 is light yellow
|
||||||
|
[148, 0, 211, 255], // class 15 is dark violet
|
||||||
|
[0, 100, 0, 255], // class 16 is dark green
|
||||||
|
[0, 0, 128, 255], // class 17 is navy blue
|
||||||
|
[165, 42, 42, 255], // class 18 is brown
|
||||||
|
[64, 224, 208, 255], // class 19 is turquoise
|
||||||
|
[255, 218, 185, 255], // class 20 is peach
|
||||||
|
[192, 192, 192, 255], // class 21 is silver
|
||||||
|
];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Options for customizing the drawing routines
|
* Options for customizing the drawing routines
|
||||||
*/
|
*/
|
||||||
|
@ -77,14 +110,47 @@ function resolve<O, I>(value: O|Callback<I, O>, data: I): O {
|
||||||
return value instanceof Function ? value(data) : value;
|
return value instanceof Function ? value(data) : value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export {RGBAColor, CategoryToColorMap};
|
||||||
|
|
||||||
/** Helper class to visualize the result of a MediaPipe Vision task. */
|
/** Helper class to visualize the result of a MediaPipe Vision task. */
|
||||||
export class DrawingUtils {
|
export class DrawingUtils {
|
||||||
|
private categoryMaskShaderContext?: CategoryMaskShaderContext;
|
||||||
|
private convertToWebGLTextureShaderContext?: MPImageShaderContext;
|
||||||
|
private readonly context2d?: CanvasRenderingContext2D;
|
||||||
|
private readonly contextWebGL?: WebGL2RenderingContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new DrawingUtils class.
|
* Creates a new DrawingUtils class.
|
||||||
*
|
*
|
||||||
* @param ctx The canvas to render onto.
|
* @param gpuContext The WebGL canvas rendering context to render into. If
|
||||||
|
* your Task is using a GPU delegate, the context must be obtained from
|
||||||
|
* its canvas (provided via `setOptions({ canvas: .. })`).
|
||||||
*/
|
*/
|
||||||
constructor(private readonly ctx: CanvasRenderingContext2D) {}
|
constructor(gpuContext: WebGL2RenderingContext);
|
||||||
|
/**
|
||||||
|
* Creates a new DrawingUtils class.
|
||||||
|
*
|
||||||
|
* @param cpuContext The 2D canvas rendering context to render into. If
|
||||||
|
* you are rendering GPU data you must also provide `gpuContext` to allow
|
||||||
|
* for data conversion.
|
||||||
|
* @param gpuContext A WebGL canvas that is used for GPU rendering and for
|
||||||
|
* converting GPU to CPU data. If your Task is using a GPU delegate, the
|
||||||
|
* context must be obtained from its canvas (provided via
|
||||||
|
* `setOptions({ canvas: .. })`).
|
||||||
|
*/
|
||||||
|
constructor(
|
||||||
|
cpuContext: CanvasRenderingContext2D,
|
||||||
|
gpuContext?: WebGL2RenderingContext);
|
||||||
|
constructor(
|
||||||
|
cpuOrGpuGontext: CanvasRenderingContext2D|WebGL2RenderingContext,
|
||||||
|
gpuContext?: WebGL2RenderingContext) {
|
||||||
|
if (cpuOrGpuGontext instanceof CanvasRenderingContext2D) {
|
||||||
|
this.context2d = cpuOrGpuGontext;
|
||||||
|
this.contextWebGL = gpuContext;
|
||||||
|
} else {
|
||||||
|
this.contextWebGL = cpuOrGpuGontext;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Restricts a number between two endpoints (order doesn't matter).
|
* Restricts a number between two endpoints (order doesn't matter).
|
||||||
|
@ -120,9 +186,35 @@ export class DrawingUtils {
|
||||||
return DrawingUtils.clamp(out, y0, y1);
|
return DrawingUtils.clamp(out, y0, y1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private getCanvasRenderingContext(): CanvasRenderingContext2D {
|
||||||
|
if (!this.context2d) {
|
||||||
|
throw new Error(
|
||||||
|
'CPU rendering requested but CanvasRenderingContext2D not provided.');
|
||||||
|
}
|
||||||
|
return this.context2d;
|
||||||
|
}
|
||||||
|
|
||||||
|
private getWebGLRenderingContext(): WebGL2RenderingContext {
|
||||||
|
if (!this.contextWebGL) {
|
||||||
|
throw new Error(
|
||||||
|
'GPU rendering requested but WebGL2RenderingContext not provided.');
|
||||||
|
}
|
||||||
|
return this.contextWebGL;
|
||||||
|
}
|
||||||
|
|
||||||
|
private getCategoryMaskShaderContext(): CategoryMaskShaderContext {
|
||||||
|
if (!this.categoryMaskShaderContext) {
|
||||||
|
this.categoryMaskShaderContext = new CategoryMaskShaderContext();
|
||||||
|
}
|
||||||
|
return this.categoryMaskShaderContext;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Draws circles onto the provided landmarks.
|
* Draws circles onto the provided landmarks.
|
||||||
*
|
*
|
||||||
|
* This method can only be used when `DrawingUtils` is initialized with a
|
||||||
|
* `CanvasRenderingContext2D`.
|
||||||
|
*
|
||||||
* @export
|
* @export
|
||||||
* @param landmarks The landmarks to draw.
|
* @param landmarks The landmarks to draw.
|
||||||
* @param style The style to visualize the landmarks.
|
* @param style The style to visualize the landmarks.
|
||||||
|
@ -132,7 +224,7 @@ export class DrawingUtils {
|
||||||
if (!landmarks) {
|
if (!landmarks) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const ctx = this.ctx;
|
const ctx = this.getCanvasRenderingContext();
|
||||||
const options = addDefaultOptions(style);
|
const options = addDefaultOptions(style);
|
||||||
ctx.save();
|
ctx.save();
|
||||||
const canvas = ctx.canvas;
|
const canvas = ctx.canvas;
|
||||||
|
@ -159,6 +251,9 @@ export class DrawingUtils {
|
||||||
/**
|
/**
|
||||||
* Draws lines between landmarks (given a connection graph).
|
* Draws lines between landmarks (given a connection graph).
|
||||||
*
|
*
|
||||||
|
* This method can only be used when `DrawingUtils` is initialized with a
|
||||||
|
* `CanvasRenderingContext2D`.
|
||||||
|
*
|
||||||
* @export
|
* @export
|
||||||
* @param landmarks The landmarks to draw.
|
* @param landmarks The landmarks to draw.
|
||||||
* @param connections The connections array that contains the start and the
|
* @param connections The connections array that contains the start and the
|
||||||
|
@ -171,7 +266,7 @@ export class DrawingUtils {
|
||||||
if (!landmarks || !connections) {
|
if (!landmarks || !connections) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const ctx = this.ctx;
|
const ctx = this.getCanvasRenderingContext();
|
||||||
const options = addDefaultOptions(style);
|
const options = addDefaultOptions(style);
|
||||||
ctx.save();
|
ctx.save();
|
||||||
const canvas = ctx.canvas;
|
const canvas = ctx.canvas;
|
||||||
|
@ -195,12 +290,15 @@ export class DrawingUtils {
|
||||||
/**
|
/**
|
||||||
* Draws a bounding box.
|
* Draws a bounding box.
|
||||||
*
|
*
|
||||||
|
* This method can only be used when `DrawingUtils` is initialized with a
|
||||||
|
* `CanvasRenderingContext2D`.
|
||||||
|
*
|
||||||
* @export
|
* @export
|
||||||
* @param boundingBox The bounding box to draw.
|
* @param boundingBox The bounding box to draw.
|
||||||
* @param style The style to visualize the boundin box.
|
* @param style The style to visualize the boundin box.
|
||||||
*/
|
*/
|
||||||
drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void {
|
drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void {
|
||||||
const ctx = this.ctx;
|
const ctx = this.getCanvasRenderingContext();
|
||||||
const options = addDefaultOptions(style);
|
const options = addDefaultOptions(style);
|
||||||
ctx.save();
|
ctx.save();
|
||||||
ctx.beginPath();
|
ctx.beginPath();
|
||||||
|
@ -218,6 +316,118 @@ export class DrawingUtils {
|
||||||
ctx.fill();
|
ctx.fill();
|
||||||
ctx.restore();
|
ctx.restore();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Draws a category mask on a CanvasRenderingContext2D. */
|
||||||
|
private drawCategoryMask2D(
|
||||||
|
mask: MPMask, background: RGBAColor|ImageSource,
|
||||||
|
categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
|
||||||
|
// Use the WebGL renderer to draw result on our internal canvas.
|
||||||
|
const gl = this.getWebGLRenderingContext();
|
||||||
|
this.runWithWebGLTexture(mask, texture => {
|
||||||
|
this.drawCategoryMaskWebGL(texture, background, categoryToColorMap);
|
||||||
|
// Draw the result on the user canvas.
|
||||||
|
const ctx = this.getCanvasRenderingContext();
|
||||||
|
ctx.drawImage(gl.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Draws a category mask on a WebGL2RenderingContext2D. */
|
||||||
|
private drawCategoryMaskWebGL(
|
||||||
|
categoryTexture: WebGLTexture, background: RGBAColor|ImageSource,
|
||||||
|
categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
|
||||||
|
const shaderContext = this.getCategoryMaskShaderContext();
|
||||||
|
const gl = this.getWebGLRenderingContext();
|
||||||
|
const backgroundImage = Array.isArray(background) ?
|
||||||
|
new ImageData(new Uint8ClampedArray(background), 1, 1) :
|
||||||
|
background;
|
||||||
|
|
||||||
|
shaderContext.run(gl, /* flipTexturesVertically= */ true, () => {
|
||||||
|
shaderContext.bindAndUploadTextures(
|
||||||
|
categoryTexture, backgroundImage, categoryToColorMap);
|
||||||
|
gl.clearColor(0, 0, 0, 0);
|
||||||
|
gl.clear(gl.COLOR_BUFFER_BIT);
|
||||||
|
gl.drawArrays(gl.TRIANGLE_FAN, 0, 4);
|
||||||
|
shaderContext.unbindTextures();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Draws a category mask using the provided category-to-color mapping.
|
||||||
|
*
|
||||||
|
* @export
|
||||||
|
* @param mask A category mask that was returned from a segmentation task.
|
||||||
|
* @param categoryToColorMap A map that maps category indices to RGBA
|
||||||
|
* values. You must specify a map entry for each category.
|
||||||
|
* @param background A color or image to use as the background. Defaults to
|
||||||
|
* black.
|
||||||
|
*/
|
||||||
|
drawCategoryMask(
|
||||||
|
mask: MPMask, categoryToColorMap: Map<number, RGBAColor>,
|
||||||
|
background?: RGBAColor|ImageSource): void;
|
||||||
|
/**
|
||||||
|
* Draws a category mask using the provided color array.
|
||||||
|
*
|
||||||
|
* @export
|
||||||
|
* @param mask A category mask that was returned from a segmentation task.
|
||||||
|
* @param categoryToColorMap An array that maps indices to RGBA values. The
|
||||||
|
* array's indices must correspond to the category indices of the model
|
||||||
|
* and an entry must be provided for each category.
|
||||||
|
* @param background A color or image to use as the background. Defaults to
|
||||||
|
* black.
|
||||||
|
*/
|
||||||
|
drawCategoryMask(
|
||||||
|
mask: MPMask, categoryToColorMap: RGBAColor[],
|
||||||
|
background?: RGBAColor|ImageSource): void;
|
||||||
|
drawCategoryMask(
|
||||||
|
mask: MPMask, categoryToColorMap: CategoryToColorMap,
|
||||||
|
background: RGBAColor|ImageSource = [0, 0, 0, 255]): void {
|
||||||
|
if (this.context2d) {
|
||||||
|
this.drawCategoryMask2D(mask, background, categoryToColorMap);
|
||||||
|
} else {
|
||||||
|
this.drawCategoryMaskWebGL(
|
||||||
|
mask.getAsWebGLTexture(), background, categoryToColorMap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the given mask to a WebGLTexture and runs the callback. Cleans
|
||||||
|
* up any new resources after the callback finished executing.
|
||||||
|
*/
|
||||||
|
private runWithWebGLTexture(
|
||||||
|
mask: MPMask, callback: (texture: WebGLTexture) => void): void {
|
||||||
|
if (!mask.hasWebGLTexture()) {
|
||||||
|
// Re-create the MPMask but use our the WebGL canvas so we can draw the
|
||||||
|
// texture directly.
|
||||||
|
const data = mask.hasFloat32Array() ? mask.getAsFloat32Array() :
|
||||||
|
mask.getAsUint8Array();
|
||||||
|
this.convertToWebGLTextureShaderContext =
|
||||||
|
this.convertToWebGLTextureShaderContext ?? new MPImageShaderContext();
|
||||||
|
const gl = this.getWebGLRenderingContext();
|
||||||
|
|
||||||
|
const convertedMask = new MPMask(
|
||||||
|
[data],
|
||||||
|
/* ownsWebGlTexture= */ false,
|
||||||
|
gl.canvas,
|
||||||
|
this.convertToWebGLTextureShaderContext,
|
||||||
|
mask.width,
|
||||||
|
mask.height,
|
||||||
|
);
|
||||||
|
callback(convertedMask.getAsWebGLTexture());
|
||||||
|
convertedMask.close();
|
||||||
|
} else {
|
||||||
|
callback(mask.getAsWebGLTexture());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Frees all WebGL resources held by this class.
|
||||||
|
* @export
|
||||||
|
*/
|
||||||
|
close(): void {
|
||||||
|
this.categoryMaskShaderContext?.close();
|
||||||
|
this.categoryMaskShaderContext = undefined;
|
||||||
|
this.convertToWebGLTextureShaderContext?.close();
|
||||||
|
this.convertToWebGLTextureShaderContext = undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
189
mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts
Normal file
189
mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts
Normal file
|
@ -0,0 +1,189 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||||
|
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A fragment shader that maps categories to colors based on a background
|
||||||
|
* texture, a mask texture and a 256x1 "color mapping texture" that contains one
|
||||||
|
* color for each pixel.
|
||||||
|
*/
|
||||||
|
const FRAGMENT_SHADER = `
|
||||||
|
precision mediump float;
|
||||||
|
uniform sampler2D backgroundTexture;
|
||||||
|
uniform sampler2D maskTexture;
|
||||||
|
uniform sampler2D colorMappingTexture;
|
||||||
|
varying vec2 vTex;
|
||||||
|
void main() {
|
||||||
|
vec4 backgroundColor = texture2D(backgroundTexture, vTex);
|
||||||
|
float category = texture2D(maskTexture, vTex).r;
|
||||||
|
vec4 categoryColor = texture2D(colorMappingTexture, vec2(category, 0.0));
|
||||||
|
gl_FragColor = mix(backgroundColor, categoryColor, categoryColor.a);
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A four channel color with values for red, green, blue and alpha
|
||||||
|
* respectively.
|
||||||
|
*/
|
||||||
|
export type RGBAColor = [number, number, number, number]|number[];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A category to color mapping that uses either a map or an array to assign
|
||||||
|
* category indexes to RGBA colors.
|
||||||
|
*/
|
||||||
|
export type CategoryToColorMap = Map<number, RGBAColor>|RGBAColor[];
|
||||||
|
|
||||||
|
|
||||||
|
/** Checks CategoryToColorMap maps for deep equality. */
|
||||||
|
function isEqualColorMap(
|
||||||
|
a: CategoryToColorMap, b: CategoryToColorMap): boolean {
|
||||||
|
if (a !== b) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const aEntries = a.entries();
|
||||||
|
const bEntries = b.entries();
|
||||||
|
for (const [aKey, aValue] of aEntries) {
|
||||||
|
const bNext = bEntries.next();
|
||||||
|
if (bNext.done) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const [bKey, bValue] = bNext.value;
|
||||||
|
if (aKey !== bKey) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (aValue[0] !== bValue[0] || aValue[1] !== bValue[1] ||
|
||||||
|
aValue[2] !== bValue[2] || aValue[3] !== bValue[3]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return !!bEntries.next().done;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** A drawing util class for category masks. */
|
||||||
|
export class CategoryMaskShaderContext extends MPImageShaderContext {
|
||||||
|
backgroundTexture?: WebGLTexture;
|
||||||
|
colorMappingTexture?: WebGLTexture;
|
||||||
|
colorMappingTextureUniform?: WebGLUniformLocation;
|
||||||
|
backgroundTextureUniform?: WebGLUniformLocation;
|
||||||
|
maskTextureUniform?: WebGLUniformLocation;
|
||||||
|
currentColorMap?: CategoryToColorMap;
|
||||||
|
|
||||||
|
bindAndUploadTextures(
|
||||||
|
categoryMask: WebGLTexture, background: ImageSource,
|
||||||
|
colorMap: Map<number, number[]>|number[][]) {
|
||||||
|
const gl = this.gl!;
|
||||||
|
|
||||||
|
// TODO: We should avoid uploading textures from CPU to GPU
|
||||||
|
// if the textures haven't changed. This can lead to drastic performance
|
||||||
|
// slowdowns (~50ms per frame). Users can reduce the penalty by passing a
|
||||||
|
// canvas object instead of ImageData/HTMLImageElement.
|
||||||
|
gl.activeTexture(gl.TEXTURE0);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!);
|
||||||
|
gl.texImage2D(
|
||||||
|
gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background);
|
||||||
|
|
||||||
|
// Bind color mapping texture if changed.
|
||||||
|
if (!this.currentColorMap ||
|
||||||
|
!isEqualColorMap(this.currentColorMap, colorMap)) {
|
||||||
|
this.currentColorMap = colorMap;
|
||||||
|
|
||||||
|
const pixels = new Array(256 * 4).fill(0);
|
||||||
|
colorMap.forEach((rgba, index) => {
|
||||||
|
if (rgba.length !== 4) {
|
||||||
|
throw new Error(
|
||||||
|
`Color at index ${index} is not a four-channel value.`);
|
||||||
|
}
|
||||||
|
pixels[index * 4] = rgba[0];
|
||||||
|
pixels[index * 4 + 1] = rgba[1];
|
||||||
|
pixels[index * 4 + 2] = rgba[2];
|
||||||
|
pixels[index * 4 + 3] = rgba[3];
|
||||||
|
});
|
||||||
|
gl.activeTexture(gl.TEXTURE1);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
|
||||||
|
gl.texImage2D(
|
||||||
|
gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE,
|
||||||
|
new Uint8Array(pixels));
|
||||||
|
} else {
|
||||||
|
gl.activeTexture(gl.TEXTURE1);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind category mask
|
||||||
|
gl.activeTexture(gl.TEXTURE2);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, categoryMask);
|
||||||
|
}
|
||||||
|
|
||||||
|
unbindTextures() {
|
||||||
|
const gl = this.gl!;
|
||||||
|
gl.activeTexture(gl.TEXTURE0);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||||
|
gl.activeTexture(gl.TEXTURE1);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||||
|
gl.activeTexture(gl.TEXTURE2);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override getFragmentShader(): string {
|
||||||
|
return FRAGMENT_SHADER;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override setupTextures(): void {
|
||||||
|
const gl = this.gl!;
|
||||||
|
gl.activeTexture(gl.TEXTURE0);
|
||||||
|
this.backgroundTexture = this.createTexture(gl, gl.LINEAR);
|
||||||
|
// Use `gl.NEAREST` to prevent interpolating values in our category to
|
||||||
|
// color map.
|
||||||
|
this.colorMappingTexture = this.createTexture(gl, gl.NEAREST);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override setupShaders(): void {
|
||||||
|
super.setupShaders();
|
||||||
|
const gl = this.gl!;
|
||||||
|
this.backgroundTextureUniform = assertNotNull(
|
||||||
|
gl.getUniformLocation(this.program!, 'backgroundTexture'),
|
||||||
|
'Uniform location');
|
||||||
|
this.colorMappingTextureUniform = assertNotNull(
|
||||||
|
gl.getUniformLocation(this.program!, 'colorMappingTexture'),
|
||||||
|
'Uniform location');
|
||||||
|
this.maskTextureUniform = assertNotNull(
|
||||||
|
gl.getUniformLocation(this.program!, 'maskTexture'),
|
||||||
|
'Uniform location');
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override configureUniforms(): void {
|
||||||
|
super.configureUniforms();
|
||||||
|
const gl = this.gl!;
|
||||||
|
gl.uniform1i(this.backgroundTextureUniform!, 0);
|
||||||
|
gl.uniform1i(this.colorMappingTextureUniform!, 1);
|
||||||
|
gl.uniform1i(this.maskTextureUniform!, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
override close(): void {
|
||||||
|
if (this.backgroundTexture) {
|
||||||
|
this.gl!.deleteTexture(this.backgroundTexture);
|
||||||
|
}
|
||||||
|
if (this.colorMappingTexture) {
|
||||||
|
this.gl!.deleteTexture(this.colorMappingTexture);
|
||||||
|
}
|
||||||
|
super.close();
|
||||||
|
}
|
||||||
|
}
|
|
@ -73,9 +73,9 @@ class MPImageShaderBuffers {
|
||||||
* For internal use only.
|
* For internal use only.
|
||||||
*/
|
*/
|
||||||
export class MPImageShaderContext {
|
export class MPImageShaderContext {
|
||||||
private gl?: WebGL2RenderingContext;
|
protected gl?: WebGL2RenderingContext;
|
||||||
private framebuffer?: WebGLFramebuffer;
|
private framebuffer?: WebGLFramebuffer;
|
||||||
private program?: WebGLProgram;
|
protected program?: WebGLProgram;
|
||||||
private vertexShader?: WebGLShader;
|
private vertexShader?: WebGLShader;
|
||||||
private fragmentShader?: WebGLShader;
|
private fragmentShader?: WebGLShader;
|
||||||
private aVertex?: GLint;
|
private aVertex?: GLint;
|
||||||
|
@ -94,6 +94,14 @@ export class MPImageShaderContext {
|
||||||
*/
|
*/
|
||||||
private shaderBuffersFlipVertically?: MPImageShaderBuffers;
|
private shaderBuffersFlipVertically?: MPImageShaderBuffers;
|
||||||
|
|
||||||
|
protected getFragmentShader(): string {
|
||||||
|
return FRAGMENT_SHADER;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected getVertexShader(): string {
|
||||||
|
return VERTEX_SHADER;
|
||||||
|
}
|
||||||
|
|
||||||
private compileShader(source: string, type: number): WebGLShader {
|
private compileShader(source: string, type: number): WebGLShader {
|
||||||
const gl = this.gl!;
|
const gl = this.gl!;
|
||||||
const shader =
|
const shader =
|
||||||
|
@ -108,14 +116,15 @@ export class MPImageShaderContext {
|
||||||
return shader;
|
return shader;
|
||||||
}
|
}
|
||||||
|
|
||||||
private setupShaders(): void {
|
protected setupShaders(): void {
|
||||||
const gl = this.gl!;
|
const gl = this.gl!;
|
||||||
this.program =
|
this.program =
|
||||||
assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
|
assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
|
||||||
|
|
||||||
this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER);
|
this.vertexShader =
|
||||||
|
this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER);
|
||||||
this.fragmentShader =
|
this.fragmentShader =
|
||||||
this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER);
|
this.compileShader(this.getFragmentShader(), gl.FRAGMENT_SHADER);
|
||||||
|
|
||||||
gl.linkProgram(this.program);
|
gl.linkProgram(this.program);
|
||||||
const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS);
|
const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS);
|
||||||
|
@ -128,6 +137,10 @@ export class MPImageShaderContext {
|
||||||
this.aTex = gl.getAttribLocation(this.program, 'aTex');
|
this.aTex = gl.getAttribLocation(this.program, 'aTex');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected setupTextures(): void {}
|
||||||
|
|
||||||
|
protected configureUniforms(): void {}
|
||||||
|
|
||||||
private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
|
private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
|
||||||
const gl = this.gl!;
|
const gl = this.gl!;
|
||||||
const vertexArrayObject =
|
const vertexArrayObject =
|
||||||
|
@ -193,17 +206,44 @@ export class MPImageShaderContext {
|
||||||
|
|
||||||
if (!this.program) {
|
if (!this.program) {
|
||||||
this.setupShaders();
|
this.setupShaders();
|
||||||
|
this.setupTextures();
|
||||||
}
|
}
|
||||||
|
|
||||||
const shaderBuffers = this.getShaderBuffers(flipVertically);
|
const shaderBuffers = this.getShaderBuffers(flipVertically);
|
||||||
gl.useProgram(this.program!);
|
gl.useProgram(this.program!);
|
||||||
shaderBuffers.bind();
|
shaderBuffers.bind();
|
||||||
|
this.configureUniforms();
|
||||||
const result = callback();
|
const result = callback();
|
||||||
shaderBuffers.unbind();
|
shaderBuffers.unbind();
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates and configures a texture.
|
||||||
|
*
|
||||||
|
* @param gl The rendering context.
|
||||||
|
* @param filter The setting to use for `gl.TEXTURE_MIN_FILTER` and
|
||||||
|
* `gl.TEXTURE_MAG_FILTER`. Defaults to `gl.LINEAR`.
|
||||||
|
* @param wrapping The setting to use for `gl.TEXTURE_WRAP_S` and
|
||||||
|
* `gl.TEXTURE_WRAP_T`. Defaults to `gl.CLAMP_TO_EDGE`.
|
||||||
|
*/
|
||||||
|
createTexture(gl: WebGL2RenderingContext, filter?: GLenum, wrapping?: GLenum):
|
||||||
|
WebGLTexture {
|
||||||
|
this.maybeInitGL(gl);
|
||||||
|
const texture =
|
||||||
|
assertNotNull(gl.createTexture(), 'Failed to create texture');
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||||
|
gl.texParameteri(
|
||||||
|
gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE);
|
||||||
|
gl.texParameteri(
|
||||||
|
gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, wrapping ?? gl.CLAMP_TO_EDGE);
|
||||||
|
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter ?? gl.LINEAR);
|
||||||
|
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter ?? gl.LINEAR);
|
||||||
|
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||||
|
return texture;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Binds a framebuffer to the canvas. If the framebuffer does not yet exist,
|
* Binds a framebuffer to the canvas. If the framebuffer does not yet exist,
|
||||||
* creates it first. Binds the provided texture to the framebuffer.
|
* creates it first. Binds the provided texture to the framebuffer.
|
||||||
|
|
|
@ -16,24 +16,6 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Pre-baked color table for a maximum of 12 classes.
|
|
||||||
const CM_ALPHA = 128;
|
|
||||||
const COLOR_MAP: Array<[number, number, number, number]> = [
|
|
||||||
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
|
|
||||||
[255, 0, 0, CM_ALPHA], // class 1 is red
|
|
||||||
[0, 255, 0, CM_ALPHA], // class 2 is light green
|
|
||||||
[0, 0, 255, CM_ALPHA], // class 3 is blue
|
|
||||||
[255, 255, 0, CM_ALPHA], // class 4 is yellow
|
|
||||||
[255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta
|
|
||||||
[0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua
|
|
||||||
[128, 128, 128, CM_ALPHA], // class 7 is gray
|
|
||||||
[255, 128, 0, CM_ALPHA], // class 8 is orange
|
|
||||||
[128, 0, 255, CM_ALPHA], // class 9 is dark purple
|
|
||||||
[0, 128, 0, CM_ALPHA], // class 10 is dark green
|
|
||||||
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
|
|
||||||
];
|
|
||||||
|
|
||||||
|
|
||||||
/** Helper function to draw a confidence mask */
|
/** Helper function to draw a confidence mask */
|
||||||
export function drawConfidenceMask(
|
export function drawConfidenceMask(
|
||||||
ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
|
ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
|
||||||
|
@ -47,23 +29,3 @@ export function drawConfidenceMask(
|
||||||
}
|
}
|
||||||
ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
|
ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Helper function to draw a category mask. For GPU, we only have F32Arrays
|
|
||||||
* for now.
|
|
||||||
*/
|
|
||||||
export function drawCategoryMask(
|
|
||||||
ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
|
|
||||||
width: number, height: number): void {
|
|
||||||
const rgbaArray = new Uint8ClampedArray(width * height * 4);
|
|
||||||
const isFloatArray = image instanceof Float32Array;
|
|
||||||
for (let i = 0; i < image.length; i++) {
|
|
||||||
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
|
|
||||||
const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
|
|
||||||
rgbaArray[4 * i] = color[0];
|
|
||||||
rgbaArray[4 * i + 1] = color[1];
|
|
||||||
rgbaArray[4 * i + 2] = color[2];
|
|
||||||
rgbaArray[4 * i + 3] = color[3];
|
|
||||||
}
|
|
||||||
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user