Support WASM asset loading for MediaPipe Task Web

PiperOrigin-RevId: 547882566
This commit is contained in:
Sebastian Schmidt 2023-07-13 12:24:57 -07:00 committed by Copybara-Service
parent 8b59567cb7
commit 327feb42d1
3 changed files with 86 additions and 68 deletions

View File

@ -25,9 +25,6 @@ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/regis
import {WasmFileset} from './wasm_fileset'; import {WasmFileset} from './wasm_fileset';
// None of the MP Tasks ship bundle assets.
const NO_ASSETS = undefined;
// Internal stream names for temporarily keeping memory alive, then freeing it. // Internal stream names for temporarily keeping memory alive, then freeing it.
const FREE_MEMORY_STREAM = 'free_memory'; const FREE_MEMORY_STREAM = 'free_memory';
const UNUSED_STREAM_SUFFIX = '_unused_out'; const UNUSED_STREAM_SUFFIX = '_unused_out';
@ -61,7 +58,8 @@ export async function createTaskRunner<T extends TaskRunner>(
}; };
const instance = await createMediaPipeLib( const instance = await createMediaPipeLib(
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas,
fileLocator);
await instance.setOptions(options); await instance.setOptions(options);
return instance; return instance;
} }
@ -96,13 +94,19 @@ export abstract class TaskRunner {
abstract setOptions(options: TaskRunnerOptions): Promise<void>; abstract setOptions(options: TaskRunnerOptions): Promise<void>;
/** /**
* Applies the current set of options, including any base options that have * Applies the current set of options, including optionally any base options
* not been processed by the task implementation. The options are applied * that have not been processed by the task implementation. The options are
* synchronously unless a `modelAssetPath` is provided. This ensures that * applied synchronously unless a `modelAssetPath` is provided. This ensures
* for most use cases options are applied directly and immediately affect * that for most use cases options are applied directly and immediately affect
* the next inference. * the next inference.
*
* @param options The options for the task.
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/ */
protected applyOptions(options: TaskRunnerOptions): Promise<void> { protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true):
Promise<void> {
if (loadTfliteModel) {
const baseOptions: BaseOptions = options.baseOptions || {}; const baseOptions: BaseOptions = options.baseOptions || {};
// Validate that exactly one model is configured // Validate that exactly one model is configured
@ -133,8 +137,8 @@ export abstract class TaskRunner {
}) })
.then(buffer => { .then(buffer => {
try { try {
// Try to delete file as we cannot overwite an existing file using // Try to delete file as we cannot overwite an existing file
// our current API. // using our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat'); this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch { } catch {
} }
@ -149,13 +153,15 @@ export abstract class TaskRunner {
this.onGraphRefreshed(); this.onGraphRefreshed();
}); });
} else { } else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer); this.setExternalFile(baseOptions.modelAssetBuffer);
}
}
// If there is no model to download, we can apply the setting synchronously.
this.refreshGraph(); this.refreshGraph();
this.onGraphRefreshed(); this.onGraphRefreshed();
return Promise.resolve(); return Promise.resolve();
} }
}
/** Appliest the current options to the MediaPipe graph. */ /** Appliest the current options to the MediaPipe graph. */
protected abstract refreshGraph(): void; protected abstract refreshGraph(): void;

View File

@ -22,4 +22,6 @@ export declare interface WasmFileset {
wasmLoaderPath: string; wasmLoaderPath: string;
/** The path to the Wasm binary. */ /** The path to the Wasm binary. */
wasmBinaryPath: string; wasmBinaryPath: string;
/** The optional path to the asset loader script. */
assetLoaderPath?: string;
} }

View File

@ -70,7 +70,8 @@ export abstract class VisionTaskRunner extends TaskRunner {
* @param imageStreamName the name of the input image stream. * @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image * @param normRectStreamName the name of the input normalized rect image
* stream used to provide (mandatory) rotation and (optional) * stream used to provide (mandatory) rotation and (optional)
* region-of-interest. * region-of-interest. `null` if the graph does not support normalized
* rects.
* @param roiAllowed Whether this task supports Region-Of-Interest * @param roiAllowed Whether this task supports Region-Of-Interest
* pre-processing * pre-processing
* *
@ -79,13 +80,20 @@ export abstract class VisionTaskRunner extends TaskRunner {
constructor( constructor(
protected override readonly graphRunner: VisionGraphRunner, protected override readonly graphRunner: VisionGraphRunner,
private readonly imageStreamName: string, private readonly imageStreamName: string,
private readonly normRectStreamName: string, private readonly normRectStreamName: string|null,
private readonly roiAllowed: boolean) { private readonly roiAllowed: boolean) {
super(graphRunner); super(graphRunner);
} }
/** Configures the shared options of a vision task. */ /**
override applyOptions(options: VisionTaskOptions): Promise<void> { * Configures the shared options of a vision task.
*
* @param options The options for the task.
* @param loadTfliteModel Whether to load the model specified in
* `options.baseOptions`.
*/
override applyOptions(options: VisionTaskOptions, loadTfliteModel = true):
Promise<void> {
if ('runningMode' in options) { if ('runningMode' in options) {
const useStreamMode = const useStreamMode =
!!options.runningMode && options.runningMode !== 'IMAGE'; !!options.runningMode && options.runningMode !== 'IMAGE';
@ -98,7 +106,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
} }
} }
return super.applyOptions(options); return super.applyOptions(options, loadTfliteModel);
} }
/** Sends a single image to the graph and awaits results. */ /** Sends a single image to the graph and awaits results. */
@ -209,11 +217,13 @@ export abstract class VisionTaskRunner extends TaskRunner {
imageSource: ImageSource, imageSource: ImageSource,
imageProcessingOptions: ImageProcessingOptions|undefined, imageProcessingOptions: ImageProcessingOptions|undefined,
timestamp: number): void { timestamp: number): void {
if (this.normRectStreamName) {
const normalizedRect = const normalizedRect =
this.convertToNormalizedRect(imageSource, imageProcessingOptions); this.convertToNormalizedRect(imageSource, imageProcessingOptions);
this.graphRunner.addProtoToStream( this.graphRunner.addProtoToStream(
normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
this.normRectStreamName, timestamp); this.normRectStreamName, timestamp);
}
this.graphRunner.addGpuBufferAsImageToStream( this.graphRunner.addGpuBufferAsImageToStream(
imageSource, this.imageStreamName, timestamp ?? performance.now()); imageSource, this.imageStreamName, timestamp ?? performance.now());
this.finishProcessing(); this.finishProcessing();