From f065910559c6883f78ef085418f49a674e4abad8 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 5 May 2023 13:26:29 -0700 Subject: [PATCH] Create non-callback APIs for APIs that return callbacks. PiperOrigin-RevId: 529799515 --- mediapipe/tasks/web/vision/core/image.ts | 8 +- .../web/vision/core/vision_task_runner.ts | 12 +- .../web/vision/face_stylizer/face_stylizer.ts | 173 ++++++++++++++---- .../face_stylizer/face_stylizer_test.ts | 47 ++--- .../vision/image_segmenter/image_segmenter.ts | 94 ++++++++-- .../image_segmenter/image_segmenter_test.ts | 17 ++ .../interactive_segmenter.ts | 84 ++++++--- .../interactive_segmenter_test.ts | 18 ++ .../vision/pose_landmarker/pose_landmarker.ts | 122 +++++++++--- .../pose_landmarker/pose_landmarker_test.ts | 22 +++ 10 files changed, 465 insertions(+), 132 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/image.ts b/mediapipe/tasks/web/vision/core/image.ts index e2b21c0e6..df7586ded 100644 --- a/mediapipe/tasks/web/vision/core/image.ts +++ b/mediapipe/tasks/web/vision/core/image.ts @@ -273,9 +273,13 @@ export class MPImage { case MPImageType.IMAGE_DATA: return this.containers.find(img => img instanceof ImageData); case MPImageType.IMAGE_BITMAP: - return this.containers.find(img => img instanceof ImageBitmap); + return this.containers.find( + img => typeof ImageBitmap !== 'undefined' && + img instanceof ImageBitmap); case MPImageType.WEBGL_TEXTURE: - return this.containers.find(img => img instanceof WebGLTexture); + return this.containers.find( + img => typeof WebGLTexture !== 'undefined' && + img instanceof WebGLTexture); default: throw new Error(`Type is not supported: ${type}`); } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 3ff6e0604..c31195508 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -230,7 +230,8 @@ export abstract class VisionTaskRunner extends TaskRunner { * (adding an alpha channel if necessary), passes through WebGLTextures and * throws for Float32Array-backed images. */ - protected convertToMPImage(wasmImage: WasmImage): MPImage { + protected convertToMPImage(wasmImage: WasmImage, shouldCopyData: boolean): + MPImage { const {data, width, height} = wasmImage; const pixels = width * height; @@ -263,10 +264,11 @@ export abstract class VisionTaskRunner extends TaskRunner { container = data; } - return new MPImage( - [container], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, - this.graphRunner.wasmModule.canvas!, this.shaderContext, width, - height); + const image = new MPImage( + [container], /* ownsImageBitmap= */ false, + /* ownsWebGLTexture= */ false, this.graphRunner.wasmModule.canvas!, + this.shaderContext, width, height); + return shouldCopyData ? image.clone() : image; } /** Closes and cleans up the resources held by this task. */ diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts index 2a9adb315..641ab61d2 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts @@ -50,7 +50,8 @@ export type FaceStylizerCallback = (image: MPImage|null) => void; /** Performs face stylization on images. */ export class FaceStylizer extends VisionTaskRunner { - private userCallback: FaceStylizerCallback = () => {}; + private userCallback?: FaceStylizerCallback; + private result?: MPImage|null; private readonly options: FaceStylizerGraphOptionsProto; /** @@ -130,21 +131,58 @@ export class FaceStylizer extends VisionTaskRunner { return super.applyOptions(options); } - /** - * Performs face stylization on the provided single image. The method returns - * synchronously once the callback returns. Only use this method when the - * FaceStylizer is created with the image running mode. + * Performs face stylization on the provided single image and invokes the + * callback with result. The method returns synchronously once the callback + * returns. Only use this method when the FaceStylizer is created with the + * image running mode. * * @param image An image to process. - * @param callback The callback that is invoked with the stylized image. The - * lifetime of the returned data is only guaranteed for the duration of the - * callback. + * @param callback The callback that is invoked with the stylized image or + * `null` if no face was detected. The lifetime of the returned data is + * only guaranteed for the duration of the callback. */ stylize(image: ImageSource, callback: FaceStylizerCallback): void; /** - * Performs face stylization on the provided single image. The method returns - * synchronously once the callback returns. Only use this method when the + * Performs face stylization on the provided single image and invokes the + * callback with result. The method returns synchronously once the callback + * returns. Only use this method when the FaceStylizer is created with the + * image running mode. + * + * The 'imageProcessingOptions' parameter can be used to specify one or all + * of: + * - the rotation to apply to the image before performing stylization, by + * setting its 'rotationDegrees' property. + * - the region-of-interest on which to perform stylization, by setting its + * 'regionOfInterest' property. If not specified, the full image is used. + * If both are specified, the crop around the region-of-interest is extracted + * first, then the specified rotation is applied to the crop. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param callback The callback that is invoked with the stylized image or + * `null` if no face was detected. The lifetime of the returned data is + * only guaranteed for the duration of the callback. + */ + stylize( + image: ImageSource, imageProcessingOptions: ImageProcessingOptions, + callback: FaceStylizerCallback): void; + /** + * Performs face stylization on the provided single image and returns the + * result. This method creates a copy of the resulting image and should not be + * used in high-throughput applictions. Only use this method when the + * FaceStylizer is created with the image running mode. + * + * @param image An image to process. + * @return A stylized face or `null` if no face was detected. The result is + * copied to avoid lifetime issues. + */ + stylize(image: ImageSource): MPImage|null; + /** + * Performs face stylization on the provided single image and returns the + * result. This method creates a copy of the resulting image and should not be + * used in high-throughput applictions. Only use this method when the * FaceStylizer is created with the image running mode. * * The 'imageProcessingOptions' parameter can be used to specify one or all @@ -159,18 +197,16 @@ export class FaceStylizer extends VisionTaskRunner { * @param image An image to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. - * @param callback The callback that is invoked with the stylized image. The - * lifetime of the returned data is only guaranteed for the duration of the - * callback. + * @return A stylized face or `null` if no face was detected. The result is + * copied to avoid lifetime issues. */ - stylize( - image: ImageSource, imageProcessingOptions: ImageProcessingOptions, - callback: FaceStylizerCallback): void; + stylize(image: ImageSource, imageProcessingOptions: ImageProcessingOptions): + MPImage|null; stylize( image: ImageSource, - imageProcessingOptionsOrCallback: ImageProcessingOptions| + imageProcessingOptionsOrCallback?: ImageProcessingOptions| FaceStylizerCallback, - callback?: FaceStylizerCallback): void { + callback?: FaceStylizerCallback): MPImage|null|void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : @@ -178,14 +214,19 @@ export class FaceStylizer extends VisionTaskRunner { this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : - callback!; + callback; this.processImageData(image, imageProcessingOptions ?? {}); - this.userCallback = () => {}; + + if (!this.userCallback) { + return this.result; + } } /** - * Performs face stylization on the provided video frame. Only use this method - * when the FaceStylizer is created with the video running mode. + * Performs face stylization on the provided video frame and invokes the + * callback with result. The method returns synchronously once the callback + * returns. Only use this method when the FaceStylizer is created with the + * video running mode. * * The input frame can be of any size. It's required to provide the video * frame's timestamp (in milliseconds). The input timestamps must be @@ -193,16 +234,18 @@ export class FaceStylizer extends VisionTaskRunner { * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. - * @param callback The callback that is invoked with the stylized image. The - * lifetime of the returned data is only guaranteed for the duration of - * the callback. + * @param callback The callback that is invoked with the stylized image or + * `null` if no face was detected. The lifetime of the returned data is only + * guaranteed for the duration of the callback. */ stylizeForVideo( videoFrame: ImageSource, timestamp: number, callback: FaceStylizerCallback): void; /** - * Performs face stylization on the provided video frame. Only use this - * method when the FaceStylizer is created with the video running mode. + * Performs face stylization on the provided video frame and invokes the + * callback with result. The method returns synchronously once the callback + * returns. Only use this method when the FaceStylizer is created with the + * video running mode. * * The 'imageProcessingOptions' parameter can be used to specify one or all * of: @@ -221,18 +264,63 @@ export class FaceStylizer extends VisionTaskRunner { * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. * @param timestamp The timestamp of the current frame, in ms. - * @param callback The callback that is invoked with the stylized image. The - * lifetime of the returned data is only guaranteed for the duration of - * the callback. + * @param callback The callback that is invoked with the stylized image or + * `null` if no face was detected. The lifetime of the returned data is only + * guaranteed for the duration of the callback. */ stylizeForVideo( videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, timestamp: number, callback: FaceStylizerCallback): void; + /** + * Performs face stylization on the provided video frame. This method creates + * a copy of the resulting image and should not be used in high-throughput + * applictions. Only use this method when the FaceStylizer is created with the + * video running mode. + * + * The input frame can be of any size. It's required to provide the video + * frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return A stylized face or `null` if no face was detected. The result is + * copied to avoid lifetime issues. + */ + stylizeForVideo(videoFrame: ImageSource, timestamp: number): MPImage|null; + /** + * Performs face stylization on the provided video frame. This method creates + * a copy of the resulting image and should not be used in high-throughput + * applictions. Only use this method when the FaceStylizer is created with the + * video running mode. + * + * The 'imageProcessingOptions' parameter can be used to specify one or all + * of: + * - the rotation to apply to the image before performing stylization, by + * setting its 'rotationDegrees' property. + * - the region-of-interest on which to perform stylization, by setting its + * 'regionOfInterest' property. If not specified, the full image is used. + * If both are specified, the crop around the region-of-interest is + * extracted first, then the specified rotation is applied to the crop. + * + * The input frame can be of any size. It's required to provide the video + * frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @return A stylized face or `null` if no face was detected. The result is + * copied to avoid lifetime issues. + */ + stylizeForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number): MPImage|null; stylizeForVideo( videoFrame: ImageSource, timestampOrImageProcessingOptions: number|ImageProcessingOptions, - timestampOrCallback: number|FaceStylizerCallback, - callback?: FaceStylizerCallback): void { + timestampOrCallback?: number|FaceStylizerCallback, + callback?: FaceStylizerCallback): MPImage|null|void { const imageProcessingOptions = typeof timestampOrImageProcessingOptions !== 'number' ? timestampOrImageProcessingOptions : @@ -243,9 +331,13 @@ export class FaceStylizer extends VisionTaskRunner { this.userCallback = typeof timestampOrCallback === 'function' ? timestampOrCallback : - callback!; + callback; this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - this.userCallback = () => {}; + this.userCallback = undefined; + + if (!this.userCallback) { + return this.result; + } } /** Updates the MediaPipe graph configuration. */ @@ -270,13 +362,20 @@ export class FaceStylizer extends VisionTaskRunner { this.graphRunner.attachImageListener( STYLIZED_IMAGE_STREAM, (wasmImage, timestamp) => { - const mpImage = this.convertToMPImage(wasmImage); - this.userCallback(mpImage); + const mpImage = this.convertToMPImage( + wasmImage, /* shouldCopyData= */ !this.userCallback); + this.result = mpImage; + if (this.userCallback) { + this.userCallback(mpImage); + } this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( STYLIZED_IMAGE_STREAM, timestamp => { - this.userCallback(null); + this.result = null; + if (this.userCallback) { + this.userCallback(null); + } this.setLatestOutputTimestamp(timestamp); }); diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts index 17764c9e5..8ea8e0f94 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts @@ -99,6 +99,30 @@ describe('FaceStylizer', () => { ]); }); + it('returns result', () => { + if (typeof ImageData === 'undefined') { + console.log('ImageData tests are not supported on Node'); + return; + } + + // Pass the test data to our listener + faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(faceStylizer); + faceStylizer.imageListener! + ({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1}, + /* timestamp= */ 1337); + }); + + // Invoke the face stylizeer + const image = faceStylizer.stylize({} as HTMLImageElement); + expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(image).not.toBeNull(); + expect(image!.has(MPImage.TYPE.IMAGE_DATA)).toBeTrue(); + expect(image!.width).toEqual(1); + expect(image!.height).toEqual(1); + image!.close(); + }); + it('invokes callback', (done) => { if (typeof ImageData === 'undefined') { console.log('ImageData tests are not supported on Node'); @@ -125,28 +149,7 @@ describe('FaceStylizer', () => { }); }); - it('invokes callback even when no faes are detected', (done) => { - if (typeof ImageData === 'undefined') { - console.log('ImageData tests are not supported on Node'); - done(); - return; - } - - // Pass the test data to our listener - faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(faceStylizer); - faceStylizer.emptyPacketListener!(/* timestamp= */ 1337); - }); - - // Invoke the face stylizeer - faceStylizer.stylize({} as HTMLImageElement, image => { - expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(image).toBeNull(); - done(); - }); - }); - - it('invokes callback even when no faes are detected', (done) => { + it('invokes callback even when no faces are detected', (done) => { // Pass the test data to our listener faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(faceStylizer); diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index 60b965345..4d0ac18f2 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -60,7 +60,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void; export class ImageSegmenter extends VisionTaskRunner { private result: ImageSegmenterResult = {}; private labels: string[] = []; - private userCallback: ImageSegmenterCallback = () => {}; + private userCallback?: ImageSegmenterCallback; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; @@ -224,22 +224,51 @@ export class ImageSegmenter extends VisionTaskRunner { segment( image: ImageSource, imageProcessingOptions: ImageProcessingOptions, callback: ImageSegmenterCallback): void; + /** + * Performs image segmentation on the provided single image and returns the + * segmentation result. This method creates a copy of the resulting masks and + * should not be used in high-throughput applictions. Only use this method + * when the ImageSegmenter is created with running mode `image`. + * + * @param image An image to process. + * @return The segmentation result. The data is copied to avoid lifetime + * issues. + */ + segment(image: ImageSource): ImageSegmenterResult; + /** + * Performs image segmentation on the provided single image and returns the + * segmentation result. This method creates a copy of the resulting masks and + * should not be used in high-v applictions. Only use this method when + * the ImageSegmenter is created with running mode `image`. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The segmentation result. The data is copied to avoid lifetime + * issues. + */ + segment(image: ImageSource, imageProcessingOptions: ImageProcessingOptions): + ImageSegmenterResult; segment( image: ImageSource, - imageProcessingOptionsOrCallback: ImageProcessingOptions| + imageProcessingOptionsOrCallback?: ImageProcessingOptions| ImageSegmenterCallback, - callback?: ImageSegmenterCallback): void { + callback?: ImageSegmenterCallback): ImageSegmenterResult|void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; + this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : - callback!; + callback; this.reset(); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + + if (!this.userCallback) { + return this.result; + } } /** @@ -265,7 +294,7 @@ export class ImageSegmenter extends VisionTaskRunner { * * @param videoFrame A video frame to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how - * to process the input image before running inference. + * to process the input frame before running inference. * @param timestamp The timestamp of the current frame, in ms. * @param callback The callback that is invoked with the segmented masks. The * lifetime of the returned data is only guaranteed for the duration of the @@ -274,11 +303,41 @@ export class ImageSegmenter extends VisionTaskRunner { segmentForVideo( videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, timestamp: number, callback: ImageSegmenterCallback): void; + /** + * Performs image segmentation on the provided video frame and returns the + * segmentation result. This method creates a copy of the resulting masks and + * should not be used in high-throughput applictions. Only use this method + * when the ImageSegmenter is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @return The segmentation result. The data is copied to avoid lifetime + * issues. + */ + segmentForVideo(videoFrame: ImageSource, timestamp: number): + ImageSegmenterResult; + /** + * Performs image segmentation on the provided video frame and returns the + * segmentation result. This method creates a copy of the resulting masks and + * should not be used in high-v applictions. Only use this method when + * the ImageSegmenter is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input frame before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @return The segmentation result. The data is copied to avoid lifetime + * issues. + */ + segmentForVideo( + videoFrame: ImageSource, + imageProcessingOptions: ImageProcessingOptions, + timestamp: number, + ): ImageSegmenterResult; segmentForVideo( videoFrame: ImageSource, timestampOrImageProcessingOptions: number|ImageProcessingOptions, - timestampOrCallback: number|ImageSegmenterCallback, - callback?: ImageSegmenterCallback): void { + timestampOrCallback?: number|ImageSegmenterCallback, + callback?: ImageSegmenterCallback): ImageSegmenterResult|void { const imageProcessingOptions = typeof timestampOrImageProcessingOptions !== 'number' ? timestampOrImageProcessingOptions : @@ -288,11 +347,14 @@ export class ImageSegmenter extends VisionTaskRunner { timestampOrCallback as number; this.userCallback = typeof timestampOrCallback === 'function' ? timestampOrCallback : - callback!; + callback; this.reset(); this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - this.userCallback = () => {}; + + if (!this.userCallback) { + return this.result; + } } /** @@ -323,7 +385,9 @@ export class ImageSegmenter extends VisionTaskRunner { return; } - this.userCallback(this.result); + if (this.userCallback) { + this.userCallback(this.result); + } } /** Updates the MediaPipe graph configuration. */ @@ -351,8 +415,9 @@ export class ImageSegmenter extends VisionTaskRunner { this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { - this.result.confidenceMasks = - masks.map(wasmImage => this.convertToMPImage(wasmImage)); + this.result.confidenceMasks = masks.map( + wasmImage => this.convertToMPImage( + wasmImage, /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); this.maybeInvokeCallback(); }); @@ -370,7 +435,8 @@ export class ImageSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { - this.result.categoryMask = this.convertToMPImage(mask); + this.result.categoryMask = this.convertToMPImage( + mask, /* shouldCopyData= */ !this.userCallback); this.setLatestOutputTimestamp(timestamp); this.maybeInvokeCallback(); }); diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts index 8c8767ec7..f9a4fe8a6 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -292,4 +292,21 @@ describe('ImageSegmenter', () => { }); }); }); + + it('returns result', () => { + const confidenceMask = new Float32Array([0.0]); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + imageSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask, width: 1, height: 1}, + ], + 1337); + }); + + const result = imageSegmenter.segment({} as HTMLImageElement); + expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); + result.confidenceMasks![0].close(); + }); }); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index 60ec9e1c5..72dfd3834 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -86,7 +86,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { private result: InteractiveSegmenterResult = {}; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; - private userCallback: InteractiveSegmenterCallback = () => {}; + private userCallback?: InteractiveSegmenterCallback; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -186,14 +186,9 @@ export class InteractiveSegmenter extends VisionTaskRunner { /** * Performs interactive segmentation on the provided single image and invokes - * the callback with the response. The `roi` parameter is used to represent a - * user's region of interest for segmentation. - * - * If the output_type is `CATEGORY_MASK`, the callback is invoked with vector - * of images that represent per-category segmented image mask. If the - * output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of - * images that contains only one confidence image mask. The method returns - * synchronously once the callback returns. + * the callback with the response. The method returns synchronously once the + * callback returns. The `roi` parameter is used to represent a user's region + * of interest for segmentation. * * @param image An image to process. * @param roi The region of interest for segmentation. @@ -206,8 +201,9 @@ export class InteractiveSegmenter extends VisionTaskRunner { callback: InteractiveSegmenterCallback): void; /** * Performs interactive segmentation on the provided single image and invokes - * the callback with the response. The `roi` parameter is used to represent a - * user's region of interest for segmentation. + * the callback with the response. The method returns synchronously once the + * callback returns. The `roi` parameter is used to represent a user's region + * of interest for segmentation. * * The 'image_processing_options' parameter can be used to specify the * rotation to apply to the image before performing segmentation, by setting @@ -215,12 +211,6 @@ export class InteractiveSegmenter extends VisionTaskRunner { * using the 'regionOfInterest' field is NOT supported and will result in an * error. * - * If the output_type is `CATEGORY_MASK`, the callback is invoked with vector - * of images that represent per-category segmented image mask. If the - * output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of - * images that contains only one confidence image mask. The method returns - * synchronously once the callback returns. - * * @param image An image to process. * @param roi The region of interest for segmentation. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how @@ -233,23 +223,63 @@ export class InteractiveSegmenter extends VisionTaskRunner { image: ImageSource, roi: RegionOfInterest, imageProcessingOptions: ImageProcessingOptions, callback: InteractiveSegmenterCallback): void; + /** + * Performs interactive segmentation on the provided video frame and returns + * the segmentation result. This method creates a copy of the resulting masks + * and should not be used in high-throughput applictions. The `roi` parameter + * is used to represent a user's region of interest for segmentation. + * + * @param image An image to process. + * @param roi The region of interest for segmentation. + * @return The segmentation result. The data is copied to avoid lifetime + * limits. + */ + segment(image: ImageSource, roi: RegionOfInterest): + InteractiveSegmenterResult; + /** + * Performs interactive segmentation on the provided video frame and returns + * the segmentation result. This method creates a copy of the resulting masks + * and should not be used in high-throughput applictions. The `roi` parameter + * is used to represent a user's region of interest for segmentation. + * + * The 'image_processing_options' parameter can be used to specify the + * rotation to apply to the image before performing segmentation, by setting + * its 'rotationDegrees' field. Note that specifying a region-of-interest + * using the 'regionOfInterest' field is NOT supported and will result in an + * error. + * + * @param image An image to process. + * @param roi The region of interest for segmentation. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The segmentation result. The data is copied to avoid lifetime + * limits. + */ segment( image: ImageSource, roi: RegionOfInterest, - imageProcessingOptionsOrCallback: ImageProcessingOptions| + imageProcessingOptions: ImageProcessingOptions): + InteractiveSegmenterResult; + segment( + image: ImageSource, roi: RegionOfInterest, + imageProcessingOptionsOrCallback?: ImageProcessingOptions| InteractiveSegmenterCallback, - callback?: InteractiveSegmenterCallback): void { + callback?: InteractiveSegmenterCallback): InteractiveSegmenterResult| + void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : - callback!; + callback; this.reset(); this.processRenderData(roi, this.getSynctheticTimestamp()); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + + if (!this.userCallback) { + return this.result; + } } private reset(): void { @@ -265,7 +295,9 @@ export class InteractiveSegmenter extends VisionTaskRunner { return; } - this.userCallback(this.result); + if (this.userCallback) { + this.userCallback(this.result); + } } /** Updates the MediaPipe graph configuration. */ @@ -295,8 +327,9 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { - this.result.confidenceMasks = - masks.map(wasmImage => this.convertToMPImage(wasmImage)); + this.result.confidenceMasks = masks.map( + wasmImage => this.convertToMPImage( + wasmImage, /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); this.maybeInvokeCallback(); }); @@ -314,7 +347,8 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { - this.result.categoryMask = this.convertToMPImage(mask); + this.result.categoryMask = this.convertToMPImage( + mask, /* shouldCopyData= */ !this.userCallback); this.setLatestOutputTimestamp(timestamp); this.maybeInvokeCallback(); }); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index a361af5a1..52742e371 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -306,4 +306,22 @@ describe('InteractiveSegmenter', () => { }); }); }); + + it('returns result', () => { + const confidenceMask = new Float32Array([0.0]); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + interactiveSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask, width: 1, height: 1}, + ], + 1337); + }); + + const result = + interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT); + expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); + result.confidenceMasks![0].close(); + }); }); diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts index d21a9a6db..7c2743062 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts @@ -64,7 +64,7 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void; export class PoseLandmarker extends VisionTaskRunner { private result: Partial = {}; private outputSegmentationMasks = false; - private userCallback: PoseLandmarkerCallback = () => {}; + private userCallback?: PoseLandmarkerCallback; private readonly options: PoseLandmarkerGraphOptions; private readonly poseLandmarksDetectorGraphOptions: PoseLandmarksDetectorGraphOptions; @@ -200,21 +200,22 @@ export class PoseLandmarker extends VisionTaskRunner { } /** - * Performs pose detection on the provided single image and waits - * synchronously for the response. Only use this method when the - * PoseLandmarker is created with running mode `image`. + * Performs pose detection on the provided single image and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the PoseLandmarker is created + * with running mode `image`. * * @param image An image to process. * @param callback The callback that is invoked with the result. The * lifetime of the returned masks is only guaranteed for the duration of * the callback. - * @return The detected pose landmarks. */ detect(image: ImageSource, callback: PoseLandmarkerCallback): void; /** - * Performs pose detection on the provided single image and waits - * synchronously for the response. Only use this method when the - * PoseLandmarker is created with running mode `image`. + * Performs pose detection on the provided single image and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the PoseLandmarker is created + * with running mode `image`. * * @param image An image to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how @@ -222,16 +223,42 @@ export class PoseLandmarker extends VisionTaskRunner { * @param callback The callback that is invoked with the result. The * lifetime of the returned masks is only guaranteed for the duration of * the callback. - * @return The detected pose landmarks. */ detect( image: ImageSource, imageProcessingOptions: ImageProcessingOptions, callback: PoseLandmarkerCallback): void; + /** + * Performs pose detection on the provided single image and waits + * synchronously for the response. This method creates a copy of the resulting + * masks and should not be used in high-throughput applictions. Only + * use this method when the PoseLandmarker is created with running mode + * `image`. + * + * @param image An image to process. + * @return The landmarker result. Any masks are copied to avoid lifetime + * limits. + * @return The detected pose landmarks. + */ + detect(image: ImageSource): PoseLandmarkerResult; + /** + * Performs pose detection on the provided single image and waits + * synchronously for the response. This method creates a copy of the resulting + * masks and should not be used in high-throughput applictions. Only + * use this method when the PoseLandmarker is created with running mode + * `image`. + * + * @param image An image to process. + * @return The landmarker result. Any masks are copied to avoid lifetime + * limits. + * @return The detected pose landmarks. + */ + detect(image: ImageSource, imageProcessingOptions: ImageProcessingOptions): + PoseLandmarkerResult; detect( image: ImageSource, - imageProcessingOptionsOrCallback: ImageProcessingOptions| + imageProcessingOptionsOrCallback?: ImageProcessingOptions| PoseLandmarkerCallback, - callback?: PoseLandmarkerCallback): void { + callback?: PoseLandmarkerCallback): PoseLandmarkerResult|void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : @@ -242,28 +269,32 @@ export class PoseLandmarker extends VisionTaskRunner { this.resetResults(); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + + if (!this.userCallback) { + return this.result as PoseLandmarkerResult; + } } /** - * Performs pose detection on the provided video frame and waits - * synchronously for the response. Only use this method when the - * PoseLandmarker is created with running mode `video`. + * Performs pose detection on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the PoseLandmarker is created + * with running mode `video`. * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @param callback The callback that is invoked with the result. The * lifetime of the returned masks is only guaranteed for the duration of * the callback. - * @return The detected pose landmarks. */ detectForVideo( videoFrame: ImageSource, timestamp: number, callback: PoseLandmarkerCallback): void; /** - * Performs pose detection on the provided video frame and waits - * synchronously for the response. Only use this method when the - * PoseLandmarker is created with running mode `video`. + * Performs pose detection on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the PoseLandmarker is created + * with running mode `video`. * * @param videoFrame A video frame to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how @@ -272,16 +303,45 @@ export class PoseLandmarker extends VisionTaskRunner { * @param callback The callback that is invoked with the result. The * lifetime of the returned masks is only guaranteed for the duration of * the callback. - * @return The detected pose landmarks. */ detectForVideo( videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, timestamp: number, callback: PoseLandmarkerCallback): void; + /** + * Performs pose detection on the provided video frame and returns the result. + * This method creates a copy of the resulting masks and should not be used + * in high-throughput applictions. Only use this method when the + * PoseLandmarker is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The landmarker result. Any masks are copied to extend the + * lifetime of the returned data. + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): + PoseLandmarkerResult; + /** + * Performs pose detection on the provided video frame and returns the result. + * This method creates a copy of the resulting masks and should not be used + * in high-throughput applictions. The method returns synchronously once the + * callback returns. Only use this method when the PoseLandmarker is created + * with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @return The landmarker result. Any masks are copied to extend the lifetime + * of the returned data. + */ + detectForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number): PoseLandmarkerResult; detectForVideo( videoFrame: ImageSource, timestampOrImageProcessingOptions: number|ImageProcessingOptions, - timestampOrCallback: number|PoseLandmarkerCallback, - callback?: PoseLandmarkerCallback): void { + timestampOrCallback?: number|PoseLandmarkerCallback, + callback?: PoseLandmarkerCallback): PoseLandmarkerResult|void { const imageProcessingOptions = typeof timestampOrImageProcessingOptions !== 'number' ? timestampOrImageProcessingOptions : @@ -291,10 +351,14 @@ export class PoseLandmarker extends VisionTaskRunner { timestampOrCallback as number; this.userCallback = typeof timestampOrCallback === 'function' ? timestampOrCallback : - callback!; + callback; + this.resetResults(); this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - this.userCallback = () => {}; + + if (!this.userCallback) { + return this.result as PoseLandmarkerResult; + } } private resetResults(): void { @@ -315,7 +379,10 @@ export class PoseLandmarker extends VisionTaskRunner { if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) { return; } - this.userCallback(this.result as Required); + + if (this.userCallback) { + this.userCallback(this.result as Required); + } } /** Sets the default values for the graph. */ @@ -438,8 +505,9 @@ export class PoseLandmarker extends VisionTaskRunner { 'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM); this.graphRunner.attachImageVectorListener( SEGMENTATION_MASK_STREAM, (masks, timestamp) => { - this.result.segmentationMasks = - masks.map(wasmImage => this.convertToMPImage(wasmImage)); + this.result.segmentationMasks = masks.map( + wasmImage => this.convertToMPImage( + wasmImage, /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); this.maybeInvokeCallback(); }); diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts index 907cb16b3..62efa5a3c 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts @@ -331,4 +331,26 @@ describe('PoseLandmarker', () => { listenerCalled = true; }); }); + + it('returns result', () => { + const landmarksProto = [createLandmarks().serializeBinary()]; + const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; + + // Pass the test data to our listener + poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { + poseLandmarker.listeners.get('normalized_landmarks')! + (landmarksProto, 1337); + poseLandmarker.listeners.get('world_landmarks')! + (worldLandmarksProto, 1337); + poseLandmarker.listeners.get('auxiliary_landmarks')! + (landmarksProto, 1337); + }); + + // Invoke the pose landmarker + const result = poseLandmarker.detect({} as HTMLImageElement); + expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); + expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); + expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); + }); });