Make delegate sticky
PiperOrigin-RevId: 513249729
This commit is contained in:
parent
4a1ba11e3f
commit
22f186724e
|
@ -56,6 +56,7 @@ mediapipe_ts_library(
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
":core",
|
||||||
":task_runner",
|
":task_runner",
|
||||||
|
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
],
|
],
|
||||||
|
|
|
@ -208,13 +208,23 @@ export abstract class TaskRunner {
|
||||||
|
|
||||||
/** Configures the `acceleration` option. */
|
/** Configures the `acceleration` option. */
|
||||||
private setAcceleration(options: BaseOptions) {
|
private setAcceleration(options: BaseOptions) {
|
||||||
const acceleration =
|
let acceleration = this.baseOptions.getAcceleration();
|
||||||
this.baseOptions.getAcceleration() ?? new Acceleration();
|
|
||||||
if (options.delegate === 'GPU') {
|
if (!acceleration) {
|
||||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
// Create default instance for the initial configuration.
|
||||||
} else {
|
acceleration = new Acceleration();
|
||||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ('delegate' in options) {
|
||||||
|
if (options.delegate === 'GPU') {
|
||||||
|
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||||
|
} else {
|
||||||
|
acceleration.setTflite(
|
||||||
|
new InferenceCalculatorOptions.Delegate.TfLite());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
this.baseOptions.setAcceleration(acceleration);
|
this.baseOptions.setAcceleration(acceleration);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import 'jasmine';
|
import 'jasmine';
|
||||||
|
|
||||||
// Placeholder for internal dependency on encodeByteArray
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
|
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||||
|
@ -97,6 +98,23 @@ describe('TaskRunner', () => {
|
||||||
tflite: {},
|
tflite: {},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
const mockBytesResultWithGpuDelegate = {
|
||||||
|
...mockBytesResult,
|
||||||
|
acceleration: {
|
||||||
|
xnnpack: undefined,
|
||||||
|
gpu: {
|
||||||
|
useAdvancedGpuApi: false,
|
||||||
|
api: InferenceCalculatorOptions.Delegate.Gpu.Api.ANY,
|
||||||
|
allowPrecisionLoss: true,
|
||||||
|
cachedKernelPath: undefined,
|
||||||
|
serializedModelDir: undefined,
|
||||||
|
modelToken: undefined,
|
||||||
|
usage: InferenceCalculatorOptions.Delegate.Gpu.InferenceUsage
|
||||||
|
.SUSTAINED_SPEED,
|
||||||
|
},
|
||||||
|
tflite: undefined,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
let fetchSpy: jasmine.Spy;
|
let fetchSpy: jasmine.Spy;
|
||||||
let taskRunner: TaskRunnerFake;
|
let taskRunner: TaskRunnerFake;
|
||||||
|
@ -224,22 +242,8 @@ describe('TaskRunner', () => {
|
||||||
delegate: 'GPU',
|
delegate: 'GPU',
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
expect(taskRunner.baseOptions.toObject()).toEqual({
|
expect(taskRunner.baseOptions.toObject())
|
||||||
...mockBytesResult,
|
.toEqual(mockBytesResultWithGpuDelegate);
|
||||||
acceleration: {
|
|
||||||
xnnpack: undefined,
|
|
||||||
gpu: {
|
|
||||||
useAdvancedGpuApi: false,
|
|
||||||
api: 0,
|
|
||||||
allowPrecisionLoss: true,
|
|
||||||
cachedKernelPath: undefined,
|
|
||||||
serializedModelDir: undefined,
|
|
||||||
modelToken: undefined,
|
|
||||||
usage: 2,
|
|
||||||
},
|
|
||||||
tflite: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can reset delegate', async () => {
|
it('can reset delegate', async () => {
|
||||||
|
@ -249,8 +253,20 @@ describe('TaskRunner', () => {
|
||||||
delegate: 'GPU',
|
delegate: 'GPU',
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// Clear backend
|
// Clear delegate
|
||||||
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
|
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
|
||||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('keeps delegate if not provided', async () => {
|
||||||
|
await taskRunner.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||||
|
delegate: 'GPU',
|
||||||
|
}
|
||||||
|
});
|
||||||
|
await taskRunner.setOptions({baseOptions: {}});
|
||||||
|
expect(taskRunner.baseOptions.toObject())
|
||||||
|
.toEqual(mockBytesResultWithGpuDelegate);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user