Invoke PoseListener callback while C++ Packet is still active

PiperOrigin-RevId: 528061429
This commit is contained in:
Sebastian Schmidt 2023-04-28 21:19:46 -07:00 committed by Copybara-Service
parent 253f13ad62
commit 8e510a3255
2 changed files with 69 additions and 13 deletions

View File

@ -62,12 +62,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
/** Performs pose landmarks detection on images. */
export class PoseLandmarker extends VisionTaskRunner {
private result: PoseLandmarkerResult = {
landmarks: [],
worldLandmarks: [],
auxilaryLandmarks: []
};
private result: Partial<PoseLandmarkerResult> = {};
private outputSegmentationMasks = false;
private userCallback: PoseLandmarkerCallback = () => {};
private readonly options: PoseLandmarkerGraphOptions;
private readonly poseLandmarksDetectorGraphOptions:
PoseLandmarksDetectorGraphOptions;
@ -239,14 +236,13 @@ export class PoseLandmarker extends VisionTaskRunner {
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
const userCallback =
typeof imageProcessingOptionsOrCallback === 'function' ?
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback!;
this.resetResults();
this.processImageData(image, imageProcessingOptions);
userCallback(this.result);
this.userCallback = () => {};
}
/**
@ -293,19 +289,33 @@ export class PoseLandmarker extends VisionTaskRunner {
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
const userCallback = typeof timestampOrCallback === 'function' ?
this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.resetResults();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
userCallback(this.result);
this.userCallback = () => {};
}
private resetResults(): void {
this.result = {landmarks: [], worldLandmarks: [], auxilaryLandmarks: []};
if (this.outputSegmentationMasks) {
this.result.segmentationMasks = [];
this.result = {};
}
/** Invokes the user callback once all data has been received. */
private maybeInvokeCallback(): void {
if (!('landmarks' in this.result)) {
return;
}
if (!('worldLandmarks' in this.result)) {
return;
}
if (!('landmarks' in this.result)) {
return;
}
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
return;
}
this.userCallback(this.result as Required<PoseLandmarkerResult>);
}
/** Sets the default values for the graph. */
@ -385,30 +395,39 @@ export class PoseLandmarker extends VisionTaskRunner {
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachEmptyPacketListener(
NORM_LANDMARKS_STREAM, timestamp => {
this.result.landmarks = [];
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachProtoVectorListener(
WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.adddJsWorldLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachEmptyPacketListener(
WORLD_LANDMARKS_STREAM, timestamp => {
this.result.worldLandmarks = [];
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachProtoVectorListener(
AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsAuxiliaryLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachEmptyPacketListener(
AUXILIARY_LANDMARKS_STREAM, timestamp => {
this.result.auxilaryLandmarks = [];
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
if (this.outputSegmentationMasks) {
@ -419,10 +438,13 @@ export class PoseLandmarker extends VisionTaskRunner {
this.result.segmentationMasks =
masks.map(wasmImage => this.convertToMPImage(wasmImage));
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachEmptyPacketListener(
SEGMENTATION_MASK_STREAM, timestamp => {
this.result.segmentationMasks = [];
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
}

View File

@ -260,4 +260,38 @@ describe('PoseLandmarker', () => {
expect(landmarks1).toBeDefined();
expect(landmarks1).toEqual(landmarks2);
});
it('invokes listener once masks are avaiblae', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const masks = [
{data: new Float32Array([0, 1, 2, 3]), width: 2, height: 2},
];
let listenerCalled = false;
poseLandmarker.setOptions({outputSegmentationMasks: true});
// Pass the test data to our listener
poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('normalized_landmarks')!
(landmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
expect(listenerCalled).toBeTrue();
done();
});
// Invoke the pose landmarker
poseLandmarker.detect({} as HTMLImageElement, () => {
listenerCalled = true;
});
});
});