diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts
index 44effa879..2d72bf1dc 100644
--- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts
+++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts
@@ -62,12 +62,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
 
 /** Performs pose landmarks detection on images. */
 export class PoseLandmarker extends VisionTaskRunner {
-  private result: PoseLandmarkerResult = {
-    landmarks: [],
-    worldLandmarks: [],
-    auxilaryLandmarks: []
-  };
+  private result: Partial<PoseLandmarkerResult> = {};
   private outputSegmentationMasks = false;
+  private userCallback: PoseLandmarkerCallback = () => {};
   private readonly options: PoseLandmarkerGraphOptions;
   private readonly poseLandmarksDetectorGraphOptions:
       PoseLandmarksDetectorGraphOptions;
@@ -239,14 +236,13 @@ export class PoseLandmarker extends VisionTaskRunner {
         typeof imageProcessingOptionsOrCallback !== 'function' ?
         imageProcessingOptionsOrCallback :
         {};
-    const userCallback =
-        typeof imageProcessingOptionsOrCallback === 'function' ?
+    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
         imageProcessingOptionsOrCallback :
         callback!;
 
     this.resetResults();
     this.processImageData(image, imageProcessingOptions);
-    userCallback(this.result);
+    this.userCallback = () => {};
   }
 
   /**
@@ -293,19 +289,33 @@ export class PoseLandmarker extends VisionTaskRunner {
     const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
         timestampOrImageProcessingOptions :
         timestampOrCallback as number;
-    const userCallback = typeof timestampOrCallback === 'function' ?
+    this.userCallback = typeof timestampOrCallback === 'function' ?
         timestampOrCallback :
         callback!;
     this.resetResults();
     this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
-    userCallback(this.result);
+    this.userCallback = () => {};
   }
 
   private resetResults(): void {
-    this.result = {landmarks: [], worldLandmarks: [], auxilaryLandmarks: []};
-    if (this.outputSegmentationMasks) {
-      this.result.segmentationMasks = [];
+    this.result = {};
+  }
+
+  /** Invokes the user callback once every expected output has been set. */
+  private maybeInvokeCallback(): void {
+    if (!('landmarks' in this.result)) {
+      return;
     }
+    if (!('worldLandmarks' in this.result)) {
+      return;
+    }
+    if (!('auxilaryLandmarks' in this.result)) {
+      return;
+    }
+    if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
+      return;
+    }
+    this.userCallback(this.result as Required<PoseLandmarkerResult>);
   }
 
   /** Sets the default values for the graph. */
@@ -385,30 +395,39 @@ export class PoseLandmarker extends VisionTaskRunner {
         NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
           this.addJsLandmarks(binaryProto);
           this.setLatestOutputTimestamp(timestamp);
+          this.maybeInvokeCallback();
         });
     this.graphRunner.attachEmptyPacketListener(
         NORM_LANDMARKS_STREAM, timestamp => {
+          this.result.landmarks = [];
           this.setLatestOutputTimestamp(timestamp);
+          this.maybeInvokeCallback();
         });
 
     this.graphRunner.attachProtoVectorListener(
         WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
           this.adddJsWorldLandmarks(binaryProto);
           this.setLatestOutputTimestamp(timestamp);
+          this.maybeInvokeCallback();
         });
     this.graphRunner.attachEmptyPacketListener(
         WORLD_LANDMARKS_STREAM, timestamp => {
+          this.result.worldLandmarks = [];
           this.setLatestOutputTimestamp(timestamp);
+          this.maybeInvokeCallback();
         });
 
     this.graphRunner.attachProtoVectorListener(
         AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => {
           this.addJsAuxiliaryLandmarks(binaryProto);
           this.setLatestOutputTimestamp(timestamp);
+          this.maybeInvokeCallback();
         });
     this.graphRunner.attachEmptyPacketListener(
         AUXILIARY_LANDMARKS_STREAM, timestamp => {
+          this.result.auxilaryLandmarks = [];
           this.setLatestOutputTimestamp(timestamp);
+          this.maybeInvokeCallback();
         });
 
     if (this.outputSegmentationMasks) {
@@ -419,10 +438,13 @@ export class PoseLandmarker extends VisionTaskRunner {
             this.result.segmentationMasks =
                 masks.map(wasmImage => this.convertToMPImage(wasmImage));
             this.setLatestOutputTimestamp(timestamp);
+            this.maybeInvokeCallback();
           });
       this.graphRunner.attachEmptyPacketListener(
           SEGMENTATION_MASK_STREAM, timestamp => {
+            this.result.segmentationMasks = [];
             this.setLatestOutputTimestamp(timestamp);
+            this.maybeInvokeCallback();
           });
     }
 
diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts
index 2d76f656f..794df68b8 100644
--- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts
+++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts
@@ -260,4 +260,38 @@ describe('PoseLandmarker', () => {
     expect(landmarks1).toBeDefined();
     expect(landmarks1).toEqual(landmarks2);
   });
+
+  it('invokes listener once masks are available', (done) => {
+    const landmarksProto = [createLandmarks().serializeBinary()];
+    const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
+    const masks = [
+      {data: new Float32Array([0, 1, 2, 3]), width: 2, height: 2},
+    ];
+    let listenerCalled = false;
+
+
+    poseLandmarker.setOptions({outputSegmentationMasks: true});
+
+    // Pass the test data to our listener
+    poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      expect(listenerCalled).toBeFalse();
+      poseLandmarker.listeners.get('normalized_landmarks')!
+      (landmarksProto, 1337);
+      expect(listenerCalled).toBeFalse();
+      poseLandmarker.listeners.get('world_landmarks')!
+      (worldLandmarksProto, 1337);
+      expect(listenerCalled).toBeFalse();
+      poseLandmarker.listeners.get('auxiliary_landmarks')!
+      (landmarksProto, 1337);
+      expect(listenerCalled).toBeFalse();
+      poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
+      expect(listenerCalled).toBeTrue();
+      done();
+    });
+
+    // Invoke the pose landmarker
+    poseLandmarker.detect({} as HTMLImageElement, () => {
+      listenerCalled = true;
+    });
+  });
 });