Simplify MPMask by removing the Type Enums from the public API

PiperOrigin-RevId: 529975377
This commit is contained in:
Sebastian Schmidt 2023-05-06 10:23:58 -07:00 committed by Copybara-Service
parent e9fc66277a
commit ddb84702f6
4 changed files with 52 additions and 58 deletions

View File

@ -17,7 +17,7 @@
import 'jasmine'; import 'jasmine';
import {MPImageShaderContext} from './image_shader_context'; import {MPImageShaderContext} from './image_shader_context';
import {MPMask, MPMaskType} from './mask'; import {MPMask} from './mask';
const WIDTH = 2; const WIDTH = 2;
const HEIGHT = 2; const HEIGHT = 2;
@ -117,13 +117,13 @@ class MPMaskTestContext {
function assertEquality(mask: MPMask, expected: MaskType): void { function assertEquality(mask: MPMask, expected: MaskType): void {
if (expected instanceof Uint8Array) { if (expected instanceof Uint8Array) {
const result = mask.get(MPMaskType.UINT8_ARRAY); const result = mask.getAsUint8Array();
expect(result).toEqual(expected); expect(result).toEqual(expected);
} else if (expected instanceof Float32Array) { } else if (expected instanceof Float32Array) {
const result = mask.get(MPMaskType.FLOAT32_ARRAY); const result = mask.getAsFloat32Array();
expect(result).toEqual(expected); expect(result).toEqual(expected);
} else { // WebGLTexture } else { // WebGLTexture
const result = mask.get(MPMaskType.WEBGL_TEXTURE); const result = mask.getAsWebGLTexture();
expect(readPixelsFromWebGLTexture(result)) expect(readPixelsFromWebGLTexture(result))
.toEqual(readPixelsFromWebGLTexture(expected)); .toEqual(readPixelsFromWebGLTexture(expected));
} }
@ -183,7 +183,7 @@ class MPMaskTestContext {
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH, /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
HEIGHT); HEIGHT);
const result = mask.clone().get(MPMaskType.UINT8_ARRAY); const result = mask.clone().getAsUint8Array();
expect(result).toEqual(context.uint8Array); expect(result).toEqual(context.uint8Array);
shaderContext.close(); shaderContext.close();
}); });
@ -199,13 +199,13 @@ class MPMaskTestContext {
// Verify that we can mix the different shader modes by running them out of // Verify that we can mix the different shader modes by running them out of
// order. // order.
let result = mask.get(MPMaskType.UINT8_ARRAY); let result = mask.getAsUint8Array();
expect(result).toEqual(context.uint8Array); expect(result).toEqual(context.uint8Array);
result = mask.clone().get(MPMaskType.UINT8_ARRAY); result = mask.clone().getAsUint8Array();
expect(result).toEqual(context.uint8Array); expect(result).toEqual(context.uint8Array);
result = mask.get(MPMaskType.UINT8_ARRAY); result = mask.getAsUint8Array();
expect(result).toEqual(context.uint8Array); expect(result).toEqual(context.uint8Array);
shaderContext.close(); shaderContext.close();
@ -217,20 +217,21 @@ class MPMaskTestContext {
const shaderContext = new MPImageShaderContext(); const shaderContext = new MPImageShaderContext();
const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT); const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT);
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true); expect(mask.hasUint8Array()).toBe(true);
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(false); expect(mask.hasFloat32Array()).toBe(false);
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(false); expect(mask.hasWebGLTexture()).toBe(false);
mask.get(MPMaskType.FLOAT32_ARRAY); mask.getAsFloat32Array();
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true); expect(mask.hasUint8Array()).toBe(true);
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true); expect(mask.hasFloat32Array()).toBe(true);
expect(mask.hasWebGLTexture()).toBe(false);
mask.get(MPMaskType.WEBGL_TEXTURE); mask.getAsWebGLTexture();
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true); expect(mask.hasUint8Array()).toBe(true);
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true); expect(mask.hasFloat32Array()).toBe(true);
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(true); expect(mask.hasWebGLTexture()).toBe(true);
mask.close(); mask.close();
shaderContext.close(); shaderContext.close();

View File

