Make delegate sticky

PiperOrigin-RevId: 513249729
This commit is contained in:
Sebastian Schmidt 2023-03-01 08:56:29 -08:00 committed by Copybara-Service
parent 4a1ba11e3f
commit 22f186724e
3 changed files with 49 additions and 22 deletions

View File

@ -56,6 +56,7 @@ mediapipe_ts_library(
deps = [ deps = [
":core", ":core",
":task_runner", ":task_runner",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
], ],

View File

@ -208,13 +208,23 @@ export abstract class TaskRunner {
/** Configures the `acceleration` option. */ /** Configures the `acceleration` option. */
private setAcceleration(options: BaseOptions) { private setAcceleration(options: BaseOptions) {
const acceleration = let acceleration = this.baseOptions.getAcceleration();
this.baseOptions.getAcceleration() ?? new Acceleration();
if (!acceleration) {
// Create default instance for the initial configuration.
acceleration = new Acceleration();
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
if ('delegate' in options) {
if (options.delegate === 'GPU') { if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else { } else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); acceleration.setTflite(
new InferenceCalculatorOptions.Delegate.TfLite());
} }
}
this.baseOptions.setAcceleration(acceleration); this.baseOptions.setAcceleration(acceleration);
} }
} }

View File

@ -16,6 +16,7 @@
import 'jasmine'; import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {ErrorListener} from '../../../web/graph_runner/graph_runner'; import {ErrorListener} from '../../../web/graph_runner/graph_runner';
@ -97,6 +98,23 @@ describe('TaskRunner', () => {
tflite: {}, tflite: {},
}, },
}; };
const mockBytesResultWithGpuDelegate = {
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: InferenceCalculatorOptions.Delegate.Gpu.Api.ANY,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage
.SUSTAINED_SPEED,
},
tflite: undefined,
},
};
let fetchSpy: jasmine.Spy; let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake; let taskRunner: TaskRunnerFake;
@ -224,22 +242,8 @@ describe('TaskRunner', () => {
delegate: 'GPU', delegate: 'GPU',
} }
}); });
expect(taskRunner.baseOptions.toObject()).toEqual({ expect(taskRunner.baseOptions.toObject())
...mockBytesResult, .toEqual(mockBytesResultWithGpuDelegate);
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
}); });
it('can reset delegate', async () => { it('can reset delegate', async () => {
@ -249,8 +253,20 @@ describe('TaskRunner', () => {
delegate: 'GPU', delegate: 'GPU',
} }
}); });
// Clear backend // Clear delegate
await taskRunner.setOptions({baseOptions: {delegate: undefined}}); await taskRunner.setOptions({baseOptions: {delegate: undefined}});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
}); });
it('keeps delegate if not provided', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
await taskRunner.setOptions({baseOptions: {}});
expect(taskRunner.baseOptions.toObject())
.toEqual(mockBytesResultWithGpuDelegate);
});
}); });