Invoke PoseListener callback while C++ Packet is still active
PiperOrigin-RevId: 528061429
This commit is contained in:
		
							parent
							
								
									253f13ad62
								
							
						
					
					
						commit
						8e510a3255
					
				|  | @ -62,12 +62,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void; | |||
| 
 | ||||
| /** Performs pose landmarks detection on images. */ | ||||
| export class PoseLandmarker extends VisionTaskRunner { | ||||
|   private result: PoseLandmarkerResult = { | ||||
|     landmarks: [], | ||||
|     worldLandmarks: [], | ||||
|     auxilaryLandmarks: [] | ||||
|   }; | ||||
|   private result: Partial<PoseLandmarkerResult> = {}; | ||||
|   private outputSegmentationMasks = false; | ||||
|   private userCallback: PoseLandmarkerCallback = () => {}; | ||||
|   private readonly options: PoseLandmarkerGraphOptions; | ||||
|   private readonly poseLandmarksDetectorGraphOptions: | ||||
|       PoseLandmarksDetectorGraphOptions; | ||||
|  | @ -239,14 +236,13 @@ export class PoseLandmarker extends VisionTaskRunner { | |||
|         typeof imageProcessingOptionsOrCallback !== 'function' ? | ||||
|         imageProcessingOptionsOrCallback : | ||||
|         {}; | ||||
|     const userCallback = | ||||
|         typeof imageProcessingOptionsOrCallback === 'function' ? | ||||
|     this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? | ||||
|         imageProcessingOptionsOrCallback : | ||||
|         callback!; | ||||
| 
 | ||||
|     this.resetResults(); | ||||
|     this.processImageData(image, imageProcessingOptions); | ||||
|     userCallback(this.result); | ||||
|     this.userCallback = () => {}; | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -293,19 +289,33 @@ export class PoseLandmarker extends VisionTaskRunner { | |||
|     const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? | ||||
|         timestampOrImageProcessingOptions : | ||||
|         timestampOrCallback as number; | ||||
|     const userCallback = typeof timestampOrCallback === 'function' ? | ||||
|     this.userCallback = typeof timestampOrCallback === 'function' ? | ||||
|         timestampOrCallback : | ||||
|         callback!; | ||||
|     this.resetResults(); | ||||
|     this.processVideoData(videoFrame, imageProcessingOptions, timestamp); | ||||
|     userCallback(this.result); | ||||
|     this.userCallback = () => {}; | ||||
|   } | ||||
| 
 | ||||
|   private resetResults(): void { | ||||
|     this.result = {landmarks: [], worldLandmarks: [], auxilaryLandmarks: []}; | ||||
|     if (this.outputSegmentationMasks) { | ||||
|       this.result.segmentationMasks = []; | ||||
|     this.result = {}; | ||||
|   } | ||||
| 
 | ||||
|   /** Invokes the user callback once all data has been received. */ | ||||
|   private maybeInvokeCallback(): void { | ||||
|     if (!('landmarks' in this.result)) { | ||||
|       return; | ||||
|     } | ||||
|     if (!('worldLandmarks' in this.result)) { | ||||
|       return; | ||||
|     } | ||||
|     if (!('landmarks' in this.result)) { | ||||
|       return; | ||||
|     } | ||||
|     if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) { | ||||
|       return; | ||||
|     } | ||||
|     this.userCallback(this.result as Required<PoseLandmarkerResult>); | ||||
|   } | ||||
| 
 | ||||
|   /** Sets the default values for the graph. */ | ||||
|  | @ -385,30 +395,39 @@ export class PoseLandmarker extends VisionTaskRunner { | |||
|         NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { | ||||
|           this.addJsLandmarks(binaryProto); | ||||
|           this.setLatestOutputTimestamp(timestamp); | ||||
|           this.maybeInvokeCallback(); | ||||
|         }); | ||||
|     this.graphRunner.attachEmptyPacketListener( | ||||
|         NORM_LANDMARKS_STREAM, timestamp => { | ||||
|           this.result.landmarks = []; | ||||
|           this.setLatestOutputTimestamp(timestamp); | ||||
|           this.maybeInvokeCallback(); | ||||
|         }); | ||||
| 
 | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { | ||||
|           this.adddJsWorldLandmarks(binaryProto); | ||||
|           this.setLatestOutputTimestamp(timestamp); | ||||
|           this.maybeInvokeCallback(); | ||||
|         }); | ||||
|     this.graphRunner.attachEmptyPacketListener( | ||||
|         WORLD_LANDMARKS_STREAM, timestamp => { | ||||
|           this.result.worldLandmarks = []; | ||||
|           this.setLatestOutputTimestamp(timestamp); | ||||
|           this.maybeInvokeCallback(); | ||||
|         }); | ||||
| 
 | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => { | ||||
|           this.addJsAuxiliaryLandmarks(binaryProto); | ||||
|           this.setLatestOutputTimestamp(timestamp); | ||||
|           this.maybeInvokeCallback(); | ||||
|         }); | ||||
|     this.graphRunner.attachEmptyPacketListener( | ||||
|         AUXILIARY_LANDMARKS_STREAM, timestamp => { | ||||
|           this.result.auxilaryLandmarks = []; | ||||
|           this.setLatestOutputTimestamp(timestamp); | ||||
|           this.maybeInvokeCallback(); | ||||
|         }); | ||||
| 
 | ||||
