Create a MediaPipe Mask Type

PiperOrigin-RevId: 529868427
This commit is contained in:
Sebastian Schmidt 2023-05-05 19:21:50 -07:00 committed by Copybara-Service
parent 3562a7f7dc
commit e707c84a3d
6 changed files with 616 additions and 0 deletions

View File

@ -21,6 +21,7 @@ VISION_LIBS = [
"//mediapipe/tasks/web/core:fileset_resolver",
"//mediapipe/tasks/web/vision/core:drawing_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/face_detector",
"//mediapipe/tasks/web/vision/face_landmarker",
"//mediapipe/tasks/web/vision/face_stylizer",

View File

@ -60,6 +60,27 @@ jasmine_node_test(
deps = [":image_test_lib"],
)
# The MPMask container type (wraps Uint8Array/Float32Array/WebGLTexture masks).
mediapipe_ts_library(
    name = "mask",
    srcs = ["mask.ts"],
    deps = [":image"],
)

# Unit tests for the MPMask type.
mediapipe_ts_library(
    name = "mask_test_lib",
    testonly = True,
    srcs = ["mask.test.ts"],
    deps = [
        ":image",
        ":mask",
    ],
)

jasmine_node_test(
    name = "mask_test",
    deps = [":mask_test_lib"],
)
mediapipe_ts_library(
name = "vision_task_runner",
srcs = ["vision_task_runner.ts"],

View File

@ -0,0 +1,268 @@
/**
* Copyright 2022 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {MPImageShaderContext} from './image_shader_context';
import {MPMask, MPMaskType} from './mask';
// Default mask dimensions used by the tests below. Note that individual
// tests may initialize masks of other sizes (see MASK_2_1 / MASK_2_3).
const WIDTH = 2;
const HEIGHT = 2;

// The suite requires OffscreenCanvas/WebGL2, so it is skipped outside a
// browser environment (`document` is undefined under plain Node).
const skip = typeof document === 'undefined';
if (skip) {
  console.log('These tests must be run in a browser.');
}

/** The mask types supported by MPMask. */
type MaskType = Uint8Array|Float32Array|WebGLTexture;

// Pixel fixtures for 2x1, 2x2 and 2x3 masks respectively.
const MASK_2_1 = [1, 2];
const MASK_2_2 = [1, 2, 3, 4];
const MASK_2_3 = [1, 2, 3, 4, 5, 6];
/** The test images and data to use for the unit tests below. */
class MPMaskTestContext {
canvas!: OffscreenCanvas;
gl!: WebGL2RenderingContext;
uint8Array!: Uint8Array;
float32Array!: Float32Array;
webGLTexture!: WebGLTexture;
async init(pixels = MASK_2_2, width = WIDTH, height = HEIGHT): Promise<void> {
// Initialize a canvas with default dimensions. Note that the canvas size
// can be different from the mask size.
this.canvas = new OffscreenCanvas(WIDTH, HEIGHT);
this.gl = this.canvas.getContext('webgl2') as WebGL2RenderingContext;
const gl = this.gl;
if (!gl.getExtension('EXT_color_buffer_float')) {
throw new Error('Missing required EXT_color_buffer_float extension');
}
this.uint8Array = new Uint8Array(pixels);
this.float32Array = new Float32Array(pixels.length);
for (let i = 0; i < this.uint8Array.length; ++i) {
this.float32Array[i] = pixels[i] / 255;
}
this.webGLTexture = gl.createTexture()!;
gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture);
gl.texImage2D(
gl.TEXTURE_2D, 0, gl.R32F, width, height, 0, gl.RED, gl.FLOAT,
new Float32Array(pixels).map(v => v / 255));
gl.bindTexture(gl.TEXTURE_2D, null);
}
get(type: unknown) {
switch (type) {
case Uint8Array:
return this.uint8Array;
case Float32Array:
return this.float32Array;
case WebGLTexture:
return this.webGLTexture;
default:
throw new Error(`Unsupported type: ${type}`);
}
}
close(): void {
this.gl.deleteTexture(this.webGLTexture);
}
}
// Browser-only suite: `skip` disables it where OffscreenCanvas/WebGL2 are
// unavailable (see the `skip` flag above).
(skip ? xdescribe : describe)('MPMask', () => {
  const context = new MPMaskTestContext();

  afterEach(() => {
    context.close();
  });

  /**
   * Reads the single-channel float pixels of `texture` back to the CPU by
   * attaching it to a temporary framebuffer.
   */
  function readPixelsFromWebGLTexture(texture: WebGLTexture): Float32Array {
    const pixels = new Float32Array(WIDTH * HEIGHT);

    const gl = context.gl;
    gl.bindTexture(gl.TEXTURE_2D, texture);

    const framebuffer = gl.createFramebuffer()!;
    gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
    gl.framebufferTexture2D(
        gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    gl.readPixels(0, 0, WIDTH, HEIGHT, gl.RED, gl.FLOAT, pixels);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteFramebuffer(framebuffer);

    gl.bindTexture(gl.TEXTURE_2D, null);

    // Sanity check values
    expect(pixels[0]).not.toBe(0);

    return pixels;
  }

  /**
   * Asserts that `mask` holds data equal to `expected`, using the accessor
   * matching the expected container type. Texture comparisons go through a
   * CPU read-back of both textures.
   */
  function assertEquality(mask: MPMask, expected: MaskType): void {
    if (expected instanceof Uint8Array) {
      const result = mask.get(MPMaskType.UINT8_ARRAY);
      expect(result).toEqual(expected);
    } else if (expected instanceof Float32Array) {
      const result = mask.get(MPMaskType.FLOAT32_ARRAY);
      expect(result).toEqual(expected);
    } else {  // WebGLTexture
      const result = mask.get(MPMaskType.WEBGL_TEXTURE);
      expect(readPixelsFromWebGLTexture(result))
          .toEqual(readPixelsFromWebGLTexture(expected));
    }
  }

  /** Wraps `input` in a non-owning MPMask with the given dimensions. */
  function createImage(
      shaderContext: MPImageShaderContext, input: MaskType, width: number,
      height: number): MPMask {
    return new MPMask(
        [input],
        /* ownsWebGLTexture= */ false, context.canvas, shaderContext, width,
        height);
  }

  /**
   * Builds a mask from `input` and verifies that converting it yields data
   * equal to `output`.
   */
  function runConversionTest(
      input: MaskType, output: MaskType, width = WIDTH, height = HEIGHT): void {
    const shaderContext = new MPImageShaderContext();
    const mask = createImage(shaderContext, input, width, height);
    assertEquality(mask, output);
    mask.close();
    shaderContext.close();
  }

  /** Verifies that cloning a mask preserves its contents. */
  function runCloneTest(input: MaskType): void {
    const shaderContext = new MPImageShaderContext();
    const mask = createImage(shaderContext, input, WIDTH, HEIGHT);
    const clone = mask.clone();
    assertEquality(clone, input);
    clone.close();
    shaderContext.close();
  }

  // Exercise every (source, destination) container-type pair.
  const sources = skip ? [] : [Uint8Array, Float32Array, WebGLTexture];

  for (let i = 0; i < sources.length; i++) {
    for (let j = 0; j < sources.length; j++) {
      it(`converts from ${sources[i].name} to ${sources[j].name}`, async () => {
        await context.init();
        runConversionTest(context.get(sources[i]), context.get(sources[j]));
      });
    }
  }

  for (let i = 0; i < sources.length; i++) {
    it(`clones ${sources[i].name}`, async () => {
      await context.init();
      runCloneTest(context.get(sources[i]));
    });
  }

  it(`does not flip textures twice`, async () => {
    await context.init();

    const shaderContext = new MPImageShaderContext();
    const mask = new MPMask(
        [context.webGLTexture],
        /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
        HEIGHT);

    // Two GPU round-trips (clone, then CPU read) must still match the
    // original bytes, i.e. the mask must not end up vertically flipped.
    const result = mask.clone().get(MPMaskType.UINT8_ARRAY);
    expect(result).toEqual(context.uint8Array);
    shaderContext.close();
  });

  it(`can clone and get mask`, async () => {
    await context.init();

    const shaderContext = new MPImageShaderContext();
    const mask = new MPMask(
        [context.webGLTexture],
        /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
        HEIGHT);

    // Verify that we can mix the different shader modes by running them out of
    // order.
    let result = mask.get(MPMaskType.UINT8_ARRAY);
    expect(result).toEqual(context.uint8Array);

    result = mask.clone().get(MPMaskType.UINT8_ARRAY);
    expect(result).toEqual(context.uint8Array);

    result = mask.get(MPMaskType.UINT8_ARRAY);
    expect(result).toEqual(context.uint8Array);

    shaderContext.close();
  });

  it('supports has()', async () => {
    await context.init();

    const shaderContext = new MPImageShaderContext();
    const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT);

    // Only the original container exists at first; each `get()` conversion
    // below caches an additional container type on the mask.
    expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
    expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(false);
    expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(false);

    mask.get(MPMaskType.FLOAT32_ARRAY);

    expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
    expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);

    mask.get(MPMaskType.WEBGL_TEXTURE);

    expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
    expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
    expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(true);

    mask.close();
    shaderContext.close();
  });

  it('supports mask that is smaller than the canvas', async () => {
    await context.init(MASK_2_1, /* width= */ 2, /* height= */ 1);

    runConversionTest(
        context.uint8Array, context.webGLTexture, /* width= */ 2,
        /* height= */ 1);
    runConversionTest(
        context.webGLTexture, context.float32Array, /* width= */ 2,
        /* height= */ 1);
    runConversionTest(
        context.float32Array, context.uint8Array, /* width= */ 2,
        /* height= */ 1);

    context.close();
  });

  it('supports mask that is larger than the canvas', async () => {
    await context.init(MASK_2_3, /* width= */ 2, /* height= */ 3);

    runConversionTest(
        context.uint8Array, context.webGLTexture, /* width= */ 2,
        /* height= */ 3);
    runConversionTest(
        context.webGLTexture, context.float32Array, /* width= */ 2,
        /* height= */ 3);
    runConversionTest(
        context.float32Array, context.uint8Array, /* width= */ 2,
        /* height= */ 3);
  });
});

