FaceDetector Web API
PiperOrigin-RevId: 521816795
commit a98f6bf231 (parent 33cad24a5a)
@@ -26,6 +26,7 @@ mediapipe_ts_declaration(
    deps = [
        ":bounding_box",
        ":category",
        ":keypoint",
    ],
)
@@ -16,6 +16,7 @@

import {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
import {Category} from '../../../../tasks/web/components/containers/category';
import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';

/** Represents one detection by a detection task. */
export declare interface Detection {
@@ -24,6 +25,15 @@ export declare interface Detection {
  /** The bounding box of the detected objects. */
  boundingBox?: BoundingBox;

  /**
   * Optional list of keypoints associated with the detection. Keypoints
   * represent points of interest related to the detection. For example, the
   * keypoints of a face detection model can represent the eyes, ears, and
   * mouth, and in template-matching detection (e.g. KNIFT) they can represent
   * the feature points used for template matching.
   */
  keypoints?: NormalizedKeypoint[];
}

/** Detection results of a model. */
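To illustrate the shape of this interface, a populated `Detection` might look as follows (the values are hypothetical, not part of the commit):

```
const detection: Detection = {
  categories: [{score: 0.9, index: 0, categoryName: 'face', displayName: ''}],
  boundingBox: {originX: 10, originY: 20, width: 100, height: 100},
  // NormalizedKeypoint coordinates are normalized to [0, 1].
  keypoints: [{x: 0.4, y: 0.3, score: 0.8, label: 'left_eye'}],
};
```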
@@ -31,6 +31,7 @@ describe('convertFromDetectionProto()', () => {
    detection.addLabelId(1);
    detection.addLabel('foo');
    detection.addDisplayName('bar');

    const locationData = new LocationData();
    const boundingBox = new LocationData.BoundingBox();
    boundingBox.setXmin(1);
@@ -38,6 +39,14 @@ describe('convertFromDetectionProto()', () => {
    boundingBox.setWidth(3);
    boundingBox.setHeight(4);
    locationData.setBoundingBox(boundingBox);

    const keypoint = new LocationData.RelativeKeypoint();
    keypoint.setX(5);
    keypoint.setY(6);
    keypoint.setScore(0.7);
    keypoint.setKeypointLabel('bar');
    locationData.addRelativeKeypoints(keypoint);

    detection.setLocationData(locationData);

    const result = convertFromDetectionProto(detection);
@@ -49,7 +58,13 @@ describe('convertFromDetectionProto()', () => {
        categoryName: 'foo',
        displayName: 'bar',
      }],
-     boundingBox: {originX: 1, originY: 2, width: 3, height: 4}
+     boundingBox: {originX: 1, originY: 2, width: 3, height: 4},
+     keypoints: [{
+       x: 5,
+       y: 6,
+       score: 0.7,
+       label: 'bar',
+     }],
    });
  });
@@ -46,5 +46,18 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
    };
  }

  if (source.getLocationData()?.getRelativeKeypointsList().length) {
    detection.keypoints = [];
    for (const keypoint of
             source.getLocationData()!.getRelativeKeypointsList()) {
      detection.keypoints.push({
        x: keypoint.getX() ?? 0.0,
        y: keypoint.getY() ?? 0.0,
        score: keypoint.getScore() ?? 0.0,
        label: keypoint.getKeypointLabel() ?? '',
      });
    }
  }

  return detection;
}
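A quick illustration of the fallbacks above (a hypothetical snippet, not part of the commit): converting a detection whose keypoint has no fields set yields zeroed coordinates and an empty label.

```
const proto = new DetectionProto();
const locationData = new LocationData();
locationData.addRelativeKeypoints(new LocationData.RelativeKeypoint());
proto.setLocationData(locationData);

// Unset proto fields come back as 0 / '' via the `??` fallbacks.
const converted = convertFromDetectionProto(proto);
console.log(converted.keypoints);  // [{x: 0, y: 0, score: 0, label: ''}]
```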
@@ -19,6 +19,7 @@ mediapipe_files(srcs = [

VISION_LIBS = [
    "//mediapipe/tasks/web/core:fileset_resolver",
    "//mediapipe/tasks/web/vision/face_detector",
    "//mediapipe/tasks/web/vision/face_landmarker",
    "//mediapipe/tasks/web/vision/face_stylizer",
    "//mediapipe/tasks/web/vision/gesture_recognizer",
@@ -2,6 +2,22 @@

This package contains the vision tasks for MediaPipe.

## Face Detection

The MediaPipe Face Detector task lets you detect the presence and location of
faces within images or videos.

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const faceDetector = await FaceDetector.createFromModelPath(vision,
    "https://storage.googleapis.com/mediapipe-models/face_detector/blaze_face_short_range/float16/1/blaze_face_short_range.tflite"
);
const image = document.getElementById("image") as HTMLImageElement;
const detections = faceDetector.detect(image).detections;
```
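For video input, the detector can also run in `video` mode with `detectForVideo()`; a minimal sketch (the element ID, render loop, and `"VIDEO"` option value are illustrative assumptions):

```
const video = document.getElementById("video") as HTMLVideoElement;
await faceDetector.setOptions({runningMode: "VIDEO"});

function renderLoop(): void {
  // Timestamps must increase monotonically across calls.
  const result = faceDetector.detectForVideo(video, performance.now());
  // ... consume result.detections ...
  requestAnimationFrame(renderLoop);
}
renderLoop();
```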
## Face Landmark Detection

The MediaPipe Face Landmarker task lets you detect the landmarks of faces in
mediapipe/tasks/web/vision/face_detector/BUILD (new file)
@@ -0,0 +1,71 @@
# This contains the MediaPipe Face Detector Task.
#
# This task takes video frames and outputs synchronized frames along with
# the detection results for one or more faces, using Face Detector.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_ts_library(
    name = "face_detector",
    srcs = ["face_detector.ts"],
    visibility = ["//visibility:public"],
    deps = [
        ":face_detector_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/framework/formats:detection_jspb_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/processors:detection_result",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/vision/core:image_processing_options",
        "//mediapipe/tasks/web/vision/core:vision_task_runner",
        "//mediapipe/web/graph_runner:graph_runner_ts",
    ],
)

mediapipe_ts_declaration(
    name = "face_detector_types",
    srcs = [
        "face_detector_options.d.ts",
        "face_detector_result.d.ts",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/tasks/web/components/containers:bounding_box",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:detection_result",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:classifier_options",
        "//mediapipe/tasks/web/vision/core:vision_task_options",
    ],
)

mediapipe_ts_library(
    name = "face_detector_test_lib",
    testonly = True,
    srcs = [
        "face_detector_test.ts",
    ],
    deps = [
        ":face_detector",
        ":face_detector_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:detection_jspb_proto",
        "//mediapipe/framework/formats:location_data_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

jasmine_node_test(
    name = "face_detector_test",
    tags = ["nomsan"],
    deps = [":face_detector_test_lib"],
)
mediapipe/tasks/web/vision/face_detector/face_detector.ts (new file)
@@ -0,0 +1,213 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {FaceDetectorGraphOptions as FaceDetectorGraphOptionsProto} from '../../../../tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb';
import {convertFromDetectionProto} from '../../../../tasks/web/components/processors/detection_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url

import {FaceDetectorOptions} from './face_detector_options';
import {FaceDetectorResult} from './face_detector_result';

const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect_in';
const DETECTIONS_STREAM = 'detections';
const FACE_DETECTOR_GRAPH =
    'mediapipe.tasks.vision.face_detector.FaceDetectorGraph';

export * from './face_detector_options';
export * from './face_detector_result';
export {ImageSource};  // Used in the public API

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/** Performs face detection on images. */
export class FaceDetector extends VisionTaskRunner {
  private result: FaceDetectorResult = {detections: []};
  private readonly options = new FaceDetectorGraphOptionsProto();

  /**
   * Initializes the Wasm runtime and creates a new face detector from the
   * provided options.
   * @param wasmFileset A configuration object that provides the location of
   *     the Wasm binary and its loader.
   * @param faceDetectorOptions The options for the FaceDetector. Note that
   *     either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   */
  static createFromOptions(
      wasmFileset: WasmFileset,
      faceDetectorOptions: FaceDetectorOptions): Promise<FaceDetector> {
    return VisionTaskRunner.createVisionInstance(
        FaceDetector, wasmFileset, faceDetectorOptions);
  }

  /**
   * Initializes the Wasm runtime and creates a new face detector based on the
   * provided model asset buffer.
   * @param wasmFileset A configuration object that provides the location of
   *     the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmFileset: WasmFileset,
      modelAssetBuffer: Uint8Array): Promise<FaceDetector> {
    return VisionTaskRunner.createVisionInstance(
        FaceDetector, wasmFileset, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new face detector based on the
   * path to the model asset.
   * @param wasmFileset A configuration object that provides the location of
   *     the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   */
  static async createFromModelPath(
      wasmFileset: WasmFileset,
      modelAssetPath: string): Promise<FaceDetector> {
    return VisionTaskRunner.createVisionInstance(
        FaceDetector, wasmFileset, {baseOptions: {modelAssetPath}});
  }

  /** @hideconstructor */
  constructor(
      wasmModule: WasmModule,
      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
    super(
        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
        NORM_RECT_STREAM, /* roiAllowed= */ false);
    this.options.setBaseOptions(new BaseOptionsProto());
    this.options.setMinDetectionConfidence(0.5);
    this.options.setMinSuppressionThreshold(0.3);
  }

  protected override get baseOptions(): BaseOptionsProto {
    return this.options.getBaseOptions()!;
  }

  protected override set baseOptions(proto: BaseOptionsProto) {
    this.options.setBaseOptions(proto);
  }

  /**
   * Sets new options for the FaceDetector.
   *
   * Calling `setOptions()` with a subset of options only affects those
   * options. You can reset an option back to its default value by explicitly
   * setting it to `undefined`.
   *
   * @param options The options for the FaceDetector.
   */
  override setOptions(options: FaceDetectorOptions): Promise<void> {
    if ('minDetectionConfidence' in options) {
      this.options.setMinDetectionConfidence(
          options.minDetectionConfidence ?? 0.5);
    }
    if ('minSuppressionThreshold' in options) {
      this.options.setMinSuppressionThreshold(
          options.minSuppressionThreshold ?? 0.3);
    }
    return this.applyOptions(options);
  }

  /**
   * Performs face detection on the provided single image and waits
   * synchronously for the response. Only use this method when the
   * FaceDetector is created with running mode `image`.
   *
   * @param image An image to process.
   * @param imageProcessingOptions The `ImageProcessingOptions` specifying how
   *     to process the input image before running inference.
   * @return A result containing the list of detected faces.
   */
  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
      FaceDetectorResult {
    this.result = {detections: []};
    this.processImageData(image, imageProcessingOptions);
    return this.result;
  }

  /**
   * Performs face detection on the provided video frame and waits
   * synchronously for the response. Only use this method when the
   * FaceDetector is created with running mode `video`.
   *
   * @param videoFrame A video frame to process.
   * @param timestamp The timestamp of the current frame, in ms.
   * @param imageProcessingOptions The `ImageProcessingOptions` specifying how
   *     to process the input image before running inference.
   * @return A result containing the list of detected faces.
   */
  detectForVideo(
      videoFrame: ImageSource, timestamp: number,
      imageProcessingOptions?: ImageProcessingOptions): FaceDetectorResult {
    this.result = {detections: []};
    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
    return this.result;
  }

  /** Converts raw data into a Detection, and adds it to our detection list. */
  private addJsFaceDetections(data: Uint8Array[]): void {
    for (const binaryProto of data) {
      const detectionProto = DetectionProto.deserializeBinary(binaryProto);
      this.result.detections.push(convertFromDetectionProto(detectionProto));
    }
  }

  /** Updates the MediaPipe graph configuration. */
  protected override refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(IMAGE_STREAM);
    graphConfig.addInputStream(NORM_RECT_STREAM);
    graphConfig.addOutputStream(DETECTIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        FaceDetectorGraphOptionsProto.ext, this.options);

    const detectorNode = new CalculatorGraphConfig.Node();
    detectorNode.setCalculator(FACE_DETECTOR_GRAPH);
    detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM);
    detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
    detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
    detectorNode.setOptions(calculatorOptions);

    graphConfig.addNode(detectorNode);

    this.graphRunner.attachProtoVectorListener(
        DETECTIONS_STREAM, (binaryProto, timestamp) => {
          this.addJsFaceDetections(binaryProto);
          this.setLatestOutputTimestamp(timestamp);
        });
    this.graphRunner.attachEmptyPacketListener(DETECTIONS_STREAM, timestamp => {
      this.setLatestOutputTimestamp(timestamp);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}
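Putting the class together, a minimal end-to-end sketch (the model path and element ID are illustrative assumptions, not part of this commit):

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm");
const detector = await FaceDetector.createFromOptions(vision, {
  baseOptions: {modelAssetPath: "face_detector.tflite"},  // illustrative path
  minDetectionConfidence: 0.6,
});

const image = document.getElementById("image") as HTMLImageElement;
const result = detector.detect(image);
for (const detection of result.detections) {
  console.log(detection.boundingBox, detection.categories[0]?.score);
}
```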
mediapipe/tasks/web/vision/face_detector/face_detector_options.d.ts (new file)
@@ -0,0 +1,33 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';

/** Options to configure the MediaPipe Face Detector Task */
export interface FaceDetectorOptions extends VisionTaskOptions {
  /**
   * The minimum confidence score for the face detection to be considered
   * successful. Defaults to 0.5.
   */
  minDetectionConfidence?: number|undefined;

  /**
   * The minimum non-maximum-suppression threshold for face detection to be
   * considered overlapped. Defaults to 0.3.
   */
  minSuppressionThreshold?: number|undefined;
}
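A short sketch of how these options interact with `setOptions()` (assuming an existing `faceDetector` instance):

```
// Options are merged: this call only changes minSuppressionThreshold.
await faceDetector.setOptions({minSuppressionThreshold: 0.4});

// Passing `undefined` explicitly resets an option to its documented default
// (0.5 for minDetectionConfidence, 0.3 for minSuppressionThreshold).
await faceDetector.setOptions({minDetectionConfidence: undefined});
```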
mediapipe/tasks/web/vision/face_detector/face_detector_result.d.ts (new file)
@@ -0,0 +1,19 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

export {BoundingBox} from '../../../../tasks/web/components/containers/bounding_box';
export {Category} from '../../../../tasks/web/components/containers/category';
export {Detection, DetectionResult as FaceDetectorResult} from '../../../../tasks/web/components/containers/detection_result';
mediapipe/tasks/web/vision/face_detector/face_detector_test.ts (new file)
@@ -0,0 +1,193 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
import {LocationData} from '../../../../framework/formats/location_data_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';

import {FaceDetector} from './face_detector';
import {FaceDetectorOptions} from './face_detector_options';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

class FaceDetectorFake extends FaceDetector implements MediapipeTasksFake {
  lastSampleRate: number|undefined;
  calculatorName = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  graph: CalculatorGraphConfig|undefined;

  fakeWasmModule: SpyWasmModule;
  protoListener:
      ((binaryProtos: Uint8Array[], timestamp: number) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;

    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('detections');
              this.protoListener = listener;
            });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
  }
}

describe('FaceDetector', () => {
  let faceDetector: FaceDetectorFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    faceDetector = new FaceDetectorFake();
    await faceDetector.setOptions(
        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
  });

  it('initializes graph', async () => {
    verifyGraph(faceDetector);
    verifyListenersRegistered(faceDetector);
  });

  it('reloads graph when settings are changed', async () => {
    await faceDetector.setOptions({minDetectionConfidence: 0.1});
    verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]);
    verifyListenersRegistered(faceDetector);

    await faceDetector.setOptions({minDetectionConfidence: 0.2});
    verifyGraph(faceDetector, ['minDetectionConfidence', 0.2]);
    verifyListenersRegistered(faceDetector);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await faceDetector.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        faceDetector,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */
        [
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('merges options', async () => {
    await faceDetector.setOptions({minDetectionConfidence: 0.1});
    await faceDetector.setOptions({minSuppressionThreshold: 0.2});
    verifyGraph(faceDetector, ['minDetectionConfidence', 0.1]);
    verifyGraph(faceDetector, ['minSuppressionThreshold', 0.2]);
  });

  describe('setOptions()', () => {
    interface TestCase {
      optionName: keyof FaceDetectorOptions;
      protoName: string;
      customValue: unknown;
      defaultValue: unknown;
    }

    const testCases: TestCase[] = [
      {
        optionName: 'minDetectionConfidence',
        protoName: 'minDetectionConfidence',
        customValue: 0.1,
        defaultValue: 0.5
      },
      {
        optionName: 'minSuppressionThreshold',
        protoName: 'minSuppressionThreshold',
        customValue: 0.2,
        defaultValue: 0.3
      },
    ];

    for (const testCase of testCases) {
      it(`can set ${testCase.optionName}`, async () => {
        await faceDetector.setOptions(
            {[testCase.optionName]: testCase.customValue});
        verifyGraph(faceDetector, [testCase.protoName, testCase.customValue]);
      });

      it(`can clear ${testCase.optionName}`, async () => {
        await faceDetector.setOptions(
            {[testCase.optionName]: testCase.customValue});
        verifyGraph(faceDetector, [testCase.protoName, testCase.customValue]);
        await faceDetector.setOptions({[testCase.optionName]: undefined});
        verifyGraph(faceDetector, [testCase.protoName, testCase.defaultValue]);
      });
    }
  });

  it('doesn\'t support region of interest', () => {
    expect(() => {
      faceDetector.detect(
          {} as HTMLImageElement,
          {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
    }).toThrowError('This task doesn\'t support region-of-interest.');
  });

  it('transforms results', async () => {
    const detection = new DetectionProto();
    detection.addScore(0.1);
    const locationData = new LocationData();
    const boundingBox = new LocationData.BoundingBox();
    locationData.setBoundingBox(boundingBox);
    detection.setLocationData(locationData);

    const binaryProto = detection.serializeBinary();

    // Pass the test data to our listener
    faceDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(faceDetector);
      faceDetector.protoListener!([binaryProto], 1337);
    });

    // Invoke the face detector
    const {detections} = faceDetector.detect({} as HTMLImageElement);

    expect(faceDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(detections.length).toEqual(1);
    expect(detections[0]).toEqual({
      categories: [{
        score: 0.1,
        index: -1,
        categoryName: '',
        displayName: '',
      }],
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
    });
  });
});
@@ -15,6 +15,7 @@
 */

import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector';
import {FaceLandmarker as FaceLandmarkerImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker';
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
@@ -28,6 +29,7 @@ import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/ob
// Declare the variables locally so that Rollup in OSS includes them explicitly
// as exports.
const FilesetResolver = FilesetResolverImpl;
const FaceDetector = FaceDetectorImpl;
const FaceLandmarker = FaceLandmarkerImpl;
const FaceStylizer = FaceStylizerImpl;
const GestureRecognizer = GestureRecognizerImpl;
@@ -40,6 +42,7 @@ const ObjectDetector = ObjectDetectorImpl;

export {
  FilesetResolver,
  FaceDetector,
  FaceLandmarker,
  FaceStylizer,
  GestureRecognizer,
@@ -15,6 +15,7 @@
 */

export * from '../../../tasks/web/core/fileset_resolver';
export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
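With these re-exports in place, consumers can import the new task directly from the package (assuming the `@mediapipe/tasks-vision` npm package name used in the README above):

```
import {FaceDetector, FilesetResolver} from '@mediapipe/tasks-vision';
```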