Simplify MPMask by removing the Type Enums from the public API

PiperOrigin-RevId: 529975377
This commit is contained in:
Sebastian Schmidt 2023-05-06 10:23:58 -07:00 committed by Copybara-Service
parent e9fc66277a
commit ddb84702f6
4 changed files with 52 additions and 58 deletions

View File

@ -17,7 +17,7 @@
import 'jasmine';
import {MPImageShaderContext} from './image_shader_context';
import {MPMask, MPMaskType} from './mask';
import {MPMask} from './mask';
const WIDTH = 2;
const HEIGHT = 2;
@ -117,13 +117,13 @@ class MPMaskTestContext {
function assertEquality(mask: MPMask, expected: MaskType): void {
if (expected instanceof Uint8Array) {
const result = mask.get(MPMaskType.UINT8_ARRAY);
const result = mask.getAsUint8Array();
expect(result).toEqual(expected);
} else if (expected instanceof Float32Array) {
const result = mask.get(MPMaskType.FLOAT32_ARRAY);
const result = mask.getAsFloat32Array();
expect(result).toEqual(expected);
} else { // WebGLTexture
const result = mask.get(MPMaskType.WEBGL_TEXTURE);
const result = mask.getAsWebGLTexture();
expect(readPixelsFromWebGLTexture(result))
.toEqual(readPixelsFromWebGLTexture(expected));
}
@ -183,7 +183,7 @@ class MPMaskTestContext {
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
HEIGHT);
const result = mask.clone().get(MPMaskType.UINT8_ARRAY);
const result = mask.clone().getAsUint8Array();
expect(result).toEqual(context.uint8Array);
shaderContext.close();
});
@ -199,13 +199,13 @@ class MPMaskTestContext {
// Verify that we can mix the different shader modes by running them out of
// order.
let result = mask.get(MPMaskType.UINT8_ARRAY);
let result = mask.getAsUint8Array();
expect(result).toEqual(context.uint8Array);
result = mask.clone().get(MPMaskType.UINT8_ARRAY);
result = mask.clone().getAsUint8Array();
expect(result).toEqual(context.uint8Array);
result = mask.get(MPMaskType.UINT8_ARRAY);
result = mask.getAsUint8Array();
expect(result).toEqual(context.uint8Array);
shaderContext.close();
@ -217,20 +217,21 @@ class MPMaskTestContext {
const shaderContext = new MPImageShaderContext();
const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT);
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(false);
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(false);
expect(mask.hasUint8Array()).toBe(true);
expect(mask.hasFloat32Array()).toBe(false);
expect(mask.hasWebGLTexture()).toBe(false);
mask.get(MPMaskType.FLOAT32_ARRAY);
mask.getAsFloat32Array();
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
expect(mask.hasUint8Array()).toBe(true);
expect(mask.hasFloat32Array()).toBe(true);
expect(mask.hasWebGLTexture()).toBe(false);
mask.get(MPMaskType.WEBGL_TEXTURE);
mask.getAsWebGLTexture();
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(true);
expect(mask.hasUint8Array()).toBe(true);
expect(mask.hasFloat32Array()).toBe(true);
expect(mask.hasWebGLTexture()).toBe(true);
mask.close();
shaderContext.close();

View File

@ -17,7 +17,7 @@
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
/** The underlying type of the image. */
export enum MPMaskType {
enum MPMaskType {
/** Represents the native `UInt8Array` type. */
UINT8_ARRAY,
/** Represents the native `Float32Array` type. */
@ -34,9 +34,9 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
*
* Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects.
* You can convert the underlying type to any other type by passing the desired
* type to `get()`. As type conversions can be expensive, it is recommended to
* limit these conversions. You can verify what underlying types are already
* available by invoking `has()`.
* type to `getAs...()`. As type conversions can be expensive, it is recommended
* to limit these conversions. You can verify what underlying types are already
* available by invoking `has...()`.
*
* Masks that are returned from a MediaPipe Tasks are owned by by the
* underlying C++ Task. If you need to extend the lifetime of these objects,
@ -47,9 +47,6 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
export class MPMask {
private gl?: WebGL2RenderingContext;
/** The underlying type of the mask. */
static TYPE = MPMaskType;
/** @hideconstructor */
constructor(
private readonly containers: MPMaskContainer[],
@ -63,13 +60,19 @@ export class MPMask {
readonly height: number,
) {}
/**
* Returns whether this `MPMask` stores the mask in the desired
* format. This method can be called to reduce expensive conversion before
* invoking `get()`.
*/
has(type: MPMaskType): boolean {
return !!this.getContainer(type);
/** Returns whether this `MPMask` contains a mask of type `Uint8Array`. */
hasUint8Array(): boolean {
return !!this.getContainer(MPMaskType.UINT8_ARRAY);
}
/** Returns whether this `MPMask` contains a mask of type `Float32Array`. */
hasFloat32Array(): boolean {
return !!this.getContainer(MPMaskType.FLOAT32_ARRAY);
}
/** Returns whether this `MPMask` contains a mask of type `WebGLTexture`. */
hasWebGLTexture(): boolean {
return !!this.getContainer(MPMaskType.WEBGL_TEXTURE);
}
/**
@ -77,42 +80,34 @@ export class MPMask {
* expensive GPU to CPU transfer if the current mask is only available as a
* `WebGLTexture`.
*
* @param type The type of mask to return.
* @return The current data as a Uint8Array.
*/
get(type: MPMaskType.UINT8_ARRAY): Uint8Array;
getAsUint8Array(): Uint8Array {
return this.convertToUint8Array();
}
/**
* Returns the underlying mask as a single channel `Float32Array`. Note that
* this involves an expensive GPU to CPU transfer if the current mask is only
* available as a `WebGLTexture`.
*
* @param type The type of mask to return.
* @return The current mask as a Float32Array.
*/
get(type: MPMaskType.FLOAT32_ARRAY): Float32Array;
getAsFloat32Array(): Float32Array {
return this.convertToFloat32Array();
}
/**
* Returns the underlying mask as a `WebGLTexture` object. Note that this
* involves a CPU to GPU transfer if the current mask is only available as
* a CPU array. The returned texture is bound to the current canvas (see
* `.canvas`).
*
* @param type The type of mask to return.
* @return The current mask as a WebGLTexture.
*/
get(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture;
get(type?: MPMaskType): MPMaskContainer {
switch (type) {
case MPMaskType.UINT8_ARRAY:
return this.convertToUint8Array();
case MPMaskType.FLOAT32_ARRAY:
return this.convertToFloat32Array();
case MPMaskType.WEBGL_TEXTURE:
getAsWebGLTexture(): WebGLTexture {
return this.convertToWebGLTexture();
default:
throw new Error(`Type is not supported: ${type}`);
}
}
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
@ -186,7 +181,7 @@ export class MPMask {
}
return new MPMask(
destinationContainers, this.has(MPMaskType.WEBGL_TEXTURE), this.canvas,
destinationContainers, this.hasWebGLTexture(), this.canvas,
this.shaderContext, this.width, this.height);
}
@ -220,9 +215,9 @@ export class MPMask {
private convertToFloat32Array(): Float32Array {
let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY);
if (!float32Array) {
if (this.has(MPMaskType.UINT8_ARRAY)) {
const source = this.getContainer(MPMaskType.UINT8_ARRAY)!;
float32Array = new Float32Array(source).map(v => v / 255);
const uint8Array = this.getContainer(MPMaskType.UINT8_ARRAY);
if (uint8Array) {
float32Array = new Float32Array(uint8Array).map(v => v / 255);
} else {
const gl = this.getGL();
const shaderContext = this.getShaderContext();

View File

@ -17,7 +17,7 @@
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask';
import {MPMask as MPMaskImpl} from '../../../tasks/web/vision/core/mask';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
@ -36,7 +36,6 @@ const DrawingUtils = DrawingUtilsImpl;
const FilesetResolver = FilesetResolverImpl;
const MPImage = MPImageImpl;
const MPMask = MPMaskImpl;
const MPMaskType = MPMaskTypeImpl;
const FaceDetector = FaceDetectorImpl;
const FaceLandmarker = FaceLandmarkerImpl;
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
@ -55,7 +54,6 @@ export {
FilesetResolver,
MPImage,
MPMask,
MPMaskType,
FaceDetector,
FaceLandmarker,
FaceLandmarksConnections,

View File

@ -17,7 +17,7 @@
export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/core/drawing_utils';
export {MPImage} from '../../../tasks/web/vision/core/image';
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask';
export {MPMask} from '../../../tasks/web/vision/core/mask';
export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';