View File

@ -0,0 +1,320 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
/** The underlying type of the mask. */
export enum MPMaskType {
  /** Represents the native `Uint8Array` type. */
  UINT8_ARRAY,
  /** Represents the native `Float32Array` type. */
  FLOAT32_ARRAY,
  /** Represents the native `WebGLTexture` type. */
  WEBGL_TEXTURE
}

/** The supported mask formats. For internal usage. */
export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
/**
 * The wrapper class for MediaPipe segmentation masks.
 *
 * Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects.
 * You can convert the underlying type to any other type by passing the desired
 * type to `get()`. As type conversions can be expensive, it is recommended to
 * limit these conversions. You can verify what underlying types are already
 * available by invoking `has()`.
 *
 * Masks that are returned from a MediaPipe Task are owned by the
 * underlying C++ Task. If you need to extend the lifetime of these objects,
 * you can invoke the `clone()` method. To free up the resources obtained
 * during any clone or type conversion operation, it is important to invoke
 * `close()` on the `MPMask` instance.
 */
export class MPMask {
  /** Lazily created WebGL2 context used for CPU <-> GPU conversions. */
  private gl?: WebGL2RenderingContext;

  /** The underlying type of the mask. */
  static TYPE = MPMaskType;

  /** @hideconstructor */
  constructor(
      private readonly containers: MPMaskContainer[],
      private ownsWebGLTexture: boolean,
      /** Returns the canvas element that the mask is bound to. */
      readonly canvas: HTMLCanvasElement|OffscreenCanvas|undefined,
      private shaderContext: MPImageShaderContext|undefined,
      /** Returns the width of the mask. */
      readonly width: number,
      /** Returns the height of the mask. */
      readonly height: number,
  ) {}

