Create non-callback APIs for APIs that return callbacks.

PiperOrigin-RevId: 529799515
This commit is contained in:
Sebastian Schmidt 2023-05-05 13:26:29 -07:00 committed by Copybara-Service
parent ecc8dca8ba
commit f065910559
10 changed files with 465 additions and 132 deletions

View File

@ -273,9 +273,13 @@ export class MPImage {
case MPImageType.IMAGE_DATA: case MPImageType.IMAGE_DATA:
return this.containers.find(img => img instanceof ImageData); return this.containers.find(img => img instanceof ImageData);
case MPImageType.IMAGE_BITMAP: case MPImageType.IMAGE_BITMAP:
return this.containers.find(img => img instanceof ImageBitmap); return this.containers.find(
img => typeof ImageBitmap !== 'undefined' &&
img instanceof ImageBitmap);
case MPImageType.WEBGL_TEXTURE: case MPImageType.WEBGL_TEXTURE:
return this.containers.find(img => img instanceof WebGLTexture); return this.containers.find(
img => typeof WebGLTexture !== 'undefined' &&
img instanceof WebGLTexture);
default: default:
throw new Error(`Type is not supported: ${type}`); throw new Error(`Type is not supported: ${type}`);
} }

View File

@ -230,7 +230,8 @@ export abstract class VisionTaskRunner extends TaskRunner {
* (adding an alpha channel if necessary), passes through WebGLTextures and * (adding an alpha channel if necessary), passes through WebGLTextures and
* throws for Float32Array-backed images. * throws for Float32Array-backed images.
*/ */
protected convertToMPImage(wasmImage: WasmImage): MPImage { protected convertToMPImage(wasmImage: WasmImage, shouldCopyData: boolean):
MPImage {
const {data, width, height} = wasmImage; const {data, width, height} = wasmImage;
const pixels = width * height; const pixels = width * height;
@ -263,10 +264,11 @@ export abstract class VisionTaskRunner extends TaskRunner {
container = data; container = data;
} }
return new MPImage( const image = new MPImage(
[container], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, [container], /* ownsImageBitmap= */ false,
this.graphRunner.wasmModule.canvas!, this.shaderContext, width, /* ownsWebGLTexture= */ false, this.graphRunner.wasmModule.canvas!,
height); this.shaderContext, width, height);
return shouldCopyData ? image.clone() : image;
} }
/** Closes and cleans up the resources held by this task. */ /** Closes and cleans up the resources held by this task. */

View File

