Extract shared types to create and test landmarks

PiperOrigin-RevId: 525568412
This commit is contained in:
Sebastian Schmidt 2023-04-19 15:34:07 -07:00 committed by Copybara-Service
parent 476c7efc18
commit ffbd799b8d
11 changed files with 217 additions and 87 deletions

View File

@ -125,3 +125,27 @@ jasmine_node_test(
name = "embedder_options_test",
deps = [":embedder_options_test_lib"],
)
mediapipe_ts_library(
name = "landmark_result",
srcs = [
"landmark_result.ts",
"landmark_result_test_lib.ts",
],
deps = [
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/web/components/containers:landmark",
],
)
mediapipe_ts_library(
name = "landmark_result_test_lib",
testonly = True,
srcs = ["landmark_result.test.ts"],
deps = [":landmark_result"],
)
jasmine_node_test(
name = "landmark_result_test",
deps = [":landmark_result_test_lib"],
)

View File

@ -0,0 +1,52 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
describe('convertToLandmarks()', () => {
it('transforms custom values', () => {
const landmarkListProto = createLandmarks(0.1, 0.2, 0.3);
const result = convertToLandmarks(landmarkListProto);
expect(result).toEqual([{x: 0.1, y: 0.2, z: 0.3}]);
});
it('transforms default values', () => {
const landmarkListProto = createLandmarks();
const result = convertToLandmarks(landmarkListProto);
expect(result).toEqual([{x: 0, y: 0, z: 0}]);
});
});
describe('convertToWorldLandmarks()', () => {
it('transforms custom values', () => {
const worldLandmarkListProto = createWorldLandmarks(10, 20, 30);
const result = convertToWorldLandmarks(worldLandmarkListProto);
expect(result).toEqual([{x: 10, y: 20, z: 30}]);
});
it('transforms default values', () => {
const worldLandmarkListProto = createWorldLandmarks();
const result = convertToWorldLandmarks(worldLandmarkListProto);
expect(result).toEqual([{x: 0, y: 0, z: 0}]);
});
});

View File

@ -0,0 +1,45 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {LandmarkList as LandmarkListProto, NormalizedLandmarkList as NormalizedLandmarkListProto} from '../../../../framework/formats/landmark_pb';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
/** Converts raw data into a landmark. */
export function convertToLandmarks(proto: NormalizedLandmarkListProto):
NormalizedLandmark[] {
const landmarks: NormalizedLandmark[] = [];
for (const landmark of proto.getLandmarkList()) {
landmarks.push({
x: landmark.getX() ?? 0,
y: landmark.getY() ?? 0,
z: landmark.getZ() ?? 0,
});
}
return landmarks;
}
/** Converts raw data into a world landmark. */
export function convertToWorldLandmarks(proto: LandmarkListProto): Landmark[] {
const worldLandmarks: Landmark[] = [];
for (const worldLandmark of proto.getLandmarkList()) {
worldLandmarks.push({
x: worldLandmark.getX() ?? 0,
y: worldLandmark.getY() ?? 0,
z: worldLandmark.getZ() ?? 0,
});
}
return worldLandmarks;
}

View File

@ -0,0 +1,44 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Landmark as LandmarkProto, LandmarkList as LandmarkListProto, NormalizedLandmark as NormalizedLandmarkProto, NormalizedLandmarkList as NormalizedLandmarkListProto} from '../../../../framework/formats/landmark_pb';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/** Creates a normalized landmark list with one entrry. */
export function createLandmarks(
x?: number, y?: number, z?: number): NormalizedLandmarkListProto {
const landmarksProto = new NormalizedLandmarkListProto();
const landmark = new NormalizedLandmarkProto();
if (x !== undefined) landmark.setX(x);
if (y !== undefined) landmark.setY(y);
if (z !== undefined) landmark.setZ(z);
landmarksProto.addLandmark(landmark);
return landmarksProto;
}
/** Creates a world landmark list with one entry. */
export function createWorldLandmarks(
x?: number, y?: number, z?: number): LandmarkListProto {
const worldLandmarksProto = new LandmarkListProto();
const landmark = new LandmarkProto();
if (x !== undefined) landmark.setX(x);
if (y !== undefined) landmark.setY(y);
if (z !== undefined) landmark.setZ(z);
worldLandmarksProto.addLandmark(landmark);
return worldLandmarksProto;
}

View File