  /**
   * Returns whether this `MPMask` stores the mask in the desired
   * format. This method can be called to reduce expensive conversion before
   * invoking `get()`.
   */
  has(type: MPMaskType): boolean {
    return !!this.getContainer(type);
  }

  /**
   * Returns the underlying mask as a `Uint8Array`. Note that this involves
   * an expensive GPU to CPU transfer if the current mask is only available
   * as a `WebGLTexture`.
   *
   * @param type The type of mask to return.
   * @return The current data as a Uint8Array.
   */
  get(type: MPMaskType.UINT8_ARRAY): Uint8Array;
  /**
   * Returns the underlying mask as a single channel `Float32Array`. Note that
   * this involves an expensive GPU to CPU transfer if the current mask is only
   * available as a `WebGLTexture`.
   *
   * @param type The type of mask to return.
   * @return The current mask as a Float32Array.
   */
  get(type: MPMaskType.FLOAT32_ARRAY): Float32Array;
  /**
   * Returns the underlying mask as a `WebGLTexture` object. Note that this
   * involves a CPU to GPU transfer if the current mask is only available as
   * a CPU array. The returned texture is bound to the current canvas (see
   * `.canvas`).
   *
   * @param type The type of mask to return.
   * @return The current mask as a WebGLTexture.
   */
  get(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture;
  get(type?: MPMaskType): MPMaskContainer {
    switch (type) {
      case MPMaskType.UINT8_ARRAY:
        return this.convertToUint8Array();
      case MPMaskType.FLOAT32_ARRAY:
        return this.convertToFloat32Array();
      case MPMaskType.WEBGL_TEXTURE:
        return this.convertToWebGLTexture();
      default:
        throw new Error(`Type is not supported: ${type}`);
    }
  }

  private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
  private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
  private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
  private getContainer(type: MPMaskType): MPMaskContainer|undefined;
  /** Returns the container for the requested storage type iff it exists. */
  private getContainer(type: MPMaskType): MPMaskContainer|undefined {
    switch (type) {
      case MPMaskType.UINT8_ARRAY:
        return this.containers.find(img => img instanceof Uint8Array);
      case MPMaskType.FLOAT32_ARRAY:
        return this.containers.find(img => img instanceof Float32Array);
      case MPMaskType.WEBGL_TEXTURE:
        // The `typeof` guard keeps this check safe in environments where
        // `WebGLTexture` is not defined (e.g. Node).
        return this.containers.find(
            img => typeof WebGLTexture !== 'undefined' &&
                img instanceof WebGLTexture);
      default:
        throw new Error(`Type is not supported: ${type}`);
    }
  }

  /**
   * Creates a copy of the resources stored in this `MPMask`. You can
   * invoke this method to extend the lifetime of a mask returned by a
   * MediaPipe Task. Note that performance critical applications should aim to
   * only use the `MPMask` within the MediaPipe Task callback so that
   * copies can be avoided.
   */
  clone(): MPMask {
    const destinationContainers: MPMaskContainer[] = [];

    // TODO: We might only want to clone one backing datastructure
    // even if multiple are defined.
    for (const container of this.containers) {
      let destinationContainer: MPMaskContainer;

      if (container instanceof Uint8Array) {
        destinationContainer = new Uint8Array(container);
      } else if (container instanceof Float32Array) {
        destinationContainer = new Float32Array(container);
      } else if (container instanceof WebGLTexture) {
        const gl = this.getGL();
        const shaderContext = this.getShaderContext();

        // Create a new texture and use it to back a framebuffer
        gl.activeTexture(gl.TEXTURE1);
        destinationContainer =
            assertNotNull(gl.createTexture(), 'Failed to create texture');
        gl.bindTexture(gl.TEXTURE_2D, destinationContainer);
        gl.texImage2D(
            gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
            gl.FLOAT, null);
        gl.bindTexture(gl.TEXTURE_2D, null);

        // Render the source texture into the destination framebuffer.
        shaderContext.bindFramebuffer(gl, destinationContainer);
        shaderContext.run(gl, /* flipVertically= */ false, () => {
          this.bindTexture();  // This activates gl.TEXTURE0
          gl.clearColor(0, 0, 0, 0);
          gl.clear(gl.COLOR_BUFFER_BIT);
          gl.drawArrays(gl.TRIANGLE_FAN, 0, 4);
          this.unbindTexture();
        });
        shaderContext.unbindFramebuffer();

        this.unbindTexture();
      } else {
        throw new Error(`Type is not supported: ${container}`);
      }

      destinationContainers.push(destinationContainer);
    }

    // If this mask had a texture container, the clone received a freshly
    // created texture above and therefore owns it.
    return new MPMask(
        destinationContainers, this.has(MPMaskType.WEBGL_TEXTURE), this.canvas,
        this.shaderContext, this.width, this.height);
  }

  /**
   * Returns the WebGL2 context of the bound canvas, creating it on first use.
   * Throws if no canvas was provided, if the canvas is bound to a different
   * context type, or if the `EXT_color_buffer_float` extension (required for
   * R32F render targets) is unavailable.
   */
  private getGL(): WebGL2RenderingContext {
    if (!this.canvas) {
      throw new Error(
          'Conversion to different image formats require that a canvas ' +
          'is passed when initializing the image.');
    }
    if (!this.gl) {
      this.gl = assertNotNull(
          this.canvas.getContext('webgl2') as WebGL2RenderingContext | null,
          'You cannot use a canvas that is already bound to a different ' +
          'type of rendering context.');
    }
    const ext = this.gl.getExtension('EXT_color_buffer_float');
    if (!ext) {
      // TODO: Ensure this works on iOS
      throw new Error('Missing required EXT_color_buffer_float extension');
    }
    return this.gl;
  }

  /** Returns the shader context, creating a private one if none was given. */
  private getShaderContext(): MPImageShaderContext {
    if (!this.shaderContext) {
      this.shaderContext = new MPImageShaderContext();
    }
    return this.shaderContext;
  }

  /**
   * Returns the mask as a `Float32Array`, converting (and caching) it from
   * another container type if needed.
   */
  private convertToFloat32Array(): Float32Array {
    let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY);
    if (!float32Array) {
      if (this.has(MPMaskType.UINT8_ARRAY)) {
        // CPU-to-CPU conversion: normalize bytes to [0, 1].
        const source = this.getContainer(MPMaskType.UINT8_ARRAY)!;
        float32Array = new Float32Array(source).map(v => v / 255);
      } else {
        const gl = this.getGL();
        const shaderContext = this.getShaderContext();
        float32Array = new Float32Array(this.width * this.height);

        // Create texture if needed
        const webGlTexture = this.convertToWebGLTexture();

        // Create a framebuffer from the texture and read back pixels
        shaderContext.bindFramebuffer(gl, webGlTexture);
        gl.readPixels(
            0, 0, this.width, this.height, gl.RED, gl.FLOAT, float32Array);
        shaderContext.unbindFramebuffer();
      }
      this.containers.push(float32Array);
    }

    return float32Array;
  }