@ -50,7 +50,8 @@ export type FaceStylizerCallback = (image: MPImage|null) => void;
/** Performs face stylization on images. */ /** Performs face stylization on images. */
export class FaceStylizer extends VisionTaskRunner { export class FaceStylizer extends VisionTaskRunner {
private userCallback: FaceStylizerCallback = () => {}; private userCallback?: FaceStylizerCallback;
private result?: MPImage|null;
private readonly options: FaceStylizerGraphOptionsProto; private readonly options: FaceStylizerGraphOptionsProto;
/** /**
@ -130,21 +131,58 @@ export class FaceStylizer extends VisionTaskRunner {
return super.applyOptions(options); return super.applyOptions(options);
} }
/** /**
* Performs face stylization on the provided single image. The method returns * Performs face stylization on the provided single image and invokes the
* synchronously once the callback returns. Only use this method when the * callback with result. The method returns synchronously once the callback
* FaceStylizer is created with the image running mode. * returns. Only use this method when the FaceStylizer is created with the
* image running mode.
* *
* @param image An image to process. * @param image An image to process.
* @param callback The callback that is invoked with the stylized image. The * @param callback The callback that is invoked with the stylized image or
* lifetime of the returned data is only guaranteed for the duration of the * `null` if no face was detected. The lifetime of the returned data is
* callback. * only guaranteed for the duration of the callback.
*/ */
stylize(image: ImageSource, callback: FaceStylizerCallback): void; stylize(image: ImageSource, callback: FaceStylizerCallback): void;
/** /**
* Performs face stylization on the provided single image. The method returns * Performs face stylization on the provided single image and invokes the
* synchronously once the callback returns. Only use this method when the * callback with result. The method returns synchronously once the callback
* returns. Only use this method when the FaceStylizer is created with the
* image running mode.
*
* The 'imageProcessingOptions' parameter can be used to specify one or all
* of:
* - the rotation to apply to the image before performing stylization, by
* setting its 'rotationDegrees' property.
* - the region-of-interest on which to perform stylization, by setting its
* 'regionOfInterest' property. If not specified, the full image is used.
* If both are specified, the crop around the region-of-interest is extracted
* first, then the specified rotation is applied to the crop.
*
* @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param callback The callback that is invoked with the stylized image or
* `null` if no face was detected. The lifetime of the returned data is
* only guaranteed for the duration of the callback.
*/
stylize(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: FaceStylizerCallback): void;
/**
* Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be
* used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode.
*
* @param image An image to process.
* @return A stylized face or `null` if no face was detected. The result is
* copied to avoid lifetime issues.
*/
stylize(image: ImageSource): MPImage|null;
/**
* Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be
* used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode. * FaceStylizer is created with the image running mode.
* *
* The 'imageProcessingOptions' parameter can be used to specify one or all * The 'imageProcessingOptions' parameter can be used to specify one or all
@ -159,18 +197,16 @@ export class FaceStylizer extends VisionTaskRunner {
* @param image An image to process. * @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference. * to process the input image before running inference.
* @param callback The callback that is invoked with the stylized image. The * @return A stylized face or `null` if no face was detected. The result is
* lifetime of the returned data is only guaranteed for the duration of the * copied to avoid lifetime issues.
* callback.
*/ */
stylize( stylize(image: ImageSource, imageProcessingOptions: ImageProcessingOptions):
image: ImageSource, imageProcessingOptions: ImageProcessingOptions, MPImage|null;
callback: FaceStylizerCallback): void;
stylize( stylize(
image: ImageSource, image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptionsOrCallback?: ImageProcessingOptions|
FaceStylizerCallback, FaceStylizerCallback,
callback?: FaceStylizerCallback): void { callback?: FaceStylizerCallback): MPImage|null|void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
@ -178,14 +214,19 @@ export class FaceStylizer extends VisionTaskRunner {
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback;
this.processImageData(image, imageProcessingOptions ?? {}); this.processImageData(image, imageProcessingOptions ?? {});
this.userCallback = () => {};
if (!this.userCallback) {
return this.result;
}
} }
/** /**
* Performs face stylization on the provided video frame. Only use this method * Performs face stylization on the provided video frame and invokes the
* when the FaceStylizer is created with the video running mode. * callback with result. The method returns synchronously once the callback
* returns. Only use this method when the FaceStylizer is created with the
* video running mode.
* *
* The input frame can be of any size. It's required to provide the video * The input frame can be of any size. It's required to provide the video
* frame's timestamp (in milliseconds). The input timestamps must be * frame's timestamp (in milliseconds). The input timestamps must be
@ -193,16 +234,18 @@ export class FaceStylizer extends VisionTaskRunner {
* *
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the stylized image. The * @param callback The callback that is invoked with the stylized image or
* lifetime of the returned data is only guaranteed for the duration of * `null` if no face was detected. The lifetime of the returned data is only
* the callback. * guaranteed for the duration of the callback.
*/ */
stylizeForVideo( stylizeForVideo(
videoFrame: ImageSource, timestamp: number, videoFrame: ImageSource, timestamp: number,
callback: FaceStylizerCallback): void; callback: FaceStylizerCallback): void;
/** /**
* Performs face stylization on the provided video frame. Only use this * Performs face stylization on the provided video frame and invokes the
* method when the FaceStylizer is created with the video running mode. * callback with result. The method returns synchronously once the callback
* returns. Only use this method when the FaceStylizer is created with the
* video running mode.
* *
* The 'imageProcessingOptions' parameter can be used to specify one or all * The 'imageProcessingOptions' parameter can be used to specify one or all
* of: * of:
@ -221,18 +264,63 @@ export class FaceStylizer extends VisionTaskRunner {
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference. * to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the stylized image. The * @param callback The callback that is invoked with the stylized image or
* lifetime of the returned data is only guaranteed for the duration of * `null` if no face was detected. The lifetime of the returned data is only
* the callback. * guaranteed for the duration of the callback.
*/ */
stylizeForVideo( stylizeForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: FaceStylizerCallback): void; timestamp: number, callback: FaceStylizerCallback): void;
/**
* Performs face stylization on the provided video frame. This method creates
* a copy of the resulting image and should not be used in high-throughput
* applictions. Only use this method when the FaceStylizer is created with the
* video running mode.
*
* The input frame can be of any size. It's required to provide the video
* frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @return A stylized face or `null` if no face was detected. The result is
* copied to avoid lifetime issues.
*/
stylizeForVideo(videoFrame: ImageSource, timestamp: number): MPImage|null;
/**
* Performs face stylization on the provided video frame. This method creates
* a copy of the resulting image and should not be used in high-throughput
* applications. Only use this method when the FaceStylizer is created with the
* video running mode.
*
* The 'imageProcessingOptions' parameter can be used to specify one or all
* of:
* - the rotation to apply to the image before performing stylization, by
* setting its 'rotationDegrees' property.
* - the region-of-interest on which to perform stylization, by setting its
* 'regionOfInterest' property. If not specified, the full image is used.
* If both are specified, the crop around the region-of-interest is
* extracted first, then the specified rotation is applied to the crop.
*
* The input frame can be of any size. It's required to provide the video
* frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @return A stylized face or `null` if no face was detected. The result is
* copied to avoid lifetime issues.
*/
stylizeForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number): MPImage|null;
stylizeForVideo( stylizeForVideo(
videoFrame: ImageSource, videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions, timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|FaceStylizerCallback, timestampOrCallback?: number|FaceStylizerCallback,
callback?: FaceStylizerCallback): void { callback?: FaceStylizerCallback): MPImage|null|void {
const imageProcessingOptions = const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ? typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions : timestampOrImageProcessingOptions :
@ -243,9 +331,13 @@ export class FaceStylizer extends VisionTaskRunner {
this.userCallback = typeof timestampOrCallback === 'function' ? this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback : timestampOrCallback :
callback!; callback;
this.processVideoData(videoFrame, imageProcessingOptions, timestamp); this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
this.userCallback = () => {}; this.userCallback = undefined;
if (!this.userCallback) {
return this.result;
}
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
@ -270,13 +362,20 @@ export class FaceStylizer extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
STYLIZED_IMAGE_STREAM, (wasmImage, timestamp) => { STYLIZED_IMAGE_STREAM, (wasmImage, timestamp) => {
const mpImage = this.convertToMPImage(wasmImage); const mpImage = this.convertToMPImage(
wasmImage, /* shouldCopyData= */ !this.userCallback);
this.result = mpImage;
if (this.userCallback) {
this.userCallback(mpImage); this.userCallback(mpImage);
}
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
STYLIZED_IMAGE_STREAM, timestamp => { STYLIZED_IMAGE_STREAM, timestamp => {
this.result = null;
if (this.userCallback) {
this.userCallback(null); this.userCallback(null);
}
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });

View File

@ -99,6 +99,30 @@ describe('FaceStylizer', () => {
]); ]);
}); });
it('returns result', () => {
if (typeof ImageData === 'undefined') {
console.log('ImageData tests are not supported on Node');
return;
}
// Pass the test data to our listener
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer);
faceStylizer.imageListener!
({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1},
/* timestamp= */ 1337);
});
// Invoke the face stylizer
const image = faceStylizer.stylize({} as HTMLImageElement);
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).not.toBeNull();
expect(image!.has(MPImage.TYPE.IMAGE_DATA)).toBeTrue();
expect(image!.width).toEqual(1);
expect(image!.height).toEqual(1);
image!.close();
});
it('invokes callback', (done) => { it('invokes callback', (done) => {
if (typeof ImageData === 'undefined') { if (typeof ImageData === 'undefined') {
console.log('ImageData tests are not supported on Node'); console.log('ImageData tests are not supported on Node');
@ -125,28 +149,7 @@ describe('FaceStylizer', () => {
}); });
}); });
it('invokes callback even when no faes are detected', (done) => { it('invokes callback even when no faces are detected', (done) => {
if (typeof ImageData === 'undefined') {
console.log('ImageData tests are not supported on Node');
done();
return;
}
// Pass the test data to our listener
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer);
faceStylizer.emptyPacketListener!(/* timestamp= */ 1337);
});
// Invoke the face stylizer
faceStylizer.stylize({} as HTMLImageElement, image => {
expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(image).toBeNull();
done();
});
});
it('invokes callback even when no faes are detected', (done) => {
// Pass the test data to our listener // Pass the test data to our listener
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer); verifyListenersRegistered(faceStylizer);

View File

@ -60,7 +60,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner { export class ImageSegmenter extends VisionTaskRunner {
private result: ImageSegmenterResult = {}; private result: ImageSegmenterResult = {};
private labels: string[] = []; private labels: string[] = [];
private userCallback: ImageSegmenterCallback = () => {}; private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private readonly options: ImageSegmenterGraphOptionsProto; private readonly options: ImageSegmenterGraphOptionsProto;
@ -224,22 +224,51 @@ export class ImageSegmenter extends VisionTaskRunner {
segment( segment(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions, image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: ImageSegmenterCallback): void; callback: ImageSegmenterCallback): void;
/**
* Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applications. Only use this method
* when the ImageSegmenter is created with running mode `image`.
*
* @param image An image to process.
* @return The segmentation result. The data is copied to avoid lifetime
* issues.
*/
segment(image: ImageSource): ImageSegmenterResult;
/**
* Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `image`.
*
* @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @return The segmentation result. The data is copied to avoid lifetime
* issues.
*/
segment(image: ImageSource, imageProcessingOptions: ImageProcessingOptions):
ImageSegmenterResult;
segment( segment(
image: ImageSource, image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptionsOrCallback?: ImageProcessingOptions|
ImageSegmenterCallback, ImageSegmenterCallback,
callback?: ImageSegmenterCallback): void { callback?: ImageSegmenterCallback): ImageSegmenterResult|void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback;
this.reset(); this.reset();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
if (!this.userCallback) {
return this.result;
}
} }
/** /**
@ -265,7 +294,7 @@ export class ImageSegmenter extends VisionTaskRunner {
* *
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference. * to process the input frame before running inference.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The * @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the * lifetime of the returned data is only guaranteed for the duration of the
@ -274,11 +303,41 @@ export class ImageSegmenter extends VisionTaskRunner {
segmentForVideo( segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: ImageSegmenterCallback): void; timestamp: number, callback: ImageSegmenterCallback): void;
/**
* Performs image segmentation on the provided video frame and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applications. Only use this method
* when the ImageSegmenter is created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @return The segmentation result. The data is copied to avoid lifetime
* issues.
*/
segmentForVideo(videoFrame: ImageSource, timestamp: number):
ImageSegmenterResult;
/**
* Performs image segmentation on the provided video frame and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input frame before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @return The segmentation result. The data is copied to avoid lifetime
* issues.
*/
segmentForVideo(
videoFrame: ImageSource,
imageProcessingOptions: ImageProcessingOptions,
timestamp: number,
): ImageSegmenterResult;
segmentForVideo( segmentForVideo(
videoFrame: ImageSource, videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions, timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|ImageSegmenterCallback, timestampOrCallback?: number|ImageSegmenterCallback,
callback?: ImageSegmenterCallback): void { callback?: ImageSegmenterCallback): ImageSegmenterResult|void {
const imageProcessingOptions = const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ? typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions : timestampOrImageProcessingOptions :
@ -288,11 +347,14 @@ export class ImageSegmenter extends VisionTaskRunner {
timestampOrCallback as number; timestampOrCallback as number;
this.userCallback = typeof timestampOrCallback === 'function' ? this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback : timestampOrCallback :
callback!; callback;
this.reset(); this.reset();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp); this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
this.userCallback = () => {};
if (!this.userCallback) {
return this.result;
}
} }
/** /**
@ -323,8 +385,10 @@ export class ImageSegmenter extends VisionTaskRunner {
return; return;
} }
if (this.userCallback) {
this.userCallback(this.result); this.userCallback(this.result);
} }
}
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void { protected override refreshGraph(): void {
@ -351,8 +415,9 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = this.result.confidenceMasks = masks.map(
masks.map(wasmImage => this.convertToMPImage(wasmImage)); wasmImage => this.convertToMPImage(
wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
}); });
@ -370,7 +435,8 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage(mask); this.result.categoryMask = this.convertToMPImage(
mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
}); });

View File

@ -292,4 +292,21 @@ describe('ImageSegmenter', () => {
}); });
}); });
}); });
it('returns result', () => {
const confidenceMask = new Float32Array([0.0]);
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
imageSegmenter.confidenceMasksListener!(
[
{data: confidenceMask, width: 1, height: 1},
],
1337);
});
const result = imageSegmenter.segment({} as HTMLImageElement);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
result.confidenceMasks![0].close();
});
}); });

View File

@ -86,7 +86,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
private result: InteractiveSegmenterResult = {}; private result: InteractiveSegmenterResult = {};
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback: InteractiveSegmenterCallback = () => {}; private userCallback?: InteractiveSegmenterCallback;
private readonly options: ImageSegmenterGraphOptionsProto; private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto;
@ -186,14 +186,9 @@ export class InteractiveSegmenter extends VisionTaskRunner {
/** /**
* Performs interactive segmentation on the provided single image and invokes * Performs interactive segmentation on the provided single image and invokes
* the callback with the response. The `roi` parameter is used to represent a * the callback with the response. The method returns synchronously once the
* user's region of interest for segmentation. * callback returns. The `roi` parameter is used to represent a user's region
* * of interest for segmentation.
* If the output_type is `CATEGORY_MASK`, the callback is invoked with vector
* of images that represent per-category segmented image mask. If the
* output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
* images that contains only one confidence image mask. The method returns
* synchronously once the callback returns.
* *
* @param image An image to process. * @param image An image to process.
* @param roi The region of interest for segmentation. * @param roi The region of interest for segmentation.
@ -206,8 +201,9 @@ export class InteractiveSegmenter extends VisionTaskRunner {
callback: InteractiveSegmenterCallback): void; callback: InteractiveSegmenterCallback): void;
/** /**
* Performs interactive segmentation on the provided single image and invokes * Performs interactive segmentation on the provided single image and invokes
* the callback with the response. The `roi` parameter is used to represent a * the callback with the response. The method returns synchronously once the
* user's region of interest for segmentation. * callback returns. The `roi` parameter is used to represent a user's region
* of interest for segmentation.
* *
* The 'image_processing_options' parameter can be used to specify the * The 'image_processing_options' parameter can be used to specify the
* rotation to apply to the image before performing segmentation, by setting * rotation to apply to the image before performing segmentation, by setting
@ -215,12 +211,6 @@ export class InteractiveSegmenter extends VisionTaskRunner {
* using the 'regionOfInterest' field is NOT supported and will result in an * using the 'regionOfInterest' field is NOT supported and will result in an
* error. * error.
* *
* If the output_type is `CATEGORY_MASK`, the callback is invoked with vector
* of images that represent per-category segmented image mask. If the
* output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
* images that contains only one confidence image mask. The method returns
* synchronously once the callback returns.
*
* @param image An image to process. * @param image An image to process.
* @param roi The region of interest for segmentation. * @param roi The region of interest for segmentation.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
@ -233,23 +223,63 @@ export class InteractiveSegmenter extends VisionTaskRunner {
image: ImageSource, roi: RegionOfInterest, image: ImageSource, roi: RegionOfInterest,
imageProcessingOptions: ImageProcessingOptions, imageProcessingOptions: ImageProcessingOptions,
callback: InteractiveSegmenterCallback): void; callback: InteractiveSegmenterCallback): void;
/**
* Performs interactive segmentation on the provided video frame and returns
* the segmentation result. This method creates a copy of the resulting masks
* and should not be used in high-throughput applictions. The `roi` parameter
* is used to represent a user's region of interest for segmentation.
*
* @param image An image to process.
* @param roi The region of interest for segmentation.
* @return The segmentation result. The data is copied to avoid lifetime
* limits.
*/
segment(image: ImageSource, roi: RegionOfInterest):
InteractiveSegmenterResult;
/**
* Performs interactive segmentation on the provided video frame and returns
* the segmentation result. This method creates a copy of the resulting masks
* and should not be used in high-throughput applictions. The `roi` parameter
* is used to represent a user's region of interest for segmentation.
*
* The 'image_processing_options' parameter can be used to specify the
* rotation to apply to the image before performing segmentation, by setting
* its 'rotationDegrees' field. Note that specifying a region-of-interest
* using the 'regionOfInterest' field is NOT supported and will result in an
* error.
*
* @param image An image to process.
* @param roi The region of interest for segmentation.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @return The segmentation result. The data is copied to avoid lifetime
* limits.
*/
segment( segment(
image: ImageSource, roi: RegionOfInterest, image: ImageSource, roi: RegionOfInterest,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptions: ImageProcessingOptions):
InteractiveSegmenterResult;
segment(
image: ImageSource, roi: RegionOfInterest,
imageProcessingOptionsOrCallback?: ImageProcessingOptions|
InteractiveSegmenterCallback, InteractiveSegmenterCallback,
callback?: InteractiveSegmenterCallback): void { callback?: InteractiveSegmenterCallback): InteractiveSegmenterResult|
void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback;
this.reset(); this.reset();
this.processRenderData(roi, this.getSynctheticTimestamp()); this.processRenderData(roi, this.getSynctheticTimestamp());
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
if (!this.userCallback) {
return this.result;
}
} }
private reset(): void { private reset(): void {
@ -265,8 +295,10 @@ export class InteractiveSegmenter extends VisionTaskRunner {
return; return;
} }
if (this.userCallback) {
this.userCallback(this.result); this.userCallback(this.result);
} }
}
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void { protected override refreshGraph(): void {
@ -295,8 +327,9 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = this.result.confidenceMasks = masks.map(
masks.map(wasmImage => this.convertToMPImage(wasmImage)); wasmImage => this.convertToMPImage(
wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
}); });
@ -314,7 +347,8 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage(mask); this.result.categoryMask = this.convertToMPImage(
mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
}); });