@ -31,6 +31,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/components/containers:matrix",
"//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
@ -73,9 +74,9 @@ mediapipe_ts_library(
":face_landmarker_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/framework/formats:matrix_data_jspb_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_jspb_proto",
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:vision_task_runner",

View File

@ -23,8 +23,8 @@ import {FaceDetectorGraphOptions} from '../../../../tasks/cc/vision/face_detecto
import {FaceGeometry as FaceGeometryProto} from '../../../../tasks/cc/vision/face_geometry/proto/face_geometry_pb';
import {FaceLandmarkerGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options_pb';
import {FaceLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options_pb';
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {convertFromClassifications} from '../../../../tasks/web/components/processors/classifier_result';
import {convertToLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
@ -243,15 +243,7 @@ export class FaceLandmarker extends VisionTaskRunner {
for (const binaryProto of data) {
const faceLandmarksProto =
NormalizedLandmarkListProto.deserializeBinary(binaryProto);
const landmarks: NormalizedLandmark[] = [];
for (const faceLandmarkProto of faceLandmarksProto.getLandmarkList()) {
landmarks.push({
x: faceLandmarkProto.getX() ?? 0,
y: faceLandmarkProto.getY() ?? 0,
z: faceLandmarkProto.getZ() ?? 0,
});
}
this.result.faceLandmarks.push(landmarks);
this.result.faceLandmarks.push(convertToLandmarks(faceLandmarksProto));
}
}

View File

@ -17,9 +17,9 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {MatrixData as MatrixDataProto} from '../../../../framework/formats/matrix_data_pb';
import {FaceGeometry as FaceGeometryProto} from '../../../../tasks/cc/vision/face_geometry/proto/face_geometry_pb';
import {createLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
@ -31,7 +31,7 @@ import {FaceLandmarkerOptions} from './face_landmarker_options';
type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void);
function createBlendshapes(): Uint8Array[] {
function createBlendshapes(): ClassificationList {
const blendshapesProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
@ -39,27 +39,17 @@ function createBlendshapes(): Uint8Array[] {
classification.setLabel('face_label');
classification.setDisplayName('face_display_name');
blendshapesProto.addClassification(classification);
return [blendshapesProto.serializeBinary()];
return blendshapesProto;
}
function createFacialTransformationMatrixes(): Uint8Array[] {
function createFacialTransformationMatrixes(): FaceGeometryProto {
const faceGeometryProto = new FaceGeometryProto();
const posteTransformationMatrix = new MatrixDataProto();
posteTransformationMatrix.setRows(1);
posteTransformationMatrix.setCols(1);
posteTransformationMatrix.setPackedDataList([1.0]);
faceGeometryProto.setPoseTransformMatrix(posteTransformationMatrix);
return [faceGeometryProto.serializeBinary()];
}
function createLandmarks(): Uint8Array[] {
const faceLandmarksProto = new NormalizedLandmarkList();
const landmark = new NormalizedLandmark();
landmark.setX(0.3);
landmark.setY(0.4);
landmark.setZ(0.5);
faceLandmarksProto.addLandmark(landmark);
return [faceLandmarksProto.serializeBinary()];
return faceGeometryProto;
}
class FaceLandmarkerFake extends FaceLandmarker implements MediapipeTasksFake {
@ -243,13 +233,17 @@ describe('FaceLandmarker', () => {
});
it('transforms results', async () => {
const landmarksProto = [createLandmarks().serializeBinary()];
const blendshapesProto = [createBlendshapes().serializeBinary()];
const faceGeometryProto =
[createFacialTransformationMatrixes().serializeBinary()];
// Pass the test data to our listener
faceLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceLandmarker);
faceLandmarker.listeners.get('face_landmarks')!(createLandmarks(), 1337);
faceLandmarker.listeners.get('blendshapes')!(createBlendshapes(), 1337);
faceLandmarker.listeners.get('face_geometry')!
(createFacialTransformationMatrixes(), 1337);
faceLandmarker.listeners.get('face_landmarks')!(landmarksProto, 1337);
faceLandmarker.listeners.get('blendshapes')!(blendshapesProto, 1337);
faceLandmarker.listeners.get('face_geometry')!(faceGeometryProto, 1337);
});
await faceLandmarker.setOptions({
@ -266,7 +260,7 @@ describe('FaceLandmarker', () => {
expect(faceLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(landmarks).toEqual({
faceLandmarks: [[{x: 0.3, y: 0.4, z: 0.5}]],
faceLandmarks: [[{x: 0, y: 0, z: 0}]],
faceBlendshapes: [{
categories: [{
index: 1,
@ -282,12 +276,16 @@ describe('FaceLandmarker', () => {
});
it('clears results between invoations', async () => {
const landmarksProto = [createLandmarks().serializeBinary()];
const blendshapesProto = [createBlendshapes().serializeBinary()];
const faceGeometryProto =
[createFacialTransformationMatrixes().serializeBinary()];
// Pass the test data to our listener
faceLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
faceLandmarker.listeners.get('face_landmarks')!(createLandmarks(), 1337);
faceLandmarker.listeners.get('blendshapes')!(createBlendshapes(), 1337);
faceLandmarker.listeners.get('face_geometry')!
(createFacialTransformationMatrixes(), 1337);
faceLandmarker.listeners.get('face_landmarks')!(landmarksProto, 1337);
faceLandmarker.listeners.get('blendshapes')!(blendshapesProto, 1337);
faceLandmarker.listeners.get('face_geometry')!(faceGeometryProto, 1337);
});
await faceLandmarker.setOptions({

View File

@ -27,6 +27,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
@ -61,7 +62,7 @@ mediapipe_ts_library(
":hand_landmarker_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:vision_task_runner",

View File

@ -24,6 +24,7 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm
import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
@ -259,15 +260,7 @@ export class HandLandmarker extends VisionTaskRunner {
for (const binaryProto of data) {
const handLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto);
const landmarks: NormalizedLandmark[] = [];
for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) {
landmarks.push({
x: handLandmarkProto.getX() ?? 0,
y: handLandmarkProto.getY() ?? 0,
z: handLandmarkProto.getZ() ?? 0,
});
}
this.landmarks.push(landmarks);
this.landmarks.push(convertToLandmarks(handLandmarksProto));
}
}
@ -279,16 +272,8 @@ export class HandLandmarker extends VisionTaskRunner {
for (const binaryProto of data) {
const handWorldLandmarksProto =
LandmarkList.deserializeBinary(binaryProto);
const worldLandmarks: Landmark[] = [];
for (const handWorldLandmarkProto of
handWorldLandmarksProto.getLandmarkList()) {
worldLandmarks.push({
x: handWorldLandmarkProto.getX() ?? 0,
y: handWorldLandmarkProto.getY() ?? 0,
z: handWorldLandmarkProto.getZ() ?? 0,
});
}
this.worldLandmarks.push(worldLandmarks);
this.worldLandmarks.push(
convertToWorldLandmarks(handWorldLandmarksProto));
}
}

View File

@ -26,7 +26,7 @@ export declare interface HandLandmarkerResult {
/** Hand landmarks of detected hands. */
landmarks: NormalizedLandmark[][];
/** Hand landmarks in world coordniates of detected hands. */
/** Hand landmarks in world coordinates of detected hands. */
worldLandmarks: Landmark[][];
/** Handedness of detected hands. */

View File

@ -17,7 +17,7 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
@ -30,7 +30,7 @@ import {HandLandmarkerOptions} from './hand_landmarker_options';
type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void);
function createHandednesses(): Uint8Array[] {
function createHandednesses(): ClassificationList {
const handsProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
@ -38,27 +38,7 @@ function createHandednesses(): Uint8Array[] {
classification.setLabel('handedness_label');
classification.setDisplayName('handedness_display_name');
handsProto.addClassification(classification);
return [handsProto.serializeBinary()];
}
function createLandmarks(): Uint8Array[] {
const handLandmarksProto = new NormalizedLandmarkList();
const landmark = new NormalizedLandmark();
landmark.setX(0.3);
landmark.setY(0.4);
landmark.setZ(0.5);
handLandmarksProto.addLandmark(landmark);
return [handLandmarksProto.serializeBinary()];
}
function createWorldLandmarks(): Uint8Array[] {
const handLandmarksProto = new LandmarkList();
const landmark = new Landmark();
landmark.setX(21);
landmark.setY(22);
landmark.setZ(23);
handLandmarksProto.addLandmark(landmark);
return [handLandmarksProto.serializeBinary()];
return handsProto;
}
class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
@ -212,13 +192,17 @@ describe('HandLandmarker', () => {
});
it('transforms results', async () => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const handednessProto = [createHandednesses().serializeBinary()];
// Pass the test data to our listener
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(handLandmarker);
handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337);
handLandmarker.listeners.get('hand_landmarks')!(landmarksProto, 1337);
handLandmarker.listeners.get('world_hand_landmarks')!
(createWorldLandmarks(), 1337);
handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337);
(worldLandmarksProto, 1337);
handLandmarker.listeners.get('handedness')!(handednessProto, 1337);
});
// Invoke the hand landmarker
@ -230,8 +214,8 @@ describe('HandLandmarker', () => {
expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(landmarks).toEqual({
'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
'landmarks': [[{'x': 0, 'y': 0, 'z': 0}]],
'worldLandmarks': [[{'x': 0, 'y': 0, 'z': 0}]],
'handednesses': [[{
'score': 0.1,
'index': 1,
@ -242,12 +226,16 @@ describe('HandLandmarker', () => {
});
it('clears results between invoations', async () => {
const landmarks = [createLandmarks().serializeBinary()];
const worldLandmarks = [createWorldLandmarks().serializeBinary()];
const handedness = [createHandednesses().serializeBinary()];
// Pass the test data to our listener
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337);
handLandmarker.listeners.get('hand_landmarks')!(landmarks, 1337);
handLandmarker.listeners.get('world_hand_landmarks')!
(createWorldLandmarks(), 1337);
handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337);
(worldLandmarks, 1337);
handLandmarker.listeners.get('handedness')!(handedness, 1337);
});
// Invoke the hand landmarker twice