diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index fd79487a4..39353b226 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner { */ async setOptions(options: AudioClassifierOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index e6b9adf20..cd7190dd9 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -26,6 +26,8 @@ mediapipe_ts_library( name = "base_options", srcs = ["base_options.ts"], deps = [ + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index 2f7d0db37..a7f7bd280 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -14,6 +14,8 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; import {BaseOptions} from '../../../../tasks/web/core/base_options'; @@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; * Converts a BaseOptions API object to its Protobuf representation. * @throws If neither a model assset path or buffer is provided */ -export async function convertBaseOptionsToProto(baseOptions: BaseOptions): - Promise { - if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); +export async function convertBaseOptionsToProto( + updatedOptions: BaseOptions, + currentOptions?: BaseOptionsProto): Promise { + const result = + currentOptions ? currentOptions.clone() : new BaseOptionsProto(); + + await configureExternalFile(updatedOptions, result); + configureAcceleration(updatedOptions, result); + + return result; +} + +/** + * Configues the `externalFile` option and validates that a single model is + * provided. + */ +async function configureExternalFile( + options: BaseOptions, proto: BaseOptionsProto) { + const externalFile = proto.getModelAsset() || new ExternalFile(); + proto.setModelAsset(externalFile); + + if (options.modelAssetPath || options.modelAssetBuffer) { + if (options.modelAssetPath && options.modelAssetBuffer) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } + + let modelAssetBuffer = options.modelAssetBuffer; + if (!modelAssetBuffer) { + const response = await fetch(options.modelAssetPath!.toString()); + modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); + } + externalFile.setFileContent(modelAssetBuffer); } - if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) { + + if (!externalFile.hasFileContent()) { throw new Error( 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); } - - let modelAssetBuffer = baseOptions.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(baseOptions.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - - const proto = new BaseOptionsProto(); - const externalFile = new ExternalFile(); - externalFile.setFileContent(modelAssetBuffer); - proto.setModelAsset(externalFile); - return proto; +} + +/** Configues the `acceleration` option. */ +function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { + if ('delegate' in options) { + const acceleration = new Acceleration(); + if (options.delegate === 'cpu') { + acceleration.setXnnpack( + new InferenceCalculatorOptions.Delegate.Xnnpack()); + proto.setAcceleration(acceleration); + } else if (options.delegate === 'gpu') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + proto.setAcceleration(acceleration); + } else { + proto.clearAcceleration(); + } + } } diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/base_options.d.ts index 02a288a87..54a59a21b 100644 --- a/mediapipe/tasks/web/core/base_options.d.ts +++ b/mediapipe/tasks/web/core/base_options.d.ts @@ -22,10 +22,14 @@ export interface BaseOptions { * The model path to the model asset file. Only one of `modelAssetPath` or * `modelAssetBuffer` can be set. */ - modelAssetPath?: string; + modelAssetPath?: string|undefined; + /** * A buffer containing the model aaset. Only one of `modelAssetPath` or * `modelAssetBuffer` can be set. */ - modelAssetBuffer?: Uint8Array; + modelAssetBuffer?: Uint8Array|undefined; + + /** Overrides the default backend to use for the provided model. */ + delegate?: 'cpu'|'gpu'|undefined; } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index ff36bb9e0..d92248b80 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner { */ async setOptions(options: TextClassifierOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index ad8db1477..1275ae875 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner { */ async setOptions(options: GestureRecognizerOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 39674e85c..cb63874c4 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner { */ async setOptions(options: ImageClassifierOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index c3bb21baa..022bf6301 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner { */ async setOptions(options: ObjectDetectorOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); }