Add drawCategoryMask() to our public API
PiperOrigin-RevId: 578526413
This commit is contained in:
		
							parent
							
								
									3a55f1156a
								
							
						
					
					
						commit
						9474394768
					
				| 
						 | 
				
			
			@ -31,27 +31,57 @@ mediapipe_ts_library(
 | 
			
		|||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "drawing_utils",
 | 
			
		||||
    srcs = ["drawing_utils.ts"],
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "drawing_utils.ts",
 | 
			
		||||
        "drawing_utils_category_mask.ts",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image",
 | 
			
		||||
        ":image_shader_context",
 | 
			
		||||
        ":mask",
 | 
			
		||||
        ":types",
 | 
			
		||||
        "//mediapipe/tasks/web/components/containers:bounding_box",
 | 
			
		||||
        "//mediapipe/tasks/web/components/containers:landmark",
 | 
			
		||||
        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "image",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "image.ts",
 | 
			
		||||
        "image_shader_context.ts",
 | 
			
		||||
    name = "drawing_utils_test_lib",
 | 
			
		||||
    testonly = True,
 | 
			
		||||
    srcs = ["drawing_utils.test.ts"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":drawing_utils",
 | 
			
		||||
        ":image",
 | 
			
		||||
        ":image_shader_context",
 | 
			
		||||
        ":mask",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
jasmine_node_test(
 | 
			
		||||
    name = "drawing_utils_test",
 | 
			
		||||
    deps = [":drawing_utils_test_lib"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "image",
 | 
			
		||||
    srcs = ["image.ts"],
 | 
			
		||||
    deps = ["image_shader_context"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "image_shader_context",
 | 
			
		||||
    srcs = ["image_shader_context.ts"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "image_test_lib",
 | 
			
		||||
    testonly = True,
 | 
			
		||||
    srcs = ["image.test.ts"],
 | 
			
		||||
    deps = [":image"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image",
 | 
			
		||||
        ":image_shader_context",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
jasmine_node_test(
 | 
			
		||||
| 
						 | 
				
			
			@ -64,6 +94,7 @@ mediapipe_ts_library(
 | 
			
		|||
    srcs = ["mask.ts"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image",
 | 
			
		||||
        ":image_shader_context",
 | 
			
		||||
        "//mediapipe/web/graph_runner:platform_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -74,6 +105,7 @@ mediapipe_ts_library(
 | 
			
		|||
    srcs = ["mask.test.ts"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image",
 | 
			
		||||
        ":image_shader_context",
 | 
			
		||||
        ":mask",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -89,6 +121,7 @@ mediapipe_ts_library(
 | 
			
		|||
    deps = [
 | 
			
		||||
        ":image",
 | 
			
		||||
        ":image_processing_options",
 | 
			
		||||
        ":image_shader_context",
 | 
			
		||||
        ":mask",
 | 
			
		||||
        ":vision_task_options",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_jspb_proto",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										103
									
								
								mediapipe/tasks/web/vision/core/drawing_utils.test.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								mediapipe/tasks/web/vision/core/drawing_utils.test.ts
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,103 @@
 | 
			
		|||
/**
 | 
			
		||||
 * Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import 'jasmine';
 | 
			
		||||
 | 
			
		||||
import {DrawingUtils} from './drawing_utils';
 | 
			
		||||
import {MPImageShaderContext} from './image_shader_context';
 | 
			
		||||
import {MPMask} from './mask';
 | 
			
		||||
 | 
			
		||||
const WIDTH = 2;
 | 
			
		||||
const HEIGHT = 2;
 | 
			
		||||
 | 
			
		||||
const skip = typeof document === 'undefined';
 | 
			
		||||
if (skip) {
 | 
			
		||||
  console.log('These tests must be run in a browser.');
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
(skip ? xdescribe : describe)('DrawingUtils', () => {
 | 
			
		||||
  let shaderContext = new MPImageShaderContext();
 | 
			
		||||
  let canvas2D: HTMLCanvasElement;
 | 
			
		||||
  let context2D: CanvasRenderingContext2D;
 | 
			
		||||
  let drawingUtils2D: DrawingUtils;
 | 
			
		||||
  let canvasWebGL: HTMLCanvasElement;
 | 
			
		||||
  let contextWebGL: WebGL2RenderingContext;
 | 
			
		||||
  let drawingUtilsWebGL: DrawingUtils;
 | 
			
		||||
 | 
			
		||||
  beforeEach(() => {
 | 
			
		||||
    shaderContext = new MPImageShaderContext();
 | 
			
		||||
 | 
			
		||||
    canvasWebGL = document.createElement('canvas');
 | 
			
		||||
    canvasWebGL.width = WIDTH;
 | 
			
		||||
    canvasWebGL.height = HEIGHT;
 | 
			
		||||
    contextWebGL = canvasWebGL.getContext('webgl2')!;
 | 
			
		||||
    drawingUtilsWebGL = new DrawingUtils(contextWebGL);
 | 
			
		||||
 | 
			
		||||
    canvas2D = document.createElement('canvas');
 | 
			
		||||
    canvas2D.width = WIDTH;
 | 
			
		||||
    canvas2D.height = HEIGHT;
 | 
			
		||||
    context2D = canvas2D.getContext('2d')!;
 | 
			
		||||
    drawingUtils2D = new DrawingUtils(context2D, contextWebGL);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  afterEach(() => {
 | 
			
		||||
    shaderContext.close();
 | 
			
		||||
    drawingUtils2D.close();
 | 
			
		||||
    drawingUtilsWebGL.close();
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  describe('drawCategoryMask() ', () => {
 | 
			
		||||
    const colors = [
 | 
			
		||||
      [0, 0, 0, 255],
 | 
			
		||||
      [0, 255, 0, 255],
 | 
			
		||||
      [0, 0, 255, 255],
 | 
			
		||||
      [255, 255, 255, 255],
 | 
			
		||||
    ];
 | 
			
		||||
    const expectedResult = new Uint8Array(
 | 
			
		||||
        [0, 0, 0, 255, 0, 255, 0, 255, 0, 0, 255, 255, 255, 255, 255, 255],
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    it('on 2D canvas', () => {
 | 
			
		||||
      const categoryMask = new MPMask(
 | 
			
		||||
          [new Uint8Array([0, 1, 2, 3])],
 | 
			
		||||
          /* ownsWebGLTexture= */ false, canvas2D, shaderContext, WIDTH,
 | 
			
		||||
          HEIGHT);
 | 
			
		||||
 | 
			
		||||
      drawingUtils2D.drawCategoryMask(categoryMask, colors);
 | 
			
		||||
 | 
			
		||||
      const actualResult = context2D.getImageData(0, 0, WIDTH, HEIGHT).data;
 | 
			
		||||
      expect(actualResult)
 | 
			
		||||
          .toEqual(new Uint8ClampedArray(expectedResult.buffer));
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    it('on WebGL canvas', () => {
 | 
			
		||||
      const categoryMask = new MPMask(
 | 
			
		||||
          [new Uint8Array([2, 3, 0, 1])],  // Note: Vertically flipped
 | 
			
		||||
          /* ownsWebGLTexture= */ false, canvasWebGL, shaderContext, WIDTH,
 | 
			
		||||
          HEIGHT);
 | 
			
		||||
 | 
			
		||||
      drawingUtilsWebGL.drawCategoryMask(categoryMask, colors);
 | 
			
		||||
 | 
			
		||||
      const actualResult = new Uint8Array(WIDTH * WIDTH * 4);
 | 
			
		||||
      contextWebGL.readPixels(
 | 
			
		||||
          0, 0, WIDTH, HEIGHT, contextWebGL.RGBA, contextWebGL.UNSIGNED_BYTE,
 | 
			
		||||
          actualResult);
 | 
			
		||||
      expect(actualResult).toEqual(expectedResult);
 | 
			
		||||
    });
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  // TODO: Add tests for drawConnectors/drawLandmarks/drawBoundingBox
 | 
			
		||||
});
 | 
			
		||||
| 
						 | 
				
			
			@ -16,7 +16,11 @@
 | 
			
		|||
 | 
			
		||||
import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
 | 
			
		||||
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 | 
			
		||||
import {CategoryMaskShaderContext, CategoryToColorMap, RGBAColor} from '../../../../tasks/web/vision/core/drawing_utils_category_mask';
 | 
			
		||||
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
 | 
			
		||||
import {MPMask} from '../../../../tasks/web/vision/core/mask';
 | 
			
		||||
import {Connection} from '../../../../tasks/web/vision/core/types';
 | 
			
		||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A user-defined callback to take input data and map it to a custom output
 | 
			
		||||
| 
						 | 
				
			
			@ -24,6 +28,9 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
 | 
			
		|||
 */
 | 
			
		||||
export type Callback<I, O> = (input: I) => O;
 | 
			
		||||
 | 
			
		||||
// Used in public API
 | 
			
		||||
export {ImageSource};
 | 
			
		||||
 | 
			
		||||
/** Data that a user can use to specialize drawing options. */
 | 
			
		||||
export declare interface LandmarkData {
 | 
			
		||||
  index?: number;
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +38,32 @@ export declare interface LandmarkData {
 | 
			
		|||
  to?: NormalizedLandmark;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/** A color map with 22 classes. Used in our demos. */
 | 
			
		||||
export const DEFAULT_CATEGORY_TO_COLOR_MAP = [
 | 
			
		||||
  [0, 0, 0, 0],          // class 0 is BG = transparent
 | 
			
		||||
  [255, 0, 0, 255],      // class 1 is red
 | 
			
		||||
  [0, 255, 0, 255],      // class 2 is light green
 | 
			
		||||
  [0, 0, 255, 255],      // class 3 is blue
 | 
			
		||||
  [255, 255, 0, 255],    // class 4 is yellow
 | 
			
		||||
  [255, 0, 255, 255],    // class 5 is light purple / magenta
 | 
			
		||||
  [0, 255, 255, 255],    // class 6 is light blue / aqua
 | 
			
		||||
  [128, 128, 128, 255],  // class 7 is gray
 | 
			
		||||
  [255, 100, 0, 255],    // class 8 is dark orange
 | 
			
		||||
  [128, 0, 255, 255],    // class 9 is dark purple
 | 
			
		||||
  [0, 150, 0, 255],      // class 10 is green
 | 
			
		||||
  [255, 255, 255, 255],  // class 11 is white
 | 
			
		||||
  [255, 105, 180, 255],  // class 12 is pink
 | 
			
		||||
  [255, 150, 0, 255],    // class 13 is orange
 | 
			
		||||
  [255, 250, 224, 255],  // class 14 is light yellow
 | 
			
		||||
  [148, 0, 211, 255],    // class 15 is dark violet
 | 
			
		||||
  [0, 100, 0, 255],      // class 16 is dark green
 | 
			
		||||
  [0, 0, 128, 255],      // class 17 is navy blue
 | 
			
		||||
  [165, 42, 42, 255],    // class 18 is brown
 | 
			
		||||
  [64, 224, 208, 255],   // class 19 is turquoise
 | 
			
		||||
  [255, 218, 185, 255],  // class 20 is peach
 | 
			
		||||
  [192, 192, 192, 255],  // class 21 is silver
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Options for customizing the drawing routines
 | 
			
		||||
 */
 | 
			
		||||
| 
						 | 
				
			
			@ -77,14 +110,47 @@ function resolve<O, I>(value: O|Callback<I, O>, data: I): O {
 | 
			
		|||
  return value instanceof Function ? value(data) : value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export {RGBAColor, CategoryToColorMap};
 | 
			
		||||
 | 
			
		||||
/** Helper class to visualize the result of a MediaPipe Vision task. */
 | 
			
		||||
export class DrawingUtils {
 | 
			
		||||
  private categoryMaskShaderContext?: CategoryMaskShaderContext;
 | 
			
		||||
  private convertToWebGLTextureShaderContext?: MPImageShaderContext;
 | 
			
		||||
  private readonly context2d?: CanvasRenderingContext2D;
 | 
			
		||||
  private readonly contextWebGL?: WebGL2RenderingContext;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a new DrawingUtils class.
 | 
			
		||||
   *
 | 
			
		||||
   * @param ctx The canvas to render onto.
 | 
			
		||||
   * @param gpuContext The WebGL canvas rendering context to render into. If
 | 
			
		||||
   *     your Task is using a GPU delegate, the context must be obtained from
 | 
			
		||||
   * its canvas (provided via `setOptions({ canvas: .. })`).
 | 
			
		||||
   */
 | 
			
		||||
  constructor(private readonly ctx: CanvasRenderingContext2D) {}
 | 
			
		||||
  constructor(gpuContext: WebGL2RenderingContext);
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a new DrawingUtils class.
 | 
			
		||||
   *
 | 
			
		||||
   * @param cpuContext The 2D canvas rendering context to render into. If
 | 
			
		||||
   *     you are rendering GPU data you must also provide `gpuContext` to allow
 | 
			
		||||
   *     for data conversion.
 | 
			
		||||
   * @param gpuContext A WebGL canvas that is used for GPU rendering and for
 | 
			
		||||
   *     converting GPU to CPU data. If your Task is using a GPU delegate, the
 | 
			
		||||
   *     context must be obtained from  its canvas (provided via
 | 
			
		||||
   *     `setOptions({ canvas: .. })`).
 | 
			
		||||
   */
 | 
			
		||||
  constructor(
 | 
			
		||||
      cpuContext: CanvasRenderingContext2D,
 | 
			
		||||
      gpuContext?: WebGL2RenderingContext);
 | 
			
		||||
  constructor(
 | 
			
		||||
      cpuOrGpuGontext: CanvasRenderingContext2D|WebGL2RenderingContext,
 | 
			
		||||
      gpuContext?: WebGL2RenderingContext) {
 | 
			
		||||
    if (cpuOrGpuGontext instanceof CanvasRenderingContext2D) {
 | 
			
		||||
      this.context2d = cpuOrGpuGontext;
 | 
			
		||||
      this.contextWebGL = gpuContext;
 | 
			
		||||
    } else {
 | 
			
		||||
      this.contextWebGL = cpuOrGpuGontext;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Restricts a number between two endpoints (order doesn't matter).
 | 
			
		||||
| 
						 | 
				
			
			@ -120,9 +186,35 @@ export class DrawingUtils {
 | 
			
		|||
    return DrawingUtils.clamp(out, y0, y1);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private getCanvasRenderingContext(): CanvasRenderingContext2D {
 | 
			
		||||
    if (!this.context2d) {
 | 
			
		||||
      throw new Error(
 | 
			
		||||
          'CPU rendering requested but CanvasRenderingContext2D not provided.');
 | 
			
		||||
    }
 | 
			
		||||
    return this.context2d;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private getWebGLRenderingContext(): WebGL2RenderingContext {
 | 
			
		||||
    if (!this.contextWebGL) {
 | 
			
		||||
      throw new Error(
 | 
			
		||||
          'GPU rendering requested but WebGL2RenderingContext not provided.');
 | 
			
		||||
    }
 | 
			
		||||
    return this.contextWebGL;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private getCategoryMaskShaderContext(): CategoryMaskShaderContext {
 | 
			
		||||
    if (!this.categoryMaskShaderContext) {
 | 
			
		||||
      this.categoryMaskShaderContext = new CategoryMaskShaderContext();
 | 
			
		||||
    }
 | 
			
		||||
    return this.categoryMaskShaderContext;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Draws circles onto the provided landmarks.
 | 
			
		||||
   *
 | 
			
		||||
   * This method can only be used when `DrawingUtils` is initialized with a
 | 
			
		||||
   * `CanvasRenderingContext2D`.
 | 
			
		||||
   *
 | 
			
		||||
   * @export
 | 
			
		||||
   * @param landmarks The landmarks to draw.
 | 
			
		||||
   * @param style The style to visualize the landmarks.
 | 
			
		||||
| 
						 | 
				
			
			@ -132,7 +224,7 @@ export class DrawingUtils {
 | 
			
		|||
    if (!landmarks) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    const ctx = this.ctx;
 | 
			
		||||
    const ctx = this.getCanvasRenderingContext();
 | 
			
		||||
    const options = addDefaultOptions(style);
 | 
			
		||||
    ctx.save();
 | 
			
		||||
    const canvas = ctx.canvas;
 | 
			
		||||
| 
						 | 
				
			
			@ -159,6 +251,9 @@ export class DrawingUtils {
 | 
			
		|||
  /**
 | 
			
		||||
   * Draws lines between landmarks (given a connection graph).
 | 
			
		||||
   *
 | 
			
		||||
   * This method can only be used when `DrawingUtils` is initialized with a
 | 
			
		||||
   * `CanvasRenderingContext2D`.
 | 
			
		||||
   *
 | 
			
		||||
   * @export
 | 
			
		||||
   * @param landmarks The landmarks to draw.
 | 
			
		||||
   * @param connections The connections array that contains the start and the
 | 
			
		||||
| 
						 | 
				
			
			@ -171,7 +266,7 @@ export class DrawingUtils {
 | 
			
		|||
    if (!landmarks || !connections) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    const ctx = this.ctx;
 | 
			
		||||
    const ctx = this.getCanvasRenderingContext();
 | 
			
		||||
    const options = addDefaultOptions(style);
 | 
			
		||||
    ctx.save();
 | 
			
		||||
    const canvas = ctx.canvas;
 | 
			
		||||
| 
						 | 
				
			
			@ -195,12 +290,15 @@ export class DrawingUtils {
 | 
			
		|||
  /**
 | 
			
		||||
   * Draws a bounding box.
 | 
			
		||||
   *
 | 
			
		||||
   * This method can only be used when `DrawingUtils` is initialized with a
 | 
			
		||||
   * `CanvasRenderingContext2D`.
 | 
			
		||||
   *
 | 
			
		||||
   * @export
 | 
			
		||||
   * @param boundingBox The bounding box to draw.
 | 
			
		||||
   * @param style The style to visualize the boundin box.
 | 
			
		||||
   */
 | 
			
		||||
  drawBoundingBox(boundingBox: BoundingBox, style?: DrawingOptions): void {
 | 
			
		||||
    const ctx = this.ctx;
 | 
			
		||||
    const ctx = this.getCanvasRenderingContext();
 | 
			
		||||
    const options = addDefaultOptions(style);
 | 
			
		||||
    ctx.save();
 | 
			
		||||
    ctx.beginPath();
 | 
			
		||||
| 
						 | 
				
			
			@ -218,6 +316,118 @@ export class DrawingUtils {
 | 
			
		|||
    ctx.fill();
 | 
			
		||||
    ctx.restore();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Draws a category mask on a CanvasRenderingContext2D. */
 | 
			
		||||
  private drawCategoryMask2D(
 | 
			
		||||
      mask: MPMask, background: RGBAColor|ImageSource,
 | 
			
		||||
      categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
 | 
			
		||||
    // Use the WebGL renderer to draw result on our internal canvas.
 | 
			
		||||
    const gl = this.getWebGLRenderingContext();
 | 
			
		||||
    this.runWithWebGLTexture(mask, texture => {
 | 
			
		||||
      this.drawCategoryMaskWebGL(texture, background, categoryToColorMap);
 | 
			
		||||
      // Draw the result on the user canvas.
 | 
			
		||||
      const ctx = this.getCanvasRenderingContext();
 | 
			
		||||
      ctx.drawImage(gl.canvas, 0, 0, ctx.canvas.width, ctx.canvas.height);
 | 
			
		||||
    });
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Draws a category mask on a WebGL2RenderingContext2D. */
 | 
			
		||||
  private drawCategoryMaskWebGL(
 | 
			
		||||
      categoryTexture: WebGLTexture, background: RGBAColor|ImageSource,
 | 
			
		||||
      categoryToColorMap: Map<number, RGBAColor>|RGBAColor[]): void {
 | 
			
		||||
    const shaderContext = this.getCategoryMaskShaderContext();
 | 
			
		||||
    const gl = this.getWebGLRenderingContext();
 | 
			
		||||
    const backgroundImage = Array.isArray(background) ?
 | 
			
		||||
        new ImageData(new Uint8ClampedArray(background), 1, 1) :
 | 
			
		||||
        background;
 | 
			
		||||
 | 
			
		||||
    shaderContext.run(gl, /* flipTexturesVertically= */ true, () => {
 | 
			
		||||
      shaderContext.bindAndUploadTextures(
 | 
			
		||||
          categoryTexture, backgroundImage, categoryToColorMap);
 | 
			
		||||
      gl.clearColor(0, 0, 0, 0);
 | 
			
		||||
      gl.clear(gl.COLOR_BUFFER_BIT);
 | 
			
		||||
      gl.drawArrays(gl.TRIANGLE_FAN, 0, 4);
 | 
			
		||||
      shaderContext.unbindTextures();
 | 
			
		||||
    });
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Draws a category mask using the provided category-to-color mapping.
 | 
			
		||||
   *
 | 
			
		||||
   * @export
 | 
			
		||||
   * @param mask A category mask that was returned from a segmentation task.
 | 
			
		||||
   * @param categoryToColorMap A map that maps category indices to RGBA
 | 
			
		||||
   *     values. You must specify a map entry for each category.
 | 
			
		||||
   * @param background A color or image to use as the background. Defaults to
 | 
			
		||||
   *     black.
 | 
			
		||||
   */
 | 
			
		||||
  drawCategoryMask(
 | 
			
		||||
      mask: MPMask, categoryToColorMap: Map<number, RGBAColor>,
 | 
			
		||||
      background?: RGBAColor|ImageSource): void;
 | 
			
		||||
  /**
 | 
			
		||||
   * Draws a category mask using the provided color array.
 | 
			
		||||
   *
 | 
			
		||||
   * @export
 | 
			
		||||
   * @param mask A category mask that was returned from a segmentation task.
 | 
			
		||||
   * @param categoryToColorMap An array that maps indices to RGBA values. The
 | 
			
		||||
   *     array's indices must correspond to the category indices of the model
 | 
			
		||||
   *     and an entry must be provided for each category.
 | 
			
		||||
   * @param background A color or image to use as the background. Defaults to
 | 
			
		||||
   *     black.
 | 
			
		||||
   */
 | 
			
		||||
  drawCategoryMask(
 | 
			
		||||
      mask: MPMask, categoryToColorMap: RGBAColor[],
 | 
			
		||||
      background?: RGBAColor|ImageSource): void;
 | 
			
		||||
  drawCategoryMask(
 | 
			
		||||
      mask: MPMask, categoryToColorMap: CategoryToColorMap,
 | 
			
		||||
      background: RGBAColor|ImageSource = [0, 0, 0, 255]): void {
 | 
			
		||||
    if (this.context2d) {
 | 
			
		||||
      this.drawCategoryMask2D(mask, background, categoryToColorMap);
 | 
			
		||||
    } else {
 | 
			
		||||
      this.drawCategoryMaskWebGL(
 | 
			
		||||
          mask.getAsWebGLTexture(), background, categoryToColorMap);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Converts the given mask to a WebGLTexture and runs the callback. Cleans
 | 
			
		||||
   * up any new resources after the callback finished executing.
 | 
			
		||||
   */
 | 
			
		||||
  private runWithWebGLTexture(
 | 
			
		||||
      mask: MPMask, callback: (texture: WebGLTexture) => void): void {
 | 
			
		||||
    if (!mask.hasWebGLTexture()) {
 | 
			
		||||
      // Re-create the MPMask but use our the WebGL canvas so we can draw the
 | 
			
		||||
      // texture directly.
 | 
			
		||||
      const data = mask.hasFloat32Array() ? mask.getAsFloat32Array() :
 | 
			
		||||
                                            mask.getAsUint8Array();
 | 
			
		||||
      this.convertToWebGLTextureShaderContext =
 | 
			
		||||
          this.convertToWebGLTextureShaderContext ?? new MPImageShaderContext();
 | 
			
		||||
      const gl = this.getWebGLRenderingContext();
 | 
			
		||||
 | 
			
		||||
      const convertedMask = new MPMask(
 | 
			
		||||
          [data],
 | 
			
		||||
          /* ownsWebGlTexture= */ false,
 | 
			
		||||
          gl.canvas,
 | 
			
		||||
          this.convertToWebGLTextureShaderContext,
 | 
			
		||||
          mask.width,
 | 
			
		||||
          mask.height,
 | 
			
		||||
      );
 | 
			
		||||
      callback(convertedMask.getAsWebGLTexture());
 | 
			
		||||
      convertedMask.close();
 | 
			
		||||
    } else {
 | 
			
		||||
      callback(mask.getAsWebGLTexture());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  /**
 | 
			
		||||
   * Frees all WebGL resources held by this class.
 | 
			
		||||
   * @export
 | 
			
		||||
   */
 | 
			
		||||
  close(): void {
 | 
			
		||||
    this.categoryMaskShaderContext?.close();
 | 
			
		||||
    this.categoryMaskShaderContext = undefined;
 | 
			
		||||
    this.convertToWebGLTextureShaderContext?.close();
 | 
			
		||||
    this.convertToWebGLTextureShaderContext = undefined;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										189
									
								
								mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								mediapipe/tasks/web/vision/core/drawing_utils_category_mask.ts
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,189 @@
 | 
			
		|||
/**
 | 
			
		||||
 * Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
 | 
			
		||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A fragment shader that maps categories to colors based on a background
 | 
			
		||||
 * texture, a mask texture and a 256x1 "color mapping texture" that contains one
 | 
			
		||||
 * color for each pixel.
 | 
			
		||||
 */
 | 
			
		||||
const FRAGMENT_SHADER = `
 | 
			
		||||
  precision mediump float;
 | 
			
		||||
  uniform sampler2D backgroundTexture;
 | 
			
		||||
  uniform sampler2D maskTexture;
 | 
			
		||||
  uniform sampler2D colorMappingTexture;
 | 
			
		||||
  varying vec2 vTex;
 | 
			
		||||
  void main() {
 | 
			
		||||
    vec4 backgroundColor = texture2D(backgroundTexture, vTex);
 | 
			
		||||
    float category = texture2D(maskTexture, vTex).r;
 | 
			
		||||
    vec4 categoryColor = texture2D(colorMappingTexture, vec2(category, 0.0));
 | 
			
		||||
    gl_FragColor = mix(backgroundColor, categoryColor, categoryColor.a);
 | 
			
		||||
  }
 | 
			
		||||
 `;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A four channel color with values for red, green, blue and alpha
 | 
			
		||||
 * respectively.
 | 
			
		||||
 */
 | 
			
		||||
export type RGBAColor = [number, number, number, number]|number[];
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A category to color mapping that uses either a map or an array to assign
 | 
			
		||||
 * category indexes to RGBA colors.
 | 
			
		||||
 */
 | 
			
		||||
export type CategoryToColorMap = Map<number, RGBAColor>|RGBAColor[];
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/** Checks CategoryToColorMap maps for deep equality. */
 | 
			
		||||
function isEqualColorMap(
 | 
			
		||||
    a: CategoryToColorMap, b: CategoryToColorMap): boolean {
 | 
			
		||||
  if (a !== b) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const aEntries = a.entries();
 | 
			
		||||
  const bEntries = b.entries();
 | 
			
		||||
  for (const [aKey, aValue] of aEntries) {
 | 
			
		||||
    const bNext = bEntries.next();
 | 
			
		||||
    if (bNext.done) {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const [bKey, bValue] = bNext.value;
 | 
			
		||||
    if (aKey !== bKey) {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (aValue[0] !== bValue[0] || aValue[1] !== bValue[1] ||
 | 
			
		||||
        aValue[2] !== bValue[2] || aValue[3] !== bValue[3]) {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return !!bEntries.next().done;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/** A drawing util class for category masks. */
 | 
			
		||||
export class CategoryMaskShaderContext extends MPImageShaderContext {
 | 
			
		||||
  backgroundTexture?: WebGLTexture;
 | 
			
		||||
  colorMappingTexture?: WebGLTexture;
 | 
			
		||||
  colorMappingTextureUniform?: WebGLUniformLocation;
 | 
			
		||||
  backgroundTextureUniform?: WebGLUniformLocation;
 | 
			
		||||
  maskTextureUniform?: WebGLUniformLocation;
 | 
			
		||||
  currentColorMap?: CategoryToColorMap;
 | 
			
		||||
 | 
			
		||||
  bindAndUploadTextures(
 | 
			
		||||
      categoryMask: WebGLTexture, background: ImageSource,
 | 
			
		||||
      colorMap: Map<number, number[]>|number[][]) {
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
 | 
			
		||||
    // TODO: We should avoid uploading textures from CPU to GPU
 | 
			
		||||
    // if the textures haven't changed. This can lead to drastic performance
 | 
			
		||||
    // slowdowns (~50ms per frame). Users can reduce the penalty by passing a
 | 
			
		||||
    // canvas object instead of ImageData/HTMLImageElement.
 | 
			
		||||
    gl.activeTexture(gl.TEXTURE0);
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, this.backgroundTexture!);
 | 
			
		||||
    gl.texImage2D(
 | 
			
		||||
        gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, background);
 | 
			
		||||
 | 
			
		||||
    // Bind color mapping texture if changed.
 | 
			
		||||
    if (!this.currentColorMap ||
 | 
			
		||||
        !isEqualColorMap(this.currentColorMap, colorMap)) {
 | 
			
		||||
      this.currentColorMap = colorMap;
 | 
			
		||||
 | 
			
		||||
      const pixels = new Array(256 * 4).fill(0);
 | 
			
		||||
      colorMap.forEach((rgba, index) => {
 | 
			
		||||
        if (rgba.length !== 4) {
 | 
			
		||||
          throw new Error(
 | 
			
		||||
              `Color at index ${index} is not a four-channel value.`);
 | 
			
		||||
        }
 | 
			
		||||
        pixels[index * 4] = rgba[0];
 | 
			
		||||
        pixels[index * 4 + 1] = rgba[1];
 | 
			
		||||
        pixels[index * 4 + 2] = rgba[2];
 | 
			
		||||
        pixels[index * 4 + 3] = rgba[3];
 | 
			
		||||
      });
 | 
			
		||||
      gl.activeTexture(gl.TEXTURE1);
 | 
			
		||||
      gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
 | 
			
		||||
      gl.texImage2D(
 | 
			
		||||
          gl.TEXTURE_2D, 0, gl.RGBA, 256, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE,
 | 
			
		||||
          new Uint8Array(pixels));
 | 
			
		||||
    } else {
 | 
			
		||||
      gl.activeTexture(gl.TEXTURE1);
 | 
			
		||||
      gl.bindTexture(gl.TEXTURE_2D, this.colorMappingTexture!);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Bind category mask
 | 
			
		||||
    gl.activeTexture(gl.TEXTURE2);
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, categoryMask);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  unbindTextures() {
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    gl.activeTexture(gl.TEXTURE0);
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, null);
 | 
			
		||||
    gl.activeTexture(gl.TEXTURE1);
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, null);
 | 
			
		||||
    gl.activeTexture(gl.TEXTURE2);
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, null);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override getFragmentShader(): string {
 | 
			
		||||
    return FRAGMENT_SHADER;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override setupTextures(): void {
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    gl.activeTexture(gl.TEXTURE0);
 | 
			
		||||
    this.backgroundTexture = this.createTexture(gl, gl.LINEAR);
 | 
			
		||||
    // Use `gl.NEAREST` to prevent interpolating values in our category to
 | 
			
		||||
    // color map.
 | 
			
		||||
    this.colorMappingTexture = this.createTexture(gl, gl.NEAREST);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override setupShaders(): void {
 | 
			
		||||
    super.setupShaders();
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    this.backgroundTextureUniform = assertNotNull(
 | 
			
		||||
        gl.getUniformLocation(this.program!, 'backgroundTexture'),
 | 
			
		||||
        'Uniform location');
 | 
			
		||||
    this.colorMappingTextureUniform = assertNotNull(
 | 
			
		||||
        gl.getUniformLocation(this.program!, 'colorMappingTexture'),
 | 
			
		||||
        'Uniform location');
 | 
			
		||||
    this.maskTextureUniform = assertNotNull(
 | 
			
		||||
        gl.getUniformLocation(this.program!, 'maskTexture'),
 | 
			
		||||
        'Uniform location');
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override configureUniforms(): void {
 | 
			
		||||
    super.configureUniforms();
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    gl.uniform1i(this.backgroundTextureUniform!, 0);
 | 
			
		||||
    gl.uniform1i(this.colorMappingTextureUniform!, 1);
 | 
			
		||||
    gl.uniform1i(this.maskTextureUniform!, 2);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  override close(): void {
 | 
			
		||||
    if (this.backgroundTexture) {
 | 
			
		||||
      this.gl!.deleteTexture(this.backgroundTexture);
 | 
			
		||||
    }
 | 
			
		||||
    if (this.colorMappingTexture) {
 | 
			
		||||
      this.gl!.deleteTexture(this.colorMappingTexture);
 | 
			
		||||
    }
 | 
			
		||||
    super.close();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -27,9 +27,9 @@ const FRAGMENT_SHADER = `
 | 
			
		|||
  precision mediump float;
 | 
			
		||||
  varying vec2 vTex;
 | 
			
		||||
  uniform sampler2D inputTexture;
 | 
			
		||||
   void main() {
 | 
			
		||||
     gl_FragColor = texture2D(inputTexture, vTex);
 | 
			
		||||
   }
 | 
			
		||||
  void main() {
 | 
			
		||||
    gl_FragColor = texture2D(inputTexture, vTex);
 | 
			
		||||
  }
 | 
			
		||||
 `;
 | 
			
		||||
 | 
			
		||||
/** Helper to assert that `value` is not null.  */
 | 
			
		||||
| 
						 | 
				
			
			@ -73,9 +73,9 @@ class MPImageShaderBuffers {
 | 
			
		|||
 * For internal use only.
 | 
			
		||||
 */
 | 
			
		||||
export class MPImageShaderContext {
 | 
			
		||||
  private gl?: WebGL2RenderingContext;
 | 
			
		||||
  protected gl?: WebGL2RenderingContext;
 | 
			
		||||
  private framebuffer?: WebGLFramebuffer;
 | 
			
		||||
  private program?: WebGLProgram;
 | 
			
		||||
  protected program?: WebGLProgram;
 | 
			
		||||
  private vertexShader?: WebGLShader;
 | 
			
		||||
  private fragmentShader?: WebGLShader;
 | 
			
		||||
  private aVertex?: GLint;
 | 
			
		||||
| 
						 | 
				
			
			@ -94,6 +94,14 @@ export class MPImageShaderContext {
 | 
			
		|||
   */
 | 
			
		||||
  private shaderBuffersFlipVertically?: MPImageShaderBuffers;
 | 
			
		||||
 | 
			
		||||
  protected getFragmentShader(): string {
 | 
			
		||||
    return FRAGMENT_SHADER;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected getVertexShader(): string {
 | 
			
		||||
    return VERTEX_SHADER;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private compileShader(source: string, type: number): WebGLShader {
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    const shader =
 | 
			
		||||
| 
						 | 
				
			
			@ -108,14 +116,15 @@ export class MPImageShaderContext {
 | 
			
		|||
    return shader;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private setupShaders(): void {
 | 
			
		||||
  protected setupShaders(): void {
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    this.program =
 | 
			
		||||
        assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
 | 
			
		||||
 | 
			
		||||
    this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER);
 | 
			
		||||
    this.vertexShader =
 | 
			
		||||
        this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER);
 | 
			
		||||
    this.fragmentShader =
 | 
			
		||||
        this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER);
 | 
			
		||||
        this.compileShader(this.getFragmentShader(), gl.FRAGMENT_SHADER);
 | 
			
		||||
 | 
			
		||||
    gl.linkProgram(this.program);
 | 
			
		||||
    const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS);
 | 
			
		||||
| 
						 | 
				
			
			@ -128,6 +137,10 @@ export class MPImageShaderContext {
 | 
			
		|||
    this.aTex = gl.getAttribLocation(this.program, 'aTex');
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected setupTextures(): void {}
 | 
			
		||||
 | 
			
		||||
  protected configureUniforms(): void {}
 | 
			
		||||
 | 
			
		||||
  private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
 | 
			
		||||
    const gl = this.gl!;
 | 
			
		||||
    const vertexArrayObject =
 | 
			
		||||
| 
						 | 
				
			
			@ -193,17 +206,44 @@ export class MPImageShaderContext {
 | 
			
		|||
 | 
			
		||||
    if (!this.program) {
 | 
			
		||||
      this.setupShaders();
 | 
			
		||||
      this.setupTextures();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const shaderBuffers = this.getShaderBuffers(flipVertically);
 | 
			
		||||
    gl.useProgram(this.program!);
 | 
			
		||||
    shaderBuffers.bind();
 | 
			
		||||
    this.configureUniforms();
 | 
			
		||||
    const result = callback();
 | 
			
		||||
    shaderBuffers.unbind();
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates and configures a texture.
 | 
			
		||||
   *
 | 
			
		||||
   * @param gl The rendering context.
 | 
			
		||||
   * @param filter The setting to use for `gl.TEXTURE_MIN_FILTER` and
 | 
			
		||||
   *     `gl.TEXTURE_MAG_FILTER`. Defaults to `gl.LINEAR`.
 | 
			
		||||
   * @param wrapping The setting to use for `gl.TEXTURE_WRAP_S` and
 | 
			
		||||
   *     `gl.TEXTURE_WRAP_T`. Defaults to `gl.CLAMP_TO_EDGE`.
 | 
			
		||||
   */
 | 
			
		||||
  createTexture(gl: WebGL2RenderingContext, filter?: GLenum, wrapping?: GLenum):
 | 
			
		||||
      WebGLTexture {
 | 
			
		||||
    this.maybeInitGL(gl);
 | 
			
		||||
    const texture =
 | 
			
		||||
        assertNotNull(gl.createTexture(), 'Failed to create texture');
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, texture);
 | 
			
		||||
    gl.texParameteri(
 | 
			
		||||
        gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE);
 | 
			
		||||
    gl.texParameteri(
 | 
			
		||||
        gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, wrapping ?? gl.CLAMP_TO_EDGE);
 | 
			
		||||
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, filter ?? gl.LINEAR);
 | 
			
		||||
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, filter ?? gl.LINEAR);
 | 
			
		||||
    gl.bindTexture(gl.TEXTURE_2D, null);
 | 
			
		||||
    return texture;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Binds a framebuffer to the canvas. If the framebuffer does not yet exist,
 | 
			
		||||
   * creates it first. Binds the provided texture to the framebuffer.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,24 +16,6 @@
 | 
			
		|||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
// Pre-baked color table for a maximum of 12 classes.
 | 
			
		||||
const CM_ALPHA = 128;
 | 
			
		||||
const COLOR_MAP: Array<[number, number, number, number]> = [
 | 
			
		||||
  [0, 0, 0, CM_ALPHA],        // class 0 is BG = transparent
 | 
			
		||||
  [255, 0, 0, CM_ALPHA],      // class 1 is red
 | 
			
		||||
  [0, 255, 0, CM_ALPHA],      // class 2 is light green
 | 
			
		||||
  [0, 0, 255, CM_ALPHA],      // class 3 is blue
 | 
			
		||||
  [255, 255, 0, CM_ALPHA],    // class 4 is yellow
 | 
			
		||||
  [255, 0, 255, CM_ALPHA],    // class 5 is light purple / magenta
 | 
			
		||||
  [0, 255, 255, CM_ALPHA],    // class 6 is light blue / aqua
 | 
			
		||||
  [128, 128, 128, CM_ALPHA],  // class 7 is gray
 | 
			
		||||
  [255, 128, 0, CM_ALPHA],    // class 8 is orange
 | 
			
		||||
  [128, 0, 255, CM_ALPHA],    // class 9 is dark purple
 | 
			
		||||
  [0, 128, 0, CM_ALPHA],      // class 10 is dark green
 | 
			
		||||
  [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/** Helper function to draw a confidence mask */
 | 
			
		||||
export function drawConfidenceMask(
 | 
			
		||||
    ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
 | 
			
		||||
| 
						 | 
				
			
			@ -47,23 +29,3 @@ export function drawConfidenceMask(
 | 
			
		|||
  }
 | 
			
		||||
  ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Helper function to draw a category mask. For GPU, we only have F32Arrays
 | 
			
		||||
 * for now.
 | 
			
		||||
 */
 | 
			
		||||
export function drawCategoryMask(
 | 
			
		||||
    ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
 | 
			
		||||
    width: number, height: number): void {
 | 
			
		||||
  const rgbaArray = new Uint8ClampedArray(width * height * 4);
 | 
			
		||||
  const isFloatArray = image instanceof Float32Array;
 | 
			
		||||
  for (let i = 0; i < image.length; i++) {
 | 
			
		||||
    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
 | 
			
		||||
    const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
 | 
			
		||||
    rgbaArray[4 * i] = color[0];
 | 
			
		||||
    rgbaArray[4 * i + 1] = color[1];
 | 
			
		||||
    rgbaArray[4 * i + 2] = color[2];
 | 
			
		||||
    rgbaArray[4 * i + 3] = color[3];
 | 
			
		||||
  }
 | 
			
		||||
  ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user