Support WASM asset loading for MediaPipe Task Web

PiperOrigin-RevId: 547882566
This commit is contained in:
Sebastian Schmidt 2023-07-13 12:24:57 -07:00 committed by Copybara-Service
parent 8b59567cb7
commit 327feb42d1
3 changed files with 86 additions and 68 deletions

View File

@ -25,9 +25,6 @@ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/regis
import {WasmFileset} from './wasm_fileset'; import {WasmFileset} from './wasm_fileset';
// None of the MP Tasks ship bundle assets.
const NO_ASSETS = undefined;
// Internal stream names for temporarily keeping memory alive, then freeing it. // Internal stream names for temporarily keeping memory alive, then freeing it.
const FREE_MEMORY_STREAM = 'free_memory'; const FREE_MEMORY_STREAM = 'free_memory';
const UNUSED_STREAM_SUFFIX = '_unused_out'; const UNUSED_STREAM_SUFFIX = '_unused_out';
@ -61,7 +58,8 @@ export async function createTaskRunner<T extends TaskRunner>(
}; };
const instance = await createMediaPipeLib( const instance = await createMediaPipeLib(
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas,
fileLocator);
await instance.setOptions(options); await instance.setOptions(options);
return instance; return instance;
} }
@ -96,65 +94,73 @@ export abstract class TaskRunner {
abstract setOptions(options: TaskRunnerOptions): Promise<void>; abstract setOptions(options: TaskRunnerOptions): Promise<void>;
/** /**
* Applies the current set of options, including any base options that have * Applies the current set of options, including optionally any base options
* not been processed by the task implementation. The options are applied * that have not been processed by the task implementation. The options are
* synchronously unless a `modelAssetPath` is provided. This ensures that * applied synchronously unless a `modelAssetPath` is provided. This ensures
* for most use cases options are applied directly and immediately affect * that for most use cases options are applied directly and immediately affect
* the next inference. * the next inference.
*
* @param options The options for the task.
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/ */
protected applyOptions(options: TaskRunnerOptions): Promise<void> { protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true):
const baseOptions: BaseOptions = options.baseOptions || {}; Promise<void> {
if (loadTfliteModel) {
const baseOptions: BaseOptions = options.baseOptions || {};
// Validate that exactly one model is configured // Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer && if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) { options.baseOptions?.modelAssetPath) {
throw new Error( throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() || this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer || options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) { options.baseOptions?.modelAssetPath)) {
throw new Error( throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => {
if (!response.ok) {
throw new Error(`Failed to fetch model: ${
baseOptions.modelAssetPath} (${response.status})`);
} else {
return response.arrayBuffer();
}
})
.then(buffer => {
try {
// Try to delete file as we cannot overwite an existing file
// using our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
} else {
this.setExternalFile(baseOptions.modelAssetBuffer);
}
} }
this.setAcceleration(baseOptions); // If there is no model to download, we can apply the setting synchronously.
if (baseOptions.modelAssetPath) { this.refreshGraph();
// We don't use `await` here since we want to apply most settings this.onGraphRefreshed();
// synchronously. return Promise.resolve();
return fetch(baseOptions.modelAssetPath.toString())
.then(response => {
if (!response.ok) {
throw new Error(`Failed to fetch model: ${
baseOptions.modelAssetPath} (${response.status})`);
} else {
return response.arrayBuffer();
}
})
.then(buffer => {
try {
// Try to delete file as we cannot overwite an existing file using
// our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
} else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer);
this.refreshGraph();
this.onGraphRefreshed();
return Promise.resolve();
}
} }
/** Appliest the current options to the MediaPipe graph. */ /** Appliest the current options to the MediaPipe graph. */

View File

@ -22,4 +22,6 @@ export declare interface WasmFileset {
wasmLoaderPath: string; wasmLoaderPath: string;
/** The path to the Wasm binary. */ /** The path to the Wasm binary. */
wasmBinaryPath: string; wasmBinaryPath: string;
/** The optional path to the asset loader script. */
assetLoaderPath?: string;
} }

View File

@ -70,7 +70,8 @@ export abstract class VisionTaskRunner extends TaskRunner {
* @param imageStreamName the name of the input image stream. * @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image * @param normRectStreamName the name of the input normalized rect image
* stream used to provide (mandatory) rotation and (optional) * stream used to provide (mandatory) rotation and (optional)
* region-of-interest. * region-of-interest. `null` if the graph does not support normalized
* rects.
* @param roiAllowed Whether this task supports Region-Of-Interest * @param roiAllowed Whether this task supports Region-Of-Interest
* pre-processing * pre-processing
* *
@ -79,13 +80,20 @@ export abstract class VisionTaskRunner extends TaskRunner {
constructor( constructor(
protected override readonly graphRunner: VisionGraphRunner, protected override readonly graphRunner: VisionGraphRunner,
private readonly imageStreamName: string, private readonly imageStreamName: string,
private readonly normRectStreamName: string, private readonly normRectStreamName: string|null,
private readonly roiAllowed: boolean) { private readonly roiAllowed: boolean) {
super(graphRunner); super(graphRunner);
} }
/** Configures the shared options of a vision task. */ /**
override applyOptions(options: VisionTaskOptions): Promise<void> { * Configures the shared options of a vision task.
*
* @param options The options for the task.
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/
override applyOptions(options: VisionTaskOptions, loadTfliteModel = true):
Promise<void> {
if ('runningMode' in options) { if ('runningMode' in options) {
const useStreamMode = const useStreamMode =
!!options.runningMode && options.runningMode !== 'IMAGE'; !!options.runningMode && options.runningMode !== 'IMAGE';
@ -98,7 +106,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
} }
} }
return super.applyOptions(options); return super.applyOptions(options, loadTfliteModel);
} }
/** Sends a single image to the graph and awaits results. */ /** Sends a single image to the graph and awaits results. */
@ -209,11 +217,13 @@ export abstract class VisionTaskRunner extends TaskRunner {
imageSource: ImageSource, imageSource: ImageSource,
imageProcessingOptions: ImageProcessingOptions|undefined, imageProcessingOptions: ImageProcessingOptions|undefined,
timestamp: number): void { timestamp: number): void {
const normalizedRect = if (this.normRectStreamName) {
this.convertToNormalizedRect(imageSource, imageProcessingOptions); const normalizedRect =
this.graphRunner.addProtoToStream( this.convertToNormalizedRect(imageSource, imageProcessingOptions);
normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', this.graphRunner.addProtoToStream(
this.normRectStreamName, timestamp); normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
this.normRectStreamName, timestamp);
}
this.graphRunner.addGpuBufferAsImageToStream( this.graphRunner.addGpuBufferAsImageToStream(
imageSource, this.imageStreamName, timestamp ?? performance.now()); imageSource, this.imageStreamName, timestamp ?? performance.now());
this.finishProcessing(); this.finishProcessing();