Allow Web developers to opt into CPU or GPU processing

PiperOrigin-RevId: 486935157
This commit is contained in:
Sebastian Schmidt 2022-11-08 06:35:43 -08:00 committed by Copybara-Service
parent 4a6562d423
commit 26066787b3
8 changed files with 72 additions and 30 deletions

View File

@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
*/
async setOptions(options: AudioClassifierOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto =
await convertBaseOptionsToProto(options.baseOptions);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}

View File

@ -26,6 +26,8 @@ mediapipe_ts_library(
name = "base_options",
srcs = ["base_options.ts"],
deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core",

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options';
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
* Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided
*/
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
Promise<BaseOptionsProto> {
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
export async function convertBaseOptionsToProto(
updatedOptions: BaseOptions,
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
const result =
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
}
if (!externalFile.hasFileContent()) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
let modelAssetBuffer = baseOptions.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(baseOptions.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
const proto = new BaseOptionsProto();
const externalFile = new ExternalFile();
externalFile.setFileContent(modelAssetBuffer);
proto.setModelAsset(externalFile);
return proto;
/** Configues the `acceleration` option. */
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
if ('delegate' in options) {
const acceleration = new Acceleration();
if (options.delegate === 'cpu') {
acceleration.setXnnpack(
new InferenceCalculatorOptions.Delegate.Xnnpack());
proto.setAcceleration(acceleration);
} else if (options.delegate === 'gpu') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
proto.setAcceleration(acceleration);
} else {
proto.clearAcceleration();
}
}
}

View File

@ -22,10 +22,14 @@ export interface BaseOptions {
* The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetPath?: string;
modelAssetPath?: string|undefined;
/**
* A buffer containing the model aaset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetBuffer?: Uint8Array;
modelAssetBuffer?: Uint8Array|undefined;
/** Overrides the default backend to use for the provided model. */
delegate?: 'cpu'|'gpu'|undefined;
}

View File

@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner {
*/
async setOptions(options: TextClassifierOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto =
await convertBaseOptionsToProto(options.baseOptions);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}

View File

@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner {
*/
async setOptions(options: GestureRecognizerOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto =
await convertBaseOptionsToProto(options.baseOptions);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}

View File

@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner {
*/
async setOptions(options: ImageClassifierOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto =
await convertBaseOptionsToProto(options.baseOptions);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}

View File

@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner {
*/
async setOptions(options: ObjectDetectorOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto =
await convertBaseOptionsToProto(options.baseOptions);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}