diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD
index 477ca15c3..714b4613b 100644
--- a/mediapipe/tasks/web/components/containers/BUILD
+++ b/mediapipe/tasks/web/components/containers/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_declaration(
     deps = [
         ":bounding_box",
         ":category",
+        ":keypoint",
     ],
 )
 
diff --git a/mediapipe/tasks/web/components/containers/detection_result.d.ts b/mediapipe/tasks/web/components/containers/detection_result.d.ts
index a338cc901..37817307c 100644
--- a/mediapipe/tasks/web/components/containers/detection_result.d.ts
+++ b/mediapipe/tasks/web/components/containers/detection_result.d.ts
@@ -16,6 +16,7 @@
 
 import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
 import {Category} from '../../../../tasks/web/components/containers/category';
+import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';
 
 /** Represents one detection by a detection task. */
 export declare interface Detection {
@@ -24,6 +25,15 @@ export declare interface Detection {
 
   /** The bounding box of the detected objects. */
   boundingBox?: BoundingBox;
+
+  /**
+   * Optional list of keypoints associated with the detection. Keypoints
+   * represent interesting points related to the detection. For example, the
+   * keypoints can represent the eyes, ears, and mouth from a face detection
+   * model, or, in a template matching detection such as KNIFT, the feature
+   * points used for template matching.
+   */
+  keypoints?: NormalizedKeypoint[];
 }
 
 /** Detection results of a model. */
diff --git a/mediapipe/tasks/web/components/processors/detection_result.test.ts b/mediapipe/tasks/web/components/processors/detection_result.test.ts
index 26f8bd8a5..289043506 100644
--- a/mediapipe/tasks/web/components/processors/detection_result.test.ts
+++ b/mediapipe/tasks/web/components/processors/detection_result.test.ts
@@ -31,6 +31,7 @@ describe('convertFromDetectionProto()', () => {
     detection.addLabelId(1);
     detection.addLabel('foo');
     detection.addDisplayName('bar');
+
     const locationData = new LocationData();
     const boundingBox = new LocationData.BoundingBox();
     boundingBox.setXmin(1);
@@ -38,6 +39,14 @@
     boundingBox.setWidth(3);
     boundingBox.setHeight(4);
     locationData.setBoundingBox(boundingBox);
+
+    const keypoint = new LocationData.RelativeKeypoint();
+    keypoint.setX(5);
+    keypoint.setY(6);
+    keypoint.setScore(0.7);
+    keypoint.setKeypointLabel('bar');
+    locationData.addRelativeKeypoints(keypoint);
+
     detection.setLocationData(locationData);
 
     const result = convertFromDetectionProto(detection);
@@ -49,7 +58,13 @@
         categoryName: 'foo',
         displayName: 'bar',
       }],
-      boundingBox: {originX: 1, originY: 2, width: 3, height: 4}
+      boundingBox: {originX: 1, originY: 2, width: 3, height: 4},
+      keypoints: [{
+        x: 5,
+        y: 6,
+        score: 0.7,
+        label: 'bar',
+      }],
     });
   });
 
diff --git a/mediapipe/tasks/web/components/processors/detection_result.ts b/mediapipe/tasks/web/components/processors/detection_result.ts
index 01041c915..6b38820bf 100644
--- a/mediapipe/tasks/web/components/processors/detection_result.ts
+++ b/mediapipe/tasks/web/components/processors/detection_result.ts
@@ -46,5 +46,18 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
     };
   }
 
+  if (source.getLocationData()?.getRelativeKeypointsList().length) {
+    detection.keypoints = [];
+    for (const keypoint of
+             source.getLocationData()!.getRelativeKeypointsList()) {
+      detection.keypoints.push({
+        x: keypoint.getX() ?? 0.0,
+        y: keypoint.getY() ?? 0.0,
+        score: keypoint.getScore() ?? 0.0,
+        label: keypoint.getKeypointLabel() ?? '',
+      });
+    }
+  }
+
   return detection;
 }
diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD
index 1f28cb0fe..19c795fd9 100644
--- a/mediapipe/tasks/web/vision/BUILD
+++ b/mediapipe/tasks/web/vision/BUILD
@@ -19,6 +19,7 @@ mediapipe_files(srcs = [
 
 VISION_LIBS = [
     "//mediapipe/tasks/web/core:fileset_resolver",
+    "//mediapipe/tasks/web/vision/face_detector",
     "//mediapipe/tasks/web/vision/face_landmarker",
     "//mediapipe/tasks/web/vision/face_stylizer",
    "//mediapipe/tasks/web/vision/gesture_recognizer",
diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md
index ebeac54c5..d5109142b 100644
--- a/mediapipe/tasks/web/vision/README.md
+++ b/mediapipe/tasks/web/vision/README.md
@@ -2,6 +2,22 @@
 
 This package contains the vision tasks for MediaPipe.
 
+## Face Detection
+
+The MediaPipe Face Detector task lets you detect the presence and location of
+faces within images or videos.
+
+```
+const vision = await FilesetResolver.forVisionTasks(
+  "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
+);
+const faceDetector = await FaceDetector.createFromModelPath(vision,
+  "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite"
+);
+const image = document.getElementById("image") as HTMLImageElement;
+const detections = faceDetector.detect(image);
+```
+
 ## Face Landmark Detection
 
 The MediaPipe Face Landmarker task lets you detect the landmarks of faces in
diff --git a/mediapipe/tasks/web/vision/face_detector/BUILD b/mediapipe/tasks/web/vision/face_detector/BUILD
new file mode 100644
index 000000000..8225e4948
--- /dev/null
+++ b/mediapipe/tasks/web/vision/face_detector/BUILD
@@ -0,0 +1,71 @@
+# This contains the MediaPipe Face Detector Task.
+#
+# This task takes video frames and outputs synchronized frames along with
+# the detection results for one or more faces, using Face Detector.
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "face_detector", + srcs = ["face_detector.ts"], + visibility = ["//visibility:public"], + deps = [ + ":face_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/processors:detection_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "face_detector_types", + srcs = [ + "face_detector_options.d.ts", + "face_detector_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:bounding_box", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:detection_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "face_detector_test_lib", + testonly = True, + srcs = [ + "face_detector_test.ts", + ], + deps = [ + ":face_detector", + ":face_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/framework/formats:location_data_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "face_detector_test", + tags = ["nomsan"], + deps = [":face_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector.ts b/mediapipe/tasks/web/vision/face_detector/face_detector.ts new file mode 100644 index 000000000..039f7dd44 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector.ts @@ -0,0 +1,213 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
+import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {FaceDetectorGraphOptions as FaceDetectorGraphOptionsProto} from '../../../../tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb';
+import {convertFromDetectionProto} from '../../../../tasks/web/components/processors/detection_result';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
+import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
+// Placeholder for internal dependency on trusted resource url
+
+import {FaceDetectorOptions} from './face_detector_options';
+import {FaceDetectorResult} from './face_detector_result';
+
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect_in';
+const DETECTIONS_STREAM = 'detections';
+const FACE_DETECTOR_GRAPH =
+    'mediapipe.tasks.vision.face_detector.FaceDetectorGraph';
+
+export * from './face_detector_options';
+export * from './face_detector_result';
+export {ImageSource};  // Used in the public API
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+/** Performs face detection on images. */
+export class FaceDetector extends VisionTaskRunner {
+  private result: FaceDetectorResult = {detections: []};
+  private readonly options = new FaceDetectorGraphOptionsProto();
+
+  /**
+   * Initializes the Wasm runtime and creates a new face detector from the
+   * provided options.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
+   * @param faceDetectorOptions The options for the FaceDetector. Note that
+   *     either a path to the model asset or a model buffer needs to be
+   *     provided (via `baseOptions`).
+   */
+  static createFromOptions(
+      wasmFileset: WasmFileset,
+      faceDetectorOptions: FaceDetectorOptions): Promise<FaceDetector> {
+    return VisionTaskRunner.createVisionInstance(
+        FaceDetector, wasmFileset, faceDetectorOptions);
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new face detector based on the
+   * provided model asset buffer.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the model.
+   */
+  static createFromModelBuffer(
+      wasmFileset: WasmFileset,
+      modelAssetBuffer: Uint8Array): Promise<FaceDetector> {
+    return VisionTaskRunner.createVisionInstance(
+        FaceDetector, wasmFileset, {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new face detector based on the
+   * path to the model asset.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
+   * @param modelAssetPath The path to the model asset.
+   */
+  static async createFromModelPath(
+      wasmFileset: WasmFileset,
+      modelAssetPath: string): Promise<FaceDetector> {
+    return VisionTaskRunner.createVisionInstance(
+        FaceDetector, wasmFileset, {baseOptions: {modelAssetPath}});
+  }
+
+  /** @hideconstructor */
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM, /* roiAllowed= */ false);
+    this.options.setBaseOptions(new BaseOptionsProto());
+    this.options.setMinDetectionConfidence(0.5);
+    this.options.setMinSuppressionThreshold(0.3);
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the FaceDetector.
+   *
+   * Calling `setOptions()` with a subset of options only affects those options.
+   * You can reset an option back to its default value by explicitly setting it
+   * to `undefined`.
+   *
+   * @param options The options for the FaceDetector.
+   */
+  override setOptions(options: FaceDetectorOptions): Promise<void> {
+    if ('minDetectionConfidence' in options) {
+      this.options.setMinDetectionConfidence(
+          options.minDetectionConfidence ?? 0.5);
+    }
+    if ('minSuppressionThreshold' in options) {
+      this.options.setMinSuppressionThreshold(
+          options.minSuppressionThreshold ?? 0.3);
+    }
+    return this.applyOptions(options);
+  }
+
+  /**
+   * Performs face detection on the provided single image and waits
+   * synchronously for the response. Only use this method when the
+   * FaceDetector is created with running mode `image`.
+   *
+   * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @return A result containing the list of detected faces.
+   */
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      FaceDetectorResult {
+    this.result = {detections: []};
+    this.processImageData(image, imageProcessingOptions);
+    return this.result;
+  }
+
+  /**
+   * Performs face detection on the provided video frame and waits
+   * synchronously for the response. Only use this method when the
+   * FaceDetector is created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @return A result containing the list of detected faces.
+   */
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): FaceDetectorResult {
+    this.result = {detections: []};
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.result;
+  }
+
+  /** Converts raw data into a Detection, and adds it to our detection list. */
+  private addJsFaceDetections(data: Uint8Array[]): void {
+    for (const binaryProto of data) {
+      const detectionProto = DetectionProto.deserializeBinary(binaryProto);
+      this.result.detections.push(convertFromDetectionProto(detectionProto));
+    }
+  }
+
+  /** Updates the MediaPipe graph configuration. */
+  protected override refreshGraph(): void {
+    const graphConfig = new CalculatorGraphConfig();
+    graphConfig.addInputStream(IMAGE_STREAM);
+    graphConfig.addInputStream(NORM_RECT_STREAM);
+    graphConfig.addOutputStream(DETECTIONS_STREAM);
+
+    const calculatorOptions = new CalculatorOptions();
+    calculatorOptions.setExtension(
+        FaceDetectorGraphOptionsProto.ext, this.options);
+
+    const detectorNode = new CalculatorGraphConfig.Node();
+    detectorNode.setCalculator(FACE_DETECTOR_GRAPH);
+    detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
+    detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
+    detectorNode.setOptions(calculatorOptions);
+
+    graphConfig.addNode(detectorNode);
+
+    this.graphRunner.attachProtoVectorListener(
+        DETECTIONS_STREAM, (binaryProto, timestamp) => {
+          this.addJsFaceDetections(binaryProto);
+          this.setLatestOutputTimestamp(timestamp);
+        });
+    this.graphRunner.attachEmptyPacketListener(DETECTIONS_STREAM, timestamp => {
+      this.setLatestOutputTimestamp(timestamp);
+    });
+
+    const binaryGraph = graphConfig.serializeBinary();
+    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
+  }
+}
+
+
diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts
new file mode 100644
index 000000000..665035f7e
--- /dev/null
+++ b/mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
+import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
+
+/** Options to configure the MediaPipe Face Detector Task */
+export interface FaceDetectorOptions extends VisionTaskOptions {
+  /**
+   * The minimum confidence score for the face detection to be considered
+   * successful. Defaults to 0.5.
+   */
+  minDetectionConfidence?: number|undefined;
+
+  /**
+   * The minimum non-maximum-suppression threshold for face detection to be
+   * considered overlapped. Defaults to 0.3.
+   */
+  minSuppressionThreshold?: number|undefined;
+}
diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts
new file mode 100644
index 000000000..6a36559f7
--- /dev/null
+++ b/mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts
@@ -0,0 +1,19 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box'; +export {Category} from '../../../../tasks/web/components/containers/category'; +export {Detection, DetectionResult as FaceDetectorResult} from '../../../../tasks/web/components/containers/detection_result'; diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts new file mode 100644 index 000000000..88dd20d2b --- /dev/null +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts @@ -0,0 +1,193 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {LocationData} from '../../../../framework/formats/location_data_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; + +import {FaceDetector} from './face_detector'; +import {FaceDetectorOptions} from './face_detector_options'; + +// The OSS JS API does not support the builder pattern. 
+// tslint:disable:jspb-use-builder-pattern + +class FaceDetectorFake extends FaceDetector implements MediapipeTasksFake { + lastSampleRate: number|undefined; + calculatorName = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + protoListener: + ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachProtoVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('detections'); + this.protoListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('FaceDetector', () => { + let faceDetector: FaceDetectorFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + faceDetector = new FaceDetectorFake(); + await faceDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(faceDetector); + verifyListenersRegistered(faceDetector); + }); + + it('reloads graph when settings are changed', async () => { + await faceDetector.setOptions({minDetectionConfidence: 0.1}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]); + verifyListenersRegistered(faceDetector); + + await faceDetector.setOptions({minDetectionConfidence: 0.2}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.2]); + verifyListenersRegistered(faceDetector); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await faceDetector.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + faceDetector, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await faceDetector.setOptions({minDetectionConfidence: 0.1}); + await faceDetector.setOptions({minSuppressionThreshold: 0.2}); + verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]); + verifyGraph(faceDetector, ['minSuppressionThreshold', 0.2]); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof FaceDetectorOptions; + protoName: string; + customValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'minDetectionConfidence', + protoName: 'minDetectionConfidence', + customValue: 0.1, + defaultValue: 0.5 + }, + { + optionName: 'minSuppressionThreshold', + protoName: 'minSuppressionThreshold', + customValue: 0.2, + defaultValue: 0.3 + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await faceDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(faceDetector, [testCase.protoName, testCase.customValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await faceDetector.setOptions( + {[testCase.optionName]: testCase.customValue}); + verifyGraph(faceDetector, 
[testCase.protoName, testCase.customValue]); + await faceDetector.setOptions({[testCase.optionName]: undefined}); + verifyGraph(faceDetector, [testCase.protoName, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + faceDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('transforms results', async () => { + const detection = new DetectionProto(); + detection.addScore(0.1); + const locationData = new LocationData(); + const boundingBox = new LocationData.BoundingBox(); + locationData.setBoundingBox(boundingBox); + detection.setLocationData(locationData); + + const binaryProto = detection.serializeBinary(); + + // Pass the test data to our listener + faceDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(faceDetector); + faceDetector.protoListener!([binaryProto], 1337); + }); + + // Invoke the face detector + const {detections} = faceDetector.detect({} as HTMLImageElement); + + expect(faceDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(detections.length).toEqual(1); + expect(detections[0]).toEqual({ + categories: [{ + score: 0.1, + index: -1, + categoryName: '', + displayName: '', + }], + boundingBox: {originX: 0, originY: 0, width: 0, height: 0} + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 856d84683..4882e22c4 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -15,6 +15,7 @@ */ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceLandmarker as FaceLandmarkerImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; @@ -28,6 +29,7 @@ import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/ob // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; +const FaceDetector = FaceDetectorImpl; const FaceLandmarker = FaceLandmarkerImpl; const FaceStylizer = FaceStylizerImpl; const GestureRecognizer = GestureRecognizerImpl; @@ -40,6 +42,7 @@ const ObjectDetector = ObjectDetectorImpl; export { FilesetResolver, + FaceDetector, FaceLandmarker, FaceStylizer, GestureRecognizer, diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index 2756b05a5..f49161adf 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -15,6 +15,7 @@ */ export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';