Write TFLite model to Wasm file system
PiperOrigin-RevId: 532482502
This commit is contained in:
parent
d7fa4b95b5
commit
8bf6c63e92
|
@ -57,6 +57,7 @@ mediapipe_ts_library(
|
|||
deps = [
|
||||
":core",
|
||||
":task_runner",
|
||||
":task_runner_test_utils",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
|
|
|
@ -111,6 +111,7 @@ export abstract class TaskRunner {
|
|||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||
this.baseOptions.getModelAsset()?.hasFileName() ||
|
||||
options.baseOptions?.modelAssetBuffer ||
|
||||
options.baseOptions?.modelAssetPath)) {
|
||||
throw new Error(
|
||||
|
@ -131,7 +132,19 @@ export abstract class TaskRunner {
|
|||
}
|
||||
})
|
||||
.then(buffer => {
|
||||
this.setExternalFile(new Uint8Array(buffer));
|
||||
try {
|
||||
// Try to delete file as we cannot overwite an existing file using
|
||||
// our current API.
|
||||
this.graphRunner.wasmModule.FS_unlink('/model.dat');
|
||||
} catch {
|
||||
}
|
||||
// TODO: Consider passing the model to the graph as an
|
||||
// input side packet as this might reduce copies.
|
||||
this.graphRunner.wasmModule.FS_createDataFile(
|
||||
'/', 'model.dat', new Uint8Array(buffer),
|
||||
/* canRead= */ true, /* canWrite= */ false,
|
||||
/* canOwn= */ false);
|
||||
this.setExternalFile('/model.dat');
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
});
|
||||
|
@ -236,10 +249,16 @@ export abstract class TaskRunner {
|
|||
}
|
||||
|
||||
/** Configures the `externalFile` option */
|
||||
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
|
||||
private setExternalFile(modelAssetPath?: string): void;
|
||||
private setExternalFile(modelAssetBuffer?: Uint8Array): void;
|
||||
private setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void {
|
||||
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
||||
if (modelAssetBuffer) {
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
if (typeof modelAssetPathOrBuffer === 'string') {
|
||||
externalFile.setFileName(modelAssetPathOrBuffer);
|
||||
externalFile.clearFileContent();
|
||||
} else if (modelAssetPathOrBuffer instanceof Uint8Array) {
|
||||
externalFile.setFileContent(modelAssetPathOrBuffer);
|
||||
externalFile.clearFileName();
|
||||
}
|
||||
this.baseOptions.setModelAsset(externalFile);
|
||||
}
|
||||
|
|
|
@ -19,11 +19,16 @@ import 'jasmine';
|
|||
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
|
||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource URL builder
|
||||
|
||||
import {CachedGraphRunner} from './task_runner';
|
||||
import {TaskRunnerOptions} from './task_runner_options.d';
|
||||
import {TaskRunnerOptions} from './task_runner_options';
|
||||
|
||||
type Writeable<T> = {
|
||||
-readonly[P in keyof T]: T[P]
|
||||
};
|
||||
|
||||
class TaskRunnerFake extends TaskRunner {
|
||||
private errorListener: ErrorListener|undefined;
|
||||
|
@ -40,7 +45,8 @@ class TaskRunnerFake extends TaskRunner {
|
|||
'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
|
||||
'registerModelResourcesGraphService', 'attachErrorListener'
|
||||
]));
|
||||
const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>;
|
||||
const graphRunner =
|
||||
this.graphRunner as jasmine.SpyObj<Writeable<CachedGraphRunner>>;
|
||||
expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
|
||||
graphRunner.attachErrorListener.and.callFake(listener => {
|
||||
this.errorListener = listener;
|
||||
|
@ -51,6 +57,11 @@ class TaskRunnerFake extends TaskRunner {
|
|||
graphRunner.finishProcessing.and.callFake(() => {
|
||||
this.throwErrors();
|
||||
});
|
||||
graphRunner.wasmModule = createSpyWasmModule();
|
||||
}
|
||||
|
||||
get wasmModule(): SpyWasmModule {
|
||||
return this.graphRunner.wasmModule as SpyWasmModule;
|
||||
}
|
||||
|
||||
enqueueError(message: string): void {
|
||||
|
@ -117,6 +128,21 @@ describe('TaskRunner', () => {
|
|||
nnapi: undefined,
|
||||
},
|
||||
};
|
||||
const mockFileResult = {
|
||||
modelAsset: {
|
||||
fileContent: '',
|
||||
fileName: '/model.dat',
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined,
|
||||
},
|
||||
useStreamMode: false,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: undefined,
|
||||
tflite: {},
|
||||
nnapi: undefined,
|
||||
},
|
||||
};
|
||||
|
||||
let fetchSpy: jasmine.Spy;
|
||||
let taskRunner: TaskRunnerFake;
|
||||
|
@ -204,12 +230,13 @@ describe('TaskRunner', () => {
|
|||
}).not.toThrowError();
|
||||
});
|
||||
|
||||
it('downloads model', async () => {
|
||||
it('writes model to file system', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetPath: `foo`}});
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalled();
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
expect(taskRunner.wasmModule.FS_createDataFile).toHaveBeenCalled();
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockFileResult);
|
||||
});
|
||||
|
||||
it('does not download model when bytes are provided', async () => {
|
||||
|
|
|
@ -33,11 +33,11 @@ export declare type SpyWasmModule = jasmine.SpyObj<SpyWasmModuleInternal>;
|
|||
*/
|
||||
export function createSpyWasmModule(): SpyWasmModule {
|
||||
const spyWasmModule = jasmine.createSpyObj<SpyWasmModuleInternal>([
|
||||
'_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
|
||||
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
|
||||
'_addStringToInputStream', '_registerModelResourcesGraphService',
|
||||
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
|
||||
'_closeGraph', '_addBoolToInputStream'
|
||||
'FS_createDataFile', 'FS_unlink', '_addBoolToInputStream',
|
||||
'_addProtoToInputStream', '_addStringToInputStream', '_attachProtoListener',
|
||||
'_attachProtoVectorListener', '_closeGraph', '_configureAudio', '_free',
|
||||
'_getGraphConfig', '_malloc', '_registerModelResourcesGraphService',
|
||||
'_setAutoRenderToScreen', '_waitUntilIdle', 'stringToNewUTF8'
|
||||
]);
|
||||
spyWasmModule._getGraphConfig.and.callFake(() => {
|
||||
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
|
||||
|
|
|
@ -59,6 +59,14 @@ export declare interface WasmModule {
|
|||
HEAPU32: Uint32Array;
|
||||
HEAPF32: Float32Array;
|
||||
HEAPF64: Float64Array;
|
||||
FS_createDataFile:
|
||||
(parent: string, name: string, data: Uint8Array, canRead: boolean,
|
||||
canWrite: boolean, canOwn: boolean) => void;
|
||||
FS_createPath:
|
||||
(parent: string, name: string, canRead: boolean,
|
||||
canWrite: boolean) => void;
|
||||
FS_unlink(path: string): void;
|
||||
|
||||
errorListener?: ErrorListener;
|
||||
_bindTextureToCanvas: () => boolean;
|
||||
_changeBinaryGraph: (size: number, dataPtr: number) => void;
|
||||
|
|
Loading…
Reference in New Issue
Block a user