From 327feb42d1c9187693b1d18a550efc1d930b2eae Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 13 Jul 2023 12:24:57 -0700 Subject: [PATCH] Support WASM asset loading for MediaPipe Task Web PiperOrigin-RevId: 547882566 --- mediapipe/tasks/web/core/task_runner.ts | 122 +++++++++--------- mediapipe/tasks/web/core/wasm_fileset.d.ts | 2 + .../web/vision/core/vision_task_runner.ts | 30 +++-- 3 files changed, 86 insertions(+), 68 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 8c6aae6cf..dde98192d 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -25,9 +25,6 @@ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/regis import {WasmFileset} from './wasm_fileset'; -// None of the MP Tasks ship bundle assets. -const NO_ASSETS = undefined; - // Internal stream names for temporarily keeping memory alive, then freeing it. const FREE_MEMORY_STREAM = 'free_memory'; const UNUSED_STREAM_SUFFIX = '_unused_out'; @@ -61,7 +58,8 @@ export async function createTaskRunner( }; const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas, + fileLocator); await instance.setOptions(options); return instance; } @@ -96,65 +94,73 @@ export abstract class TaskRunner { abstract setOptions(options: TaskRunnerOptions): Promise; /** - * Applies the current set of options, including any base options that have - * not been processed by the task implementation. The options are applied - * synchronously unless a `modelAssetPath` is provided. This ensures that - * for most use cases options are applied directly and immediately affect + * Applies the current set of options, including optionally any base options + * that have not been processed by the task implementation. The options are + * applied synchronously unless a `modelAssetPath` is provided. This ensures + * that for most use cases options are applied directly and immediately affect * the next inference. + * + * @param options The options for the task. + * @param loadTfliteModel Whether to load the model specified in + * `options.baseOptions`. */ - protected applyOptions(options: TaskRunnerOptions): Promise { - const baseOptions: BaseOptions = options.baseOptions || {}; + protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true): + Promise { + if (loadTfliteModel) { + const baseOptions: BaseOptions = options.baseOptions || {}; - // Validate that exactly one model is configured - if (options.baseOptions?.modelAssetBuffer && - options.baseOptions?.modelAssetPath) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || - this.baseOptions.getModelAsset()?.hasFileName() || - options.baseOptions?.modelAssetBuffer || - options.baseOptions?.modelAssetPath)) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + // Validate that exactly one model is configured + if (options.baseOptions?.modelAssetBuffer && + options.baseOptions?.modelAssetPath) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || + this.baseOptions.getModelAsset()?.hasFileName() || + options.baseOptions?.modelAssetBuffer || + options.baseOptions?.modelAssetPath)) { + throw new Error( + 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + } + + this.setAcceleration(baseOptions); + if (baseOptions.modelAssetPath) { + // We don't use `await` here since we want to apply most settings + // synchronously. + return fetch(baseOptions.modelAssetPath.toString()) + .then(response => { + if (!response.ok) { + throw new Error(`Failed to fetch model: ${ + baseOptions.modelAssetPath} (${response.status})`); + } else { + return response.arrayBuffer(); + } + }) + .then(buffer => { + try { + // Try to delete file as we cannot overwite an existing file + // using our current API. + this.graphRunner.wasmModule.FS_unlink('/model.dat'); + } catch { + } + // TODO: Consider passing the model to the graph as an + // input side packet as this might reduce copies. + this.graphRunner.wasmModule.FS_createDataFile( + '/', 'model.dat', new Uint8Array(buffer), + /* canRead= */ true, /* canWrite= */ false, + /* canOwn= */ false); + this.setExternalFile('/model.dat'); + this.refreshGraph(); + this.onGraphRefreshed(); + }); + } else { + this.setExternalFile(baseOptions.modelAssetBuffer); + } } - this.setAcceleration(baseOptions); - if (baseOptions.modelAssetPath) { - // We don't use `await` here since we want to apply most settings - // synchronously. - return fetch(baseOptions.modelAssetPath.toString()) - .then(response => { - if (!response.ok) { - throw new Error(`Failed to fetch model: ${ - baseOptions.modelAssetPath} (${response.status})`); - } else { - return response.arrayBuffer(); - } - }) - .then(buffer => { - try { - // Try to delete file as we cannot overwite an existing file using - // our current API. - this.graphRunner.wasmModule.FS_unlink('/model.dat'); - } catch { - } - // TODO: Consider passing the model to the graph as an - // input side packet as this might reduce copies. - this.graphRunner.wasmModule.FS_createDataFile( - '/', 'model.dat', new Uint8Array(buffer), - /* canRead= */ true, /* canWrite= */ false, - /* canOwn= */ false); - this.setExternalFile('/model.dat'); - this.refreshGraph(); - this.onGraphRefreshed(); - }); - } else { - // Apply the setting synchronously. - this.setExternalFile(baseOptions.modelAssetBuffer); - this.refreshGraph(); - this.onGraphRefreshed(); - return Promise.resolve(); - } + // If there is no model to download, we can apply the setting synchronously. + this.refreshGraph(); + this.onGraphRefreshed(); + return Promise.resolve(); } /** Appliest the current options to the MediaPipe graph. */ diff --git a/mediapipe/tasks/web/core/wasm_fileset.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts index 558aa3faf..dda466ad9 100644 --- a/mediapipe/tasks/web/core/wasm_fileset.d.ts +++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts @@ -22,4 +22,6 @@ export declare interface WasmFileset { wasmLoaderPath: string; /** The path to the Wasm binary. */ wasmBinaryPath: string; + /** The optional path to the asset loader script. */ + assetLoaderPath?: string; } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index f8f7826d0..3ed15b97d 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -70,7 +70,8 @@ export abstract class VisionTaskRunner extends TaskRunner { * @param imageStreamName the name of the input image stream. * @param normRectStreamName the name of the input normalized rect image * stream used to provide (mandatory) rotation and (optional) - * region-of-interest. + * region-of-interest. `null` if the graph does not support normalized + * rects. * @param roiAllowed Whether this task supports Region-Of-Interest * pre-processing * @@ -79,13 +80,20 @@ export abstract class VisionTaskRunner extends TaskRunner { constructor( protected override readonly graphRunner: VisionGraphRunner, private readonly imageStreamName: string, - private readonly normRectStreamName: string, + private readonly normRectStreamName: string|null, private readonly roiAllowed: boolean) { super(graphRunner); } - /** Configures the shared options of a vision task. */ - override applyOptions(options: VisionTaskOptions): Promise { + /** + * Configures the shared options of a vision task. + * + * @param options The options for the task. + * @param loadTfliteModel Whether to load the model specified in + * `options.baseOptions`. + */ + override applyOptions(options: VisionTaskOptions, loadTfliteModel = true): + Promise { if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'IMAGE'; @@ -98,7 +106,7 @@ export abstract class VisionTaskRunner extends TaskRunner { } } - return super.applyOptions(options); + return super.applyOptions(options, loadTfliteModel); } /** Sends a single image to the graph and awaits results. */ @@ -209,11 +217,13 @@ export abstract class VisionTaskRunner extends TaskRunner { imageSource: ImageSource, imageProcessingOptions: ImageProcessingOptions|undefined, timestamp: number): void { - const normalizedRect = - this.convertToNormalizedRect(imageSource, imageProcessingOptions); - this.graphRunner.addProtoToStream( - normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', - this.normRectStreamName, timestamp); + if (this.normRectStreamName) { + const normalizedRect = + this.convertToNormalizedRect(imageSource, imageProcessingOptions); + this.graphRunner.addProtoToStream( + normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', + this.normRectStreamName, timestamp); + } this.graphRunner.addGpuBufferAsImageToStream( imageSource, this.imageStreamName, timestamp ?? performance.now()); this.finishProcessing();