From 22f186724e53a23f434d0bc966812af25467ea5e Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 1 Mar 2023 08:56:29 -0800 Subject: [PATCH] Make delegate sticky PiperOrigin-RevId: 513249729 --- mediapipe/tasks/web/core/BUILD | 1 + mediapipe/tasks/web/core/task_runner.ts | 20 ++++++-- mediapipe/tasks/web/core/task_runner_test.ts | 50 +++++++++++++------- 3 files changed, 49 insertions(+), 22 deletions(-) diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 371c75da0..ec65548d4 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -56,6 +56,7 @@ mediapipe_ts_library( deps = [ ":core", ":task_runner", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_ts", ], diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index b6babe730..79b2ca173 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -208,13 +208,23 @@ export abstract class TaskRunner { /** Configures the `acceleration` option. */ private setAcceleration(options: BaseOptions) { - const acceleration = - this.baseOptions.getAcceleration() ?? new Acceleration(); - if (options.delegate === 'GPU') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { + let acceleration = this.baseOptions.getAcceleration(); + + if (!acceleration) { + // Create default instance for the initial configuration. + acceleration = new Acceleration(); acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); } + + if ('delegate' in options) { + if (options.delegate === 'GPU') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + } else { + acceleration.setTflite( + new InferenceCalculatorOptions.Delegate.TfLite()); + } + } + this.baseOptions.setAcceleration(acceleration); } } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index 1276c2c9a..41ac76952 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -16,6 +16,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; @@ -97,6 +98,23 @@ describe('TaskRunner', () => { tflite: {}, }, }; + const mockBytesResultWithGpuDelegate = { + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: InferenceCalculatorOptions.Delegate.Gpu.Api.ANY, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage + .SUSTAINED_SPEED, + }, + tflite: undefined, + }, + }; let fetchSpy: jasmine.Spy; let taskRunner: TaskRunnerFake; @@ -224,22 +242,8 @@ describe('TaskRunner', () => { delegate: 'GPU', } }); - expect(taskRunner.baseOptions.toObject()).toEqual({ - ...mockBytesResult, - acceleration: { - xnnpack: undefined, - gpu: { - useAdvancedGpuApi: false, - api: 0, - allowPrecisionLoss: true, - cachedKernelPath: undefined, - serializedModelDir: undefined, - modelToken: undefined, - usage: 2, - }, - tflite: undefined, - }, - }); + expect(taskRunner.baseOptions.toObject()) + .toEqual(mockBytesResultWithGpuDelegate); }); it('can reset delegate', async () => { @@ -249,8 +253,20 @@ describe('TaskRunner', () => { delegate: 'GPU', } }); - // Clear backend + // Clear delegate await taskRunner.setOptions({baseOptions: {delegate: undefined}}); expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); }); + + it('keeps delegate if not provided', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + await taskRunner.setOptions({baseOptions: {}}); + expect(taskRunner.baseOptions.toObject()) + .toEqual(mockBytesResultWithGpuDelegate); + }); });