@ -17,7 +17,7 @@
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
/** The underlying type of the image. */ /** The underlying type of the image. */
export enum MPMaskType { enum MPMaskType {
/** Represents the native `UInt8Array` type. */ /** Represents the native `UInt8Array` type. */
UINT8_ARRAY, UINT8_ARRAY,
/** Represents the native `Float32Array` type. */ /** Represents the native `Float32Array` type. */
@ -34,9 +34,9 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
* *
* Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects. * Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects.
* You can convert the underlying type to any other type by passing the desired * You can convert the underlying type to any other type by passing the desired
* type to `get()`. As type conversions can be expensive, it is recommended to * type to `getAs...()`. As type conversions can be expensive, it is recommended
* limit these conversions. You can verify what underlying types are already * to limit these conversions. You can verify what underlying types are already
* available by invoking `has()`. * available by invoking `has...()`.
* *
* Masks that are returned from a MediaPipe Tasks are owned by by the * Masks that are returned from a MediaPipe Tasks are owned by by the
* underlying C++ Task. If you need to extend the lifetime of these objects, * underlying C++ Task. If you need to extend the lifetime of these objects,
@ -47,9 +47,6 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
export class MPMask { export class MPMask {
private gl?: WebGL2RenderingContext; private gl?: WebGL2RenderingContext;
/** The underlying type of the mask. */
static TYPE = MPMaskType;
/** @hideconstructor */ /** @hideconstructor */
constructor( constructor(
private readonly containers: MPMaskContainer[], private readonly containers: MPMaskContainer[],
@ -63,13 +60,19 @@ export class MPMask {
readonly height: number, readonly height: number,
) {} ) {}
/** /** Returns whether this `MPMask` contains a mask of type `Uint8Array`. */
* Returns whether this `MPMask` stores the mask in the desired hasUint8Array(): boolean {
* format. This method can be called to reduce expensive conversion before return !!this.getContainer(MPMaskType.UINT8_ARRAY);
* invoking `get()`. }
*/
has(type: MPMaskType): boolean { /** Returns whether this `MPMask` contains a mask of type `Float32Array`. */
return !!this.getContainer(type); hasFloat32Array(): boolean {
return !!this.getContainer(MPMaskType.FLOAT32_ARRAY);
}
/** Returns whether this `MPMask` contains a mask of type `WebGLTexture`. */
hasWebGLTexture(): boolean {
return !!this.getContainer(MPMaskType.WEBGL_TEXTURE);
} }
/** /**
@ -77,42 +80,34 @@ export class MPMask {
* expensive GPU to CPU transfer if the current mask is only available as a * expensive GPU to CPU transfer if the current mask is only available as a
* `WebGLTexture`. * `WebGLTexture`.
* *
* @param type The type of mask to return.
* @return The current data as a Uint8Array. * @return The current data as a Uint8Array.
*/ */
get(type: MPMaskType.UINT8_ARRAY): Uint8Array; getAsUint8Array(): Uint8Array {
return this.convertToUint8Array();
}
/** /**
* Returns the underlying mask as a single channel `Float32Array`. Note that * Returns the underlying mask as a single channel `Float32Array`. Note that
* this involves an expensive GPU to CPU transfer if the current mask is only * this involves an expensive GPU to CPU transfer if the current mask is only
* available as a `WebGLTexture`. * available as a `WebGLTexture`.
* *
* @param type The type of mask to return.
* @return The current mask as a Float32Array. * @return The current mask as a Float32Array.
*/ */
get(type: MPMaskType.FLOAT32_ARRAY): Float32Array; getAsFloat32Array(): Float32Array {
return this.convertToFloat32Array();
}
/** /**
* Returns the underlying mask as a `WebGLTexture` object. Note that this * Returns the underlying mask as a `WebGLTexture` object. Note that this
* involves a CPU to GPU transfer if the current mask is only available as * involves a CPU to GPU transfer if the current mask is only available as
* a CPU array. The returned texture is bound to the current canvas (see * a CPU array. The returned texture is bound to the current canvas (see
* `.canvas`). * `.canvas`).
* *
* @param type The type of mask to return.
* @return The current mask as a WebGLTexture. * @return The current mask as a WebGLTexture.
*/ */
get(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture; getAsWebGLTexture(): WebGLTexture {
get(type?: MPMaskType): MPMaskContainer {
switch (type) {
case MPMaskType.UINT8_ARRAY:
return this.convertToUint8Array();
case MPMaskType.FLOAT32_ARRAY:
return this.convertToFloat32Array();
case MPMaskType.WEBGL_TEXTURE:
return this.convertToWebGLTexture(); return this.convertToWebGLTexture();
default:
throw new Error(`Type is not supported: ${type}`);
} }
}
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined; private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined; private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
@ -186,7 +181,7 @@ export class MPMask {
} }
return new MPMask( return new MPMask(
destinationContainers, this.has(MPMaskType.WEBGL_TEXTURE), this.canvas, destinationContainers, this.hasWebGLTexture(), this.canvas,
this.shaderContext, this.width, this.height); this.shaderContext, this.width, this.height);
} }
@ -220,9 +215,9 @@ export class MPMask {
private convertToFloat32Array(): Float32Array { private convertToFloat32Array(): Float32Array {
let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY); let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY);
if (!float32Array) { if (!float32Array) {
if (this.has(MPMaskType.UINT8_ARRAY)) { const uint8Array = this.getContainer(MPMaskType.UINT8_ARRAY);
const source = this.getContainer(MPMaskType.UINT8_ARRAY)!; if (uint8Array) {
float32Array = new Float32Array(source).map(v => v / 255); float32Array = new Float32Array(uint8Array).map(v => v / 255);
} else { } else {
const gl = this.getGL(); const gl = this.getGL();
const shaderContext = this.getShaderContext(); const shaderContext = this.getShaderContext();

View File

@ -17,7 +17,7 @@
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils'; import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image'; import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask'; import {MPMask as MPMaskImpl} from '../../../tasks/web/vision/core/mask';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
@ -36,7 +36,6 @@ const DrawingUtils = DrawingUtilsImpl;
const FilesetResolver = FilesetResolverImpl; const FilesetResolver = FilesetResolverImpl;
const MPImage = MPImageImpl; const MPImage = MPImageImpl;
const MPMask = MPMaskImpl; const MPMask = MPMaskImpl;
const MPMaskType = MPMaskTypeImpl;
const FaceDetector = FaceDetectorImpl; const FaceDetector = FaceDetectorImpl;
const FaceLandmarker = FaceLandmarkerImpl; const FaceLandmarker = FaceLandmarkerImpl;
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl; const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
@ -55,7 +54,6 @@ export {
FilesetResolver, FilesetResolver,
MPImage, MPImage,
MPMask, MPMask,
MPMaskType,
FaceDetector, FaceDetector,
FaceLandmarker, FaceLandmarker,
FaceLandmarksConnections, FaceLandmarksConnections,

View File

@ -17,7 +17,7 @@
export * from '../../../tasks/web/core/fileset_resolver'; export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/core/drawing_utils'; export * from '../../../tasks/web/vision/core/drawing_utils';
export {MPImage} from '../../../tasks/web/vision/core/image'; export {MPImage} from '../../../tasks/web/vision/core/image';
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask'; export {MPMask} from '../../../tasks/web/vision/core/mask';
export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';