Support multiple poses for PoseLandmarker

PiperOrigin-RevId: 529430797
This commit is contained in:
Sebastian Schmidt 2023-05-04 09:27:21 -07:00 committed by Copybara-Service
parent c6e3f08282
commit 767db32d69
3 changed files with 52 additions and 12 deletions

View File

@ -309,7 +309,7 @@ export class PoseLandmarker extends VisionTaskRunner {
if (!('worldLandmarks' in this.result)) { if (!('worldLandmarks' in this.result)) {
return; return;
} }
if (!('landmarks' in this.result)) { if (!('auxilaryLandmarks' in this.result)) {
return; return;
} }
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) { if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
@ -332,10 +332,11 @@ export class PoseLandmarker extends VisionTaskRunner {
* Converts raw data into a landmark, and adds it to our landmarks list. * Converts raw data into a landmark, and adds it to our landmarks list.
*/ */
private addJsLandmarks(data: Uint8Array[]): void { private addJsLandmarks(data: Uint8Array[]): void {
this.result.landmarks = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const poseLandmarksProto = const poseLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto); NormalizedLandmarkList.deserializeBinary(binaryProto);
this.result.landmarks = convertToLandmarks(poseLandmarksProto); this.result.landmarks.push(convertToLandmarks(poseLandmarksProto));
} }
} }
@ -344,11 +345,12 @@ export class PoseLandmarker extends VisionTaskRunner {
* worldLandmarks list. * worldLandmarks list.
*/ */
private adddJsWorldLandmarks(data: Uint8Array[]): void { private adddJsWorldLandmarks(data: Uint8Array[]): void {
this.result.worldLandmarks = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const poseWorldLandmarksProto = const poseWorldLandmarksProto =
LandmarkList.deserializeBinary(binaryProto); LandmarkList.deserializeBinary(binaryProto);
this.result.worldLandmarks = this.result.worldLandmarks.push(
convertToWorldLandmarks(poseWorldLandmarksProto); convertToWorldLandmarks(poseWorldLandmarksProto));
} }
} }
@ -357,11 +359,12 @@ export class PoseLandmarker extends VisionTaskRunner {
* landmarks list. * landmarks list.
*/ */
private addJsAuxiliaryLandmarks(data: Uint8Array[]): void { private addJsAuxiliaryLandmarks(data: Uint8Array[]): void {
this.result.auxilaryLandmarks = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const auxiliaryLandmarksProto = const auxiliaryLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto); NormalizedLandmarkList.deserializeBinary(binaryProto);
this.result.auxilaryLandmarks = this.result.auxilaryLandmarks.push(
convertToLandmarks(auxiliaryLandmarksProto); convertToLandmarks(auxiliaryLandmarksProto));
} }
} }

View File

@ -26,13 +26,13 @@ export {Category, Landmark, NormalizedLandmark};
*/ */
export declare interface PoseLandmarkerResult { export declare interface PoseLandmarkerResult {
/** Pose landmarks of detected poses. */ /** Pose landmarks of detected poses. */
landmarks: NormalizedLandmark[]; landmarks: NormalizedLandmark[][];
/** Pose landmarks in world coordinates of detected poses. */ /** Pose landmarks in world coordinates of detected poses. */
worldLandmarks: Landmark[]; worldLandmarks: Landmark[][];
/** Detected auxiliary landmarks, used for deriving ROI for next frame. */ /** Detected auxiliary landmarks, used for deriving ROI for next frame. */
auxilaryLandmarks: NormalizedLandmark[]; auxilaryLandmarks: NormalizedLandmark[][];
/** Segmentation mask for the detected pose. */ /** Segmentation mask for the detected pose. */
segmentationMasks?: MPImage[]; segmentationMasks?: MPImage[];

View File

@ -222,9 +222,9 @@ describe('PoseLandmarker', () => {
.toHaveBeenCalledTimes(1); .toHaveBeenCalledTimes(1);
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.landmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]); expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]); expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]); expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.segmentationMasks![0]).toBeInstanceOf(MPImage); expect(result.segmentationMasks![0]).toBeInstanceOf(MPImage);
done(); done();
}); });
@ -261,6 +261,43 @@ describe('PoseLandmarker', () => {
expect(landmarks1).toEqual(landmarks2); expect(landmarks1).toEqual(landmarks2);
}); });
it('supports multiple poses', (done) => {
const landmarksProto = [
createLandmarks(0.1, 0.2, 0.3).serializeBinary(),
createLandmarks(0.4, 0.5, 0.6).serializeBinary()
];
const worldLandmarksProto = [
createWorldLandmarks(1, 2, 3).serializeBinary(),
createWorldLandmarks(4, 5, 6).serializeBinary()
];
poseLandmarker.setOptions({numPoses: 1});
// Pass the test data to our listener
poseLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
poseLandmarker.listeners.get('normalized_landmarks')!
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
});
// Invoke the pose landmarker
poseLandmarker.detect({} as HTMLImageElement, result => {
expect(result.landmarks).toEqual([
[{'x': 0.1, 'y': 0.2, 'z': 0.3}], [{'x': 0.4, 'y': 0.5, 'z': 0.6}]
]);
expect(result.worldLandmarks).toEqual([
[{'x': 1, 'y': 2, 'z': 3}], [{'x': 4, 'y': 5, 'z': 6}]
]);
expect(result.auxilaryLandmarks).toEqual([
[{'x': 0.1, 'y': 0.2, 'z': 0.3}], [{'x': 0.4, 'y': 0.5, 'z': 0.6}]
]);
done();
});
});
it('invokes listener once masks are avaiblae', (done) => { it('invokes listener once masks are avaiblae', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()]; const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];