Simplify MPImage API by removing the Type Enums from the public API

PiperOrigin-RevId: 529960399
This commit is contained in:
Sebastian Schmidt 2023-05-06 07:47:37 -07:00 committed by Copybara-Service
parent 8a6fe90759
commit e9fc66277a
6 changed files with 67 additions and 78 deletions

View File

@ -16,7 +16,7 @@
import 'jasmine'; import 'jasmine';
import {MPImage, MPImageType} from './image'; import {MPImage} from './image';
import {MPImageShaderContext} from './image_shader_context'; import {MPImageShaderContext} from './image_shader_context';
const WIDTH = 2; const WIDTH = 2;
@ -122,14 +122,14 @@ class MPImageTestContext {
function assertEquality(image: MPImage, expected: ImageType): void { function assertEquality(image: MPImage, expected: ImageType): void {
if (expected instanceof ImageData) { if (expected instanceof ImageData) {
const result = image.get(MPImageType.IMAGE_DATA); const result = image.getAsImageData();
expect(result).toEqual(expected); expect(result).toEqual(expected);
} else if (expected instanceof ImageBitmap) { } else if (expected instanceof ImageBitmap) {
const result = image.get(MPImageType.IMAGE_BITMAP); const result = image.getAsImageBitmap();
expect(readPixelsFromImageBitmap(result)) expect(readPixelsFromImageBitmap(result))
.toEqual(readPixelsFromImageBitmap(expected)); .toEqual(readPixelsFromImageBitmap(expected));
} else { // WebGLTexture } else { // WebGLTexture
const result = image.get(MPImageType.WEBGL_TEXTURE); const result = image.getAsWebGLTexture();
expect(readPixelsFromWebGLTexture(result)) expect(readPixelsFromWebGLTexture(result))
.toEqual(readPixelsFromWebGLTexture(expected)); .toEqual(readPixelsFromWebGLTexture(expected));
} }
@ -139,7 +139,8 @@ class MPImageTestContext {
shaderContext: MPImageShaderContext, input: ImageType, width: number, shaderContext: MPImageShaderContext, input: ImageType, width: number,
height: number): MPImage { height: number): MPImage {
return new MPImage( return new MPImage(
[input], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, [input],
/* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
context.canvas, shaderContext, width, height); context.canvas, shaderContext, width, height);
} }
@ -189,7 +190,7 @@ class MPImageTestContext {
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH, /* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
HEIGHT); HEIGHT);
const result = image.clone().get(MPImageType.IMAGE_DATA); const result = image.clone().getAsImageData();
expect(result).toEqual(context.imageData); expect(result).toEqual(context.imageData);
shaderContext.close(); shaderContext.close();
@ -206,13 +207,13 @@ class MPImageTestContext {
// Verify that we can mix the different shader modes by running them out of // Verify that we can mix the different shader modes by running them out of
// order. // order.
let result = image.get(MPImageType.IMAGE_DATA); let result = image.getAsImageData();
expect(result).toEqual(context.imageData); expect(result).toEqual(context.imageData);
result = image.clone().get(MPImageType.IMAGE_DATA); result = image.clone().getAsImageData();
expect(result).toEqual(context.imageData); expect(result).toEqual(context.imageData);
result = image.get(MPImageType.IMAGE_DATA); result = image.getAsImageData();
expect(result).toEqual(context.imageData); expect(result).toEqual(context.imageData);
shaderContext.close(); shaderContext.close();
@ -224,21 +225,21 @@ class MPImageTestContext {
const shaderContext = new MPImageShaderContext(); const shaderContext = new MPImageShaderContext();
const image = createImage(shaderContext, context.imageData, WIDTH, HEIGHT); const image = createImage(shaderContext, context.imageData, WIDTH, HEIGHT);
expect(image.has(MPImageType.IMAGE_DATA)).toBe(true); expect(image.hasImageData()).toBe(true);
expect(image.has(MPImageType.WEBGL_TEXTURE)).toBe(false); expect(image.hasWebGLTexture()).toBe(false);
expect(image.has(MPImageType.IMAGE_BITMAP)).toBe(false); expect(image.hasImageBitmap()).toBe(false);
image.get(MPImageType.WEBGL_TEXTURE); image.getAsWebGLTexture();
expect(image.has(MPImageType.IMAGE_DATA)).toBe(true); expect(image.hasImageData()).toBe(true);
expect(image.has(MPImageType.WEBGL_TEXTURE)).toBe(true); expect(image.hasWebGLTexture()).toBe(true);
expect(image.has(MPImageType.IMAGE_BITMAP)).toBe(false); expect(image.hasImageBitmap()).toBe(false);
image.get(MPImageType.IMAGE_BITMAP); image.getAsImageBitmap();
expect(image.has(MPImageType.IMAGE_DATA)).toBe(true); expect(image.hasImageData()).toBe(true);
expect(image.has(MPImageType.WEBGL_TEXTURE)).toBe(true); expect(image.hasWebGLTexture()).toBe(true);
expect(image.has(MPImageType.IMAGE_BITMAP)).toBe(true); expect(image.hasImageBitmap()).toBe(true);
image.close(); image.close();
shaderContext.close(); shaderContext.close();

View File

@ -17,7 +17,7 @@
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
/** The underlying type of the image. */ /** The underlying type of the image. */
export enum MPImageType { enum MPImageType {
/** Represents the native `ImageData` type. */ /** Represents the native `ImageData` type. */
IMAGE_DATA, IMAGE_DATA,
/** Represents the native `ImageBitmap` type. */ /** Represents the native `ImageBitmap` type. */
@ -34,9 +34,9 @@ export type MPImageContainer = ImageData|ImageBitmap|WebGLTexture;
* *
* Images are stored as `ImageData`, `ImageBitmap` or `WebGLTexture` objects. * Images are stored as `ImageData`, `ImageBitmap` or `WebGLTexture` objects.
* You can convert the underlying type to any other type by passing the * You can convert the underlying type to any other type by passing the
* desired type to `get()`. As type conversions can be expensive, it is * desired type to `getAs...()`. As type conversions can be expensive, it is
* recommended to limit these conversions. You can verify what underlying * recommended to limit these conversions. You can verify what underlying
* types are already available by invoking `has()`. * types are already available by invoking `has...()`.
* *
* Images that are returned from a MediaPipe Tasks are owned by by the * Images that are returned from a MediaPipe Tasks are owned by by the
* underlying C++ Task. If you need to extend the lifetime of these objects, * underlying C++ Task. If you need to extend the lifetime of these objects,
@ -52,9 +52,6 @@ export type MPImageContainer = ImageData|ImageBitmap|WebGLTexture;
export class MPImage { export class MPImage {
private gl?: WebGL2RenderingContext; private gl?: WebGL2RenderingContext;
/** The underlying type of the image. */
static TYPE = MPImageType;
/** @hideconstructor */ /** @hideconstructor */
constructor( constructor(
private readonly containers: MPImageContainer[], private readonly containers: MPImageContainer[],
@ -69,13 +66,19 @@ export class MPImage {
readonly height: number, readonly height: number,
) {} ) {}
/** /** Returns whether this `MPImage` contains a mask of type `ImageData`. */
* Returns whether this `MPImage` stores the image in the desired format. hasImageData(): boolean {
* This method can be called to reduce expensive conversion before invoking return !!this.getContainer(MPImageType.IMAGE_DATA);
* `get()`. }
*/
has(type: MPImageType): boolean { /** Returns whether this `MPImage` contains a mask of type `ImageBitmap`. */
return !!this.getContainer(type); hasImageBitmap(): boolean {
return !!this.getContainer(MPImageType.IMAGE_BITMAP);
}
/** Returns whether this `MPImage` contains a mask of type `WebGLTexture`. */
hasWebGLTexture(): boolean {
return !!this.getContainer(MPImageType.WEBGL_TEXTURE);
} }
/** /**
@ -85,7 +88,10 @@ export class MPImage {
* *
* @return The current image as an ImageData object. * @return The current image as an ImageData object.
*/ */
get(type: MPImageType.IMAGE_DATA): ImageData; getAsImageData(): ImageData {
return this.convertToImageData();
}
/** /**
* Returns the underlying image as an `ImageBitmap`. Note that * Returns the underlying image as an `ImageBitmap`. Note that
* conversions to `ImageBitmap` are expensive, especially if the data * conversions to `ImageBitmap` are expensive, especially if the data
@ -96,32 +102,24 @@ export class MPImage {
* https://developer.mozilla.org/en-US/docs/Web/API/OffscreenCanvas/getContext * https://developer.mozilla.org/en-US/docs/Web/API/OffscreenCanvas/getContext
* for a list of supported platforms. * for a list of supported platforms.
* *
* @param type The type of image to return.
* @return The current image as an ImageBitmap object. * @return The current image as an ImageBitmap object.
*/ */
get(type: MPImageType.IMAGE_BITMAP): ImageBitmap; getAsImageBitmap(): ImageBitmap {
return this.convertToImageBitmap();
}
/** /**
* Returns the underlying image as a `WebGLTexture` object. Note that this * Returns the underlying image as a `WebGLTexture` object. Note that this
* involves a CPU to GPU transfer if the current image is only available as * involves a CPU to GPU transfer if the current image is only available as
* an `ImageData` object. The returned texture is bound to the current * an `ImageData` object. The returned texture is bound to the current
* canvas (see `.canvas`). * canvas (see `.canvas`).
* *
* @param type The type of image to return.
* @return The current image as a WebGLTexture. * @return The current image as a WebGLTexture.
*/ */
get(type: MPImageType.WEBGL_TEXTURE): WebGLTexture; getAsWebGLTexture(): WebGLTexture {
get(type?: MPImageType): MPImageContainer {
switch (type) {
case MPImageType.IMAGE_DATA:
return this.convertToImageData();
case MPImageType.IMAGE_BITMAP:
return this.convertToImageBitmap();
case MPImageType.WEBGL_TEXTURE:
return this.convertToWebGLTexture(); return this.convertToWebGLTexture();
default:
throw new Error(`Type is not supported: ${type}`);
}
} }
private getContainer(type: MPImageType.IMAGE_DATA): ImageData|undefined; private getContainer(type: MPImageType.IMAGE_DATA): ImageData|undefined;
private getContainer(type: MPImageType.IMAGE_BITMAP): ImageBitmap|undefined; private getContainer(type: MPImageType.IMAGE_BITMAP): ImageBitmap|undefined;
private getContainer(type: MPImageType.WEBGL_TEXTURE): WebGLTexture|undefined; private getContainer(type: MPImageType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -200,9 +198,8 @@ export class MPImage {
} }
return new MPImage( return new MPImage(
destinationContainers, this.has(MPImageType.IMAGE_BITMAP), destinationContainers, this.hasImageBitmap(), this.hasWebGLTexture(),
this.has(MPImageType.WEBGL_TEXTURE), this.canvas, this.shaderContext, this.canvas, this.shaderContext, this.width, this.height);
this.width, this.height);
} }
private getOffscreenCanvas(): OffscreenCanvas { private getOffscreenCanvas(): OffscreenCanvas {
@ -251,8 +248,6 @@ export class MPImage {
private convertToImageData(): ImageData { private convertToImageData(): ImageData {
let imageData = this.getContainer(MPImageType.IMAGE_DATA); let imageData = this.getContainer(MPImageType.IMAGE_DATA);
if (!imageData) { if (!imageData) {
if (this.has(MPImageType.IMAGE_BITMAP) ||
this.has(MPImageType.WEBGL_TEXTURE)) {
const gl = this.getGL(); const gl = this.getGL();
const shaderContext = this.getShaderContext(); const shaderContext = this.getShaderContext();
const pixels = new Uint8Array(this.width * this.height * 4); const pixels = new Uint8Array(this.width * this.height * 4);
@ -269,9 +264,6 @@ export class MPImage {
imageData = new ImageData( imageData = new ImageData(
new Uint8ClampedArray(pixels.buffer), this.width, this.height); new Uint8ClampedArray(pixels.buffer), this.width, this.height);
this.containers.push(imageData); this.containers.push(imageData);
} else {
throw new Error('Couldn\t find backing image for ImageData conversion');
}
} }
return imageData; return imageData;

View File

@ -47,7 +47,6 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
], ],
) )

View File

@ -19,7 +19,6 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {FaceStylizer} from './face_stylizer'; import {FaceStylizer} from './face_stylizer';
@ -117,7 +116,7 @@ describe('FaceStylizer', () => {
const image = faceStylizer.stylize({} as HTMLImageElement); const image = faceStylizer.stylize({} as HTMLImageElement);
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).not.toBeNull(); expect(image).not.toBeNull();
expect(image!.has(MPImage.TYPE.IMAGE_DATA)).toBeTrue(); expect(image!.hasImageData()).toBeTrue();
expect(image!.width).toEqual(1); expect(image!.width).toEqual(1);
expect(image!.height).toEqual(1); expect(image!.height).toEqual(1);
image!.close(); image!.close();
@ -142,7 +141,7 @@ describe('FaceStylizer', () => {
faceStylizer.stylize({} as HTMLImageElement, image => { faceStylizer.stylize({} as HTMLImageElement, image => {
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).not.toBeNull(); expect(image).not.toBeNull();
expect(image!.has(MPImage.TYPE.IMAGE_DATA)).toBeTrue(); expect(image!.hasImageData()).toBeTrue();
expect(image!.width).toEqual(1); expect(image!.width).toEqual(1);
expect(image!.height).toEqual(1); expect(image!.height).toEqual(1);
done(); done();

View File

@ -16,7 +16,7 @@
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils'; import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
import {MPImage as MPImageImpl, MPImageType as MPImageTypeImpl} from '../../../tasks/web/vision/core/image'; import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask'; import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
@ -35,7 +35,6 @@ import {PoseLandmarker as PoseLandmarkerImpl} from '../../../tasks/web/vision/po
const DrawingUtils = DrawingUtilsImpl; const DrawingUtils = DrawingUtilsImpl;
const FilesetResolver = FilesetResolverImpl; const FilesetResolver = FilesetResolverImpl;
const MPImage = MPImageImpl; const MPImage = MPImageImpl;
const MPImageType = MPImageTypeImpl;
const MPMask = MPMaskImpl; const MPMask = MPMaskImpl;
const MPMaskType = MPMaskTypeImpl; const MPMaskType = MPMaskTypeImpl;
const FaceDetector = FaceDetectorImpl; const FaceDetector = FaceDetectorImpl;
@ -55,7 +54,6 @@ export {
DrawingUtils, DrawingUtils,
FilesetResolver, FilesetResolver,
MPImage, MPImage,
MPImageType,
MPMask, MPMask,
MPMaskType, MPMaskType,
FaceDetector, FaceDetector,

View File

@ -16,7 +16,7 @@
export * from '../../../tasks/web/core/fileset_resolver'; export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/core/drawing_utils'; export * from '../../../tasks/web/vision/core/drawing_utils';
export {MPImage, MPImageType} from '../../../tasks/web/vision/core/image'; export {MPImage} from '../../../tasks/web/vision/core/image';
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask'; export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask';
export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';