Simplify MPImage API by removing the Type Enums from the public API

PiperOrigin-RevId: 529960399
This commit is contained in:
Sebastian Schmidt 2023-05-06 07:47:37 -07:00 committed by Copybara-Service
parent 8a6fe90759
commit e9fc66277a
6 changed files with 67 additions and 78 deletions

View File

@ -16,7 +16,7 @@
import 'jasmine';
import {MPImage, MPImageType} from './image';
import {MPImage} from './image';
import {MPImageShaderContext} from './image_shader_context';
const WIDTH = 2;
@ -122,14 +122,14 @@ class MPImageTestContext {
function assertEquality(image: MPImage, expected: ImageType): void {
if (expected instanceof ImageData) {
const result = image.get(MPImageType.IMAGE_DATA);
const result = image.getAsImageData();
expect(result).toEqual(expected);
} else if (expected instanceof ImageBitmap) {
const result = image.get(MPImageType.IMAGE_BITMAP);
const result = image.getAsImageBitmap();
expect(readPixelsFromImageBitmap(result))
.toEqual(readPixelsFromImageBitmap(expected));
} else { // WebGLTexture
const result = image.get(MPImageType.WEBGL_TEXTURE);
const result = image.getAsWebGLTexture();
expect(readPixelsFromWebGLTexture(result))
.toEqual(readPixelsFromWebGLTexture(expected));
}
@ -139,7 +139,8 @@ class MPImageTestContext {
shaderContext: MPImageShaderContext, input: ImageType, width: number,
height: number): MPImage {
return new MPImage(
[input], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
[input],
/* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
context.canvas, shaderContext, width, height);
}
@ -189,7 +190,7 @@ class MPImageTestContext {
/* ownsWebGLTexture= */ false, context.canvas, shaderContext, WIDTH,
HEIGHT);
const result = image.clone().get(MPImageType.IMAGE_DATA);
const result = image.clone().getAsImageData();
expect(result).toEqual(context.imageData);
shaderContext.close();
@ -206,13 +207,13 @@ class MPImageTestContext {
// Verify that we can mix the different shader modes by running them out of
// order.
let result = image.get(MPImageType.IMAGE_DATA);
let result = image.getAsImageData();
expect(result).toEqual(context.imageData);
result = image.clone().get(MPImageType.IMAGE_DATA);
result = image.clone().getAsImageData();
expect(result).toEqual(context.imageData);
result = image.get(MPImageType.IMAGE_DATA);
result = image.getAsImageData();
expect(result).toEqual(context.imageData);
shaderContext.close();
@ -224,21 +225,21 @@ class MPImageTestContext {
const shaderContext = new MPImageShaderContext();
const image = createImage(shaderContext, context.imageData, WIDTH, HEIGHT);
expect(image.has(MPImageType.IMAGE_DATA)).toBe(true);
expect(image.has(MPImageType.WEBGL_TEXTURE)).toBe(false);
expect(image.has(MPImageType.IMAGE_BITMAP)).toBe(false);
expect(image.hasImageData()).toBe(true);
expect(image.hasWebGLTexture()).toBe(false);
expect(image.hasImageBitmap()).toBe(false);
image.get(MPImageType.WEBGL_TEXTURE);
image.getAsWebGLTexture();
expect(image.has(MPImageType.IMAGE_DATA)).toBe(true);
expect(image.has(MPImageType.WEBGL_TEXTURE)).toBe(true);
expect(image.has(MPImageType.IMAGE_BITMAP)).toBe(false);
expect(image.hasImageData()).toBe(true);
expect(image.hasWebGLTexture()).toBe(true);
expect(image.hasImageBitmap()).toBe(false);
image.get(MPImageType.IMAGE_BITMAP);
image.getAsImageBitmap();
expect(image.has(MPImageType.IMAGE_DATA)).toBe(true);
expect(image.has(MPImageType.WEBGL_TEXTURE)).toBe(true);
expect(image.has(MPImageType.IMAGE_BITMAP)).toBe(true);
expect(image.hasImageData()).toBe(true);
expect(image.hasWebGLTexture()).toBe(true);
expect(image.hasImageBitmap()).toBe(true);
image.close();
shaderContext.close();

View File

@ -17,7 +17,7 @@
import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
/** The underlying type of the image. */
export enum MPImageType {
enum MPImageType {
/** Represents the native `ImageData` type. */
IMAGE_DATA,
/** Represents the native `ImageBitmap` type. */
@ -34,9 +34,9 @@ export type MPImageContainer = ImageData|ImageBitmap|WebGLTexture;
*
* Images are stored as `ImageData`, `ImageBitmap` or `WebGLTexture` objects.
* You can convert the underlying type to any other type by passing the
* desired type to `get()`. As type conversions can be expensive, it is
* desired type to `getAs...()`. As type conversions can be expensive, it is
* recommended to limit these conversions. You can verify what underlying
* types are already available by invoking `has()`.
* types are already available by invoking `has...()`.
*
* Images that are returned from a MediaPipe Tasks are owned by by the
* underlying C++ Task. If you need to extend the lifetime of these objects,
@ -52,9 +52,6 @@ export type MPImageContainer = ImageData|ImageBitmap|WebGLTexture;
export class MPImage {
private gl?: WebGL2RenderingContext;
/** The underlying type of the image. */
static TYPE = MPImageType;
/** @hideconstructor */
constructor(
private readonly containers: MPImageContainer[],
@ -69,13 +66,19 @@ export class MPImage {
readonly height: number,
) {}
/**
* Returns whether this `MPImage` stores the image in the desired format.
* This method can be called to reduce expensive conversion before invoking
* `get()`.
*/
has(type: MPImageType): boolean {
return !!this.getContainer(type);
/** Returns whether this `MPImage` contains a mask of type `ImageData`. */
hasImageData(): boolean {
return !!this.getContainer(MPImageType.IMAGE_DATA);
}
/** Returns whether this `MPImage` contains a mask of type `ImageBitmap`. */
hasImageBitmap(): boolean {
return !!this.getContainer(MPImageType.IMAGE_BITMAP);
}
/** Returns whether this `MPImage` contains a mask of type `WebGLTexture`. */
hasWebGLTexture(): boolean {
return !!this.getContainer(MPImageType.WEBGL_TEXTURE);
}
/**
@ -85,7 +88,10 @@ export class MPImage {
*
* @return The current image as an ImageData object.
*/
get(type: MPImageType.IMAGE_DATA): ImageData;
getAsImageData(): ImageData {
return this.convertToImageData();
}
/**
* Returns the underlying image as an `ImageBitmap`. Note that
* conversions to `ImageBitmap` are expensive, especially if the data
@ -96,32 +102,24 @@ export class MPImage {
* https://developer.mozilla.org/en-US/docs/Web/API/OffscreenCanvas/getContext
* for a list of supported platforms.
*
* @param type The type of image to return.
* @return The current image as an ImageBitmap object.
*/
get(type: MPImageType.IMAGE_BITMAP): ImageBitmap;
getAsImageBitmap(): ImageBitmap {
return this.convertToImageBitmap();
}
/**
* Returns the underlying image as a `WebGLTexture` object. Note that this
* involves a CPU to GPU transfer if the current image is only available as
* an `ImageData` object. The returned texture is bound to the current
* canvas (see `.canvas`).
*
* @param type The type of image to return.
* @return The current image as a WebGLTexture.
*/
get(type: MPImageType.WEBGL_TEXTURE): WebGLTexture;
get(type?: MPImageType): MPImageContainer {
switch (type) {
case MPImageType.IMAGE_DATA:
return this.convertToImageData();
case MPImageType.IMAGE_BITMAP:
return this.convertToImageBitmap();
case MPImageType.WEBGL_TEXTURE:
return this.convertToWebGLTexture();
default:
throw new Error(`Type is not supported: ${type}`);
}
getAsWebGLTexture(): WebGLTexture {
return this.convertToWebGLTexture();
}
private getContainer(type: MPImageType.IMAGE_DATA): ImageData|undefined;
private getContainer(type: MPImageType.IMAGE_BITMAP): ImageBitmap|undefined;
private getContainer(type: MPImageType.WEBGL_TEXTURE): WebGLTexture|undefined;
@ -200,9 +198,8 @@ export class MPImage {
}
return new MPImage(
destinationContainers, this.has(MPImageType.IMAGE_BITMAP),
this.has(MPImageType.WEBGL_TEXTURE), this.canvas, this.shaderContext,
this.width, this.height);
destinationContainers, this.hasImageBitmap(), this.hasWebGLTexture(),
this.canvas, this.shaderContext, this.width, this.height);
}
private getOffscreenCanvas(): OffscreenCanvas {
@ -251,27 +248,22 @@ export class MPImage {
private convertToImageData(): ImageData {
let imageData = this.getContainer(MPImageType.IMAGE_DATA);
if (!imageData) {
if (this.has(MPImageType.IMAGE_BITMAP) ||
this.has(MPImageType.WEBGL_TEXTURE)) {
const gl = this.getGL();
const shaderContext = this.getShaderContext();
const pixels = new Uint8Array(this.width * this.height * 4);
const gl = this.getGL();
const shaderContext = this.getShaderContext();
const pixels = new Uint8Array(this.width * this.height * 4);
// Create texture if needed
const webGlTexture = this.convertToWebGLTexture();
// Create texture if needed
const webGlTexture = this.convertToWebGLTexture();
// Create a framebuffer from the texture and read back pixels
shaderContext.bindFramebuffer(gl, webGlTexture);
gl.readPixels(
0, 0, this.width, this.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
shaderContext.unbindFramebuffer();
// Create a framebuffer from the texture and read back pixels
shaderContext.bindFramebuffer(gl, webGlTexture);
gl.readPixels(
0, 0, this.width, this.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
shaderContext.unbindFramebuffer();
imageData = new ImageData(
new Uint8ClampedArray(pixels.buffer), this.width, this.height);
this.containers.push(imageData);
} else {
throw new Error('Couldn\t find backing image for ImageData conversion');
}
imageData = new ImageData(
new Uint8ClampedArray(pixels.buffer), this.width, this.height);
this.containers.push(imageData);
}
return imageData;

View File

@ -47,7 +47,6 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
],
)

View File

@ -19,7 +19,6 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {FaceStylizer} from './face_stylizer';
@ -117,7 +116,7 @@ describe('FaceStylizer', () => {
const image = faceStylizer.stylize({} as HTMLImageElement);
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).not.toBeNull();
expect(image!.has(MPImage.TYPE.IMAGE_DATA)).toBeTrue();
expect(image!.hasImageData()).toBeTrue();
expect(image!.width).toEqual(1);
expect(image!.height).toEqual(1);
image!.close();
@ -142,7 +141,7 @@ describe('FaceStylizer', () => {
faceStylizer.stylize({} as HTMLImageElement, image => {
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).not.toBeNull();
expect(image!.has(MPImage.TYPE.IMAGE_DATA)).toBeTrue();
expect(image!.hasImageData()).toBeTrue();
expect(image!.width).toEqual(1);
expect(image!.height).toEqual(1);
done();

View File

@ -16,7 +16,7 @@
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
import {MPImage as MPImageImpl, MPImageType as MPImageTypeImpl} from '../../../tasks/web/vision/core/image';
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
import {MPMask as MPMaskImpl, MPMaskType as MPMaskTypeImpl} from '../../../tasks/web/vision/core/mask';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
@ -35,7 +35,6 @@ import {PoseLandmarker as PoseLandmarkerImpl} from '../../../tasks/web/vision/po
const DrawingUtils = DrawingUtilsImpl;
const FilesetResolver = FilesetResolverImpl;
const MPImage = MPImageImpl;
const MPImageType = MPImageTypeImpl;
const MPMask = MPMaskImpl;
const MPMaskType = MPMaskTypeImpl;
const FaceDetector = FaceDetectorImpl;
@ -55,7 +54,6 @@ export {
DrawingUtils,
FilesetResolver,
MPImage,
MPImageType,
MPMask,
MPMaskType,
FaceDetector,

View File

@ -16,7 +16,7 @@
export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/core/drawing_utils';
export {MPImage, MPImageType} from '../../../tasks/web/vision/core/image';
export {MPImage} from '../../../tasks/web/vision/core/image';
export {MPMask, MPMaskType} from '../../../tasks/web/vision/core/mask';
export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';