diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index b12adb0df..ee9caaa1f 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -22,6 +22,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {LabelMapItem} from '../../../../util/label_map_pb'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; @@ -58,7 +59,8 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void; /** Performs image segmentation on images. 
*/ export class ImageSegmenter extends VisionTaskRunner { - private result: ImageSegmenterResult = {}; + private categoryMask?: MPMask; + private confidenceMasks?: MPMask[]; private labels: string[] = []; private userCallback?: ImageSegmenterCallback; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; @@ -265,10 +267,7 @@ export class ImageSegmenter extends VisionTaskRunner { this.reset(); this.processImageData(image, imageProcessingOptions); - - if (!this.userCallback) { - return this.result; - } + return this.processResults(); } /** @@ -347,10 +346,7 @@ export class ImageSegmenter extends VisionTaskRunner { this.reset(); this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - - if (!this.userCallback) { - return this.result; - } + return this.processResults(); } /** @@ -369,21 +365,20 @@ export class ImageSegmenter extends VisionTaskRunner { } private reset(): void { - this.result = {}; + this.categoryMask = undefined; + this.confidenceMasks = undefined; } - /** Invokes the user callback once all data has been received. */ - private maybeInvokeCallback(): void { - if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) { - return; - } - if (this.outputCategoryMask && !('categoryMask' in this.result)) { - return; - } - - if (this.userCallback) { - this.userCallback(this.result); - + private processResults(): ImageSegmenterResult|void { + try { + const result = + new ImageSegmenterResult(this.confidenceMasks, this.categoryMask); + if (this.userCallback) { + this.userCallback(result); + } else { + return result; + } + } finally { // Free the image memory, now that we've kept all streams alive long // enough to be returned in our callbacks. 
this.freeKeepaliveStreams(); @@ -417,17 +412,15 @@ export class ImageSegmenter extends VisionTaskRunner { this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { - this.result.confidenceMasks = masks.map( + this.confidenceMasks = masks.map( wasmImage => this.convertToMPMask( wasmImage, /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( CONFIDENCE_MASKS_STREAM, timestamp => { - this.result.confidenceMasks = undefined; + this.confidenceMasks = []; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); } @@ -438,16 +431,14 @@ export class ImageSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { - this.result.categoryMask = this.convertToMPMask( + this.categoryMask = this.convertToMPMask( mask, /* shouldCopyData= */ !this.userCallback); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( CATEGORY_MASK_STREAM, timestamp => { - this.result.categoryMask = undefined; + this.categoryMask = undefined; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); } diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts index 25962d57e..9107a5c80 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.ts @@ -17,18 +17,26 @@ import {MPMask} from '../../../../tasks/web/vision/core/mask'; /** The output result of ImageSegmenter. 
*/ -export declare interface ImageSegmenterResult { - /** - * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed - * `MPImage`s where, for each mask, each pixel represents the prediction - * confidence, usually in the [0, 1] range. - */ - confidenceMasks?: MPMask[]; +export class ImageSegmenterResult { + constructor( + /** + * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed + * `MPImage`s where, for each mask, each pixel represents the prediction + * confidence, usually in the [0, 1] range. + */ + readonly confidenceMasks?: MPMask[], + /** + * A category mask represented as a `Uint8ClampedArray` or + * `WebGLTexture`-backed `MPImage` where each pixel represents the class + * which the pixel in the original image was predicted to belong to. + */ + readonly categoryMask?: MPMask) {} - /** - * A category mask represented as a `Uint8ClampedArray` or - * `WebGLTexture`-backed `MPImage` where each pixel represents the class which - * the pixel in the original image was predicted to belong to. - */ - categoryMask?: MPMask; + /** Frees the resources held by the category and confidence masks. 
*/ + close(): void { + this.confidenceMasks?.forEach(m => { + m.close(); + }); + this.categoryMask?.close(); + } } diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts index f9172ecd3..10983b488 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -263,7 +263,7 @@ describe('ImageSegmenter', () => { }); }); - it('invokes listener once masks are available', async () => { + it('invokes listener after masks are available', async () => { const categoryMask = new Uint8Array([1]); const confidenceMask = new Float32Array([0.0]); let listenerCalled = false; @@ -282,7 +282,7 @@ describe('ImageSegmenter', () => { {data: confidenceMask, width: 1, height: 1}, ], 1337); - expect(listenerCalled).toBeTrue(); + expect(listenerCalled).toBeFalse(); }); return new Promise(resolve => { @@ -307,6 +307,6 @@ describe('ImageSegmenter', () => { const result = imageSegmenter.segment({} as HTMLImageElement); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask); - result.confidenceMasks![0].close(); + result.close(); }); }); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index e3f79d26d..16bf10eeb 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {MPMask} from 
'../../../../tasks/web/vision/core/mask'; import {RegionOfInterest} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {Color as ColorProto} from '../../../../util/color_pb'; @@ -83,7 +84,8 @@ export type InteractiveSegmenterCallback = * - batch is always 1 */ export class InteractiveSegmenter extends VisionTaskRunner { - private result: InteractiveSegmenterResult = {}; + private categoryMask?: MPMask; + private confidenceMasks?: MPMask[]; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private userCallback?: InteractiveSegmenterCallback; @@ -276,28 +278,24 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.reset(); this.processRenderData(roi, this.getSynctheticTimestamp()); this.processImageData(image, imageProcessingOptions); - - if (!this.userCallback) { - return this.result; - } + return this.processResults(); } private reset(): void { - this.result = {}; + this.confidenceMasks = undefined; + this.categoryMask = undefined; } - /** Invokes the user callback once all data has been received. */ - private maybeInvokeCallback(): void { - if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) { - return; - } - if (this.outputCategoryMask && !('categoryMask' in this.result)) { - return; - } - - if (this.userCallback) { - this.userCallback(this.result); - + private processResults(): InteractiveSegmenterResult|void { + try { + const result = new InteractiveSegmenterResult( + this.confidenceMasks, this.categoryMask); + if (this.userCallback) { + this.userCallback(result); + } else { + return result; + } + } finally { // Free the image memory, now that we've kept all streams alive long // enough to be returned in our callbacks. 
this.freeKeepaliveStreams(); @@ -333,17 +331,15 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { - this.result.confidenceMasks = masks.map( + this.confidenceMasks = masks.map( wasmImage => this.convertToMPMask( wasmImage, /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( CONFIDENCE_MASKS_STREAM, timestamp => { - this.result.confidenceMasks = undefined; + this.confidenceMasks = []; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); } @@ -354,16 +350,14 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { - this.result.categoryMask = this.convertToMPMask( + this.categoryMask = this.convertToMPMask( mask, /* shouldCopyData= */ !this.userCallback); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( CATEGORY_MASK_STREAM, timestamp => { - this.result.categoryMask = undefined; + this.categoryMask = undefined; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts index e773b5e64..5da7e4df3 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.ts @@ -17,18 +17,26 @@ import {MPMask} from '../../../../tasks/web/vision/core/mask'; /** The output result of InteractiveSegmenter. 
*/ -export declare interface InteractiveSegmenterResult { - /** - * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed - * `MPImage`s where, for each mask, each pixel represents the prediction - * confidence, usually in the [0, 1] range. - */ - confidenceMasks?: MPMask[]; +export class InteractiveSegmenterResult { + constructor( + /** + * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed + * `MPImage`s where, for each mask, each pixel represents the prediction + * confidence, usually in the [0, 1] range. + */ + readonly confidenceMasks?: MPMask[], + /** + * A category mask represented as a `Uint8ClampedArray` or + * `WebGLTexture`-backed `MPImage` where each pixel represents the class + * which the pixel in the original image was predicted to belong to. + */ + readonly categoryMask?: MPMask) {} - /** - * A category mask represented as a `Uint8ClampedArray` or - * `WebGLTexture`-backed `MPImage` where each pixel represents the class which - * the pixel in the original image was predicted to belong to. - */ - categoryMask?: MPMask; + /** Frees the resources held by the category and confidence masks. 
*/ + close(): void { + this.confidenceMasks?.forEach(m => { + m.close(); + }); + this.categoryMask?.close(); + } } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index c5603c5c6..6550202e0 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -277,7 +277,7 @@ describe('InteractiveSegmenter', () => { }); }); - it('invokes listener once masks are avaiblae', async () => { + it('invokes listener after masks are available', async () => { const categoryMask = new Uint8Array([1]); const confidenceMask = new Float32Array([0.0]); let listenerCalled = false; @@ -296,7 +296,7 @@ {data: confidenceMask, width: 1, height: 1}, ], 1337); - expect(listenerCalled).toBeTrue(); + expect(listenerCalled).toBeFalse(); }); return new Promise(resolve => { @@ -322,6 +322,6 @@ const result = interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask); - result.confidenceMasks![0].close(); + result.close(); }); }); diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts index 0d3181aa0..927b3c24b 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts @@ -21,9 +21,11 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {PoseDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb'; import {PoseLandmarkerGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options_pb'; import {PoseLandmarksDetectorGraphOptions}
from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {Connection} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; @@ -61,7 +63,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void; /** Performs pose landmarks detection on images. */ export class PoseLandmarker extends VisionTaskRunner { - private result: Partial = {}; + private landmarks: NormalizedLandmark[][] = []; + private worldLandmarks: Landmark[][] = []; + private segmentationMasks?: MPMask[]; private outputSegmentationMasks = false; private userCallback?: PoseLandmarkerCallback; private readonly options: PoseLandmarkerGraphOptions; @@ -268,10 +272,7 @@ export class PoseLandmarker extends VisionTaskRunner { this.resetResults(); this.processImageData(image, imageProcessingOptions); - - if (!this.userCallback) { - return this.result as PoseLandmarkerResult; - } + return this.processResults(); } /** @@ -352,31 +353,25 @@ export class PoseLandmarker extends VisionTaskRunner { this.resetResults(); this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - - if (!this.userCallback) { - return this.result as PoseLandmarkerResult; - } + return this.processResults(); } private resetResults(): void { - this.result = {}; + this.landmarks = []; + this.worldLandmarks = []; + this.segmentationMasks = 
undefined; } - /** Invokes the user callback once all data has been received. */ - private maybeInvokeCallback(): void { - if (!('landmarks' in this.result)) { - return; - } - if (!('worldLandmarks' in this.result)) { - return; - } - if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) { - return; - } - - if (this.userCallback) { - this.userCallback(this.result as Required); - + private processResults(): PoseLandmarkerResult|void { + try { + const result = new PoseLandmarkerResult( + this.landmarks, this.worldLandmarks, this.segmentationMasks); + if (this.userCallback) { + this.userCallback(result); + } else { + return result; + } + } finally { // Free the image memory, now that we've finished our callback. this.freeKeepaliveStreams(); } @@ -396,11 +391,11 @@ export class PoseLandmarker extends VisionTaskRunner { * Converts raw data into a landmark, and adds it to our landmarks list. */ private addJsLandmarks(data: Uint8Array[]): void { - this.result.landmarks = []; + this.landmarks = []; for (const binaryProto of data) { const poseLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - this.result.landmarks.push(convertToLandmarks(poseLandmarksProto)); + this.landmarks.push(convertToLandmarks(poseLandmarksProto)); } } @@ -409,11 +404,11 @@ export class PoseLandmarker extends VisionTaskRunner { * worldLandmarks list. 
*/ private adddJsWorldLandmarks(data: Uint8Array[]): void { - this.result.worldLandmarks = []; + this.worldLandmarks = []; for (const binaryProto of data) { const poseWorldLandmarksProto = LandmarkList.deserializeBinary(binaryProto); - this.result.worldLandmarks.push( + this.worldLandmarks.push( convertToWorldLandmarks(poseWorldLandmarksProto)); } } @@ -448,26 +443,22 @@ export class PoseLandmarker extends VisionTaskRunner { NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { this.addJsLandmarks(binaryProto); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( NORM_LANDMARKS_STREAM, timestamp => { - this.result.landmarks = []; + this.landmarks = []; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachProtoVectorListener( WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { this.adddJsWorldLandmarks(binaryProto); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( WORLD_LANDMARKS_STREAM, timestamp => { - this.result.worldLandmarks = []; + this.worldLandmarks = []; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); if (this.outputSegmentationMasks) { @@ -477,17 +468,15 @@ export class PoseLandmarker extends VisionTaskRunner { this.graphRunner.attachImageVectorListener( SEGMENTATION_MASK_STREAM, (masks, timestamp) => { - this.result.segmentationMasks = masks.map( + this.segmentationMasks = masks.map( wasmImage => this.convertToMPMask( wasmImage, /* shouldCopyData= */ !this.userCallback)); this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); this.graphRunner.attachEmptyPacketListener( SEGMENTATION_MASK_STREAM, timestamp => { - this.result.segmentationMasks = []; + this.segmentationMasks = []; this.setLatestOutputTimestamp(timestamp); - this.maybeInvokeCallback(); }); } diff --git 
a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_result.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_result.ts index 96e698a85..92ba804d6 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_result.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_result.ts @@ -24,13 +24,18 @@ export {Category, Landmark, NormalizedLandmark}; * Represents the pose landmarks deection results generated by `PoseLandmarker`. * Each vector element represents a single pose detected in the image. */ -export declare interface PoseLandmarkerResult { - /** Pose landmarks of detected poses. */ - landmarks: NormalizedLandmark[][]; +export class PoseLandmarkerResult { + constructor(/** Pose landmarks of detected poses. */ + readonly landmarks: NormalizedLandmark[][], + /** Pose landmarks in world coordinates of detected poses. */ + readonly worldLandmarks: Landmark[][], + /** Segmentation mask for the detected pose. */ + readonly segmentationMasks?: MPMask[]) {} - /** Pose landmarks in world coordinates of detected poses. */ - worldLandmarks: Landmark[][]; - - /** Segmentation mask for the detected pose. */ - segmentationMasks?: MPMask[]; + /** Frees the resources held by the segmentation masks. 
*/ + close(): void { + this.segmentationMasks?.forEach(m => { + m.close(); + }); + } } diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts index d4a49db97..9131b93ec 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts @@ -287,7 +287,7 @@ describe('PoseLandmarker', () => { }); }); - it('invokes listener once masks are available', (done) => { + it('invokes listener after masks are available', (done) => { const landmarksProto = [createLandmarks().serializeBinary()]; const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; const masks = [ @@ -309,13 +309,12 @@ describe('PoseLandmarker', () => { expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse(); poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337); - expect(listenerCalled).toBeTrue(); - done(); }); // Invoke the pose landmarker poseLandmarker.detect({} as HTMLImageElement, () => { listenerCalled = true; + done(); }); }); @@ -336,5 +335,6 @@ describe('PoseLandmarker', () => { expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); + result.close(); }); });