Simplify MPMask by removing the Type Enums from the public API
PiperOrigin-RevId: 529975377
This commit is contained in:
parent
e9fc66277a
commit
ddb84702f6
|
@ -17,7 +17,7 @@
|
||||||
import 'jasmine';
|
import 'jasmine';
|
||||||
|
|
||||||
import {MPImageShaderContext} from './image_shader_context';
|
import {MPImageShaderContext} from './image_shader_context';
|
||||||
import {MPMask, MPMaskType} from './mask';
|
import {MPMask} from './mask';
|
||||||
|
|
||||||
const WIDTH = 2;
|
const WIDTH = 2;
|
||||||
const HEIGHT = 2;
|
const HEIGHT = 2;
|
||||||
|
@ -117,13 +117,13 @@ class MPMaskTestContext {
|
||||||
|
|
||||||
function assertEquality(mask: MPMask, expected: MaskType): void {
|
function assertEquality(mask: MPMask, expected: MaskType): void {
|
||||||
if (expected instanceof Uint8Array) {
|
if (expected instanceof Uint8Array) {
|
||||||
const result = mask.get(MPMaskType.UINT8_ARRAY);
|
const result = mask.getAsUint8Array();
|
||||||
expect(result).toEqual(expected);
|
expect(result).toEqual(expected);
|
||||||
} else if (expected instanceof Float32Array) {
|
} else if (expected instanceof Float32Array) {
|
||||||
const result = mask.get(MPMaskType.FLOAT32_ARRAY);
|
const result = mask.getAsFloat32Array();
|
||||||
expect(result).toEqual(expected);
|
expect(result).toEqual(expected);
|
||||||
} else { // WebGLTexture
|
} else { // WebGLTexture
|
||||||
const result = mask.get(MPMaskType.WEBGL_TEXTURE);
|
const result = mask.getAsWebGLTexture();
|
||||||
expect(readPixelsFromWebGLTexture(result))
|
expect(readPixelsFromWebGLTexture(result))
|
||||||
.toEqual(readPixelsFromWebGLTexture(expected));
|
.toEqual(readPixelsFromWebGLTexture(expected));
|
||||||
}
|
}
|
||||||
|
@ -183,7 +183,7 @@ class MPMaskTestContext {
|
||||||
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
|
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
|
||||||
HEIGHT);
|
HEIGHT);
|
||||||
|
|
||||||
const result = mask.clone().get(MPMaskType.UINT8_ARRAY);
|
const result = mask.clone().getAsUint8Array();
|
||||||
expect(result).toEqual(context.uint8Array);
|
expect(result).toEqual(context.uint8Array);
|
||||||
shaderContext.close();
|
shaderContext.close();
|
||||||
});
|
});
|
||||||
|
@ -199,13 +199,13 @@ class MPMaskTestContext {
|
||||||
|
|
||||||
// Verify that we can mix the different shader modes by running them out of
|
// Verify that we can mix the different shader modes by running them out of
|
||||||
// order.
|
// order.
|
||||||
let result = mask.get(MPMaskType.UINT8_ARRAY);
|
let result = mask.getAsUint8Array();
|
||||||
expect(result).toEqual(context.uint8Array);
|
expect(result).toEqual(context.uint8Array);
|
||||||
|
|
||||||
result = mask.clone().get(MPMaskType.UINT8_ARRAY);
|
result = mask.clone().getAsUint8Array();
|
||||||
expect(result).toEqual(context.uint8Array);
|
expect(result).toEqual(context.uint8Array);
|
||||||
|
|
||||||
result = mask.get(MPMaskType.UINT8_ARRAY);
|
result = mask.getAsUint8Array();
|
||||||
expect(result).toEqual(context.uint8Array);
|
expect(result).toEqual(context.uint8Array);
|
||||||
|
|
||||||
shaderContext.close();
|
shaderContext.close();
|
||||||
|
@ -217,20 +217,21 @@ class MPMaskTestContext {
|
||||||
const shaderContext = new MPImageShaderContext();
|
const shaderContext = new MPImageShaderContext();
|
||||||
const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT);
|
const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT);
|
||||||
|
|
||||||
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
|
expect(mask.hasUint8Array()).toBe(true);
|
||||||
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(false);
|
expect(mask.hasFloat32Array()).toBe(false);
|
||||||
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(false);
|
expect(mask.hasWebGLTexture()).toBe(false);
|
||||||
|
|
||||||
mask.get(MPMaskType.FLOAT32_ARRAY);
|
mask.getAsFloat32Array();
|
||||||
|
|
||||||
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
|
expect(mask.hasUint8Array()).toBe(true);
|
||||||
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
|
expect(mask.hasFloat32Array()).toBe(true);
|
||||||
|
expect(mask.hasWebGLTexture()).toBe(false);
|
||||||
|
|
||||||
mask.get(MPMaskType.WEBGL_TEXTURE);
|
mask.getAsWebGLTexture();
|
||||||
|
|
||||||
expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true);
|
expect(mask.hasUint8Array()).toBe(true);
|
||||||
expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true);
|
expect(mask.hasFloat32Array()).toBe(true);
|
||||||
expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(true);
|
expect(mask.hasWebGLTexture()).toBe(true);
|
||||||
|
|
||||||
mask.close();
|
mask.close();
|
||||||
shaderContext.close();
|
shaderContext.close();
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
|
||||||
|
|
||||||
/** The underlying type of the image. */
|
/** The underlying type of the image. */
|
||||||
export enum MPMaskType {
|
enum MPMaskType {
|
||||||
/** Represents the native `UInt8Array` type. */
|
/** Represents the native `UInt8Array` type. */
|
||||||
UINT8_ARRAY,
|
UINT8_ARRAY,
|
||||||
/** Represents the native `Float32Array` type. */
|
/** Represents the native `Float32Array` type. */
|
||||||
|
@ -34,9 +34,9 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
||||||
*
|
*
|
||||||
* Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects.
|
* Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects.
|
||||||
* You can convert the underlying type to any other type by passing the desired
|
* You can convert the underlying type to any other type by passing the desired
|
||||||
* type to `get()`. As type conversions can be expensive, it is recommended to
|
* type to `getAs...()`. As type conversions can be expensive, it is recommended
|
||||||
* limit these conversions. You can verify what underlying types are already
|
* to limit these conversions. You can verify what underlying types are already
|
||||||
* available by invoking `has()`.
|
* available by invoking `has...()`.
|
||||||
*
|
*
|
||||||
* Masks that are returned from a MediaPipe Tasks are owned by by the
|
* Masks that are returned from a MediaPipe Tasks are owned by by the
|
||||||
* underlying C++ Task. If you need to extend the lifetime of these objects,
|
* underlying C++ Task. If you need to extend the lifetime of these objects,
|
||||||
|
@ -47,9 +47,6 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture;
|
||||||
export class MPMask {
|
export class MPMask {
|
||||||
private gl?: WebGL2RenderingContext;
|
private gl?: WebGL2RenderingContext;
|
||||||
|
|
||||||
/** The underlying type of the mask. */
|
|
||||||
static TYPE = MPMaskType;
|
|
||||||
|
|
||||||
/** @hideconstructor */
|
/** @hideconstructor */
|
||||||
constructor(
|
constructor(
|
||||||
private readonly containers: MPMaskContainer[],
|
private readonly containers: MPMaskContainer[],
|
||||||
|
@ -63,13 +60,19 @@ export class MPMask {
|
||||||
readonly height: number,
|
readonly height: number,
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
/**
|
/** Returns whether this `MPMask` contains a mask of type `Uint8Array`. */
|
||||||
* Returns whether this `MPMask` stores the mask in the desired
|
hasUint8Array(): boolean {
|
||||||
* format. This method can be called to reduce expensive conversion before
|
return !!this.getContainer(MPMaskType.UINT8_ARRAY);
|
||||||
* invoking `get()`.
|
}
|
||||||
*/
|
|
||||||
has(type: MPMaskType): boolean {
|
/** Returns whether this `MPMask` contains a mask of type `Float32Array`. */
|
||||||
return !!this.getContainer(type);
|
hasFloat32Array(): boolean {
|
||||||
|
return !!this.getContainer(MPMaskType.FLOAT32_ARRAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns whether this `MPMask` contains a mask of type `WebGLTexture`. */
|
||||||
|
hasWebGLTexture(): boolean {
|
||||||
|
return !!this.getContainer(MPMaskType.WEBGL_TEXTURE);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -77,43 +80,35 @@ export class MPMask {
|
||||||
* expensive GPU to CPU transfer if the current mask is only available as a
|
* expensive GPU to CPU transfer if the current mask is only available as a
|
||||||
* `WebGLTexture`.
|
* `WebGLTexture`.
|
||||||
*
|
*
|
||||||
* @param type The type of mask to return.
|
|
||||||
* @return The current data as a Uint8Array.
|
* @return The current data as a Uint8Array.
|
||||||
*/
|
*/
|
||||||
get(type: MPMaskType.UINT8_ARRAY): Uint8Array;
|
getAsUint8Array(): Uint8Array {
|
||||||
|
return this.convertToUint8Array();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the underlying mask as a single channel `Float32Array`. Note that
|
* Returns the underlying mask as a single channel `Float32Array`. Note that
|
||||||
* this involves an expensive GPU to CPU transfer if the current mask is only
|
* this involves an expensive GPU to CPU transfer if the current mask is only
|
||||||
* available as a `WebGLTexture`.
|
* available as a `WebGLTexture`.
|
||||||
*
|
*
|
||||||
* @param type The type of mask to return.
|
|
||||||
* @return The current mask as a Float32Array.
|
* @return The current mask as a Float32Array.
|
||||||
*/
|
*/
|
||||||
get(type: MPMaskType.FLOAT32_ARRAY): Float32Array;
|
getAsFloat32Array(): Float32Array {
|
||||||
|
return this.convertToFloat32Array();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the underlying mask as a `WebGLTexture` object. Note that this
|
* Returns the underlying mask as a `WebGLTexture` object. Note that this
|
||||||
* involves a CPU to GPU transfer if the current mask is only available as
|
* involves a CPU to GPU transfer if the current mask is only available as
|
||||||
* a CPU array. The returned texture is bound to the current canvas (see
|
* a CPU array. The returned texture is bound to the current canvas (see
|
||||||
* `.canvas`).
|
* `.canvas`).
|
||||||
*
|
*
|
||||||
* @param type The type of mask to return.
|
|
||||||
* @return The current mask as a WebGLTexture.
|
* @return The current mask as a WebGLTexture.
|
||||||
*/
|
*/
|
||||||
get(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture;
|
getAsWebGLTexture(): WebGLTexture {
|
||||||
get(type?: MPMaskType): MPMaskContainer {
|
return this.convertToWebGLTexture();
|
||||||
switch (type) {
|
|
||||||
case MPMaskType.UINT8_ARRAY:
|
|
||||||
return this.convertToUint8Array();
|
|
||||||
case MPMaskType.FLOAT32_ARRAY:
|
|
||||||
return this.convertToFloat32Array();
|
|
||||||
case MPMaskType.WEBGL_TEXTURE:
|
|
||||||
return this.convertToWebGLTexture();
|
|
||||||
default:
|
|
||||||
throw new Error(`Type is not supported: ${type}`);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
|
private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined;
|
||||||
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
|
private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined;
|
||||||
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
|
private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined;
|
||||||
|
@ -186,7 +181,7 @@ export class MPMask {
|
||||||
}
|
}
|
||||||
|
|
||||||
return new MPMask(
|
return new MPMask(
|
||||||
destinationContainers, this.has(MPMaskType.WEBGL_TEXTURE), this.canvas,
|
destinationContainers, this.hasWebGLTexture(), this.canvas,
|
||||||
this.shaderContext, this.width, this.height);
|
this.shaderContext, this.width, this.height);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,9 +215,9 @@ export class MPMask {
|
||||||
private convertToFloat32Array(): Float32Array {
|
private convertToFloat32Array(): Float32Array {
|
||||||
let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY);
|
let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY);
|
||||||
if (!float32Array) {
|
if (!float32Array) {
|
||||||
if (this.has(MPMaskType.UINT8_ARRAY)) {
|
const uint8Array = this.getContainer(MPMaskType.UINT8_ARRAY);
|
||||||
const source = this.getContainer(MPMaskType.UINT8_ARRAY)!;
|
if (uint8Array) {
|
||||||
float32Array = new Float32Array(source).map(v => v / 255);
|
float32Array = new Float32Array(uint8Array).map(v => v / 255);
|
||||||
} else {
|
} else {
|
||||||
const gl = this.getGL();
|
const gl = this.getGL();
|
||||||
const shaderContext = this.getShaderContext();
|
const shaderContext = this.getShaderContext();
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
||||||
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
|
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
|
||||||
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
|
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
|
||||||
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask';
|
import {MPMask as MPMaskImpl} from '../../../tasks/web/vision/core/mask';
|
||||||
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
|
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
|
||||||
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
||||||
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
||||||
|
@ -36,7 +36,6 @@ const DrawingUtils = DrawingUtilsImpl;
|
||||||
const FilesetResolver = FilesetResolverImpl;
|
const FilesetResolver = FilesetResolverImpl;
|
||||||
const MPImage = MPImageImpl;
|
const MPImage = MPImageImpl;
|
||||||
const MPMask = MPMaskImpl;
|
const MPMask = MPMaskImpl;
|
||||||
const MPMaskType = MPMaskTypeImpl;
|
|
||||||
const FaceDetector = FaceDetectorImpl;
|
const FaceDetector = FaceDetectorImpl;
|
||||||
const FaceLandmarker = FaceLandmarkerImpl;
|
const FaceLandmarker = FaceLandmarkerImpl;
|
||||||
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
|
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
|
||||||
|
@ -55,7 +54,6 @@ export {
|
||||||
FilesetResolver,
|
FilesetResolver,
|
||||||
MPImage,
|
MPImage,
|
||||||
MPMask,
|
MPMask,
|
||||||
MPMaskType,
|
|
||||||
FaceDetector,
|
FaceDetector,
|
||||||
FaceLandmarker,
|
FaceLandmarker,
|
||||||
FaceLandmarksConnections,
|
FaceLandmarksConnections,
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
export * from '../../../tasks/web/core/fileset_resolver';
|
export * from '../../../tasks/web/core/fileset_resolver';
|
||||||
export * from '../../../tasks/web/vision/core/drawing_utils';
|
export * from '../../../tasks/web/vision/core/drawing_utils';
|
||||||
export {MPImage} from '../../../tasks/web/vision/core/image';
|
export {MPImage} from '../../../tasks/web/vision/core/image';
|
||||||
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask';
|
export {MPMask} from '../../../tasks/web/vision/core/mask';
|
||||||
export * from '../../../tasks/web/vision/face_detector/face_detector';
|
export * from '../../../tasks/web/vision/face_detector/face_detector';
|
||||||
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
||||||
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
||||||
|
|
Loading…
Reference in New Issue
Block a user