Return custom error if model download fails

PiperOrigin-RevId: 520066065
This commit is contained in:
Sebastian Schmidt 2023-03-28 10:24:16 -07:00 committed by Copybara-Service
parent d4ec485971
commit 5c295da6ff
2 changed files with 19 additions and 2 deletions

View File

@ -114,7 +114,14 @@ export abstract class TaskRunner {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => response.arrayBuffer())
.then(response => {
if (!response.ok) {
throw new Error(`Failed to fetch model: ${
baseOptions.modelAssetPath} (${response.status})`);
} else {
return response.arrayBuffer();
}
})
.then(buffer => {
this.setExternalFile(new Uint8Array(buffer));
this.refreshGraph();

View File

@ -118,12 +118,14 @@ describe('TaskRunner', () => {
let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake;
let fetchStatus = 200;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
ok: fetchStatus === 200,
status: fetchStatus,
} as unknown as Response;
});
global.fetch = fetchSpy;
@ -225,6 +227,14 @@ describe('TaskRunner', () => {
return resolvedPromise;
});
it('returns custom error if model download failed', () => {
fetchStatus = 404;
return expectAsync(taskRunner.setOptions({
baseOptions: {modelAssetPath: `notfound.tflite`}
}))
.toBeRejectedWithError('Failed to fetch model: notfound.tflite (404)');
});
it('can enable CPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {