Simplify MPMask by removing the Type Enums from the public API
PiperOrigin-RevId: 529975377
This commit is contained in:
parent
e9fc66277a
commit
ddb84702f6
|
@ -17,7 +17,7 @@
|
|||
import 'jasmine';
|
||||
|
||||
import {MPImageShaderContext} from './image_shader_context';
|
||||
import {MPMask, MPMaskType} from './mask';
|
||||
import {MPMask} from './mask';
|
||||
|
||||
const WIDTH = 2;
|
||||
const HEIGHT = 2;
|
||||
|
@ -117,13 +117,13 @@ class MPMaskTestContext {
|
|||
|
||||
function assertEquality(mask: MPMask, expected: MaskType): void {
|
||||
if (expected instanceof Uint8Array) {
|
||||
const result = mask.get(MPMaskType.UINT8_ARRAY);
|
||||
const result = mask.getAsUint8Array();
|
||||
expect(result).toEqual(expected);
|
||||
} else if (expected instanceof Float32Array) {
|
||||
const result = mask.get(MPMaskType.FLOAT32_ARRAY);
|
||||
const result = mask.getAsFloat32Array();
|
||||
expect(result).toEqual(expected);
|
||||
} else { // WebGLTexture
|
||||
const result = mask.get(MPMaskType.WEBGL_TEXTURE);
|
||||
const result = mask.getAsWebGLTexture();
|
||||
expect(readPixelsFromWebGLTexture(result))
|
||||
.toEqual(readPixelsFromWebGLTexture(expected));
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ class MPMaskTestContext {
|
|||
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
|
||||
HEIGHT);
|
||||
|
||||
const result = mask.clone().get(MPMaskType.UINT8_ARRAY);
|
||||
const result = mask.clone().getAsUint8Array();
|
||||
expect(result).toEqual(context.uint8Array);
|
||||
shaderContext.close();
|
||||
});
|
||||
|
@ -199,13 +199,13 @@ class MPMaskTestContext {
|
|||
|
||||
// Verify that we can mix the different shader modes by running them out of
|
||||
// order.
|
||||
let result = mask.get(MPMaskType.UINT8_ARRAY);
|
||||
let result = mask.getAsUint8Array();
|
||||
expect(result).toEqual(context.uint8Array);
|
||||
|
||||
result = mask.clone().get(MPMaskType.UINT8_ARRAY);
|
||||
result = mask.clone().getAsUint8Array();
|
||||
expect(result).toEqual(context.uint8Array);
|
||||
|
||||
result = mask.get(MPMaskType.UINT8_ARRAY);
|
||||
result = mask.getAsUint8Array();
|
||||
expect(result).toEqual(context.uint8Array);
|
||||
|
||||
shaderContext.close();
|
||||
|
@ -217,20 +217,21 @@ class MPMaskTestContext {
|
|||
const shaderContext = new MPImageShaderContext();
|
||||
const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT);
|
||||
|
||||
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
|
||||
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(false);
|
||||
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(false);
|
||||
expect(mask.hasUint8Array()).toBe(true);
|
||||
expect(mask.hasFloat32Array()).toBe(false);
|
||||
expect(mask.hasWebGLTexture()).toBe(false);
|
||||
|
||||
mask.get(MPMaskType.FLOAT32_ARRAY);
|
||||
mask.getAsFloat32Array();
|
||||
|
||||
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
|
||||
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
|
||||
expect(mask.hasUint8Array()).toBe(true);
|
||||
expect(mask.hasFloat32Array()).toBe(true);
|
||||
expect(mask.hasWebGLTexture()).toBe(false);
|
||||
|
||||
mask.get(MPMaskType.WEBGL_TEXTURE);
|
||||
mask.getAsWebGLTexture();
|
||||
|
||||
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
|
||||
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
|
||||
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(true);
|
||||
expect(mask.hasUint8Array()).toBe(true);
|
||||
expect(mask.hasFloat32Array()).toBe(true);
|
||||
expect(mask.hasWebGLTexture()).toBe(true);
|
||||
|
||||
mask.close();
|
||||
shaderContext.close();
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||
|
||||
/** The underlying type of the image. */
|
||||
export enum MPMaskType {
|
||||
enum MPMaskType {
|
||||
/** Represents the native `UInt8Array` type. */
|
||||
UINT8_ARRAY,
|
||||
/** Represents the native `Float32Array` type. */
|
||||
|
@ -34,9 +34,9 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
|||
*
|
||||
* Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects.
|
||||
* You can convert the underlying type to any other type by passing the desired
|
||||
* type to `get()`. As type conversions can be expensive, it is recommended to
|
||||
* limit these conversions. You can verify what underlying types are already
|
||||
* available by invoking `has()`.
|
||||
* type to `getAs...()`. As type conversions can be expensive, it is recommended
|
||||
* to limit these conversions. You can verify what underlying types are already
|
||||
* available by invoking `has...()`.
|
||||
*
|
||||
* Masks that are returned from a MediaPipe Tasks are owned by by the
|
||||
* underlying C++ Task. If you need to extend the lifetime of these objects,
|
||||
|
@ -47,9 +47,6 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
|||
export class MPMask {
|
||||
private gl?: WebGL2RenderingContext;
|
||||
|
||||
/** The underlying type of the mask. */
|
||||
static TYPE = MPMaskType;
|
||||
|
||||
/** @hideconstructor */
|
||||
constructor(
|
||||
private readonly containers: MPMaskContainer[],
|
||||
|
@ -63,13 +60,19 @@ export class MPMask {
|
|||
readonly height: number,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Returns whether this `MPMask` stores the mask in the desired
|
||||
* format. This method can be called to reduce expensive conversion before
|
||||
* invoking `get()`.
|
||||
*/
|
||||
has(type: MPMaskType): boolean {
|
||||
return !!this.getContainer(type);
|
||||
/** Returns whether this `MPMask` contains a mask of type `Uint8Array`. */
|
||||
hasUint8Array(): boolean {
|
||||
return !!this.getContainer(MPMaskType.UINT8_ARRAY);
|
||||
}
|
||||
|
||||
/** Returns whether this `MPMask` contains a mask of type `Float32Array`. */
|
||||
hasFloat32Array(): boolean {
|
||||
return !!this.getContainer(MPMaskType.FLOAT32_ARRAY);
|
||||
}
|
||||
|
||||
/** Returns whether this `MPMask` contains a mask of type `WebGLTexture`. */
|
||||
hasWebGLTexture(): boolean {
|
||||
return !!this.getContainer(MPMaskType.WEBGL_TEXTURE);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -77,43 +80,35 @@ export class MPMask {
|
|||
* expensive GPU to CPU transfer if the current mask is only available as a
|
||||
* `WebGLTexture`.
|
||||
*
|
||||
* @param type The type of mask to return.
|
||||
* @return The current data as a Uint8Array.
|
||||
*/
|
||||
get(type: MPMaskType.UINT8_ARRAY): Uint8Array;
|
||||
getAsUint8Array(): Uint8Array {
|
||||
return this.convertToUint8Array();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the underlying mask as a single channel `Float32Array`. Note that
|
||||
* this involves an expensive GPU to CPU transfer if the current mask is only
|
||||
* available as a `WebGLTexture`.
|
||||
*
|
||||
* @param type The type of mask to return.
|
||||
* @return The current mask as a Float32Array.
|
||||
*/
|
||||
get(type: MPMaskType.FLOAT32_ARRAY): Float32Array;
|
||||
getAsFloat32Array(): Float32Array {
|
||||
return this.convertToFloat32Array();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the underlying mask as a `WebGLTexture` object. Note that this
|
||||
* involves a CPU to GPU transfer if the current mask is only available as
|
||||
* a CPU array. The returned texture is bound to the current canvas (see
|
||||
* `.canvas`).
|
||||
*
|
||||
* @param type The type of mask to return.
|
||||
* @return The current mask as a WebGLTexture.
|
||||
*/
|
||||
get(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture;
|
||||
get(type?: MPMaskType): MPMaskContainer {
|
||||
switch (type) {
|
||||
case MPMaskType.UINT8_ARRAY:
|
||||
return this.convertToUint8Array();
|
||||
case MPMaskType.FLOAT32_ARRAY:
|
||||
return this.convertToFloat32Array();
|
||||
case MPMaskType.WEBGL_TEXTURE:
|
||||
return this.convertToWebGLTexture();
|
||||
default:
|
||||
throw new Error(`Type is not supported: ${type}`);
|
||||
}
|
||||
getAsWebGLTexture(): WebGLTexture {
|
||||
return this.convertToWebGLTexture();
|
||||
}
|
||||
|
||||
|
||||
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
|
||||
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
|
||||
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
|
||||
|
@ -186,7 +181,7 @@ export class MPMask {
|
|||
}
|
||||
|
||||
return new MPMask(
|
||||
destinationContainers, this.has(MPMaskType.WEBGL_TEXTURE), this.canvas,
|
||||
destinationContainers, this.hasWebGLTexture(), this.canvas,
|
||||
this.shaderContext, this.width, this.height);
|
||||
}
|
||||
|
||||
|
@ -220,9 +215,9 @@ export class MPMask {
|
|||
private convertToFloat32Array(): Float32Array {
|
||||
let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY);
|
||||
if (!float32Array) {
|
||||
if (this.has(MPMaskType.UINT8_ARRAY)) {
|
||||
const source = this.getContainer(MPMaskType.UINT8_ARRAY)!;
|
||||
float32Array = new Float32Array(source).map(v => v / 255);
|
||||
const uint8Array = this.getContainer(MPMaskType.UINT8_ARRAY);
|
||||
if (uint8Array) {
|
||||
float32Array = new Float32Array(uint8Array).map(v => v / 255);
|
||||
} else {
|
||||
const gl = this.getGL();
|
||||
const shaderContext = this.getShaderContext();
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
||||
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
|
||||
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
|
||||
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask';
|
||||
import {MPMask as MPMaskImpl} from '../../../tasks/web/vision/core/mask';
|
||||
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
|
||||
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
||||
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
||||
|
@ -36,7 +36,6 @@ const DrawingUtils = DrawingUtilsImpl;
|
|||
const FilesetResolver = FilesetResolverImpl;
|
||||
const MPImage = MPImageImpl;
|
||||
const MPMask = MPMaskImpl;
|
||||
const MPMaskType = MPMaskTypeImpl;
|
||||
const FaceDetector = FaceDetectorImpl;
|
||||
const FaceLandmarker = FaceLandmarkerImpl;
|
||||
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
|
||||
|
@ -55,7 +54,6 @@ export {
|
|||
FilesetResolver,
|
||||
MPImage,
|
||||
MPMask,
|
||||
MPMaskType,
|
||||
FaceDetector,
|
||||
FaceLandmarker,
|
||||
FaceLandmarksConnections,
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
export * from '../../../tasks/web/core/fileset_resolver';
|
||||
export * from '../../../tasks/web/vision/core/drawing_utils';
|
||||
export {MPImage} from '../../../tasks/web/vision/core/image';
|
||||
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask';
|
||||
export {MPMask} from '../../../tasks/web/vision/core/mask';
|
||||
export * from '../../../tasks/web/vision/face_detector/face_detector';
|
||||
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
||||
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
||||
|
|
Loading…
Reference in New Issue
Block a user