Make delegate sticky
PiperOrigin-RevId: 513249729
This commit is contained in:
parent
4a1ba11e3f
commit
22f186724e
|
@ -56,6 +56,7 @@ mediapipe_ts_library(
|
|||
deps = [
|
||||
":core",
|
||||
":task_runner",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
|
|
|
@ -208,13 +208,23 @@ export abstract class TaskRunner {
|
|||
|
||||
/** Configures the `acceleration` option. */
|
||||
private setAcceleration(options: BaseOptions) {
|
||||
const acceleration =
|
||||
this.baseOptions.getAcceleration() ?? new Acceleration();
|
||||
let acceleration = this.baseOptions.getAcceleration();
|
||||
|
||||
if (!acceleration) {
|
||||
// Create default instance for the initial configuration.
|
||||
acceleration = new Acceleration();
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
}
|
||||
|
||||
if ('delegate' in options) {
|
||||
if (options.delegate === 'GPU') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
} else {
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
acceleration.setTflite(
|
||||
new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
}
|
||||
}
|
||||
|
||||
this.baseOptions.setAcceleration(acceleration);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||
|
@ -97,6 +98,23 @@ describe('TaskRunner', () => {
|
|||
tflite: {},
|
||||
},
|
||||
};
|
||||
const mockBytesResultWithGpuDelegate = {
|
||||
...mockBytesResult,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: {
|
||||
useAdvancedGpuApi: false,
|
||||
api: InferenceCalculatorOptions.Delegate.Gpu.Api.ANY,
|
||||
allowPrecisionLoss: true,
|
||||
cachedKernelPath: undefined,
|
||||
serializedModelDir: undefined,
|
||||
modelToken: undefined,
|
||||
usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage
|
||||
.SUSTAINED_SPEED,
|
||||
},
|
||||
tflite: undefined,
|
||||
},
|
||||
};
|
||||
|
||||
let fetchSpy: jasmine.Spy;
|
||||
let taskRunner: TaskRunnerFake;
|
||||
|
@ -224,22 +242,8 @@ describe('TaskRunner', () => {
|
|||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual({
|
||||
...mockBytesResult,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: {
|
||||
useAdvancedGpuApi: false,
|
||||
api: 0,
|
||||
allowPrecisionLoss: true,
|
||||
cachedKernelPath: undefined,
|
||||
serializedModelDir: undefined,
|
||||
modelToken: undefined,
|
||||
usage: 2,
|
||||
},
|
||||
tflite: undefined,
|
||||
},
|
||||
});
|
||||
expect(taskRunner.baseOptions.toObject())
|
||||
.toEqual(mockBytesResultWithGpuDelegate);
|
||||
});
|
||||
|
||||
it('can reset delegate', async () => {
|
||||
|
@ -249,8 +253,20 @@ describe('TaskRunner', () => {
|
|||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
// Clear backend
|
||||
// Clear delegate
|
||||
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('keeps delegate if not provided', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
await taskRunner.setOptions({baseOptions: {}});
|
||||
expect(taskRunner.baseOptions.toObject())
|
||||
.toEqual(mockBytesResultWithGpuDelegate);
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user