Invoke PoseListener callback while C++ Packet is still active
PiperOrigin-RevId: 528061429
This commit is contained in:
parent
253f13ad62
commit
8e510a3255
|
@ -62,12 +62,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
|
|||
|
||||
/** Performs pose landmarks detection on images. */
|
||||
export class PoseLandmarker extends VisionTaskRunner {
|
||||
private result: PoseLandmarkerResult = {
|
||||
landmarks: [],
|
||||
worldLandmarks: [],
|
||||
auxilaryLandmarks: []
|
||||
};
|
||||
private result: Partial<PoseLandmarkerResult> = {};
|
||||
private outputSegmentationMasks = false;
|
||||
private userCallback: PoseLandmarkerCallback = () => {};
|
||||
private readonly options: PoseLandmarkerGraphOptions;
|
||||
private readonly poseLandmarksDetectorGraphOptions:
|
||||
PoseLandmarksDetectorGraphOptions;
|
||||
|
@ -239,14 +236,13 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
const userCallback =
|
||||
typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
|
||||
this.resetResults();
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
userCallback(this.result);
|
||||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -293,19 +289,33 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
timestampOrCallback as number;
|
||||
const userCallback = typeof timestampOrCallback === 'function' ?
|
||||
this.userCallback = typeof timestampOrCallback === 'function' ?
|
||||
timestampOrCallback :
|
||||
callback!;
|
||||
this.resetResults();
|
||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||
userCallback(this.result);
|
||||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
private resetResults(): void {
|
||||
this.result = {landmarks: [], worldLandmarks: [], auxilaryLandmarks: []};
|
||||
if (this.outputSegmentationMasks) {
|
||||
this.result.segmentationMasks = [];
|
||||
this.result = {};
|
||||
}
|
||||
|
||||
/** Invokes the user callback once all data has been received. */
|
||||
private maybeInvokeCallback(): void {
|
||||
if (!('landmarks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
if (!('worldLandmarks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
if (!('landmarks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
this.userCallback(this.result as Required<PoseLandmarkerResult>);
|
||||
}
|
||||
|
||||
/** Sets the default values for the graph. */
|
||||
|
@ -385,30 +395,39 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||
this.addJsLandmarks(binaryProto);
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
NORM_LANDMARKS_STREAM, timestamp => {
|
||||
this.result.landmarks = [];
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
|
||||
this.graphRunner.attachProtoVectorListener(
|
||||
WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||
this.adddJsWorldLandmarks(binaryProto);
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
WORLD_LANDMARKS_STREAM, timestamp => {
|
||||
this.result.worldLandmarks = [];
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
|
||||
this.graphRunner.attachProtoVectorListener(
|
||||
AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||
this.addJsAuxiliaryLandmarks(binaryProto);
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
AUXILIARY_LANDMARKS_STREAM, timestamp => {
|
||||
this.result.auxilaryLandmarks = [];
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
|
||||
if (this.outputSegmentationMasks) {
|
||||
|
@ -419,10 +438,13 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
this.result.segmentationMasks =
|
||||
masks.map(wasmImage => this.convertToMPImage(wasmImage));
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
SEGMENTATION_MASK_STREAM, timestamp => {
|
||||
this.result.segmentationMasks = [];
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -260,4 +260,38 @@ describe('PoseLandmarker', () => {
|
|||
expect(landmarks1).toBeDefined();
|
||||
expect(landmarks1).toEqual(landmarks2);
|
||||
});
|
||||
|
||||
it('invokes listener once masks are avaiblae', (done) => {
|
||||
const landmarksProto = [createLandmarks().serializeBinary()];
|
||||
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
|
||||
const masks = [
|
||||
{data: new Float32Array([0, 1, 2, 3]), width: 2, height: 2},
|
||||
];
|
||||
let listenerCalled = false;
|
||||
|
||||
|
||||
poseLandmarker.setOptions({outputSegmentationMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
expect(listenerCalled).toBeFalse();
|
||||
poseLandmarker.listeners.get('normalized_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
poseLandmarker.listeners.get('world_landmarks')!
|
||||
(worldLandmarksProto, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
poseLandmarker.listeners.get('auxiliary_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
|
||||
expect(listenerCalled).toBeTrue();
|
||||
done();
|
||||
});
|
||||
|
||||
// Invoke the pose landmarker
|
||||
poseLandmarker.detect({} as HTMLImageElement, () => {
|
||||
listenerCalled = true;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user