|     if (this.outputSegmentationMasks) { | ||||
|  | @ -419,10 +438,13 @@ export class PoseLandmarker extends VisionTaskRunner { | |||
|             this.result.segmentationMasks = | ||||
|                 masks.map(wasmImage => this.convertToMPImage(wasmImage)); | ||||
|             this.setLatestOutputTimestamp(timestamp); | ||||
|             this.maybeInvokeCallback(); | ||||
|           }); | ||||
|       this.graphRunner.attachEmptyPacketListener( | ||||
|           SEGMENTATION_MASK_STREAM, timestamp => { | ||||
|             this.result.segmentationMasks = []; | ||||
|             this.setLatestOutputTimestamp(timestamp); | ||||
|             this.maybeInvokeCallback(); | ||||
|           }); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -260,4 +260,38 @@ describe('PoseLandmarker', () => { | |||
|     expect(landmarks1).toBeDefined(); | ||||
|     expect(landmarks1).toEqual(landmarks2); | ||||
|   }); | ||||
| 
 | ||||
|   it('invokes listener once masks are avaiblae', (done) => { | ||||
|     const landmarksProto = [createLandmarks().serializeBinary()]; | ||||
|     const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; | ||||
|     const masks = [ | ||||
|       {data: new Float32Array([0, 1, 2, 3]), width: 2, height: 2}, | ||||
|     ]; | ||||
|     let listenerCalled = false; | ||||
| 
 | ||||
| 
 | ||||
|     poseLandmarker.setOptions({outputSegmentationMasks: true}); | ||||
| 
 | ||||
|     // Pass the test data to our listener
 | ||||
|     poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { | ||||
|       expect(listenerCalled).toBeFalse(); | ||||
|       poseLandmarker.listeners.get('normalized_landmarks')! | ||||
|           (landmarksProto, 1337); | ||||
|       expect(listenerCalled).toBeFalse(); | ||||
|       poseLandmarker.listeners.get('world_landmarks')! | ||||
|           (worldLandmarksProto, 1337); | ||||
|       expect(listenerCalled).toBeFalse(); | ||||
|       poseLandmarker.listeners.get('auxiliary_landmarks')! | ||||
|           (landmarksProto, 1337); | ||||
|       expect(listenerCalled).toBeFalse(); | ||||
|       poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337); | ||||
|       expect(listenerCalled).toBeTrue(); | ||||
|       done(); | ||||
|     }); | ||||
| 
 | ||||
|     // Invoke the pose landmarker
 | ||||
|     poseLandmarker.detect({} as HTMLImageElement, () => { | ||||
|       listenerCalled = true; | ||||
|     }); | ||||
|   }); | ||||
| }); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user