  /**
   * Returns the mask as a `Uint8Array`, converting (and caching) it from the
   * float representation if needed.
   */
  private convertToUint8Array(): Uint8Array {
    let uint8Array = this.getContainer(MPMaskType.UINT8_ARRAY);
    if (!uint8Array) {
      const floatArray = this.convertToFloat32Array();
      uint8Array = new Uint8Array(floatArray.map(v => 255 * v));
      this.containers.push(uint8Array);
    }
    return uint8Array;
  }

  /**
   * Returns the mask as a `WebGLTexture`, uploading (and caching) the float
   * data to the GPU if needed.
   */
  private convertToWebGLTexture(): WebGLTexture {
    let webGLTexture = this.getContainer(MPMaskType.WEBGL_TEXTURE);
    if (!webGLTexture) {
      const gl = this.getGL();
      webGLTexture = this.bindTexture();
      const data = this.convertToFloat32Array();
      // TODO: Add support for R16F to support iOS
      gl.texImage2D(
          gl.TEXTURE_2D, 0, gl.R32F, this.width, this.height, 0, gl.RED,
          gl.FLOAT, data);
      this.unbindTexture();
    }

    return webGLTexture;
  }

  /**
   * Binds the backing texture to the canvas. If the texture does not yet
   * exist, creates it first.
   */
  private bindTexture(): WebGLTexture {
    const gl = this.getGL();

    gl.viewport(0, 0, this.width, this.height);
    gl.activeTexture(gl.TEXTURE0);

    let webGLTexture = this.getContainer(MPMaskType.WEBGL_TEXTURE);
    if (!webGLTexture) {
      webGLTexture =
          assertNotNull(gl.createTexture(), 'Failed to create texture');
      this.containers.push(webGLTexture);
      // A texture created here belongs to this instance and must be freed
      // in `close()`.
      this.ownsWebGLTexture = true;
    }

    gl.bindTexture(gl.TEXTURE_2D, webGLTexture);
    // TODO: Ideally, we would only set these once per texture and
    // not once every frame.
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);

