Support WASM asset loading for MediaPipe Task Web
PiperOrigin-RevId: 547882566
This commit is contained in:
parent
8b59567cb7
commit
327feb42d1
|
@ -25,9 +25,6 @@ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/regis
|
|||
|
||||
import {WasmFileset} from './wasm_fileset';
|
||||
|
||||
// None of the MP Tasks ship bundle assets.
|
||||
const NO_ASSETS = undefined;
|
||||
|
||||
// Internal stream names for temporarily keeping memory alive, then freeing it.
|
||||
const FREE_MEMORY_STREAM = 'free_memory';
|
||||
const UNUSED_STREAM_SUFFIX = '_unused_out';
|
||||
|
@ -61,7 +58,8 @@ export async function createTaskRunner<T extends TaskRunner>(
|
|||
};
|
||||
|
||||
const instance = await createMediaPipeLib(
|
||||
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
|
||||
type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas,
|
||||
fileLocator);
|
||||
await instance.setOptions(options);
|
||||
return instance;
|
||||
}
|
||||
|
@ -96,65 +94,73 @@ export abstract class TaskRunner {
|
|||
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
|
||||
|
||||
/**
|
||||
* Applies the current set of options, including any base options that have
|
||||
* not been processed by the task implementation. The options are applied
|
||||
* synchronously unless a `modelAssetPath` is provided. This ensures that
|
||||
* for most use cases options are applied directly and immediately affect
|
||||
* Applies the current set of options, including optionally any base options
|
||||
* that have not been processed by the task implementation. The options are
|
||||
* applied synchronously unless a `modelAssetPath` is provided. This ensures
|
||||
* that for most use cases options are applied directly and immediately affect
|
||||
* the next inference.
|
||||
*
|
||||
* @param options The options for the task.
|
||||
* @param loadTfliteModel Whether to load the model specified in
|
||||
* `options.baseOptions`.
|
||||
*/
|
||||
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
const baseOptions: BaseOptions = options.baseOptions || {};
|
||||
protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true):
|
||||
Promise<void> {
|
||||
if (loadTfliteModel) {
|
||||
const baseOptions: BaseOptions = options.baseOptions || {};
|
||||
|
||||
// Validate that exactly one model is configured
|
||||
if (options.baseOptions?.modelAssetBuffer &&
|
||||
options.baseOptions?.modelAssetPath) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||
this.baseOptions.getModelAsset()?.hasFileName() ||
|
||||
options.baseOptions?.modelAssetBuffer ||
|
||||
options.baseOptions?.modelAssetPath)) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
// Validate that exactly one model is configured
|
||||
if (options.baseOptions?.modelAssetBuffer &&
|
||||
options.baseOptions?.modelAssetPath) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||
this.baseOptions.getModelAsset()?.hasFileName() ||
|
||||
options.baseOptions?.modelAssetBuffer ||
|
||||
options.baseOptions?.modelAssetPath)) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
|
||||
this.setAcceleration(baseOptions);
|
||||
if (baseOptions.modelAssetPath) {
|
||||
// We don't use `await` here since we want to apply most settings
|
||||
// synchronously.
|
||||
return fetch(baseOptions.modelAssetPath.toString())
|
||||
.then(response => {
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch model: ${
|
||||
baseOptions.modelAssetPath} (${response.status})`);
|
||||
} else {
|
||||
return response.arrayBuffer();
|
||||
}
|
||||
})
|
||||
.then(buffer => {
|
||||
try {
|
||||
// Try to delete file as we cannot overwite an existing file
|
||||
// using our current API.
|
||||
this.graphRunner.wasmModule.FS_unlink('/model.dat');
|
||||
} catch {
|
||||
}
|
||||
// TODO: Consider passing the model to the graph as an
|
||||
// input side packet as this might reduce copies.
|
||||
this.graphRunner.wasmModule.FS_createDataFile(
|
||||
'/', 'model.dat', new Uint8Array(buffer),
|
||||
/* canRead= */ true, /* canWrite= */ false,
|
||||
/* canOwn= */ false);
|
||||
this.setExternalFile('/model.dat');
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
});
|
||||
} else {
|
||||
this.setExternalFile(baseOptions.modelAssetBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
this.setAcceleration(baseOptions);
|
||||
if (baseOptions.modelAssetPath) {
|
||||
// We don't use `await` here since we want to apply most settings
|
||||
// synchronously.
|
||||
return fetch(baseOptions.modelAssetPath.toString())
|
||||
.then(response => {
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch model: ${
|
||||
baseOptions.modelAssetPath} (${response.status})`);
|
||||
} else {
|
||||
return response.arrayBuffer();
|
||||
}
|
||||
})
|
||||
.then(buffer => {
|
||||
try {
|
||||
// Try to delete file as we cannot overwite an existing file using
|
||||
// our current API.
|
||||
this.graphRunner.wasmModule.FS_unlink('/model.dat');
|
||||
} catch {
|
||||
}
|
||||
// TODO: Consider passing the model to the graph as an
|
||||
// input side packet as this might reduce copies.
|
||||
this.graphRunner.wasmModule.FS_createDataFile(
|
||||
'/', 'model.dat', new Uint8Array(buffer),
|
||||
/* canRead= */ true, /* canWrite= */ false,
|
||||
/* canOwn= */ false);
|
||||
this.setExternalFile('/model.dat');
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
});
|
||||
} else {
|
||||
// Apply the setting synchronously.
|
||||
this.setExternalFile(baseOptions.modelAssetBuffer);
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
return Promise.resolve();
|
||||
}
|
||||
// If there is no model to download, we can apply the setting synchronously.
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
/** Appliest the current options to the MediaPipe graph. */
|
||||
|
|
2
mediapipe/tasks/web/core/wasm_fileset.d.ts
vendored
2
mediapipe/tasks/web/core/wasm_fileset.d.ts
vendored
|
@ -22,4 +22,6 @@ export declare interface WasmFileset {
|
|||
wasmLoaderPath: string;
|
||||
/** The path to the Wasm binary. */
|
||||
wasmBinaryPath: string;
|
||||
/** The optional path to the asset loader script. */
|
||||
assetLoaderPath?: string;
|
||||
}
|
||||
|
|
|
@ -70,7 +70,8 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
|||
* @param imageStreamName the name of the input image stream.
|
||||
* @param normRectStreamName the name of the input normalized rect image
|
||||
* stream used to provide (mandatory) rotation and (optional)
|
||||
* region-of-interest.
|
||||
* region-of-interest. `null` if the graph does not support normalized
|
||||
* rects.
|
||||
* @param roiAllowed Whether this task supports Region-Of-Interest
|
||||
* pre-processing
|
||||
*
|
||||
|
@ -79,13 +80,20 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
|||
constructor(
|
||||
protected override readonly graphRunner: VisionGraphRunner,
|
||||
private readonly imageStreamName: string,
|
||||
private readonly normRectStreamName: string,
|
||||
private readonly normRectStreamName: string|null,
|
||||
private readonly roiAllowed: boolean) {
|
||||
super(graphRunner);
|
||||
}
|
||||
|
||||
/** Configures the shared options of a vision task. */
|
||||
override applyOptions(options: VisionTaskOptions): Promise<void> {
|
||||
/**
|
||||
* Configures the shared options of a vision task.
|
||||
*
|
||||
* @param options The options for the task.
|
||||
* @param loadTfliteModel Whether to load the model specified in
|
||||
* `options.baseOptions`.
|
||||
*/
|
||||
override applyOptions(options: VisionTaskOptions, loadTfliteModel = true):
|
||||
Promise<void> {
|
||||
if ('runningMode' in options) {
|
||||
const useStreamMode =
|
||||
!!options.runningMode && options.runningMode !== 'IMAGE';
|
||||
|
@ -98,7 +106,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
|||
}
|
||||
}
|
||||
|
||||
return super.applyOptions(options);
|
||||
return super.applyOptions(options, loadTfliteModel);
|
||||
}
|
||||
|
||||
/** Sends a single image to the graph and awaits results. */
|
||||
|
@ -209,11 +217,13 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
|||
imageSource: ImageSource,
|
||||
imageProcessingOptions: ImageProcessingOptions|undefined,
|
||||
timestamp: number): void {
|
||||
const normalizedRect =
|
||||
this.convertToNormalizedRect(imageSource, imageProcessingOptions);
|
||||
this.graphRunner.addProtoToStream(
|
||||
normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
|
||||
this.normRectStreamName, timestamp);
|
||||
if (this.normRectStreamName) {
|
||||
const normalizedRect =
|
||||
this.convertToNormalizedRect(imageSource, imageProcessingOptions);
|
||||
this.graphRunner.addProtoToStream(
|
||||
normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
|
||||
this.normRectStreamName, timestamp);
|
||||
}
|
||||
this.graphRunner.addGpuBufferAsImageToStream(
|
||||
imageSource, this.imageStreamName, timestamp ?? performance.now());
|
||||
this.finishProcessing();
|
||||
|
|
Loading…
Reference in New Issue
Block a user