Allow Web developers to opt into CPU or GPU processing

PiperOrigin-RevId: 486935157
This commit is contained in:
Sebastian Schmidt 2022-11-08 06:35:43 -08:00 committed by Copybara-Service
parent 4a6562d423
commit 26066787b3
8 changed files with 72 additions and 30 deletions

View File

@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
*/ */
async setOptions(options: AudioClassifierOptions): Promise<void> { async setOptions(options: AudioClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -26,6 +26,8 @@ mediapipe_ts_library(
name = "base_options", name = "base_options",
srcs = ["base_options.ts"], srcs = ["base_options.ts"],
deps = [ deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",

View File

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {BaseOptions} from '../../../../tasks/web/core/base_options';
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
* Converts a BaseOptions API object to its Protobuf representation. * Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided * @throws If neither a model assset path or buffer is provided
*/ */
export async function convertBaseOptionsToProto(baseOptions: BaseOptions): export async function convertBaseOptionsToProto(
Promise<BaseOptionsProto> { updatedOptions: BaseOptions,
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) { currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
const result =
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error( throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} }
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
}
if (!externalFile.hasFileContent()) {
throw new Error( throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
} }
let modelAssetBuffer = baseOptions.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(baseOptions.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
} }
const proto = new BaseOptionsProto(); /** Configues the `acceleration` option. */
const externalFile = new ExternalFile(); function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
externalFile.setFileContent(modelAssetBuffer); if ('delegate' in options) {
proto.setModelAsset(externalFile); const acceleration = new Acceleration();
return proto; if (options.delegate === 'cpu') {
acceleration.setXnnpack(
new InferenceCalculatorOptions.Delegate.Xnnpack());
proto.setAcceleration(acceleration);
} else if (options.delegate === 'gpu') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
proto.setAcceleration(acceleration);
} else {
proto.clearAcceleration();
}
}
} }

View File

@ -22,10 +22,14 @@ export interface BaseOptions {
* The model path to the model asset file. Only one of `modelAssetPath` or * The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set. * `modelAssetBuffer` can be set.
*/ */
modelAssetPath?: string; modelAssetPath?: string|undefined;
/** /**
* A buffer containing the model aaset. Only one of `modelAssetPath` or * A buffer containing the model aaset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set. * `modelAssetBuffer` can be set.
*/ */
modelAssetBuffer?: Uint8Array; modelAssetBuffer?: Uint8Array|undefined;
/** Overrides the default backend to use for the provided model. */
delegate?: 'cpu'|'gpu'|undefined;
} }

View File

@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner {
*/ */
async setOptions(options: TextClassifierOptions): Promise<void> { async setOptions(options: TextClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner {
*/ */
async setOptions(options: GestureRecognizerOptions): Promise<void> { async setOptions(options: GestureRecognizerOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner {
*/ */
async setOptions(options: ImageClassifierOptions): Promise<void> { async setOptions(options: ImageClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner {
*/ */
async setOptions(options: ObjectDetectorOptions): Promise<void> { async setOptions(options: ObjectDetectorOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }