Allow Web developers to opt into CPU or GPU processing
PiperOrigin-RevId: 486935157
This commit is contained in:
parent
4a6562d423
commit
26066787b3
|
@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
|
|||
*/
|
||||
async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto =
|
||||
await convertBaseOptionsToProto(options.baseOptions);
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,8 @@ mediapipe_ts_library(
|
|||
name = "base_options",
|
||||
srcs = ["base_options.ts"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
|
||||
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||
|
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
|||
* Converts a BaseOptions API object to its Protobuf representation.
|
||||
* @throws If neither a model assset path or buffer is provided
|
||||
*/
|
||||
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
|
||||
Promise<BaseOptionsProto> {
|
||||
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
|
||||
export async function convertBaseOptionsToProto(
|
||||
updatedOptions: BaseOptions,
|
||||
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
|
||||
const result =
|
||||
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
|
||||
|
||||
await configureExternalFile(updatedOptions, result);
|
||||
configureAcceleration(updatedOptions, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configues the `externalFile` option and validates that a single model is
|
||||
* provided.
|
||||
*/
|
||||
async function configureExternalFile(
|
||||
options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const externalFile = proto.getModelAsset() || new ExternalFile();
|
||||
proto.setModelAsset(externalFile);
|
||||
|
||||
if (options.modelAssetPath || options.modelAssetBuffer) {
|
||||
if (options.modelAssetPath && options.modelAssetBuffer) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
}
|
||||
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
|
||||
|
||||
let modelAssetBuffer = options.modelAssetBuffer;
|
||||
if (!modelAssetBuffer) {
|
||||
const response = await fetch(options.modelAssetPath!.toString());
|
||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||
}
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
}
|
||||
|
||||
if (!externalFile.hasFileContent()) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
|
||||
let modelAssetBuffer = baseOptions.modelAssetBuffer;
|
||||
if (!modelAssetBuffer) {
|
||||
const response = await fetch(baseOptions.modelAssetPath!.toString());
|
||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||
}
|
||||
|
||||
const proto = new BaseOptionsProto();
|
||||
const externalFile = new ExternalFile();
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
proto.setModelAsset(externalFile);
|
||||
return proto;
|
||||
}
|
||||
|
||||
/** Configues the `acceleration` option. */
|
||||
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
||||
if ('delegate' in options) {
|
||||
const acceleration = new Acceleration();
|
||||
if (options.delegate === 'cpu') {
|
||||
acceleration.setXnnpack(
|
||||
new InferenceCalculatorOptions.Delegate.Xnnpack());
|
||||
proto.setAcceleration(acceleration);
|
||||
} else if (options.delegate === 'gpu') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
proto.setAcceleration(acceleration);
|
||||
} else {
|
||||
proto.clearAcceleration();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
8
mediapipe/tasks/web/core/base_options.d.ts
vendored
8
mediapipe/tasks/web/core/base_options.d.ts
vendored
|
@ -22,10 +22,14 @@ export interface BaseOptions {
|
|||
* The model path to the model asset file. Only one of `modelAssetPath` or
|
||||
* `modelAssetBuffer` can be set.
|
||||
*/
|
||||
modelAssetPath?: string;
|
||||
modelAssetPath?: string|undefined;
|
||||
|
||||
/**
|
||||
* A buffer containing the model aaset. Only one of `modelAssetPath` or
|
||||
* `modelAssetBuffer` can be set.
|
||||
*/
|
||||
modelAssetBuffer?: Uint8Array;
|
||||
modelAssetBuffer?: Uint8Array|undefined;
|
||||
|
||||
/** Overrides the default backend to use for the provided model. */
|
||||
delegate?: 'cpu'|'gpu'|undefined;
|
||||
}
|
||||
|
|
|
@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner {
|
|||
*/
|
||||
async setOptions(options: TextClassifierOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto =
|
||||
await convertBaseOptionsToProto(options.baseOptions);
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
|
|
|
@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner {
|
|||
*/
|
||||
async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto =
|
||||
await convertBaseOptionsToProto(options.baseOptions);
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
|
|
|
@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner {
|
|||
*/
|
||||
async setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto =
|
||||
await convertBaseOptionsToProto(options.baseOptions);
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
|
|
|
@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner {
|
|||
*/
|
||||
async setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto =
|
||||
await convertBaseOptionsToProto(options.baseOptions);
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user