From 8bf6c63e924236a1f28bd4cf121acbed4c95808e Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 16 May 2023 09:23:27 -0700 Subject: [PATCH] Write TFLite model to Wasm file system PiperOrigin-RevId: 532482502 --- mediapipe/tasks/web/core/BUILD | 1 + mediapipe/tasks/web/core/task_runner.ts | 27 +++++++++++--- mediapipe/tasks/web/core/task_runner_test.ts | 35 ++++++++++++++++--- .../tasks/web/core/task_runner_test_utils.ts | 10 +++--- mediapipe/web/graph_runner/graph_runner.ts | 8 +++++ 5 files changed, 68 insertions(+), 13 deletions(-) diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index a417d4d72..0c102a86a 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -57,6 +57,7 @@ mediapipe_ts_library( deps = [ ":core", ":task_runner", + ":task_runner_test_utils", "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index efeffbb87..8c6aae6cf 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -111,6 +111,7 @@ export abstract class TaskRunner { throw new Error( 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || + this.baseOptions.getModelAsset()?.hasFileName() || options.baseOptions?.modelAssetBuffer || options.baseOptions?.modelAssetPath)) { throw new Error( @@ -131,7 +132,19 @@ export abstract class TaskRunner { } }) .then(buffer => { - this.setExternalFile(new Uint8Array(buffer)); + try { + // Try to delete file as we cannot overwite an existing file using + // our current API. + this.graphRunner.wasmModule.FS_unlink('/model.dat'); + } catch { + } + // TODO: Consider passing the model to the graph as an + // input side packet as this might reduce copies. + this.graphRunner.wasmModule.FS_createDataFile( + '/', 'model.dat', new Uint8Array(buffer), + /* canRead= */ true, /* canWrite= */ false, + /* canOwn= */ false); + this.setExternalFile('/model.dat'); this.refreshGraph(); this.onGraphRefreshed(); }); @@ -236,10 +249,16 @@ export abstract class TaskRunner { } /** Configures the `externalFile` option */ - private setExternalFile(modelAssetBuffer?: Uint8Array): void { + private setExternalFile(modelAssetPath?: string): void; + private setExternalFile(modelAssetBuffer?: Uint8Array): void; + private setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void { const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); - if (modelAssetBuffer) { - externalFile.setFileContent(modelAssetBuffer); + if (typeof modelAssetPathOrBuffer === 'string') { + externalFile.setFileName(modelAssetPathOrBuffer); + externalFile.clearFileContent(); + } else if (modelAssetPathOrBuffer instanceof Uint8Array) { + externalFile.setFileContent(modelAssetPathOrBuffer); + externalFile.clearFileName(); } this.baseOptions.setModelAsset(externalFile); } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index 873e6fea1..a68ba224a 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -19,11 +19,16 @@ import 'jasmine'; import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; +import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource URL builder import {CachedGraphRunner} from './task_runner'; -import {TaskRunnerOptions} from './task_runner_options.d'; +import {TaskRunnerOptions} from './task_runner_options'; + +type Writeable = { + -readonly[P in keyof T]: T[P] +}; class TaskRunnerFake extends TaskRunner { private errorListener: ErrorListener|undefined; @@ -40,7 +45,8 @@ class TaskRunnerFake extends TaskRunner { 'setAutoRenderToScreen', 'setGraph', 'finishProcessing', 'registerModelResourcesGraphService', 'attachErrorListener' ])); - const graphRunner = this.graphRunner as jasmine.SpyObj; + const graphRunner = + this.graphRunner as jasmine.SpyObj>; expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled(); graphRunner.attachErrorListener.and.callFake(listener => { this.errorListener = listener; @@ -51,6 +57,11 @@ class TaskRunnerFake extends TaskRunner { graphRunner.finishProcessing.and.callFake(() => { this.throwErrors(); }); + graphRunner.wasmModule = createSpyWasmModule(); + } + + get wasmModule(): SpyWasmModule { + return this.graphRunner.wasmModule as SpyWasmModule; } enqueueError(message: string): void { @@ -117,6 +128,21 @@ describe('TaskRunner', () => { nnapi: undefined, }, }; + const mockFileResult = { + modelAsset: { + fileContent: '', + fileName: '/model.dat', + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + nnapi: undefined, + }, + }; let fetchSpy: jasmine.Spy; let taskRunner: TaskRunnerFake; @@ -204,12 +230,13 @@ describe('TaskRunner', () => { }).not.toThrowError(); }); - it('downloads model', async () => { + it('writes model to file system', async () => { await taskRunner.setOptions( {baseOptions: {modelAssetPath: `foo`}}); expect(fetchSpy).toHaveBeenCalled(); - expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + expect(taskRunner.wasmModule.FS_createDataFile).toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockFileResult); }); it('does not download model when bytes are provided', async () => { diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 1532eb2a5..777cb8704 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -33,11 +33,11 @@ export declare type SpyWasmModule = jasmine.SpyObj; */ export function createSpyWasmModule(): SpyWasmModule { const spyWasmModule = jasmine.createSpyObj([ - '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', - '_attachProtoVectorListener', '_free', '_waitUntilIdle', - '_addStringToInputStream', '_registerModelResourcesGraphService', - '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig', - '_closeGraph', '_addBoolToInputStream' + 'FS_createDataFile', 'FS_unlink', '_addBoolToInputStream', + '_addProtoToInputStream', '_addStringToInputStream', '_attachProtoListener', + '_attachProtoVectorListener', '_closeGraph', '_configureAudio', '_free', + '_getGraphConfig', '_malloc', '_registerModelResourcesGraphService', + '_setAutoRenderToScreen', '_waitUntilIdle', 'stringToNewUTF8' ]); spyWasmModule._getGraphConfig.and.callFake(() => { (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 615971cb3..3444653f8 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -59,6 +59,14 @@ export declare interface WasmModule { HEAPU32: Uint32Array; HEAPF32: Float32Array; HEAPF64: Float64Array; + FS_createDataFile: + (parent: string, name: string, data: Uint8Array, canRead: boolean, + canWrite: boolean, canOwn: boolean) => void; + FS_createPath: + (parent: string, name: string, canRead: boolean, + canWrite: boolean) => void; + FS_unlink(path: string): void; + errorListener?: ErrorListener; _bindTextureToCanvas: () => boolean; _changeBinaryGraph: (size: number, dataPtr: number) => void;