Support WASM asset loading for MediaPipe Task Web

PiperOrigin-RevId: 547882566
This commit is contained in:
Sebastian Schmidt 2023-07-13 12:24:57 -07:00 committed by Copybara-Service
parent 8b59567cb7
commit 327feb42d1
3 changed files with 86 additions and 68 deletions

View File

@ -25,9 +25,6 @@ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/regis
import {WasmFileset} from './wasm_fileset';
// None of the MP Tasks ship bundle assets.
const NO_ASSETS = undefined;
// Internal stream names for temporarily keeping memory alive, then freeing it.
const FREE_MEMORY_STREAM = 'free_memory';
const UNUSED_STREAM_SUFFIX = '_unused_out';
@ -61,7 +58,8 @@ export async function createTaskRunner<T extends TaskRunner>(
};
const instance = await createMediaPipeLib(
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas,
fileLocator);
await instance.setOptions(options);
return instance;
}
@ -96,65 +94,73 @@ export abstract class TaskRunner {
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
/**
* Applies the current set of options, including any base options that have
* not been processed by the task implementation. The options are applied
* synchronously unless a `modelAssetPath` is provided. This ensures that
* for most use cases options are applied directly and immediately affect
* Applies the current set of options, including optionally any base options
* that have not been processed by the task implementation. The options are
* applied synchronously unless a `modelAssetPath` is provided. This ensures
* that for most use cases options are applied directly and immediately affect
* the next inference.
*
* @param options The options for the task.
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
const baseOptions: BaseOptions = options.baseOptions || {};
protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true):
Promise<void> {
if (loadTfliteModel) {
const baseOptions: BaseOptions = options.baseOptions || {};
// Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
// Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => {
if (!response.ok) {
throw new Error(`Failed to fetch model: ${
baseOptions.modelAssetPath} (${response.status})`);
} else {
return response.arrayBuffer();
}
})
.then(buffer => {
try {
// Try to delete file as we cannot overwite an existing file
// using our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
} else {
this.setExternalFile(baseOptions.modelAssetBuffer);
}
}
this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => {
if (!response.ok) {
throw new Error(`Failed to fetch model: ${
baseOptions.modelAssetPath} (${response.status})`);
} else {
return response.arrayBuffer();
}
})
.then(buffer => {
try {
// Try to delete file as we cannot overwite an existing file using
// our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
} else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer);
this.refreshGraph();
this.onGraphRefreshed();
return Promise.resolve();
}
// If there is no model to download, we can apply the setting synchronously.
this.refreshGraph();
this.onGraphRefreshed();
return Promise.resolve();
}
/** Appliest the current options to the MediaPipe graph. */

View File

@ -22,4 +22,6 @@ export declare interface WasmFileset {
wasmLoaderPath: string;
/** The path to the Wasm binary. */
wasmBinaryPath: string;
/** The optional path to the asset loader script. */
assetLoaderPath?: string;
}

View File

@ -70,7 +70,8 @@ export abstract class VisionTaskRunner extends TaskRunner {
* @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image
* stream used to provide (mandatory) rotation and (optional)
* region-of-interest.
* region-of-interest. `null` if the graph does not support normalized
* rects.
* @param roiAllowed Whether this task supports Region-Of-Interest
* pre-processing
*
@ -79,13 +80,20 @@ export abstract class VisionTaskRunner extends TaskRunner {
constructor(
protected override readonly graphRunner: VisionGraphRunner,
private readonly imageStreamName: string,
private readonly normRectStreamName: string,
private readonly normRectStreamName: string|null,
private readonly roiAllowed: boolean) {
super(graphRunner);
}
/** Configures the shared options of a vision task. */
override applyOptions(options: VisionTaskOptions): Promise<void> {
/**
* Configures the shared options of a vision task.
*
* @param options The options for the task.
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/
override applyOptions(options: VisionTaskOptions, loadTfliteModel = true):
Promise<void> {
if ('runningMode' in options) {
const useStreamMode =
!!options.runningMode && options.runningMode !== 'IMAGE';
@ -98,7 +106,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
}
}
return super.applyOptions(options);
return super.applyOptions(options, loadTfliteModel);
}
/** Sends a single image to the graph and awaits results. */
@ -209,11 +217,13 @@ export abstract class VisionTaskRunner extends TaskRunner {
imageSource: ImageSource,
imageProcessingOptions: ImageProcessingOptions|undefined,
timestamp: number): void {
const normalizedRect =
this.convertToNormalizedRect(imageSource, imageProcessingOptions);
this.graphRunner.addProtoToStream(
normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
this.normRectStreamName, timestamp);
if (this.normRectStreamName) {
const normalizedRect =
this.convertToNormalizedRect(imageSource, imageProcessingOptions);
this.graphRunner.addProtoToStream(
normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
this.normRectStreamName, timestamp);
}
this.graphRunner.addGpuBufferAsImageToStream(
imageSource, this.imageStreamName, timestamp ?? performance.now());
this.finishProcessing();