Invoke PoseListener callback while C++ Packet is still active

PiperOrigin-RevId: 528061429
This commit is contained in:
Sebastian Schmidt 2023-04-28 21:19:46 -07:00 committed by Copybara-Service
parent 253f13ad62
commit 8e510a3255
2 changed files with 69 additions and 13 deletions

View File

@ -62,12 +62,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
/** Performs pose landmarks detection on images. */ /** Performs pose landmarks detection on images. */
export class PoseLandmarker extends VisionTaskRunner { export class PoseLandmarker extends VisionTaskRunner {
private result: PoseLandmarkerResult = { private result: Partial<PoseLandmarkerResult> = {};
landmarks: [],
worldLandmarks: [],
auxilaryLandmarks: []
};
private outputSegmentationMasks = false; private outputSegmentationMasks = false;
private userCallback: PoseLandmarkerCallback = () => {};
private readonly options: PoseLandmarkerGraphOptions; private readonly options: PoseLandmarkerGraphOptions;
private readonly poseLandmarksDetectorGraphOptions: private readonly poseLandmarksDetectorGraphOptions:
PoseLandmarksDetectorGraphOptions; PoseLandmarksDetectorGraphOptions;
@ -239,14 +236,13 @@ export class PoseLandmarker extends VisionTaskRunner {
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
const userCallback = this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback!;
this.resetResults(); this.resetResults();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
userCallback(this.result); this.userCallback = () => {};
} }
/** /**
@ -293,19 +289,33 @@ export class PoseLandmarker extends VisionTaskRunner {
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions : timestampOrImageProcessingOptions :
timestampOrCallback as number; timestampOrCallback as number;
const userCallback = typeof timestampOrCallback === 'function' ? this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback : timestampOrCallback :
callback!; callback!;
this.resetResults(); this.resetResults();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp); this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
userCallback(this.result); this.userCallback = () => {};
} }
private resetResults(): void { private resetResults(): void {
this.result = {landmarks: [], worldLandmarks: [], auxilaryLandmarks: []}; this.result = {};
if (this.outputSegmentationMasks) {
this.result.segmentationMasks = [];
} }
/** Invokes the user callback once all data has been received. */
private maybeInvokeCallback(): void {
if (!('landmarks' in this.result)) {
return;
}
if (!('worldLandmarks' in this.result)) {
return;
}
if (!('landmarks' in this.result)) {
return;
}
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
return;
}
this.userCallback(this.result as Required<PoseLandmarkerResult>);
} }
/** Sets the default values for the graph. */ /** Sets the default values for the graph. */
@ -385,30 +395,39 @@ export class PoseLandmarker extends VisionTaskRunner {
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto); this.addJsLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
NORM_LANDMARKS_STREAM, timestamp => { NORM_LANDMARKS_STREAM, timestamp => {
this.result.landmarks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.adddJsWorldLandmarks(binaryProto); this.adddJsWorldLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
WORLD_LANDMARKS_STREAM, timestamp => { WORLD_LANDMARKS_STREAM, timestamp => {
this.result.worldLandmarks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => { AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsAuxiliaryLandmarks(binaryProto); this.addJsAuxiliaryLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
AUXILIARY_LANDMARKS_STREAM, timestamp => { AUXILIARY_LANDMARKS_STREAM, timestamp => {
this.result.auxilaryLandmarks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
if (this.outputSegmentationMasks) { if (this.outputSegmentationMasks) {
@ -419,10 +438,13 @@ export class PoseLandmarker extends VisionTaskRunner {
this.result.segmentationMasks = this.result.segmentationMasks =
masks.map(wasmImage => this.convertToMPImage(wasmImage)); masks.map(wasmImage => this.convertToMPImage(wasmImage));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
SEGMENTATION_MASK_STREAM, timestamp => { SEGMENTATION_MASK_STREAM, timestamp => {
this.result.segmentationMasks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }

View File

@ -260,4 +260,38 @@ describe('PoseLandmarker', () => {
expect(landmarks1).toBeDefined(); expect(landmarks1).toBeDefined();
expect(landmarks1).toEqual(landmarks2); expect(landmarks1).toEqual(landmarks2);
}); });
it('invokes listener once masks are avaiblae', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const masks = [
{data: new Float32Array([0, 1, 2, 3]), width: 2, height: 2},
];
let listenerCalled = false;
poseLandmarker.setOptions({outputSegmentationMasks: true});
// Pass the test data to our listener
poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('normalized_landmarks')!
(landmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
expect(listenerCalled).toBeTrue();
done();
});
// Invoke the pose landmarker
poseLandmarker.detect({} as HTMLImageElement, () => {
listenerCalled = true;
});
});
}); });