Write TFLite model to Wasm file system
PiperOrigin-RevId: 532482502
This commit is contained in:
parent
d7fa4b95b5
commit
8bf6c63e92
|
@ -57,6 +57,7 @@ mediapipe_ts_library(
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
":core",
|
||||||
":task_runner",
|
":task_runner",
|
||||||
|
":task_runner_test_utils",
|
||||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
|
|
|
@ -111,6 +111,7 @@ export abstract class TaskRunner {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||||
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||||
|
this.baseOptions.getModelAsset()?.hasFileName() ||
|
||||||
options.baseOptions?.modelAssetBuffer ||
|
options.baseOptions?.modelAssetBuffer ||
|
||||||
options.baseOptions?.modelAssetPath)) {
|
options.baseOptions?.modelAssetPath)) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
|
@ -131,7 +132,19 @@ export abstract class TaskRunner {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(buffer => {
|
.then(buffer => {
|
||||||
this.setExternalFile(new Uint8Array(buffer));
|
try {
|
||||||
|
// Try to delete file as we cannot overwite an existing file using
|
||||||
|
// our current API.
|
||||||
|
this.graphRunner.wasmModule.FS_unlink('/model.dat');
|
||||||
|
} catch {
|
||||||
|
}
|
||||||
|
// TODO: Consider passing the model to the graph as an
|
||||||
|
// input side packet as this might reduce copies.
|
||||||
|
this.graphRunner.wasmModule.FS_createDataFile(
|
||||||
|
'/', 'model.dat', new Uint8Array(buffer),
|
||||||
|
/* canRead= */ true, /* canWrite= */ false,
|
||||||
|
/* canOwn= */ false);
|
||||||
|
this.setExternalFile('/model.dat');
|
||||||
this.refreshGraph();
|
this.refreshGraph();
|
||||||
this.onGraphRefreshed();
|
this.onGraphRefreshed();
|
||||||
});
|
});
|
||||||
|
@ -236,10 +249,16 @@ export abstract class TaskRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Configures the `externalFile` option */
|
/** Configures the `externalFile` option */
|
||||||
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
|
private setExternalFile(modelAssetPath?: string): void;
|
||||||
|
private setExternalFile(modelAssetBuffer?: Uint8Array): void;
|
||||||
|
private setExternalFile(modelAssetPathOrBuffer?: Uint8Array|string): void {
|
||||||
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
||||||
if (modelAssetBuffer) {
|
if (typeof modelAssetPathOrBuffer === 'string') {
|
||||||
externalFile.setFileContent(modelAssetBuffer);
|
externalFile.setFileName(modelAssetPathOrBuffer);
|
||||||
|
externalFile.clearFileContent();
|
||||||
|
} else if (modelAssetPathOrBuffer instanceof Uint8Array) {
|
||||||
|
externalFile.setFileContent(modelAssetPathOrBuffer);
|
||||||
|
externalFile.clearFileName();
|
||||||
}
|
}
|
||||||
this.baseOptions.setModelAsset(externalFile);
|
this.baseOptions.setModelAsset(externalFile);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,11 +19,16 @@ import 'jasmine';
|
||||||
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||||
|
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
|
||||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource URL builder
|
// Placeholder for internal dependency on trusted resource URL builder
|
||||||
|
|
||||||
import {CachedGraphRunner} from './task_runner';
|
import {CachedGraphRunner} from './task_runner';
|
||||||
import {TaskRunnerOptions} from './task_runner_options.d';
|
import {TaskRunnerOptions} from './task_runner_options';
|
||||||
|
|
||||||
|
type Writeable<T> = {
|
||||||
|
-readonly[P in keyof T]: T[P]
|
||||||
|
};
|
||||||
|
|
||||||
class TaskRunnerFake extends TaskRunner {
|
class TaskRunnerFake extends TaskRunner {
|
||||||
private errorListener: ErrorListener|undefined;
|
private errorListener: ErrorListener|undefined;
|
||||||
|
@ -40,7 +45,8 @@ class TaskRunnerFake extends TaskRunner {
|
||||||
'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
|
'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
|
||||||
'registerModelResourcesGraphService', 'attachErrorListener'
|
'registerModelResourcesGraphService', 'attachErrorListener'
|
||||||
]));
|
]));
|
||||||
const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>;
|
const graphRunner =
|
||||||
|
this.graphRunner as jasmine.SpyObj<Writeable<CachedGraphRunner>>;
|
||||||
expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
|
expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
|
||||||
graphRunner.attachErrorListener.and.callFake(listener => {
|
graphRunner.attachErrorListener.and.callFake(listener => {
|
||||||
this.errorListener = listener;
|
this.errorListener = listener;
|
||||||
|
@ -51,6 +57,11 @@ class TaskRunnerFake extends TaskRunner {
|
||||||
graphRunner.finishProcessing.and.callFake(() => {
|
graphRunner.finishProcessing.and.callFake(() => {
|
||||||
this.throwErrors();
|
this.throwErrors();
|
||||||
});
|
});
|
||||||
|
graphRunner.wasmModule = createSpyWasmModule();
|
||||||
|
}
|
||||||
|
|
||||||
|
get wasmModule(): SpyWasmModule {
|
||||||
|
return this.graphRunner.wasmModule as SpyWasmModule;
|
||||||
}
|
}
|
||||||
|
|
||||||
enqueueError(message: string): void {
|
enqueueError(message: string): void {
|
||||||
|
@ -117,6 +128,21 @@ describe('TaskRunner', () => {
|
||||||
nnapi: undefined,
|
nnapi: undefined,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
const mockFileResult = {
|
||||||
|
modelAsset: {
|
||||||
|
fileContent: '',
|
||||||
|
fileName: '/model.dat',
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined,
|
||||||
|
},
|
||||||
|
useStreamMode: false,
|
||||||
|
acceleration: {
|
||||||
|
xnnpack: undefined,
|
||||||
|
gpu: undefined,
|
||||||
|
tflite: {},
|
||||||
|
nnapi: undefined,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
let fetchSpy: jasmine.Spy;
|
let fetchSpy: jasmine.Spy;
|
||||||
let taskRunner: TaskRunnerFake;
|
let taskRunner: TaskRunnerFake;
|
||||||
|
@ -204,12 +230,13 @@ describe('TaskRunner', () => {
|
||||||
}).not.toThrowError();
|
}).not.toThrowError();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('downloads model', async () => {
|
it('writes model to file system', async () => {
|
||||||
await taskRunner.setOptions(
|
await taskRunner.setOptions(
|
||||||
{baseOptions: {modelAssetPath: `foo`}});
|
{baseOptions: {modelAssetPath: `foo`}});
|
||||||
|
|
||||||
expect(fetchSpy).toHaveBeenCalled();
|
expect(fetchSpy).toHaveBeenCalled();
|
||||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
expect(taskRunner.wasmModule.FS_createDataFile).toHaveBeenCalled();
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockFileResult);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('does not download model when bytes are provided', async () => {
|
it('does not download model when bytes are provided', async () => {
|
||||||
|
|
|
@ -33,11 +33,11 @@ export declare type SpyWasmModule = jasmine.SpyObj<SpyWasmModuleInternal>;
|
||||||
*/
|
*/
|
||||||
export function createSpyWasmModule(): SpyWasmModule {
|
export function createSpyWasmModule(): SpyWasmModule {
|
||||||
const spyWasmModule = jasmine.createSpyObj<SpyWasmModuleInternal>([
|
const spyWasmModule = jasmine.createSpyObj<SpyWasmModuleInternal>([
|
||||||
'_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
|
'FS_createDataFile', 'FS_unlink', '_addBoolToInputStream',
|
||||||
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
|
'_addProtoToInputStream', '_addStringToInputStream', '_attachProtoListener',
|
||||||
'_addStringToInputStream', '_registerModelResourcesGraphService',
|
'_attachProtoVectorListener', '_closeGraph', '_configureAudio', '_free',
|
||||||
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
|
'_getGraphConfig', '_malloc', '_registerModelResourcesGraphService',
|
||||||
'_closeGraph', '_addBoolToInputStream'
|
'_setAutoRenderToScreen', '_waitUntilIdle', 'stringToNewUTF8'
|
||||||
]);
|
]);
|
||||||
spyWasmModule._getGraphConfig.and.callFake(() => {
|
spyWasmModule._getGraphConfig.and.callFake(() => {
|
||||||
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
|
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
|
||||||
|
|
|
@ -59,6 +59,14 @@ export declare interface WasmModule {
|
||||||
HEAPU32: Uint32Array;
|
HEAPU32: Uint32Array;
|
||||||
HEAPF32: Float32Array;
|
HEAPF32: Float32Array;
|
||||||
HEAPF64: Float64Array;
|
HEAPF64: Float64Array;
|
||||||
|
FS_createDataFile:
|
||||||
|
(parent: string, name: string, data: Uint8Array, canRead: boolean,
|
||||||
|
canWrite: boolean, canOwn: boolean) => void;
|
||||||
|
FS_createPath:
|
||||||
|
(parent: string, name: string, canRead: boolean,
|
||||||
|
canWrite: boolean) => void;
|
||||||
|
FS_unlink(path: string): void;
|
||||||
|
|
||||||
errorListener?: ErrorListener;
|
errorListener?: ErrorListener;
|
||||||
_bindTextureToCanvas: () => boolean;
|
_bindTextureToCanvas: () => boolean;
|
||||||
_changeBinaryGraph: (size: number, dataPtr: number) => void;
|
_changeBinaryGraph: (size: number, dataPtr: number) => void;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user