From ddb84702f649a0037300a4bebf6c0be54d80844c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 6 May 2023 10:23:58 -0700 Subject: [PATCH] Simplify MPMask by removing the Type Enums from the public API PiperOrigin-RevId: 529975377 --- mediapipe/tasks/web/vision/core/mask.test.ts | 37 +++++------ mediapipe/tasks/web/vision/core/mask.ts | 67 +++++++++----------- mediapipe/tasks/web/vision/index.ts | 4 +- mediapipe/tasks/web/vision/types.ts | 2 +- 4 files changed, 52 insertions(+), 58 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/mask.test.ts b/mediapipe/tasks/web/vision/core/mask.test.ts index 310a59ef3..b632f2dc5 100644 --- a/mediapipe/tasks/web/vision/core/mask.test.ts +++ b/mediapipe/tasks/web/vision/core/mask.test.ts @@ -17,7 +17,7 @@ import 'jasmine'; import {MPImageShaderContext} from './image_shader_context'; -import {MPMask, MPMaskType} from './mask'; +import {MPMask} from './mask'; const WIDTH = 2; const HEIGHT = 2; @@ -117,13 +117,13 @@ class MPMaskTestContext { function assertEquality(mask: MPMask, expected: MaskType): void { if (expected instanceof Uint8Array) { - const result = mask.get(MPMaskType.UINT8_ARRAY); + const result = mask.getAsUint8Array(); expect(result).toEqual(expected); } else if (expected instanceof Float32Array) { - const result = mask.get(MPMaskType.FLOAT32_ARRAY); + const result = mask.getAsFloat32Array(); expect(result).toEqual(expected); } else { // WebGLTexture - const result = mask.get(MPMaskType.WEBGL_TEXTURE); + const result = mask.getAsWebGLTexture(); expect(readPixelsFromWebGLTexture(result)) .toEqual(readPixelsFromWebGLTexture(expected)); } @@ -183,7 +183,7 @@ class MPMaskTestContext { /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH, HEIGHT); - const result = mask.clone().get(MPMaskType.UINT8_ARRAY); + const result = mask.clone().getAsUint8Array(); expect(result).toEqual(context.uint8Array); shaderContext.close(); }); @@ -199,13 +199,13 @@ class MPMaskTestContext { // Verify that we can mix the different shader modes by running them out of // order. - let result = mask.get(MPMaskType.UINT8_ARRAY); + let result = mask.getAsUint8Array(); expect(result).toEqual(context.uint8Array); - result = mask.clone().get(MPMaskType.UINT8_ARRAY); + result = mask.clone().getAsUint8Array(); expect(result).toEqual(context.uint8Array); - result = mask.get(MPMaskType.UINT8_ARRAY); + result = mask.getAsUint8Array(); expect(result).toEqual(context.uint8Array); shaderContext.close(); @@ -217,20 +217,21 @@ class MPMaskTestContext { const shaderContext = new MPImageShaderContext(); const mask = createImage(shaderContext, context.uint8Array, WIDTH, HEIGHT); - expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true); - expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(false); - expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(false); + expect(mask.hasUint8Array()).toBe(true); + expect(mask.hasFloat32Array()).toBe(false); + expect(mask.hasWebGLTexture()).toBe(false); - mask.get(MPMaskType.FLOAT32_ARRAY); + mask.getAsFloat32Array(); - expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true); - expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true); + expect(mask.hasUint8Array()).toBe(true); + expect(mask.hasFloat32Array()).toBe(true); + expect(mask.hasWebGLTexture()).toBe(false); - mask.get(MPMaskType.WEBGL_TEXTURE); + mask.getAsWebGLTexture(); - expect(mask.has(MPMaskType.UINT8_ARRAY)).toBe(true); - expect(mask.has(MPMaskType.FLOAT32_ARRAY)).toBe(true); - expect(mask.has(MPMaskType.WEBGL_TEXTURE)).toBe(true); + expect(mask.hasUint8Array()).toBe(true); + expect(mask.hasFloat32Array()).toBe(true); + expect(mask.hasWebGLTexture()).toBe(true); mask.close(); shaderContext.close(); diff --git a/mediapipe/tasks/web/vision/core/mask.ts b/mediapipe/tasks/web/vision/core/mask.ts index a3dedf63a..da14f104f 100644 --- a/mediapipe/tasks/web/vision/core/mask.ts +++ b/mediapipe/tasks/web/vision/core/mask.ts @@ -17,7 +17,7 @@ import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; /** The underlying type of the image. */ -export enum MPMaskType { +enum MPMaskType { /** Represents the native `UInt8Array` type. */ UINT8_ARRAY, /** Represents the native `Float32Array` type. */ @@ -34,9 +34,9 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture; * * Masks are stored as `Uint8Array`, `Float32Array` or `WebGLTexture` objects. * You can convert the underlying type to any other type by passing the desired - * type to `get()`. As type conversions can be expensive, it is recommended to - * limit these conversions. You can verify what underlying types are already - * available by invoking `has()`. + * type to `getAs...()`. As type conversions can be expensive, it is recommended + * to limit these conversions. You can verify what underlying types are already + * available by invoking `has...()`. * * Masks that are returned from a MediaPipe Tasks are owned by by the * underlying C++ Task. If you need to extend the lifetime of these objects, @@ -47,9 +47,6 @@ export type MPMaskContainer = Uint8Array|Float32Array|WebGLTexture; export class MPMask { private gl?: WebGL2RenderingContext; - /** The underlying type of the mask. */ - static TYPE = MPMaskType; - /** @hideconstructor */ constructor( private readonly containers: MPMaskContainer[], @@ -63,13 +60,19 @@ export class MPMask { readonly height: number, ) {} - /** - * Returns whether this `MPMask` stores the mask in the desired - * format. This method can be called to reduce expensive conversion before - * invoking `get()`. - */ - has(type: MPMaskType): boolean { - return !!this.getContainer(type); + /** Returns whether this `MPMask` contains a mask of type `Uint8Array`. */ + hasUint8Array(): boolean { + return !!this.getContainer(MPMaskType.UINT8_ARRAY); + } + + /** Returns whether this `MPMask` contains a mask of type `Float32Array`. */ + hasFloat32Array(): boolean { + return !!this.getContainer(MPMaskType.FLOAT32_ARRAY); + } + + /** Returns whether this `MPMask` contains a mask of type `WebGLTexture`. */ + hasWebGLTexture(): boolean { + return !!this.getContainer(MPMaskType.WEBGL_TEXTURE); } /** @@ -77,43 +80,35 @@ export class MPMask { * expensive GPU to CPU transfer if the current mask is only available as a * `WebGLTexture`. * - * @param type The type of mask to return. * @return The current data as a Uint8Array. */ - get(type: MPMaskType.UINT8_ARRAY): Uint8Array; + getAsUint8Array(): Uint8Array { + return this.convertToUint8Array(); + } + /** * Returns the underlying mask as a single channel `Float32Array`. Note that * this involves an expensive GPU to CPU transfer if the current mask is only * available as a `WebGLTexture`. * - * @param type The type of mask to return. * @return The current mask as a Float32Array. */ - get(type: MPMaskType.FLOAT32_ARRAY): Float32Array; + getAsFloat32Array(): Float32Array { + return this.convertToFloat32Array(); + } + /** * Returns the underlying mask as a `WebGLTexture` object. Note that this * involves a CPU to GPU transfer if the current mask is only available as * a CPU array. The returned texture is bound to the current canvas (see * `.canvas`). * - * @param type The type of mask to return. * @return The current mask as a WebGLTexture. */ - get(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture; - get(type?: MPMaskType): MPMaskContainer { - switch (type) { - case MPMaskType.UINT8_ARRAY: - return this.convertToUint8Array(); - case MPMaskType.FLOAT32_ARRAY: - return this.convertToFloat32Array(); - case MPMaskType.WEBGL_TEXTURE: - return this.convertToWebGLTexture(); - default: - throw new Error(`Type is not supported: ${type}`); - } + getAsWebGLTexture(): WebGLTexture { + return this.convertToWebGLTexture(); } - private getContainer(type: MPMaskType.UINT8_ARRAY): Uint8Array|undefined; private getContainer(type: MPMaskType.FLOAT32_ARRAY): Float32Array|undefined; private getContainer(type: MPMaskType.WEBGL_TEXTURE): WebGLTexture|undefined; @@ -186,7 +181,7 @@ export class MPMask { } return new MPMask( - destinationContainers, this.has(MPMaskType.WEBGL_TEXTURE), this.canvas, + destinationContainers, this.hasWebGLTexture(), this.canvas, this.shaderContext, this.width, this.height); } @@ -220,9 +215,9 @@ export class MPMask { private convertToFloat32Array(): Float32Array { let float32Array = this.getContainer(MPMaskType.FLOAT32_ARRAY); if (!float32Array) { - if (this.has(MPMaskType.UINT8_ARRAY)) { - const source = this.getContainer(MPMaskType.UINT8_ARRAY)!; - float32Array = new Float32Array(source).map(v => v / 255); + const uint8Array = this.getContainer(MPMaskType.UINT8_ARRAY); + if (uint8Array) { + float32Array = new Float32Array(uint8Array).map(v => v / 255); } else { const gl = this.getGL(); const shaderContext = this.getShaderContext(); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index ea385fdfc..5b643b84e 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -17,7 +17,7 @@ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils'; import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image'; -import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask'; +import {MPMask as MPMaskImpl} from '../../../tasks/web/vision/core/mask'; import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; @@ -36,7 +36,6 @@ const DrawingUtils = DrawingUtilsImpl; const FilesetResolver = FilesetResolverImpl; const MPImage = MPImageImpl; const MPMask = MPMaskImpl; -const MPMaskType = MPMaskTypeImpl; const FaceDetector = FaceDetectorImpl; const FaceLandmarker = FaceLandmarkerImpl; const FaceLandmarksConnections = FaceLandmarksConnectionsImpl; @@ -55,7 +54,6 @@ export { FilesetResolver, MPImage, MPMask, - MPMaskType, FaceDetector, FaceLandmarker, FaceLandmarksConnections, diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index 203983fa1..760b97b77 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -17,7 +17,7 @@ export * from '../../../tasks/web/core/fileset_resolver'; export * from '../../../tasks/web/vision/core/drawing_utils'; export {MPImage} from '../../../tasks/web/vision/core/image'; -export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask'; +export {MPMask} from '../../../tasks/web/vision/core/mask'; export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';