329 lines
12 KiB
TypeScript
329 lines
12 KiB
TypeScript
/**
|
|
* Copyright 2022 The MediaPipe Authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
|
import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
|
|
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
|
|
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
|
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
|
|
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
|
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner';
|
|
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
|
|
|
import {WasmFileset} from './wasm_fileset';
|
|
|
|
// Internal stream names for temporarily keeping memory alive, then freeing it.
|
|
const FREE_MEMORY_STREAM = 'free_memory';
|
|
const UNUSED_STREAM_SUFFIX = '_unused_out';
|
|
|
|
// tslint:disable-next-line:enforce-name-casing
|
|
const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
|
|
|
|
// The OSS JS API does not support the builder pattern.
|
|
// tslint:disable:jspb-use-builder-pattern
|
|
|
|
/**
|
|
* An implementation of the GraphRunner that exposes the resource graph
|
|
* service.
|
|
*/
|
|
export class CachedGraphRunner extends CachedGraphRunnerType {}
|
|
|
|
/**
|
|
* Creates a new instance of a Mediapipe Task. Determines if SIMD is
|
|
* supported and loads the relevant WASM binary.
|
|
* @return A fully instantiated instance of `T`.
|
|
*/
|
|
export async function createTaskRunner<T extends TaskRunner>(
|
|
type: WasmMediaPipeConstructor<T>,
|
|
canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined,
|
|
fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
|
|
const fileLocator: FileLocator = {
|
|
locateFile() {
|
|
// The only file loaded with this mechanism is the Wasm binary
|
|
return fileset.wasmBinaryPath.toString();
|
|
}
|
|
};
|
|
|
|
const instance = await createMediaPipeLib(
|
|
type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas,
|
|
fileLocator);
|
|
await instance.setOptions(options);
|
|
return instance;
|
|
}
|
|
|
|
/** Base class for all MediaPipe Tasks. */
|
|
export abstract class TaskRunner {
|
|
protected abstract baseOptions: BaseOptionsProto;
|
|
private processingErrors: Error[] = [];
|
|
private latestOutputTimestamp = 0;
|
|
private keepaliveNode?: CalculatorGraphConfig.Node;
|
|
|
|
/**
|
|
* Creates a new instance of a Mediapipe Task. Determines if SIMD is
|
|
* supported and loads the relevant WASM binary.
|
|
* @return A fully instantiated instance of `T`.
|
|
*/
|
|
protected static async createInstance<T extends TaskRunner>(
|
|
type: WasmMediaPipeConstructor<T>,
|
|
canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined,
|
|
fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
|
|
return createTaskRunner(type, canvas, fileset, options);
|
|
}
|
|
|
|
/** @hideconstructor protected */
|
|
constructor(protected readonly graphRunner: CachedGraphRunner) {
|
|
// Disables the automatic render-to-screen code, which allows for pure
|
|
// CPU processing.
|
|
this.graphRunner.setAutoRenderToScreen(false);
|
|
}
|
|
|
|
/** Configures the task with custom options. */
|
|
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
|
|
|
|
/**
|
|
* Applies the current set of options, including optionally any base options
|
|
* that have not been processed by the task implementation. The options are
|
|
* applied synchronously unless a `modelAssetPath` is provided. This ensures
|
|
* that for most use cases options are applied directly and immediately affect
|
|
* the next inference.
|
|
*
|
|
* @param options The options for the task.
|
|
* @param loadTfliteModel Whether to load the model specified in
|
|
* `options.baseOptions`.
|
|
*/
|
|
protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true):
|
|
Promise<void> {
|
|
if (loadTfliteModel) {
|
|
const baseOptions: BaseOptions = options.baseOptions || {};
|
|
|
|
// Validate that exactly one model is configured
|
|
if (options.baseOptions?.modelAssetBuffer &&
|
|
options.baseOptions?.modelAssetPath) {
|
|
throw new Error(
|
|
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
|
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
|
this.baseOptions.getModelAsset()?.hasFileName() ||
|
|
options.baseOptions?.modelAssetBuffer ||
|
|
options.baseOptions?.modelAssetPath)) {
|
|
throw new Error(
|
|
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
|
}
|
|
|
|
this.setAcceleration(baseOptions);
|
|
if (baseOptions.modelAssetPath) {
|
|
// We don't use `await` here since we want to apply most settings
|
|
// synchronously.
|
|
return fetch(baseOptions.modelAssetPath.toString())
|
|
.then(response => {
|
|
if (!response.ok) {
|
|
throw new Error(`Failed to fetch model: ${
|
|
baseOptions.modelAssetPath} (${response.status})`);
|
|
} else {
|
|
return response.arrayBuffer();
|
|
}
|
|
})
|
|
.then(buffer => {
|
|
try {
|
|
// Try to delete file as we cannot overwite an existing file
|
|
// using our current API.
|
|
this.graphRunner.wasmModule.FS_unlink('/model.dat');
|
|
} catch {
|
|
}
|
|
// TODO: Consider passing the model to the graph as an
|
|
// input side packet as this might reduce copies.
|
|
this.graphRunner.wasmModule.FS_createDataFile(
|
|
'/', 'model.dat', new Uint8Array(buffer),
|
|
/* canRead= */ true, /* canWrite= */ false,
|
|
/* canOwn= */ false);
|
|
this.setExternalFile('/model.dat');
|
|
this.refreshGraph();
|
|
this.onGraphRefreshed();
|
|
});
|
|
} else {
|
|
this.setExternalFile(baseOptions.modelAssetBuffer);
|
|
}
|
|
}
|
|
|
|
// If there is no model to download, we can apply the setting synchronously.
|
|
this.refreshGraph();
|
|
this.onGraphRefreshed();
|
|
return Promise.resolve();
|
|
}
|
|
|
|
/** Appliest the current options to the MediaPipe graph. */
|
|
protected abstract refreshGraph(): void;
|
|
|
|
/**
|
|
* Callback that gets invoked once a new graph configuration has been
|
|
* applied.
|
|
*/
|
|
protected onGraphRefreshed(): void {}
|
|
|
|
/** Returns the current CalculatorGraphConfig. */
|
|
protected getCalculatorGraphConfig(): CalculatorGraphConfig {
|
|
let config: CalculatorGraphConfig|undefined;
|
|
this.graphRunner.getCalculatorGraphConfig(binaryData => {
|
|
config = CalculatorGraphConfig.deserializeBinary(binaryData);
|
|
});
|
|
if (!config) {
|
|
throw new Error('Failed to retrieve CalculatorGraphConfig');
|
|
}
|
|
return config;
|
|
}
|
|
|
|
/**
|
|
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
|
* over the video stream. Will replace the previously running MediaPipe graph,
|
|
* if there is one.
|
|
* @param graphData The raw MediaPipe graph data, either in binary
|
|
* protobuffer format (.binarypb), or else in raw text format (.pbtxt or
|
|
* .textproto).
|
|
* @param isBinary This should be set to true if the graph is in
|
|
* binary format, and false if it is in human-readable text format.
|
|
*/
|
|
protected setGraph(graphData: Uint8Array, isBinary: boolean): void {
|
|
this.graphRunner.attachErrorListener((code, message) => {
|
|
this.processingErrors.push(new Error(message));
|
|
});
|
|
|
|
// Enables use of our model resource caching graph service; we apply this to
|
|
// every MediaPipe graph we run.
|
|
this.graphRunner.registerModelResourcesGraphService();
|
|
|
|
this.graphRunner.setGraph(graphData, isBinary);
|
|
this.keepaliveNode = undefined;
|
|
this.handleErrors();
|
|
}
|
|
|
|
/**
|
|
* Forces all queued-up packets to be pushed through the MediaPipe graph as
|
|
* far as possible, performing all processing until no more processing can be
|
|
* done.
|
|
*/
|
|
protected finishProcessing(): void {
|
|
this.graphRunner.finishProcessing();
|
|
this.handleErrors();
|
|
}
|
|
|
|
/*
|
|
* Sets the latest output timestamp received from the graph (in ms).
|
|
* Timestamps that are smaller than the currently latest output timestamp are
|
|
* ignored.
|
|
*/
|
|
protected setLatestOutputTimestamp(timestamp: number): void {
|
|
this.latestOutputTimestamp =
|
|
Math.max(this.latestOutputTimestamp, timestamp);
|
|
}
|
|
|
|
/**
|
|
* Gets a syncthethic timestamp in ms that can be used to send data to the
|
|
* next packet. The timestamp is one millisecond past the last timestamp
|
|
* received from the graph.
|
|
*/
|
|
protected getSynctheticTimestamp(): number {
|
|
return this.latestOutputTimestamp + 1;
|
|
}
|
|
|
|
/** Throws the error from the error listener if an error was raised. */
|
|
private handleErrors() {
|
|
try {
|
|
const errorCount = this.processingErrors.length;
|
|
if (errorCount === 1) {
|
|
// Re-throw error to get a more meaningful stacktrace
|
|
throw new Error(this.processingErrors[0].message);
|
|
} else if (errorCount > 1) {
|
|
throw new Error(
|
|
'Encountered multiple errors: ' +
|
|
this.processingErrors.map(e => e.message).join(', '));
|
|
}
|
|
} finally {
|
|
this.processingErrors = [];
|
|
}
|
|
}
|
|
|
|
/** Configures the `externalFile` option */
|
|
private setExternalFile(modelAssetPath?: string): void;
|
|
private setExternalFile(modelAssetBuffer?: Uint8Array): void;
|
|
private setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void {
|
|
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
|
if (typeof modelAssetPathOrBuffer === 'string') {
|
|
externalFile.setFileName(modelAssetPathOrBuffer);
|
|
externalFile.clearFileContent();
|
|
} else if (modelAssetPathOrBuffer instanceof Uint8Array) {
|
|
externalFile.setFileContent(modelAssetPathOrBuffer);
|
|
externalFile.clearFileName();
|
|
}
|
|
this.baseOptions.setModelAsset(externalFile);
|
|
}
|
|
|
|
/** Configures the `acceleration` option. */
|
|
private setAcceleration(options: BaseOptions) {
|
|
let acceleration = this.baseOptions.getAcceleration();
|
|
|
|
if (!acceleration) {
|
|
// Create default instance for the initial configuration.
|
|
acceleration = new Acceleration();
|
|
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
|
}
|
|
|
|
if ('delegate' in options) {
|
|
if (options.delegate === 'GPU') {
|
|
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
|
} else {
|
|
acceleration.setTflite(
|
|
new InferenceCalculatorOptions.Delegate.TfLite());
|
|
}
|
|
}
|
|
|
|
this.baseOptions.setAcceleration(acceleration);
|
|
}
|
|
|
|
/**
|
|
* Adds a node to the graph to temporarily keep certain streams alive.
|
|
* NOTE: To use this call, PassThroughCalculator must be included in your wasm
|
|
* dependencies.
|
|
*/
|
|
protected addKeepaliveNode(graphConfig: CalculatorGraphConfig) {
|
|
this.keepaliveNode = new CalculatorGraphConfig.Node();
|
|
this.keepaliveNode.setCalculator('PassThroughCalculator');
|
|
this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM);
|
|
this.keepaliveNode.addOutputStream(
|
|
FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX);
|
|
graphConfig.addInputStream(FREE_MEMORY_STREAM);
|
|
graphConfig.addNode(this.keepaliveNode);
|
|
}
|
|
|
|
/** Adds streams to the keepalive node to be kept alive until callback. */
|
|
protected keepStreamAlive(streamName: string) {
|
|
this.keepaliveNode!.addInputStream(streamName);
|
|
this.keepaliveNode!.addOutputStream(streamName + UNUSED_STREAM_SUFFIX);
|
|
}
|
|
|
|
/** Frees any streams being kept alive by the keepStreamAlive callback. */
|
|
protected freeKeepaliveStreams() {
|
|
this.graphRunner.addBoolToStream(
|
|
true, FREE_MEMORY_STREAM, this.latestOutputTimestamp);
|
|
}
|
|
|
|
/** Closes and cleans up the resources held by this task. */
|
|
close(): void {
|
|
this.keepaliveNode = undefined;
|
|
this.graphRunner.closeGraph();
|
|
}
|
|
}
|
|
|
|
|