    return webGLTexture;
  }

  /** Unbinds the backing texture from the active texture unit. */
  private unbindTexture(): void {
    this.gl!.bindTexture(this.gl!.TEXTURE_2D, null);
  }

  /**
   * Frees up any resources owned by this `MPMask` instance.
   *
   * Note that this method does not free masks that are owned by the C++
   * Task, as these are freed automatically once you leave the MediaPipe
   * callback. Additionally, some shared state is freed only once you invoke
   * the Task's `close()` method.
   */
  close(): void {
    if (this.ownsWebGLTexture) {
      const gl = this.getGL();
      gl.deleteTexture(this.getContainer(MPMaskType.WEBGL_TEXTURE)!);
    }
  }
}

View File

@ -17,6 +17,7 @@
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
import {MPImage as MPImageImpl, MPImageType as MPImageTypeImpl} from '../../../tasks/web/vision/core/image';
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
@ -35,6 +36,8 @@ const DrawingUtils = DrawingUtilsImpl;
const FilesetResolver = FilesetResolverImpl;
const MPImage = MPImageImpl;
const MPImageType = MPImageTypeImpl;
const MPMask = MPMaskImpl;
const MPMaskType = MPMaskTypeImpl;
const FaceDetector = FaceDetectorImpl;
const FaceLandmarker = FaceLandmarkerImpl;
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
@ -53,6 +56,8 @@ export {
FilesetResolver,
MPImage,
MPImageType,
MPMask,
MPMaskType,
FaceDetector,
FaceLandmarker,
FaceLandmarksConnections,

View File

@ -17,6 +17,7 @@
export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/core/drawing_utils';
export {MPImage, MPImageChannelConverter, MPImageType} from '../../../tasks/web/vision/core/image';
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask';
export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';