Write TFLite model to Wasm file system

PiperOrigin-RevId: 532482502
This commit is contained in:
Sebastian Schmidt 2023-05-16 09:23:27 -07:00 committed by Copybara-Service
parent d7fa4b95b5
commit 8bf6c63e92
5 changed files with 68 additions and 13 deletions

View File

@ -57,6 +57,7 @@ mediapipe_ts_library(
deps = [
":core",
":task_runner",
":task_runner_test_utils",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts",

View File

@ -111,6 +111,7 @@ export abstract class TaskRunner {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
this.baseOptions.getModelAsset()?.hasFileName() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
throw new Error(
@ -131,7 +132,19 @@ export abstract class TaskRunner {
}
})
.then(buffer => {
this.setExternalFile(new Uint8Array(buffer));
try {
// Try to delete file as we cannot overwite an existing file using
// our current API.
this.graphRunner.wasmModule.FS_unlink('/model.dat');
} catch {
}
// TODO: Consider passing the model to the graph as an
// input side packet as this might reduce copies.
this.graphRunner.wasmModule.FS_createDataFile(
'/', 'model.dat', new Uint8Array(buffer),
/* canRead= */ true, /* canWrite= */ false,
/* canOwn= */ false);
this.setExternalFile('/model.dat');
this.refreshGraph();
this.onGraphRefreshed();
});
@ -236,10 +249,16 @@ export abstract class TaskRunner {
}
/** Configures the `externalFile` option */
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
private setExternalFile(modelAssetPath?: string): void;
private setExternalFile(modelAssetBuffer?: Uint8Array): void;
private setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void {
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
if (modelAssetBuffer) {
externalFile.setFileContent(modelAssetBuffer);
if (typeof modelAssetPathOrBuffer === 'string') {
externalFile.setFileName(modelAssetPathOrBuffer);
externalFile.clearFileContent();
} else if (modelAssetPathOrBuffer instanceof Uint8Array) {
externalFile.setFileContent(modelAssetPathOrBuffer);
externalFile.clearFileName();
}
this.baseOptions.setModelAsset(externalFile);
}

View File

@ -19,11 +19,16 @@ import 'jasmine';
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource URL builder
import {CachedGraphRunner} from './task_runner';
import {TaskRunnerOptions} from './task_runner_options.d';
import {TaskRunnerOptions} from './task_runner_options';
type Writeable<T> = {
-readonly[P in keyof T]: T[P]
};
class TaskRunnerFake extends TaskRunner {
private errorListener: ErrorListener|undefined;
@ -40,7 +45,8 @@ class TaskRunnerFake extends TaskRunner {
'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
'registerModelResourcesGraphService', 'attachErrorListener'
]));
const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>;
const graphRunner =
this.graphRunner as jasmine.SpyObj<Writeable<CachedGraphRunner>>;
expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
graphRunner.attachErrorListener.and.callFake(listener => {
this.errorListener = listener;
@ -51,6 +57,11 @@ class TaskRunnerFake extends TaskRunner {
graphRunner.finishProcessing.and.callFake(() => {
this.throwErrors();
});
graphRunner.wasmModule = createSpyWasmModule();
}
get wasmModule(): SpyWasmModule {
return this.graphRunner.wasmModule as SpyWasmModule;
}
enqueueError(message: string): void {
@ -117,6 +128,21 @@ describe('TaskRunner', () => {
nnapi: undefined,
},
};
const mockFileResult = {
modelAsset: {
fileContent: '',
fileName: '/model.dat',
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
nnapi: undefined,
},
};
let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake;
@ -204,12 +230,13 @@ describe('TaskRunner', () => {
}).not.toThrowError();
});
it('downloads model', async () => {
it('writes model to file system', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetPath: `foo`}});
expect(fetchSpy).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
expect(taskRunner.wasmModule.FS_createDataFile).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockFileResult);
});
it('does not download model when bytes are provided', async () => {

View File

@ -33,11 +33,11 @@ export declare type SpyWasmModule = jasmine.SpyObj<SpyWasmModuleInternal>;
*/
export function createSpyWasmModule(): SpyWasmModule {
const spyWasmModule = jasmine.createSpyObj<SpyWasmModuleInternal>([
'_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
'_addStringToInputStream', '_registerModelResourcesGraphService',
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
'_closeGraph', '_addBoolToInputStream'
'FS_createDataFile', 'FS_unlink', '_addBoolToInputStream',
'_addProtoToInputStream', '_addStringToInputStream', '_attachProtoListener',
'_attachProtoVectorListener', '_closeGraph', '_configureAudio', '_free',
'_getGraphConfig', '_malloc', '_registerModelResourcesGraphService',
'_setAutoRenderToScreen', '_waitUntilIdle', 'stringToNewUTF8'
]);
spyWasmModule._getGraphConfig.and.callFake(() => {
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as

View File

@ -59,6 +59,14 @@ export declare interface WasmModule {
HEAPU32: Uint32Array;
HEAPF32: Float32Array;
HEAPF64: Float64Array;
FS_createDataFile:
(parent: string, name: string, data: Uint8Array, canRead: boolean,
canWrite: boolean, canOwn: boolean) => void;
FS_createPath:
(parent: string, name: string, canRead: boolean,
canWrite: boolean) => void;
FS_unlink(path: string): void;
errorListener?: ErrorListener;
_bindTextureToCanvas: () => boolean;
_changeBinaryGraph: (size: number, dataPtr: number) => void;