Update FaceStylizer to return MPImage

PiperOrigin-RevId: 527980696
This commit is contained in:
Sebastian Schmidt 2023-04-28 14:02:53 -07:00 committed by Copybara-Service
parent 5cffb3973f
commit a544098100
8 changed files with 96 additions and 42 deletions

View File

@ -60,6 +60,7 @@ mediapipe_ts_library(
name = "vision_task_runner", name = "vision_task_runner",
srcs = ["vision_task_runner.ts"], srcs = ["vision_task_runner.ts"],
deps = [ deps = [
":image",
":image_processing_options", ":image_processing_options",
":vision_task_options", ":vision_task_options",
"//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto",

View File

@ -59,8 +59,11 @@ function assertNotNull<T>(value: T|null, msg: string): T {
return value; return value;
} }
// TODO: Move internal-only types to different module.
/** /**
* Utility class that encapsulates the buffers used by `MPImageShaderContext`. * Utility class that encapsulates the buffers used by `MPImageShaderContext`.
* For internal use only.
*/ */
class MPImageShaderBuffers { class MPImageShaderBuffers {
constructor( constructor(
@ -87,6 +90,8 @@ class MPImageShaderBuffers {
/** /**
* A class that encapsulates the shaders used by an MPImage. Can be re-used * A class that encapsulates the shaders used by an MPImage. Can be re-used
* across MPImages that use the same WebGL2RenderingContext. * across MPImages that use the same WebGL2RenderingContext.
*
* For internal use only.
*/ */
export class MPImageShaderContext { export class MPImageShaderContext {
private gl?: WebGL2RenderingContext; private gl?: WebGL2RenderingContext;

View File

@ -17,6 +17,7 @@
import {NormalizedRect} from '../../../../framework/formats/rect_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb';
import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {MPImage, MPImageShaderContext} from '../../../../tasks/web/vision/core/image';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner'; import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
@ -51,6 +52,8 @@ function createCanvas(): HTMLCanvasElement|OffscreenCanvas|undefined {
/** Base class for all MediaPipe Vision Tasks. */ /** Base class for all MediaPipe Vision Tasks. */
export abstract class VisionTaskRunner extends TaskRunner { export abstract class VisionTaskRunner extends TaskRunner {
private readonly shaderContext = new MPImageShaderContext();
protected static async createVisionInstance<T extends VisionTaskRunner>( protected static async createVisionInstance<T extends VisionTaskRunner>(
type: WasmMediaPipeConstructor<T>, fileset: WasmFileset, type: WasmMediaPipeConstructor<T>, fileset: WasmFileset,
options: VisionTaskOptions): Promise<T> { options: VisionTaskOptions): Promise<T> {
@ -219,29 +222,54 @@ export abstract class VisionTaskRunner extends TaskRunner {
this.finishProcessing(); this.finishProcessing();
} }
/** Converts the RGB or RGBA Uint8Array of a WasmImage to ImageData. */ /**
protected convertToImageData(wasmImage: WasmImage): ImageData { * Converts a WasmImage to an MPImage.
*
* Converts the underlying Uint8ClampedArray-backed images to ImageData
* (adding an alpha channel if necessary), passes through WebGLTextures and
* throws for Float32Array-backed images.
*/
protected convertToMPImage(wasmImage: WasmImage): MPImage {
const {data, width, height} = wasmImage; const {data, width, height} = wasmImage;
if (!(data instanceof Uint8ClampedArray)) {
throw new Error(
'Only Uint8ClampedArray-based images can be converted to ImageData');
}
if (data instanceof Uint8ClampedArray) {
let rgba: Uint8ClampedArray;
if (data.length === width * height * 4) { if (data.length === width * height * 4) {
return new ImageData(data, width, height); rgba = data;
} else if (data.length === width * height * 3) { } else if (data.length === width * height * 3) {
const rgba = new Uint8ClampedArray(width * height * 4); // TODO: Convert in C++
rgba = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < width * height; ++i) { for (let i = 0; i < width * height; ++i) {
rgba[4 * i] = data[3 * i]; rgba[4 * i] = data[3 * i];
rgba[4 * i + 1] = data[3 * i + 1]; rgba[4 * i + 1] = data[3 * i + 1];
rgba[4 * i + 2] = data[3 * i + 2]; rgba[4 * i + 2] = data[3 * i + 2];
rgba[4 * i + 3] = 255; rgba[4 * i + 3] = 255;
} }
return new ImageData(rgba, width, height);
} else { } else {
throw new Error( throw new Error(
`Unsupported channel count: ${data.length / width / height}`); `Unsupported channel count: ${data.length / width / height}`);
} }
return new MPImage(
[new ImageData(rgba, width, height)],
/* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
height);
} else if (data instanceof WebGLTexture) {
return new MPImage(
[data], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
height);
} else {
throw new Error(
`Cannot convert type ${data.constructor.name} to MPImage.`);
}
}
/** Closes and cleans up the resources held by this task. */
override close(): void {
this.shaderContext.close();
super.close();
} }
} }

View File

@ -17,6 +17,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:types", "//mediapipe/tasks/web/vision/core:types",
"//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/tasks/web/vision/core:vision_task_runner",
@ -46,6 +47,7 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
], ],
) )

View File

@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {FaceStylizerGraphOptions as FaceStylizerGraphOptionsProto} from '../../../../tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb'; import {FaceStylizerGraphOptions as FaceStylizerGraphOptionsProto} from '../../../../tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
@ -39,15 +40,13 @@ const FACE_STYLIZER_GRAPH =
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/** /**
* A callback that receives an image from the face stylizer, or `null` if no * A callback that receives an `MPImage` object from the face stylizer, or
* face was detected. The lifetime of the underlying data is limited to the * `null` if no face was detected. The lifetime of the underlying data is
* duration of the callback. If asynchronous processing is needed, all data * limited to the duration of the callback. If asynchronous processing is
* needs to be copied before the callback returns. * needed, all data needs to be copied before the callback returns (via
* * `image.clone()`).
* The `WebGLTexture` output type is reserved for future usage.
*/ */
export type FaceStylizerCallback = export type FaceStylizerCallback = (image: MPImage|null) => void;
(image: ImageData|WebGLTexture|null, width: number, height: number) => void;
/** Performs face stylization on images. */ /** Performs face stylization on images. */
export class FaceStylizer extends VisionTaskRunner { export class FaceStylizer extends VisionTaskRunner {
@ -270,18 +269,14 @@ export class FaceStylizer extends VisionTaskRunner {
graphConfig.addNode(segmenterNode); graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
STYLIZED_IMAGE_STREAM, (image, timestamp) => { STYLIZED_IMAGE_STREAM, (wasmImage, timestamp) => {
if (image.data instanceof WebGLTexture) { const mpImage = this.convertToMPImage(wasmImage);
this.userCallback(image.data, image.width, image.height); this.userCallback(mpImage);
} else {
const imageData = this.convertToImageData(image);
this.userCallback(imageData, image.width, image.height);
}
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
STYLIZED_IMAGE_STREAM, timestamp => { STYLIZED_IMAGE_STREAM, timestamp => {
this.userCallback(null, /* width= */ 0, /* height= */ 0); this.userCallback(null);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });

View File

@ -19,6 +19,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImageStorageType} from '../../../../tasks/web/vision/core/image';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {FaceStylizer} from './face_stylizer'; import {FaceStylizer} from './face_stylizer';
@ -114,11 +115,33 @@ describe('FaceStylizer', () => {
}); });
// Invoke the face stylizer // Invoke the face stylizer
faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => { faceStylizer.stylize({} as HTMLImageElement, image => {
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).toBeInstanceOf(ImageData); expect(image).not.toBeNull();
expect(width).toEqual(1); expect(image!.hasType(MPImageStorageType.IMAGE_DATA)).toBeTrue();
expect(height).toEqual(1); expect(image!.width).toEqual(1);
expect(image!.height).toEqual(1);
done();
});
});
it('invokes callback even when no faes are detected', (done) => {
if (typeof ImageData === 'undefined') {
console.log('ImageData tests are not supported on Node');
done();
return;
}
// Pass the test data to our listener
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer);
faceStylizer.emptyPacketListener!(/* timestamp= */ 1337);
});
// Invoke the face stylizer
faceStylizer.stylize({} as HTMLImageElement, image => {
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).toBeNull();
done(); done();
}); });
}); });
@ -131,11 +154,9 @@ describe('FaceStylizer', () => {
}); });
// Invoke the face stylizer // Invoke the face stylizer
faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => { faceStylizer.stylize({} as HTMLImageElement, image => {
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).toBeNull(); expect(image).toBeNull();
expect(width).toEqual(0);
expect(height).toEqual(0);
done(); done();
}); });
}); });

