Update FaceStylizer to return MPImage
PiperOrigin-RevId: 527980696
This commit is contained in:
parent
5cffb3973f
commit
a544098100
|
@ -60,6 +60,7 @@ mediapipe_ts_library(
|
||||||
name = "vision_task_runner",
|
name = "vision_task_runner",
|
||||||
srcs = ["vision_task_runner.ts"],
|
srcs = ["vision_task_runner.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":image",
|
||||||
":image_processing_options",
|
":image_processing_options",
|
||||||
":vision_task_options",
|
":vision_task_options",
|
||||||
"//mediapipe/framework/formats:rect_jspb_proto",
|
"//mediapipe/framework/formats:rect_jspb_proto",
|
||||||
|
|
|
@ -59,8 +59,11 @@ function assertNotNull<T>(value: T|null, msg: string): T {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Move internal-only types to different module.
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Utility class that encapsulates the buffers used by `MPImageShaderContext`.
|
* Utility class that encapsulates the buffers used by `MPImageShaderContext`.
|
||||||
|
* For internal use only.
|
||||||
*/
|
*/
|
||||||
class MPImageShaderBuffers {
|
class MPImageShaderBuffers {
|
||||||
constructor(
|
constructor(
|
||||||
|
@ -87,6 +90,8 @@ class MPImageShaderBuffers {
|
||||||
/**
|
/**
|
||||||
* A class that encapsulates the shaders used by an MPImage. Can be re-used
|
* A class that encapsulates the shaders used by an MPImage. Can be re-used
|
||||||
* across MPImages that use the same WebGL2Rendering context.
|
* across MPImages that use the same WebGL2Rendering context.
|
||||||
|
*
|
||||||
|
* For internal use only.
|
||||||
*/
|
*/
|
||||||
export class MPImageShaderContext {
|
export class MPImageShaderContext {
|
||||||
private gl?: WebGL2RenderingContext;
|
private gl?: WebGL2RenderingContext;
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
import {NormalizedRect} from '../../../../framework/formats/rect_pb';
|
import {NormalizedRect} from '../../../../framework/formats/rect_pb';
|
||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {MPImage, MPImageShaderContext} from '../../../../tasks/web/vision/core/image';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
|
import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
|
||||||
import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||||
|
@ -51,6 +52,8 @@ function createCanvas(): HTMLCanvasElement|OffscreenCanvas|undefined {
|
||||||
|
|
||||||
/** Base class for all MediaPipe Vision Tasks. */
|
/** Base class for all MediaPipe Vision Tasks. */
|
||||||
export abstract class VisionTaskRunner extends TaskRunner {
|
export abstract class VisionTaskRunner extends TaskRunner {
|
||||||
|
private readonly shaderContext = new MPImageShaderContext();
|
||||||
|
|
||||||
protected static async createVisionInstance<T extends VisionTaskRunner>(
|
protected static async createVisionInstance<T extends VisionTaskRunner>(
|
||||||
type: WasmMediaPipeConstructor<T>, fileset: WasmFileset,
|
type: WasmMediaPipeConstructor<T>, fileset: WasmFileset,
|
||||||
options: VisionTaskOptions): Promise<T> {
|
options: VisionTaskOptions): Promise<T> {
|
||||||
|
@ -219,30 +222,55 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
||||||
this.finishProcessing();
|
this.finishProcessing();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Converts the RGB or RGBA Uint8Array of a WasmImage to ImageData. */
|
/**
|
||||||
protected convertToImageData(wasmImage: WasmImage): ImageData {
|
* Converts a WasmImage to an MPImage.
|
||||||
|
*
|
||||||
|
* Converts the underlying Uint8ClampedArray-backed images to ImageData
|
||||||
|
* (adding an alpha channel if necessary), passes through WebGLTextures and
|
||||||
|
* throws for Float32Array-backed images.
|
||||||
|
*/
|
||||||
|
protected convertToMPImage(wasmImage: WasmImage): MPImage {
|
||||||
const {data, width, height} = wasmImage;
|
const {data, width, height} = wasmImage;
|
||||||
if (!(data instanceof Uint8ClampedArray)) {
|
|
||||||
throw new Error(
|
|
||||||
'Only Uint8ClampedArray-based images can be converted to ImageData');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data.length === width * height * 4) {
|
if (data instanceof Uint8ClampedArray) {
|
||||||
return new ImageData(data, width, height);
|
let rgba: Uint8ClampedArray;
|
||||||
} else if (data.length === width * height * 3) {
|
if (data.length === width * height * 4) {
|
||||||
const rgba = new Uint8ClampedArray(width * height * 4);
|
rgba = data;
|
||||||
for (let i = 0; i < width * height; ++i) {
|
} else if (data.length === width * height * 3) {
|
||||||
rgba[4 * i] = data[3 * i];
|
// TODO: Convert in C++
|
||||||
rgba[4 * i + 1] = data[3 * i + 1];
|
rgba = new Uint8ClampedArray(width * height * 4);
|
||||||
rgba[4 * i + 2] = data[3 * i + 2];
|
for (let i = 0; i < width * height; ++i) {
|
||||||
rgba[4 * i + 3] = 255;
|
rgba[4 * i] = data[3 * i];
|
||||||
|
rgba[4 * i + 1] = data[3 * i + 1];
|
||||||
|
rgba[4 * i + 2] = data[3 * i + 2];
|
||||||
|
rgba[4 * i + 3] = 255;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error(
|
||||||
|
`Unsupported channel count: ${data.length / width / height}`);
|
||||||
}
|
}
|
||||||
return new ImageData(rgba, width, height);
|
|
||||||
|
return new MPImage(
|
||||||
|
[new ImageData(rgba, width, height)],
|
||||||
|
/* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
|
||||||
|
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
|
||||||
|
height);
|
||||||
|
} else if (data instanceof WebGLTexture) {
|
||||||
|
return new MPImage(
|
||||||
|
[data], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
|
||||||
|
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
|
||||||
|
height);
|
||||||
} else {
|
} else {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Unsupported channel count: ${data.length / width / height}`);
|
`Cannot convert type ${data.constructor.name} to MPImage.`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Closes and cleans up the resources held by this task. */
|
||||||
|
override close(): void {
|
||||||
|
this.shaderContext.close();
|
||||||
|
super.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_jspb_proto",
|
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/vision/core:image",
|
||||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/web/vision/core:types",
|
"//mediapipe/tasks/web/vision/core:types",
|
||||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||||
|
@ -46,6 +47,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/framework:calculator_jspb_proto",
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||||
|
"//mediapipe/tasks/web/vision/core:image",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {FaceStylizerGraphOptions as FaceStylizerGraphOptionsProto} from '../../../../tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb';
|
import {FaceStylizerGraphOptions as FaceStylizerGraphOptionsProto} from '../../../../tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {MPImage} from '../../../../tasks/web/vision/core/image';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
@ -39,15 +40,13 @@ const FACE_STYLIZER_GRAPH =
|
||||||
// tslint:disable:jspb-use-builder-pattern
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A callback that receives an image from the face stylizer, or `null` if no
|
* A callback that receives an `MPImage` object from the face stylizer, or
|
||||||
* face was detected. The lifetime of the underlying data is limited to the
|
* `null` if no face was detected. The lifetime of the underlying data is
|
||||||
* duration of the callback. If asynchronous processing is needed, all data
|
* limited to the duration of the callback. If asynchronous processing is
|
||||||
* needs to be copied before the callback returns.
|
* needed, all data needs to be copied before the callback returns (via
|
||||||
*
|
* `image.clone()`).
|
||||||
* The `WebGLTexture` output type is reserved for future usage.
|
|
||||||
*/
|
*/
|
||||||
export type FaceStylizerCallback =
|
export type FaceStylizerCallback = (image: MPImage|null) => void;
|
||||||
(image: ImageData|WebGLTexture|null, width: number, height: number) => void;
|
|
||||||
|
|
||||||
/** Performs face stylization on images. */
|
/** Performs face stylization on images. */
|
||||||
export class FaceStylizer extends VisionTaskRunner {
|
export class FaceStylizer extends VisionTaskRunner {
|
||||||
|
@ -270,18 +269,14 @@ export class FaceStylizer extends VisionTaskRunner {
|
||||||
graphConfig.addNode(segmenterNode);
|
graphConfig.addNode(segmenterNode);
|
||||||
|
|
||||||
this.graphRunner.attachImageListener(
|
this.graphRunner.attachImageListener(
|
||||||
STYLIZED_IMAGE_STREAM, (image, timestamp) => {
|
STYLIZED_IMAGE_STREAM, (wasmImage, timestamp) => {
|
||||||
if (image.data instanceof WebGLTexture) {
|
const mpImage = this.convertToMPImage(wasmImage);
|
||||||
this.userCallback(image.data, image.width, image.height);
|
this.userCallback(mpImage);
|
||||||
} else {
|
|
||||||
const imageData = this.convertToImageData(image);
|
|
||||||
this.userCallback(imageData, image.width, image.height);
|
|
||||||
}
|
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
STYLIZED_IMAGE_STREAM, timestamp => {
|
STYLIZED_IMAGE_STREAM, timestamp => {
|
||||||
this.userCallback(null, /* width= */ 0, /* height= */ 0);
|
this.userCallback(null);
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ import 'jasmine';
|
||||||
// Placeholder for internal dependency on encodeByteArray
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
import {MPImageStorageType} from '../../../../tasks/web/vision/core/image';
|
||||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||||
|
|
||||||
import {FaceStylizer} from './face_stylizer';
|
import {FaceStylizer} from './face_stylizer';
|
||||||
|
@ -114,11 +115,33 @@ describe('FaceStylizer', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Invoke the face stylizeer
|
// Invoke the face stylizeer
|
||||||
faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => {
|
faceStylizer.stylize({} as HTMLImageElement, image => {
|
||||||
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(image).toBeInstanceOf(ImageData);
|
expect(image).not.toBeNull();
|
||||||
expect(width).toEqual(1);
|
expect(image!.hasType(MPImageStorageType.IMAGE_DATA)).toBeTrue();
|
||||||
expect(height).toEqual(1);
|
expect(image!.width).toEqual(1);
|
||||||
|
expect(image!.height).toEqual(1);
|
||||||
|
done();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('invokes callback even when no faes are detected', (done) => {
|
||||||
|
if (typeof ImageData === 'undefined') {
|
||||||
|
console.log('ImageData tests are not supported on Node');
|
||||||
|
done();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(faceStylizer);
|
||||||
|
faceStylizer.emptyPacketListener!(/* timestamp= */ 1337);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the face stylizeer
|
||||||
|
faceStylizer.stylize({} as HTMLImageElement, image => {
|
||||||
|
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(image).toBeNull();
|
||||||
done();
|
done();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@ -131,11 +154,9 @@ describe('FaceStylizer', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Invoke the face stylizeer
|
// Invoke the face stylizeer
|
||||||
faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => {
|
faceStylizer.stylize({} as HTMLImageElement, image => {
|
||||||
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(image).toBeNull();
|
expect(image).toBeNull();
|
||||||
expect(width).toEqual(0);
|
|
||||||
expect(height).toEqual(0);
|
|
||||||
done();
|
done();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
||||||
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
|
import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils';
|
||||||
import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image';
|
import {MPImage as MPImageImpl, MPImageStorageType as MPImageStorageTypeImpl} from '../../../tasks/web/vision/core/image';
|
||||||
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
|
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
|
||||||
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
||||||
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
||||||
|
@ -34,6 +34,7 @@ import {PoseLandmarker as PoseLandmarkerImpl} from '../../../tasks/web/vision/po
|
||||||
const DrawingUtils = DrawingUtilsImpl;
|
const DrawingUtils = DrawingUtilsImpl;
|
||||||
const FilesetResolver = FilesetResolverImpl;
|
const FilesetResolver = FilesetResolverImpl;
|
||||||
const MPImage = MPImageImpl;
|
const MPImage = MPImageImpl;
|
||||||
|
const MPImageStorageType = MPImageStorageTypeImpl;
|
||||||
const FaceDetector = FaceDetectorImpl;
|
const FaceDetector = FaceDetectorImpl;
|
||||||
const FaceLandmarker = FaceLandmarkerImpl;
|
const FaceLandmarker = FaceLandmarkerImpl;
|
||||||
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
|
const FaceLandmarksConnections = FaceLandmarksConnectionsImpl;
|
||||||
|
@ -51,6 +52,7 @@ export {
|
||||||
DrawingUtils,
|
DrawingUtils,
|
||||||
FilesetResolver,
|
FilesetResolver,
|
||||||
MPImage,
|
MPImage,
|
||||||
|
MPImageStorageType,
|
||||||
FaceDetector,
|
FaceDetector,
|
||||||
FaceLandmarker,
|
FaceLandmarker,
|
||||||
FaceLandmarksConnections,
|
FaceLandmarksConnections,
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
export * from '../../../tasks/web/core/fileset_resolver';
|
export * from '../../../tasks/web/core/fileset_resolver';
|
||||||
export * from '../../../tasks/web/vision/core/drawing_utils';
|
export * from '../../../tasks/web/vision/core/drawing_utils';
|
||||||
export * from '../../../tasks/web/vision/core/image';
|
export {MPImage, MPImageChannelConverter, MPImageStorageType} from '../../../tasks/web/vision/core/image';
|
||||||
export * from '../../../tasks/web/vision/face_detector/face_detector';
|
export * from '../../../tasks/web/vision/face_detector/face_detector';
|
||||||
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
|
||||||
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
|
||||||
|
|
Loading…
Reference in New Issue
Block a user