Write TFLite model to Wasm file system

PiperOrigin-RevId: 532482502
This commit is contained in:
Sebastian Schmidt 2023-05-16 09:23:27 -07:00 committed by Copybara-Service
parent d7fa4b95b5
commit 8bf6c63e92
5 changed files with 68 additions and 13 deletions

View File

@ -57,6 +57,7 @@ mediapipe_ts_library(
deps = [ deps = [
":core", ":core",
":task_runner", ":task_runner",
":task_runner_test_utils",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto", "//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",

View File

@ -111,6 +111,7 @@ export abstract class TaskRunner {
throw new Error( throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer || options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) { options.baseOptions?.modelAssetPath)) {
throw new Error( throw new Error(
@ -131,7 +132,19 @@ export abstract class TaskRunner {
} }
}) })
.then(buffer => { .then(buffer => {
this.setExternalFile(new Uint8Array(buffer)); try {
// Try to delete file as we cannot overwite an existing file using
// our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph(); this.refreshGraph();
this.onGraphRefreshed(); this.onGraphRefreshed();
}); });
@ -236,10 +249,16 @@ export abstract class TaskRunner {
} }
/** Configures the `externalFile` option */ /** Configures the `externalFile` option */
private setExternalFile(modelAssetBuffer?: Uint8Array): void { private setExternalFile(modelAssetPath?: string): void;
private setExternalFile(modelAssetBuffer?: Uint8Array): void;
private setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void {
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
if (modelAssetBuffer) { if (typeof modelAssetPathOrBuffer === 'string') {
externalFile.setFileContent(modelAssetBuffer); externalFile.setFileName(modelAssetPathOrBuffer);
externalFile.clearFileContent();
} else if (modelAssetPathOrBuffer instanceof Uint8Array) {
externalFile.setFileContent(modelAssetPathOrBuffer);
externalFile.clearFileName();
} }
this.baseOptions.setModelAsset(externalFile); this.baseOptions.setModelAsset(externalFile);
} }

View File

@ -19,11 +19,16 @@ import 'jasmine';
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
import {ErrorListener} from '../../../web/graph_runner/graph_runner'; import {ErrorListener} from '../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource URL builder // Placeholder for internal dependency on trusted resource URL builder
import {CachedGraphRunner} from './task_runner'; import {CachedGraphRunner} from './task_runner';
import {TaskRunnerOptions} from './task_runner_options.d'; import {TaskRunnerOptions} from './task_runner_options';
type Writeable<T> = {
-readonly[P in keyof T]: T[P]
};
class TaskRunnerFake extends TaskRunner { class TaskRunnerFake extends TaskRunner {
private errorListener: ErrorListener|undefined; private errorListener: ErrorListener|undefined;
@ -40,7 +45,8 @@ class TaskRunnerFake extends TaskRunner {
'setAutoRenderToScreen', 'setGraph', 'finishProcessing', 'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
'registerModelResourcesGraphService', 'attachErrorListener' 'registerModelResourcesGraphService', 'attachErrorListener'
])); ]));
const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>; const graphRunner =
this.graphRunner as jasmine.SpyObj<Writeable<CachedGraphRunner>>;
expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
graphRunner.attachErrorListener.and.callFake(listener => { graphRunner.attachErrorListener.and.callFake(listener => {
this.errorListener = listener; this.errorListener = listener;
@ -51,6 +57,11 @@ class TaskRunnerFake extends TaskRunner {
graphRunner.finishProcessing.and.callFake(() => { graphRunner.finishProcessing.and.callFake(() => {
this.throwErrors(); this.throwErrors();
}); });
graphRunner.wasmModule = createSpyWasmModule();
}
get wasmModule(): SpyWasmModule {
return this.graphRunner.wasmModule as SpyWasmModule;
} }
enqueueError(message: string): void { enqueueError(message: string): void {
@ -117,6 +128,21 @@ describe('TaskRunner', () => {
nnapi: undefined, nnapi: undefined,
}, },
}; };
const mockFileResult = {
modelAsset: {
fileContent: '',
fileName: '/model.dat',
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
nnapi: undefined,
},
};
let fetchSpy: jasmine.Spy; let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake; let taskRunner: TaskRunnerFake;
@ -204,12 +230,13 @@ describe('TaskRunner', () => {
}).not.toThrowError(); }).not.toThrowError();
}); });
it('downloads model', async () => { it('writes model to file system', async () => {
await taskRunner.setOptions( await taskRunner.setOptions(
{baseOptions: {modelAssetPath: `foo`}}); {baseOptions: {modelAssetPath: `foo`}});
expect(fetchSpy).toHaveBeenCalled(); expect(fetchSpy).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); expect(taskRunner.wasmModule.FS_createDataFile).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockFileResult);
}); });
it('does not download model when bytes are provided', async () => { it('does not download model when bytes are provided', async () => {

View File

@ -33,11 +33,11 @@ export declare type SpyWasmModule = jasmine.SpyObj<SpyWasmModuleInternal>;
*/ */
export function createSpyWasmModule(): SpyWasmModule { export function createSpyWasmModule(): SpyWasmModule {
const spyWasmModule = jasmine.createSpyObj<SpyWasmModuleInternal>([ const spyWasmModule = jasmine.createSpyObj<SpyWasmModuleInternal>([
'_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', 'FS_createDataFile', 'FS_unlink', '_addBoolToInputStream',
'_attachProtoVectorListener', '_free', '_waitUntilIdle', '_addProtoToInputStream', '_addStringToInputStream', '_attachProtoListener',
'_addStringToInputStream', '_registerModelResourcesGraphService', '_attachProtoVectorListener', '_closeGraph', '_configureAudio', '_free',
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig', '_getGraphConfig', '_malloc', '_registerModelResourcesGraphService',
'_closeGraph', '_addBoolToInputStream' '_setAutoRenderToScreen', '_waitUntilIdle', 'stringToNewUTF8'
]); ]);
spyWasmModule._getGraphConfig.and.callFake(() => { spyWasmModule._getGraphConfig.and.callFake(() => {
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as

View File

@ -59,6 +59,14 @@ export declare interface WasmModule {
HEAPU32: Uint32Array; HEAPU32: Uint32Array;
HEAPF32: Float32Array; HEAPF32: Float32Array;
HEAPF64: Float64Array; HEAPF64: Float64Array;
FS_createDataFile:
(parent: string, name: string, data: Uint8Array, canRead: boolean,
canWrite: boolean, canOwn: boolean) => void;
FS_createPath:
(parent: string, name: string, canRead: boolean,
canWrite: boolean) => void;
FS_unlink(path: string): void;
errorListener?: ErrorListener; errorListener?: ErrorListener;
_bindTextureToCanvas: () => boolean; _bindTextureToCanvas: () => boolean;
_changeBinaryGraph: (size: number, dataPtr: number) => void; _changeBinaryGraph: (size: number, dataPtr: number) => void;