View File

@ -306,4 +306,22 @@ describe('InteractiveSegmenter', () => {
}); });
}); });
}); });
it('returns result', () => {
const confidenceMask = new Float32Array([0.0]);
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
interactiveSegmenter.confidenceMasksListener!(
[
{data: confidenceMask, width: 1, height: 1},
],
1337);
});
const result =
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
result.confidenceMasks![0].close();
});
}); });

View File

@ -64,7 +64,7 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
export class PoseLandmarker extends VisionTaskRunner { export class PoseLandmarker extends VisionTaskRunner {
private result: Partial<PoseLandmarkerResult> = {}; private result: Partial<PoseLandmarkerResult> = {};
private outputSegmentationMasks = false; private outputSegmentationMasks = false;
private userCallback: PoseLandmarkerCallback = () => {}; private userCallback?: PoseLandmarkerCallback;
private readonly options: PoseLandmarkerGraphOptions; private readonly options: PoseLandmarkerGraphOptions;
private readonly poseLandmarksDetectorGraphOptions: private readonly poseLandmarksDetectorGraphOptions:
PoseLandmarksDetectorGraphOptions; PoseLandmarksDetectorGraphOptions;
@ -200,21 +200,22 @@ export class PoseLandmarker extends VisionTaskRunner {
} }
/** /**
* Performs pose detection on the provided single image and waits * Performs pose detection on the provided single image and invokes the
* synchronously for the response. Only use this method when the * callback with the response. The method returns synchronously once the
* PoseLandmarker is created with running mode `image`. * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `image`.
* *
* @param image An image to process. * @param image An image to process.
* @param callback The callback that is invoked with the result. The * @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of * lifetime of the returned masks is only guaranteed for the duration of
* the callback. * the callback.
* @return The detected pose landmarks.
*/ */
detect(image: ImageSource, callback: PoseLandmarkerCallback): void; detect(image: ImageSource, callback: PoseLandmarkerCallback): void;
/** /**
* Performs pose detection on the provided single image and waits * Performs pose detection on the provided single image and invokes the
* synchronously for the response. Only use this method when the * callback with the response. The method returns synchronously once the
* PoseLandmarker is created with running mode `image`. * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `image`.
* *
* @param image An image to process. * @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
@ -222,16 +223,42 @@ export class PoseLandmarker extends VisionTaskRunner {
* @param callback The callback that is invoked with the result. The * @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of * lifetime of the returned masks is only guaranteed for the duration of
* the callback. * the callback.
* @return The detected pose landmarks.
*/ */
detect( detect(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions, image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: PoseLandmarkerCallback): void; callback: PoseLandmarkerCallback): void;
/**
* Performs pose detection on the provided single image and waits
* synchronously for the response. This method creates a copy of the resulting
 * masks and should not be used in high-throughput applications. Only
* use this method when the PoseLandmarker is created with running mode
* `image`.
*
* @param image An image to process.
 * @return The detected pose landmarks. Any masks are copied to avoid
 *     lifetime limits.
*/
detect(image: ImageSource): PoseLandmarkerResult;
/**
* Performs pose detection on the provided single image and waits
* synchronously for the response. This method creates a copy of the resulting
 * masks and should not be used in high-throughput applications. Only
* use this method when the PoseLandmarker is created with running mode
* `image`.
*
* @param image An image to process.
 * @return The detected pose landmarks. Any masks are copied to avoid
 *     lifetime limits.
*/
detect(image: ImageSource, imageProcessingOptions: ImageProcessingOptions):
PoseLandmarkerResult;
detect( detect(
image: ImageSource, image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptionsOrCallback?: ImageProcessingOptions|
PoseLandmarkerCallback, PoseLandmarkerCallback,
callback?: PoseLandmarkerCallback): void { callback?: PoseLandmarkerCallback): PoseLandmarkerResult|void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
@ -242,28 +269,32 @@ export class PoseLandmarker extends VisionTaskRunner {
this.resetResults(); this.resetResults();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
if (!this.userCallback) {
return this.result as PoseLandmarkerResult;
}
} }
/** /**
* Performs pose detection on the provided video frame and waits * Performs pose detection on the provided video frame and invokes the
* synchronously for the response. Only use this method when the * callback with the response. The method returns synchronously once the
* PoseLandmarker is created with running mode `video`. * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `video`.
* *
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the result. The * @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of * lifetime of the returned masks is only guaranteed for the duration of
* the callback. * the callback.
* @return The detected pose landmarks.
*/ */
detectForVideo( detectForVideo(
videoFrame: ImageSource, timestamp: number, videoFrame: ImageSource, timestamp: number,
callback: PoseLandmarkerCallback): void; callback: PoseLandmarkerCallback): void;
/** /**
* Performs pose detection on the provided video frame and waits * Performs pose detection on the provided video frame and invokes the
* synchronously for the response. Only use this method when the * callback with the response. The method returns synchronously once the
* PoseLandmarker is created with running mode `video`. * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `video`.
* *
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
@ -272,16 +303,45 @@ export class PoseLandmarker extends VisionTaskRunner {
* @param callback The callback that is invoked with the result. The * @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of * lifetime of the returned masks is only guaranteed for the duration of
* the callback. * the callback.
* @return The detected pose landmarks.
*/ */
detectForVideo( detectForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: PoseLandmarkerCallback): void; timestamp: number, callback: PoseLandmarkerCallback): void;
/**
* Performs pose detection on the provided video frame and returns the result.
* This method creates a copy of the resulting masks and should not be used
 * in high-throughput applications. Only use this method when the
* PoseLandmarker is created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @return The landmarker result. Any masks are copied to extend the
* lifetime of the returned data.
*/
detectForVideo(videoFrame: ImageSource, timestamp: number):
PoseLandmarkerResult;
/**
* Performs pose detection on the provided video frame and returns the result.
* This method creates a copy of the resulting masks and should not be used
 * in high-throughput applications. The method returns synchronously once the
* callback returns. Only use this method when the PoseLandmarker is created
* with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @return The landmarker result. Any masks are copied to extend the lifetime
* of the returned data.
*/
detectForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number): PoseLandmarkerResult;
detectForVideo( detectForVideo(
videoFrame: ImageSource, videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions, timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|PoseLandmarkerCallback, timestampOrCallback?: number|PoseLandmarkerCallback,
callback?: PoseLandmarkerCallback): void { callback?: PoseLandmarkerCallback): PoseLandmarkerResult|void {
const imageProcessingOptions = const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ? typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions : timestampOrImageProcessingOptions :
@ -291,10 +351,14 @@ export class PoseLandmarker extends VisionTaskRunner {
timestampOrCallback as number; timestampOrCallback as number;
this.userCallback = typeof timestampOrCallback === 'function' ? this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback : timestampOrCallback :
callback!; callback;
this.resetResults(); this.resetResults();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp); this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
this.userCallback = () => {};
if (!this.userCallback) {
return this.result as PoseLandmarkerResult;
}
} }
private resetResults(): void { private resetResults(): void {
@ -315,8 +379,11 @@ export class PoseLandmarker extends VisionTaskRunner {
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) { if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
return; return;
} }
if (this.userCallback) {
this.userCallback(this.result as Required<PoseLandmarkerResult>); this.userCallback(this.result as Required<PoseLandmarkerResult>);
} }
}
/** Sets the default values for the graph. */ /** Sets the default values for the graph. */
private initDefaults(): void { private initDefaults(): void {
@ -438,8 +505,9 @@ export class PoseLandmarker extends VisionTaskRunner {
'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM); 'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
SEGMENTATION_MASK_STREAM, (masks, timestamp) => { SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
this.result.segmentationMasks = this.result.segmentationMasks = masks.map(
masks.map(wasmImage => this.convertToMPImage(wasmImage)); wasmImage => this.convertToMPImage(
wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
}); });

View File

@ -331,4 +331,26 @@ describe('PoseLandmarker', () => {
listenerCalled = true; listenerCalled = true;
}); });
}); });
it('returns result', () => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
// Pass the test data to our listener
poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
poseLandmarker.listeners.get('normalized_landmarks')!
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
});
// Invoke the pose landmarker
const result = poseLandmarker.detect({} as HTMLImageElement);
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
});
}); });