diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 67db27ddb..1f28cb0fe 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -19,6 +19,7 @@ mediapipe_files(srcs = [ VISION_LIBS = [ "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/face_landmarker", "//mediapipe/tasks/web/vision/face_stylizer", "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index a1444e10b..ebeac54c5 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -2,6 +2,23 @@ This package contains the vision tasks for MediaPipe. +## Face Landmark Detection + +The MediaPipe Face Landmarker task lets you detect the landmarks of faces in +an image. You can use this task to localize key points of a face and render +visual effects over the faces. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const faceLandmarker = await FaceLandmarker.createFromModelPath(vision, + "model.task" +); +const image = document.getElementById("image") as HTMLImageElement; +const landmarks = faceLandmarker.detect(image); +``` + ## Face Stylizer The MediaPipe Face Stylizer lets you perform face stylization on images. diff --git a/mediapipe/tasks/web/vision/face_landmarker/BUILD b/mediapipe/tasks/web/vision/face_landmarker/BUILD new file mode 100644 index 000000000..bd5e971a3 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_landmarker/BUILD @@ -0,0 +1,81 @@ +# This contains the MediaPipe Face Landmarker Task. +# +# This task takes video frames and outputs synchronized frames along with +# the face landmark detection results for one or more faces, using Face Landmarker.
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "face_landmarker", + srcs = ["face_landmarker.ts"], + visibility = ["//visibility:public"], + deps = [ + ":face_landmarker_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_jspb_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/components/containers:matrix", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "face_landmarker_types", + srcs = [ + "face_landmarker_options.d.ts", + "face_landmarker_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/components/containers:matrix", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "face_landmarker_test_lib", + testonly = True, + srcs = [ + "face_landmarker_test.ts", + ], + deps = [ + ":face_landmarker", + ":face_landmarker_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/framework/formats:matrix_data_jspb_proto", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + ], +) + +jasmine_node_test( + name = "face_landmarker_test", + tags = ["nomsan"], + deps = [":face_landmarker_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts new file mode 100644 index 000000000..113dd37c5 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts @@ -0,0 +1,355 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationList as ClassificationListProto} from '../../../../framework/formats/classification_pb'; +import {NormalizedLandmarkList as NormalizedLandmarkListProto} from '../../../../framework/formats/landmark_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {FaceDetectorGraphOptions} from '../../../../tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb'; +import {FaceGeometry as FaceGeometryProto} from '../../../../tasks/cc/vision/face_geometry/proto/face_geometry_pb'; +import {FaceLandmarkerGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options_pb'; +import {FaceLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options_pb'; +import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {convertFromClassifications} from '../../../../tasks/web/components/processors/classifier_result'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {FaceLandmarkerOptions} from './face_landmarker_options'; +import {FaceLandmarkerResult} from './face_landmarker_result'; + +export * from './face_landmarker_options'; +export * from './face_landmarker_result'; +export {ImageSource}; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const LANDMARKS_STREAM = 'face_landmarks'; +const BLENDSHAPES_STREAM = 'blendshapes'; +const FACE_GEOMETRY_STREAM = 'face_geometry'; +const FACE_LANDMARKER_GRAPH = + 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph'; + +const DEFAULT_NUM_FACES = 1; +const DEFAULT_SCORE_THRESHOLD = 0.5; + +/** + * Performs face landmarks detection on images. + * + * This API expects a pre-trained face landmarker model asset bundle. + */ +export class FaceLandmarker extends VisionTaskRunner { + private result: FaceLandmarkerResult = {faceLandmarks: []}; + private outputFaceBlendshapes = false; + private outputFacialTransformationMatrixes = false; + + private readonly options: FaceLandmarkerGraphOptions; + private readonly faceLandmarksDetectorGraphOptions: + FaceLandmarksDetectorGraphOptions; + private readonly faceDetectorGraphOptions: FaceDetectorGraphOptions; + + /** + * Initializes the Wasm runtime and creates a new `FaceLandmarker` from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. 
* @param faceLandmarkerOptions The options for the FaceLandmarker. + *     Note that either a path to the model asset or a model buffer needs to + *     be provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + faceLandmarkerOptions: FaceLandmarkerOptions): Promise<FaceLandmarker> { + return VisionTaskRunner.createVisionInstance( + FaceLandmarker, wasmFileset, faceLandmarkerOptions); + } + + /** + * Initializes the Wasm runtime and creates a new `FaceLandmarker` based on + * the provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise<FaceLandmarker> { + return VisionTaskRunner.createVisionInstance( + FaceLandmarker, wasmFileset, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new `FaceLandmarker` based on + * the path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise<FaceLandmarker> { + return VisionTaskRunner.createVisionInstance( + FaceLandmarker, wasmFileset, {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ false); + + this.options = new FaceLandmarkerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); + this.faceLandmarksDetectorGraphOptions = + new FaceLandmarksDetectorGraphOptions(); + this.options.setFaceLandmarksDetectorGraphOptions( + this.faceLandmarksDetectorGraphOptions); + this.faceDetectorGraphOptions = new FaceDetectorGraphOptions(); + this.options.setFaceDetectorGraphOptions(this.faceDetectorGraphOptions); + + this.initDefaults(); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for this `FaceLandmarker`. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the face landmarker. + */ + override setOptions(options: FaceLandmarkerOptions): Promise<void> { + // Configure face detector options. + if ('numFaces' in options) { + this.faceDetectorGraphOptions.setNumFaces( + options.numFaces ?? DEFAULT_NUM_FACES); + } + if ('minFaceDetectionConfidence' in options) { + this.faceDetectorGraphOptions.setMinDetectionConfidence( + options.minFaceDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + // Configure face landmark detector options. + if ('minTrackingConfidence' in options) { + this.options.setMinTrackingConfidence( + options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + if ('minFacePresenceConfidence' in options) { + this.faceLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minFacePresenceConfidence ??
DEFAULT_SCORE_THRESHOLD); + } + + if ('outputFaceBlendshapes' in options) { + this.outputFaceBlendshapes = !!options.outputFaceBlendshapes; + } + + if ('outputFacialTransformationMatrixes' in options) { + this.outputFacialTransformationMatrixes = + !!options.outputFacialTransformationMatrixes; + } + + return this.applyOptions(options); + } + + /** + * Performs face landmarks detection on the provided single image and waits + * synchronously for the response. Only use this method when the + * FaceLandmarker is created with running mode `image`. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The detected face landmarks. + */ + detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions): + FaceLandmarkerResult { + this.resetResults(); + this.processImageData(image, imageProcessingOptions); + return this.result; + } + + /** + * Performs face landmarks detection on the provided video frame and waits + * synchronously for the response. Only use this method when the + * FaceLandmarker is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @return The detected face landmarks. + */ + detectForVideo( + videoFrame: ImageSource, timestamp: number, + imageProcessingOptions?: ImageProcessingOptions): FaceLandmarkerResult { + this.resetResults(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + return this.result; + } + + private resetResults(): void { + this.result = {faceLandmarks: []}; + if (this.outputFaceBlendshapes) { + this.result.faceBlendshapes = []; + } + if (this.outputFacialTransformationMatrixes) { + this.result.facialTransformationMatrixes = []; + } + } + + /** Sets the default values for the graph. */ + private initDefaults(): void { + this.faceDetectorGraphOptions.setNumFaces(DEFAULT_NUM_FACES); + this.faceDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.faceLandmarksDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.options.setMinTrackingConfidence(DEFAULT_SCORE_THRESHOLD); + } + + /** Adds new face landmarks from the given protos. */ + private addJsLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const faceLandmarksProto = + NormalizedLandmarkListProto.deserializeBinary(binaryProto); + const landmarks: NormalizedLandmark[] = []; + for (const faceLandmarkProto of faceLandmarksProto.getLandmarkList()) { + landmarks.push({ + x: faceLandmarkProto.getX() ?? 0, + y: faceLandmarkProto.getY() ?? 0, + z: faceLandmarkProto.getZ() ?? 0, + }); + } + this.result.faceLandmarks.push(landmarks); + } + } + + /** Adds new blendshapes from the given protos. */ + private addBlendshapes(data: Uint8Array[]): void { + if (!this.result.faceBlendshapes) { + return; + } + + for (const binaryProto of data) { + const classificationList = + ClassificationListProto.deserializeBinary(binaryProto); + this.result.faceBlendshapes.push(convertFromClassifications( + classificationList.getClassificationList() ?? [])); + } + } + + /** Adds new transformation matrixes from the given protos.
*/ + private addFacialTransformationMatrixes(data: Uint8Array[]): void { + if (!this.result.facialTransformationMatrixes) { + return; + } + + for (const binaryProto of data) { + const faceGeometryProto = + FaceGeometryProto.deserializeBinary(binaryProto); + const poseTransformMatrix = faceGeometryProto.getPoseTransformMatrix(); + if (poseTransformMatrix) { + this.result.facialTransformationMatrixes.push({ + rows: poseTransformMatrix.getRows() ?? 0, + columns: poseTransformMatrix.getCols() ?? 0, + data: poseTransformMatrix.getPackedDataList() ?? [], + }); + } + } + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(LANDMARKS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + FaceLandmarkerGraphOptions.ext, this.options); + + const landmarkerNode = new CalculatorGraphConfig.Node(); + landmarkerNode.setCalculator(FACE_LANDMARKER_GRAPH); + landmarkerNode.addInputStream('IMAGE:' + IMAGE_STREAM); + landmarkerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + landmarkerNode.addOutputStream('NORM_LANDMARKS:' + LANDMARKS_STREAM); + landmarkerNode.setOptions(calculatorOptions); + + graphConfig.addNode(landmarkerNode); + + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, (binaryProto, timestamp) => { + this.addJsLandmarks(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + LANDMARKS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + + if (this.outputFaceBlendshapes) { + graphConfig.addOutputStream(BLENDSHAPES_STREAM); + landmarkerNode.addOutputStream('BLENDSHAPES:' + BLENDSHAPES_STREAM); + this.graphRunner.attachProtoVectorListener( + BLENDSHAPES_STREAM, (binaryProto, timestamp) => { + this.addBlendshapes(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + BLENDSHAPES_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + if (this.outputFacialTransformationMatrixes) { + graphConfig.addOutputStream(FACE_GEOMETRY_STREAM); + landmarkerNode.addOutputStream('FACE_GEOMETRY:' + FACE_GEOMETRY_STREAM); + + this.graphRunner.attachProtoVectorListener( + FACE_GEOMETRY_STREAM, (binaryProto, timestamp) => { + this.addFacialTransformationMatrixes(binaryProto); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + FACE_GEOMETRY_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_options.d.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_options.d.ts new file mode 100644 index 000000000..f537ca2f6 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_options.d.ts @@ -0,0 +1,58 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe FaceLandmarker Task */ +export declare interface FaceLandmarkerOptions extends VisionTaskOptions { + /** + * The maximum number of faces that can be detected by the FaceLandmarker. + * Defaults to 1. + */ + numFaces?: number|undefined; + + /** + * The minimum confidence score for the face detection to be considered + * successful. Defaults to 0.5. + */ + minFaceDetectionConfidence?: number|undefined; + + /** + * The minimum confidence score of face presence in the face landmark + * detection. Defaults to 0.5. + */ + minFacePresenceConfidence?: number|undefined; + + /** + * The minimum confidence score for the face tracking to be considered + * successful. Defaults to 0.5. + */ + minTrackingConfidence?: number|undefined; + + /** + * Whether FaceLandmarker outputs face blendshapes classification. Face + * blendshapes are used for rendering the 3D face model. + */ + outputFaceBlendshapes?: boolean|undefined; + + /** + * Whether FaceLandmarker outputs the facial transformation matrix. The + * facial transformation matrix is used to transform the face landmarks from + * the canonical face to the detected face, so that users can apply face + * effects on the detected landmarks. + */ + outputFacialTransformationMatrixes?: boolean|undefined; +} diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_result.d.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_result.d.ts new file mode 100644 index 000000000..123c0a82a --- /dev/null +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_result.d.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Classifications} from '../../../../tasks/web/components/containers/classification_result'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {Matrix} from '../../../../tasks/web/components/containers/matrix'; + +export {Category, Landmark, NormalizedLandmark}; + +/** + * Represents the face landmarks detection results generated by `FaceLandmarker`. + */ +export declare interface FaceLandmarkerResult { + /** Detected face landmarks in normalized image coordinates. */ + faceLandmarks: NormalizedLandmark[][]; + + /** Optional face blendshapes results. */ + faceBlendshapes?: Classifications[]; + + /** Optional facial transformation matrixes.
*/ + facialTransformationMatrixes?: Matrix[]; +} diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts new file mode 100644 index 000000000..92012a6f3 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts @@ -0,0 +1,306 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import 'jasmine'; + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; +import {NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {MatrixData as MatrixDataProto} from '../../../../framework/formats/matrix_data_pb'; +import {FaceGeometry as FaceGeometryProto} from '../../../../tasks/cc/vision/face_geometry/proto/face_geometry_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; + +import {FaceLandmarker} from './face_landmarker'; +import {FaceLandmarkerOptions} from './face_landmarker_options'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); + +function createBlendshapes(): Uint8Array[] { + const blendshapesProto = new ClassificationList(); + const classification = new Classification(); + classification.setScore(0.1); + classification.setIndex(1); + classification.setLabel('face_label'); + classification.setDisplayName('face_display_name'); + blendshapesProto.addClassification(classification); + return [blendshapesProto.serializeBinary()]; +} + +function createFacialTransformationMatrixes(): Uint8Array[] { + const faceGeometryProto = new FaceGeometryProto(); + const poseTransformationMatrix = new MatrixDataProto(); + poseTransformationMatrix.setRows(1); + poseTransformationMatrix.setCols(1); + poseTransformationMatrix.setPackedDataList([1.0]); + faceGeometryProto.setPoseTransformMatrix(poseTransformationMatrix); + return [faceGeometryProto.serializeBinary()]; +} + +function createLandmarks(): Uint8Array[] { + const faceLandmarksProto = new NormalizedLandmarkList(); + const landmark = new NormalizedLandmark(); + landmark.setX(0.3); + landmark.setY(0.4); + landmark.setZ(0.5); + faceLandmarksProto.addLandmark(landmark); + return [faceLandmarksProto.serializeBinary()]; +} + +class FaceLandmarkerFake extends FaceLandmarker implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + fakeWasmModule: SpyWasmModule; + listeners = new Map<string, ProtoListener>(); + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toMatch( + /(face_landmarks|blendshapes|face_geometry)/); + this.listeners.set(stream, listener); + }); + + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + spyOn(this.graphRunner, 'addProtoToStream'); + } + + getGraphRunner(): VisionGraphRunner { + return this.graphRunner; + } +} + +describe('FaceLandmarker', () => { + let faceLandmarker: FaceLandmarkerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + faceLandmarker = new FaceLandmarkerFake(); + await faceLandmarker.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(faceLandmarker); + verifyListenersRegistered(faceLandmarker); + }); + + it('reloads graph when settings are changed', async () => { + verifyListenersRegistered(faceLandmarker); + + await faceLandmarker.setOptions({numFaces: 1}); + verifyGraph(faceLandmarker, [['faceDetectorGraphOptions', 'numFaces'], 1]); + verifyListenersRegistered(faceLandmarker); + + await faceLandmarker.setOptions({numFaces: 5}); + verifyGraph(faceLandmarker, [['faceDetectorGraphOptions', 'numFaces'], 5]); + verifyListenersRegistered(faceLandmarker); + }); + + it('merges options', async () => { + await faceLandmarker.setOptions({numFaces: 1}); + await faceLandmarker.setOptions({minFaceDetectionConfidence: 0.5}); + verifyGraph(faceLandmarker, [ + 'faceDetectorGraphOptions', { + numFaces: 1, + baseOptions: undefined, + minDetectionConfidence: 0.5, + minSuppressionThreshold: 0.5 + } + ]); + });
describe('setOptions()', () => { + interface TestCase { + optionPath: [keyof FaceLandmarkerOptions, ...string[]]; + fieldPath: string[]; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionPath: ['numFaces'], + fieldPath: ['faceDetectorGraphOptions', 'numFaces'], + customValue: 5, + defaultValue: 1 + }, + { + optionPath: ['minFaceDetectionConfidence'], + fieldPath: ['faceDetectorGraphOptions', 'minDetectionConfidence'], + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionPath: ['minFacePresenceConfidence'], + fieldPath: + ['faceLandmarksDetectorGraphOptions', 'minDetectionConfidence'], + customValue: 0.2, + defaultValue: 0.5 + }, + { + optionPath: ['minTrackingConfidence'], + fieldPath: ['minTrackingConfidence'], + customValue: 0.3, + defaultValue: 0.5 + }, + ]; + + /** Creates an options object that can be passed to setOptions() */ + function createOptions( + path: string[], value: unknown): FaceLandmarkerOptions { + const options: Record<string, unknown> = {}; + let currentLevel = options; + for (const element of path.slice(0, -1)) { + currentLevel[element] = {}; + currentLevel = currentLevel[element] as Record<string, unknown>; + } + currentLevel[path[path.length - 1]] = value; + return options; + } + + for (const testCase of testCases) { + it(`uses default value for ${testCase.optionPath[0]}`, async () => { + verifyGraph( + faceLandmarker, [testCase.fieldPath, testCase.defaultValue]); + }); + + it(`can set ${testCase.optionPath[0]}`, async () => { + await faceLandmarker.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph(faceLandmarker, [testCase.fieldPath, testCase.customValue]); + }); + + it(`can clear ${testCase.optionPath[0]}`, async () => { + await faceLandmarker.setOptions( + createOptions(testCase.optionPath, testCase.customValue)); + verifyGraph(faceLandmarker, [testCase.fieldPath, testCase.customValue]); + + await faceLandmarker.setOptions( + createOptions(testCase.optionPath, undefined)); + verifyGraph( + faceLandmarker, [testCase.fieldPath, testCase.defaultValue]); + }); + } + + it('supports outputFaceBlendshapes', async () => { + const stream = 'blendshapes'; + await faceLandmarker.setOptions({}); + expect(faceLandmarker.graph!.getOutputStreamList()).not.toContain(stream); + + await faceLandmarker.setOptions({outputFaceBlendshapes: false}); + expect(faceLandmarker.graph!.getOutputStreamList()).not.toContain(stream); + + await faceLandmarker.setOptions({outputFaceBlendshapes: true}); + expect(faceLandmarker.graph!.getOutputStreamList()).toContain(stream); + }); + + it('supports outputFacialTransformationMatrixes', async () => { + const stream = 'face_geometry'; + await faceLandmarker.setOptions({}); + expect(faceLandmarker.graph!.getOutputStreamList()).not.toContain(stream); + + await faceLandmarker.setOptions( + {outputFacialTransformationMatrixes: false}); + expect(faceLandmarker.graph!.getOutputStreamList()).not.toContain(stream); + + await faceLandmarker.setOptions( + {outputFacialTransformationMatrixes: true}); + expect(faceLandmarker.graph!.getOutputStreamList()).toContain(stream); + }); + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + faceLandmarker.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + // Pass the test data to our listener + faceLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceLandmarker); + faceLandmarker.listeners.get('face_landmarks')!(createLandmarks(), 1337); + faceLandmarker.listeners.get('blendshapes')!(createBlendshapes(), 1337); + faceLandmarker.listeners.get('face_geometry')! + (createFacialTransformationMatrixes(), 1337); + }); + + await faceLandmarker.setOptions({ + outputFaceBlendshapes: true, + outputFacialTransformationMatrixes: true + }); + + // Invoke the face landmarker + const landmarks = faceLandmarker.detect({} as HTMLImageElement); + expect(faceLandmarker.getGraphRunner().addProtoToStream) + .toHaveBeenCalledTimes(1); + expect(faceLandmarker.getGraphRunner().addGpuBufferAsImageToStream) + .toHaveBeenCalledTimes(1); + expect(faceLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + + expect(landmarks).toEqual({ + faceLandmarks: [[{x: 0.3, y: 0.4, z: 0.5}]], + faceBlendshapes: [{ + categories: [{ + index: 1, + score: 0.1, + categoryName: 'face_label', + displayName: 'face_display_name' + }], + headIndex: -1, + headName: '' + }], + facialTransformationMatrixes: [({rows: 1, columns: 1, data: [1]})] + }); + }); + + it('clears results between invocations', async () => { + // Pass the test data to our listener + faceLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { + faceLandmarker.listeners.get('face_landmarks')!(createLandmarks(), 1337); + faceLandmarker.listeners.get('blendshapes')!(createBlendshapes(), 1337); + faceLandmarker.listeners.get('face_geometry')! + (createFacialTransformationMatrixes(), 1337); + }); + + await faceLandmarker.setOptions({ + outputFaceBlendshapes: true, + outputFacialTransformationMatrixes: true + }); + + // Invoke the face landmarker twice + const landmarks1 = faceLandmarker.detect({} as HTMLImageElement); + const landmarks2 = faceLandmarker.detect({} as HTMLImageElement); + + // Verify that landmarks2 is not a concatenation of all previously returned + // landmarks. + expect(landmarks1).toEqual(landmarks2); + }); +}); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 7fca725ec..856d84683 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -15,6 +15,7 @@ */ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {FaceLandmarker as FaceLandmarkerImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; @@ -27,6 +28,7 @@ import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/ob // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; +const FaceLandmarker = FaceLandmarkerImpl; const FaceStylizer = FaceStylizerImpl; const GestureRecognizer = GestureRecognizerImpl; const HandLandmarker = HandLandmarkerImpl; @@ -38,6 +40,7 @@ const ObjectDetector = ObjectDetectorImpl; export { FilesetResolver, + FaceLandmarker, FaceStylizer, GestureRecognizer, HandLandmarker,
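For context, the sketch below shows how the new task can be driven in video mode, complementing the image-mode example in the README change above. It is a minimal sketch, not part of this change: the model path `face_landmarker.task` and the `<video id="video">` element are assumptions, and `runningMode: 'VIDEO'` is the `VisionTaskOptions` value assumed to pair with `detectForVideo()`.

```ts
// Illustrative video-mode usage; file names and element ids are assumptions.
import {FaceLandmarker, FilesetResolver} from '@mediapipe/tasks-vision';

async function run(): Promise<void> {
  // Load the Wasm fileset from the CDN, as in the README example.
  const vision = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm');
  const faceLandmarker = await FaceLandmarker.createFromOptions(vision, {
    baseOptions: {modelAssetPath: 'face_landmarker.task'},  // assumed path
    runningMode: 'VIDEO',  // assumed mode matching detectForVideo()
    numFaces: 2,
    outputFaceBlendshapes: true,
  });

  const video = document.getElementById('video') as HTMLVideoElement;
  function renderLoop(): void {
    // detectForVideo() is synchronous and takes the frame timestamp in ms.
    const result = faceLandmarker.detectForVideo(video, performance.now());
    // result.faceLandmarks holds one NormalizedLandmark[] per detected face;
    // result.faceBlendshapes is present because outputFaceBlendshapes is set.
    console.log('faces:', result.faceLandmarks.length);
    requestAnimationFrame(renderLoop);
  }
  renderLoop();
}

run();
```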