From a54409810053e9b2d4ba871d4fae7e95797574c4 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 28 Apr 2023 14:02:53 -0700 Subject: [PATCH] Update FaceStylizer to return MPImage PiperOrigin-RevId: 527980696 --- mediapipe/tasks/web/vision/core/BUILD | 1 + mediapipe/tasks/web/vision/core/image.ts | 5 ++ .../web/vision/core/vision_task_runner.ts | 62 ++++++++++++++----- .../tasks/web/vision/face_stylizer/BUILD | 2 + .../web/vision/face_stylizer/face_stylizer.ts | 27 ++++---- .../face_stylizer/face_stylizer_test.ts | 35 ++++++++--- mediapipe/tasks/web/vision/index.ts | 4 +- mediapipe/tasks/web/vision/types.ts | 2 +- 8 files changed, 96 insertions(+), 42 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index f010a8bdd..daeef060d 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -60,6 +60,7 @@ mediapipe_ts_library( name = "vision_task_runner", srcs = ["vision_task_runner.ts"], deps = [ + ":image", ":image_processing_options", ":vision_task_options", "//mediapipe/framework/formats:rect_jspb_proto", diff --git a/mediapipe/tasks/web/vision/core/image.ts b/mediapipe/tasks/web/vision/core/image.ts index 739f05f0f..7a12a923b 100644 --- a/mediapipe/tasks/web/vision/core/image.ts +++ b/mediapipe/tasks/web/vision/core/image.ts @@ -59,8 +59,11 @@ function assertNotNull(value: T|null, msg: string): T { return value; } +// TODO: Move internal-only types to different module. + /** * Utility class that encapsulates the buffers used by `MPImageShaderContext`. + * For internal use only. */ class MPImageShaderBuffers { constructor( @@ -87,6 +90,8 @@ class MPImageShaderBuffers { /** * A class that encapsulates the shaders used by an MPImage. Can be re-used * across MPImages that use the same WebGL2Rendering context. + * + * For internal use only. 
*/ export class MPImageShaderContext { private gl?: WebGL2RenderingContext; diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index a79cee559..5099d2960 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -17,6 +17,7 @@ import {NormalizedRect} from '../../../../framework/formats/rect_pb'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {MPImage, MPImageShaderContext} from '../../../../tasks/web/vision/core/image'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner'; import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; @@ -51,6 +52,8 @@ function createCanvas(): HTMLCanvasElement|OffscreenCanvas|undefined { /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { + private readonly shaderContext = new MPImageShaderContext(); + protected static async createVisionInstance( type: WasmMediaPipeConstructor, fileset: WasmFileset, options: VisionTaskOptions): Promise { @@ -219,30 +222,55 @@ export abstract class VisionTaskRunner extends TaskRunner { this.finishProcessing(); } - /** Converts the RGB or RGBA Uint8Array of a WasmImage to ImageData. */ - protected convertToImageData(wasmImage: WasmImage): ImageData { + /** + * Converts a WasmImage to an MPImage. + * + * Converts the underlying Uint8ClampedArray-backed images to ImageData + * (adding an alpha channel if necessary), passes through WebGLTextures and + * throws for Float32Array-backed images. 
+ */ + protected convertToMPImage(wasmImage: WasmImage): MPImage { const {data, width, height} = wasmImage; - if (!(data instanceof Uint8ClampedArray)) { - throw new Error( - 'Only Uint8ClampedArray-based images can be converted to ImageData'); - } - if (data.length === width * height * 4) { - return new ImageData(data, width, height); - } else if (data.length === width * height * 3) { - const rgba = new Uint8ClampedArray(width * height * 4); - for (let i = 0; i < width * height; ++i) { - rgba[4 * i] = data[3 * i]; - rgba[4 * i + 1] = data[3 * i + 1]; - rgba[4 * i + 2] = data[3 * i + 2]; - rgba[4 * i + 3] = 255; + if (data instanceof Uint8ClampedArray) { + let rgba: Uint8ClampedArray; + if (data.length === width * height * 4) { + rgba = data; + } else if (data.length === width * height * 3) { + // TODO: Convert in C++ + rgba = new Uint8ClampedArray(width * height * 4); + for (let i = 0; i < width * height; ++i) { + rgba[4 * i] = data[3 * i]; + rgba[4 * i + 1] = data[3 * i + 1]; + rgba[4 * i + 2] = data[3 * i + 2]; + rgba[4 * i + 3] = 255; + } + } else { + throw new Error( + `Unsupported channel count: ${data.length / width / height}`); } - return new ImageData(rgba, width, height); + + return new MPImage( + [new ImageData(rgba, width, height)], + /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, + this.graphRunner.wasmModule.canvas!, this.shaderContext, width, + height); + } else if (data instanceof WebGLTexture) { + return new MPImage( + [data], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, + this.graphRunner.wasmModule.canvas!, this.shaderContext, width, + height); } else { throw new Error( - `Unsupported channel count: ${data.length / width / height}`); + `Cannot convert type ${data.constructor.name} to MPImage.`); } } + + /** Closes and cleans up the resources held by this task. 
*/ + override close(): void { + this.shaderContext.close(); + super.close(); + } } diff --git a/mediapipe/tasks/web/vision/face_stylizer/BUILD b/mediapipe/tasks/web/vision/face_stylizer/BUILD index 7716d617f..0c0167dbd 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/BUILD +++ b/mediapipe/tasks/web/vision/face_stylizer/BUILD @@ -17,6 +17,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_jspb_proto", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:types", "//mediapipe/tasks/web/vision/core:vision_task_runner", @@ -46,6 +47,7 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:image", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", ], ) diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts index 13558e235..2a9adb315 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts @@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {FaceStylizerGraphOptions as FaceStylizerGraphOptionsProto} from '../../../../tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {MPImage} from '../../../../tasks/web/vision/core/image'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from 
'../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; @@ -39,15 +40,13 @@ const FACE_STYLIZER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** - * A callback that receives an image from the face stylizer, or `null` if no - * face was detected. The lifetime of the underlying data is limited to the - * duration of the callback. If asynchronous processing is needed, all data - * needs to be copied before the callback returns. - * - * The `WebGLTexture` output type is reserved for future usage. + * A callback that receives an `MPImage` object from the face stylizer, or + * `null` if no face was detected. The lifetime of the underlying data is + * limited to the duration of the callback. If asynchronous processing is + * needed, all data needs to be copied before the callback returns (via + * `image.clone()`). */ -export type FaceStylizerCallback = - (image: ImageData|WebGLTexture|null, width: number, height: number) => void; +export type FaceStylizerCallback = (image: MPImage|null) => void; /** Performs face stylization on images. 
*/ export class FaceStylizer extends VisionTaskRunner { @@ -270,18 +269,14 @@ export class FaceStylizer extends VisionTaskRunner { graphConfig.addNode(segmenterNode); this.graphRunner.attachImageListener( - STYLIZED_IMAGE_STREAM, (image, timestamp) => { - if (image.data instanceof WebGLTexture) { - this.userCallback(image.data, image.width, image.height); - } else { - const imageData = this.convertToImageData(image); - this.userCallback(imageData, image.width, image.height); - } + STYLIZED_IMAGE_STREAM, (wasmImage, timestamp) => { + const mpImage = this.convertToMPImage(wasmImage); + this.userCallback(mpImage); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( STYLIZED_IMAGE_STREAM, timestamp => { - this.userCallback(null, /* width= */ 0, /* height= */ 0); + this.userCallback(null); this.setLatestOutputTimestamp(timestamp); }); diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts index 167fd674e..7d30ef2a9 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts @@ -19,6 +19,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {MPImageStorageType} from '../../../../tasks/web/vision/core/image'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {FaceStylizer} from './face_stylizer'; @@ -114,11 +115,33 @@ describe('FaceStylizer', () => { }); // Invoke the face stylizeer - faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => { + faceStylizer.stylize({} as HTMLImageElement, image => { 
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(image).toBeInstanceOf(ImageData); - expect(width).toEqual(1); - expect(height).toEqual(1); + expect(image).not.toBeNull(); + expect(image!.hasType(MPImageStorageType.IMAGE_DATA)).toBeTrue(); + expect(image!.width).toEqual(1); + expect(image!.height).toEqual(1); + done(); + }); + }); + + it('invokes callback even when no faes are detected', (done) => { + if (typeof ImageData === 'undefined') { + console.log('ImageData tests are not supported on Node'); + done(); + return; + } + + // Pass the test data to our listener + faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(faceStylizer); + faceStylizer.emptyPacketListener!(/* timestamp= */ 1337); + }); + + // Invoke the face stylizer + faceStylizer.stylize({} as HTMLImageElement, image => { + expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(image).toBeNull(); done(); }); }); @@ -131,11 +154,9 @@ describe('FaceStylizer', () => { }); // Invoke the face stylizeer - faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => { + faceStylizer.stylize({} as HTMLImageElement, image => { expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(image).toBeNull(); - expect(width).toEqual(0); - expect(height).toEqual(0); done(); }); }); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 2ea0e7278..632a294d6 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -16,7 +16,7 @@ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils'; -import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image'; +import {MPImage as MPImageImpl, MPImageStorageType as MPImageStorageTypeImpl} from '../../../tasks/web/vision/core/image'; import 
{FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; @@ -34,6 +34,7 @@ import {PoseLandmarker as PoseLandmarkerImpl} from '../../../tasks/web/vision/po const DrawingUtils = DrawingUtilsImpl; const FilesetResolver = FilesetResolverImpl; const MPImage = MPImageImpl; +const MPImageStorageType = MPImageStorageTypeImpl; const FaceDetector = FaceDetectorImpl; const FaceLandmarker = FaceLandmarkerImpl; const FaceLandmarksConnections = FaceLandmarksConnectionsImpl; @@ -51,6 +52,7 @@ export { DrawingUtils, FilesetResolver, MPImage, + MPImageStorageType, FaceDetector, FaceLandmarker, FaceLandmarksConnections, diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index 1c0466bc1..381052881 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -16,7 +16,7 @@ export * from '../../../tasks/web/core/fileset_resolver'; export * from '../../../tasks/web/vision/core/drawing_utils'; -export * from '../../../tasks/web/vision/core/image'; +export {MPImage, MPImageChannelConverter, MPImageStorageType} from '../../../tasks/web/vision/core/image'; export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';