View File

@ -16,7 +16,7 @@
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils'; import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image'; import {MPImage as MPImageImpl, MPImageStorageType as MPImageStorageTypeImpl} from '../../../tasks/web/vision/core/image';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
@ -34,6 +34,7 @@ import {PoseLandmarker as PoseLandmarkerImpl} from '../../../tasks/web/vision/po
const DrawingUtils = DrawingUtilsImpl; const DrawingUtils = DrawingUtilsImpl;
const FilesetResolver = FilesetResolverImpl; const FilesetResolver = FilesetResolverImpl;
const MPImage = MPImageImpl; const MPImage = MPImageImpl;
const MPImageStorageType = MPImageStorageTypeImpl;
const FaceDetector = FaceDetectorImpl; const FaceDetector = FaceDetectorImpl;
const FaceLandmarker = FaceLandmarkerImpl; const FaceLandmarker = FaceLandmarkerImpl;
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl; const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
@ -51,6 +52,7 @@ export {
DrawingUtils, DrawingUtils,
FilesetResolver, FilesetResolver,
MPImage, MPImage,
MPImageStorageType,
FaceDetector, FaceDetector,
FaceLandmarker, FaceLandmarker,
FaceLandmarksConnections, FaceLandmarksConnections,

View File

@ -16,7 +16,7 @@
export * from '../../../tasks/web/core/fileset_resolver'; export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/core/drawing_utils'; export * from '../../../tasks/web/vision/core/drawing_utils';
export * from '../../../tasks/web/vision/core/image'; export {MPImage, MPImageChannelConverter, MPImageStorageType} from '../../../tasks/web/vision/core/image';
export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';