Support multiple poses for PoseLandmarker

PiperOrigin-RevId: 529430797
This commit is contained in:
Sebastian Schmidt 2023-05-04 09:27:21 -07:00 committed by Copybara-Service
parent c6e3f08282
commit 767db32d69
3 changed files with 52 additions and 12 deletions

View File

@ -309,7 +309,7 @@ export class PoseLandmarker extends VisionTaskRunner {
if (!('worldLandmarks' in this.result)) {
return;
}
if (!('landmarks' in this.result)) {
if (!('auxilaryLandmarks' in this.result)) {
return;
}
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
@ -332,10 +332,11 @@ export class PoseLandmarker extends VisionTaskRunner {
* Converts raw data into a landmark, and adds it to our landmarks list.
*/
private addJsLandmarks(data: Uint8Array[]): void {
// Start from an empty list so the result only holds what this call's `data`
// contains.
this.result.landmarks = [];
// Each element of `data` is deserialized as one serialized
// NormalizedLandmarkList.
for (const binaryProto of data) {
const poseLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto);
// NOTE(review): the next two statements look like the removed and the added
// side of one diff hunk whose +/- markers were lost in rendering — confirm
// against the commit. Read as live code, the assignment would replace the
// accumulated list on every iteration before the push runs.
this.result.landmarks = convertToLandmarks(poseLandmarksProto);
// Appends the converted list as its own entry, i.e. one inner array per
// element of `data`.
this.result.landmarks.push(convertToLandmarks(poseLandmarksProto));
}
}
@ -344,11 +345,12 @@ export class PoseLandmarker extends VisionTaskRunner {
* worldLandmarks list.
*/
// NOTE(review): the method name is spelled with three d's; renaming it would
// require updating call sites that are not visible in this hunk.
private adddJsWorldLandmarks(data: Uint8Array[]): void {
// Start from an empty list so the result only holds what this call's `data`
// contains.
this.result.worldLandmarks = [];
// Each element of `data` is deserialized as one serialized LandmarkList.
for (const binaryProto of data) {
const poseWorldLandmarksProto =
LandmarkList.deserializeBinary(binaryProto);
// NOTE(review): the assignment below and the push after it look like the
// removed and the added side of one diff hunk whose +/- markers were lost in
// rendering — confirm against the commit. Read as live code, the assignment
// would replace the accumulated list on every iteration before the push runs.
this.result.worldLandmarks =
convertToWorldLandmarks(poseWorldLandmarksProto);
// Appends the converted list as its own entry, i.e. one inner array per
// element of `data`.
this.result.worldLandmarks.push(
convertToWorldLandmarks(poseWorldLandmarksProto));
}
}
@ -357,11 +359,12 @@ export class PoseLandmarker extends VisionTaskRunner {
* landmarks list.
*/
private addJsAuxiliaryLandmarks(data: Uint8Array[]): void {
// Start from an empty list so the result only holds what this call's `data`
// contains. (`auxilaryLandmarks` is the field's spelling in the result
// interface and must be matched exactly.)
this.result.auxilaryLandmarks = [];
// Each element of `data` is deserialized as one serialized
// NormalizedLandmarkList.
for (const binaryProto of data) {
const auxiliaryLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto);
// NOTE(review): the assignment below and the push after it look like the
// removed and the added side of one diff hunk whose +/- markers were lost in
// rendering — confirm against the commit. Read as live code, the assignment
// would replace the accumulated list on every iteration before the push runs.
this.result.auxilaryLandmarks =
convertToLandmarks(auxiliaryLandmarksProto);
// Appends the converted list as its own entry, i.e. one inner array per
// element of `data`.
this.result.auxilaryLandmarks.push(
convertToLandmarks(auxiliaryLandmarksProto));
}
}

View File

@ -26,13 +26,13 @@ export {Category, Landmark, NormalizedLandmark};
*/
export declare interface PoseLandmarkerResult {
/** Pose landmarks of detected poses. */
landmarks: NormalizedLandmark[];
landmarks: NormalizedLandmark[][];
/** Pose landmarks in world coordinates of detected poses. */
worldLandmarks: Landmark[];
worldLandmarks: Landmark[][];
/** Detected auxiliary landmarks, used for deriving ROI for next frame. */
auxilaryLandmarks: NormalizedLandmark[];
auxilaryLandmarks: NormalizedLandmark[][];
/** Segmentation mask for the detected pose. */
segmentationMasks?: MPImage[];

View File

@ -222,9 +222,9 @@ describe('PoseLandmarker', () => {
.toHaveBeenCalledTimes(1);
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.landmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]);
expect(result.worldLandmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]);
expect(result.auxilaryLandmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]);
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.segmentationMasks![0]).toBeInstanceOf(MPImage);
done();
});
@ -261,6 +261,43 @@ describe('PoseLandmarker', () => {
expect(landmarks1).toEqual(landmarks2);
});
// Verifies that when each listener receives two serialized landmark lists,
// the result exposes two inner arrays per field, in the same order.
it('supports multiple poses', (done) => {
// Two serialized normalized-landmark lists with distinct coordinates so the
// two entries can be told apart in the expectations below.
const landmarksProto = [
createLandmarks(0.1, 0.2, 0.3).serializeBinary(),
createLandmarks(0.4, 0.5, 0.6).serializeBinary()
];
// Two serialized world-landmark lists, likewise distinguishable.
const worldLandmarksProto = [
createWorldLandmarks(1, 2, 3).serializeBinary(),
createWorldLandmarks(4, 5, 6).serializeBinary()
];
// NOTE(review): numPoses is 1 although two poses are emitted and expected
// below; presumably the fake module ignores this option, but 2 would match
// the intent of the test — confirm.
poseLandmarker.setOptions({numPoses: 1});
// When the fake module is asked to wait until idle, feed the test data to
// the three registered stream listeners. The auxiliary stream reuses the
// normalized-landmark fixtures.
// NOTE(review): the second argument (1337) is presumably a timestamp —
// confirm against the listener signature.
poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
poseLandmarker.listeners.get('normalized_landmarks')!
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
});
// Invoke the pose landmarker; the callback receives the assembled result.
poseLandmarker.detect({} as HTMLImageElement, result => {
// One inner array per input list, preserving input order.
expect(result.landmarks).toEqual([
[{'x': 0.1, 'y': 0.2, 'z': 0.3}], [{'x': 0.4, 'y': 0.5, 'z': 0.6}]
]);
expect(result.worldLandmarks).toEqual([
[{'x': 1, 'y': 2, 'z': 3}], [{'x': 4, 'y': 5, 'z': 6}]
]);
expect(result.auxilaryLandmarks).toEqual([
[{'x': 0.1, 'y': 0.2, 'z': 0.3}], [{'x': 0.4, 'y': 0.5, 'z': 0.6}]
]);
// Signal Jasmine that the asynchronous spec has finished.
done();
});
});
it('invokes listener once masks are avaiblae', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];