Allow Web developers to opt into CPU or GPU processing
PiperOrigin-RevId: 486935157
This commit is contained in:
parent
4a6562d423
commit
26066787b3
|
@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
|
||||||
*/
|
*/
|
||||||
async setOptions(options: AudioClassifierOptions): Promise<void> {
|
async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
if (options.baseOptions) {
|
||||||
const baseOptionsProto =
|
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||||
await convertBaseOptionsToProto(options.baseOptions);
|
options.baseOptions, this.options.getBaseOptions());
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ mediapipe_ts_library(
|
||||||
name = "base_options",
|
name = "base_options",
|
||||||
srcs = ["base_options.ts"],
|
srcs = ["base_options.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
|
||||||
|
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||||
|
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||||
* Converts a BaseOptions API object to its Protobuf representation.
|
* Converts a BaseOptions API object to its Protobuf representation.
|
||||||
* @throws If neither a model assset path or buffer is provided
|
* @throws If neither a model assset path or buffer is provided
|
||||||
*/
|
*/
|
||||||
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
|
export async function convertBaseOptionsToProto(
|
||||||
Promise<BaseOptionsProto> {
|
updatedOptions: BaseOptions,
|
||||||
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
|
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
|
||||||
|
const result =
|
||||||
|
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
|
||||||
|
|
||||||
|
await configureExternalFile(updatedOptions, result);
|
||||||
|
configureAcceleration(updatedOptions, result);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configues the `externalFile` option and validates that a single model is
|
||||||
|
* provided.
|
||||||
|
*/
|
||||||
|
async function configureExternalFile(
|
||||||
|
options: BaseOptions, proto: BaseOptionsProto) {
|
||||||
|
const externalFile = proto.getModelAsset() || new ExternalFile();
|
||||||
|
proto.setModelAsset(externalFile);
|
||||||
|
|
||||||
|
if (options.modelAssetPath || options.modelAssetBuffer) {
|
||||||
|
if (options.modelAssetPath && options.modelAssetBuffer) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||||
}
|
}
|
||||||
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
|
|
||||||
|
let modelAssetBuffer = options.modelAssetBuffer;
|
||||||
|
if (!modelAssetBuffer) {
|
||||||
|
const response = await fetch(options.modelAssetPath!.toString());
|
||||||
|
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||||
|
}
|
||||||
|
externalFile.setFileContent(modelAssetBuffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!externalFile.hasFileContent()) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||||
}
|
}
|
||||||
|
|
||||||
let modelAssetBuffer = baseOptions.modelAssetBuffer;
|
|
||||||
if (!modelAssetBuffer) {
|
|
||||||
const response = await fetch(baseOptions.modelAssetPath!.toString());
|
|
||||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const proto = new BaseOptionsProto();
|
/** Configues the `acceleration` option. */
|
||||||
const externalFile = new ExternalFile();
|
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
||||||
externalFile.setFileContent(modelAssetBuffer);
|
if ('delegate' in options) {
|
||||||
proto.setModelAsset(externalFile);
|
const acceleration = new Acceleration();
|
||||||
return proto;
|
if (options.delegate === 'cpu') {
|
||||||
|
acceleration.setXnnpack(
|
||||||
|
new InferenceCalculatorOptions.Delegate.Xnnpack());
|
||||||
|
proto.setAcceleration(acceleration);
|
||||||
|
} else if (options.delegate === 'gpu') {
|
||||||
|
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||||
|
proto.setAcceleration(acceleration);
|
||||||
|
} else {
|
||||||
|
proto.clearAcceleration();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
8
mediapipe/tasks/web/core/base_options.d.ts
vendored
8
mediapipe/tasks/web/core/base_options.d.ts
vendored
|
@ -22,10 +22,14 @@ export interface BaseOptions {
|
||||||
* The model path to the model asset file. Only one of `modelAssetPath` or
|
* The model path to the model asset file. Only one of `modelAssetPath` or
|
||||||
* `modelAssetBuffer` can be set.
|
* `modelAssetBuffer` can be set.
|
||||||
*/
|
*/
|
||||||
modelAssetPath?: string;
|
modelAssetPath?: string|undefined;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A buffer containing the model aaset. Only one of `modelAssetPath` or
|
* A buffer containing the model aaset. Only one of `modelAssetPath` or
|
||||||
* `modelAssetBuffer` can be set.
|
* `modelAssetBuffer` can be set.
|
||||||
*/
|
*/
|
||||||
modelAssetBuffer?: Uint8Array;
|
modelAssetBuffer?: Uint8Array|undefined;
|
||||||
|
|
||||||
|
/** Overrides the default backend to use for the provided model. */
|
||||||
|
delegate?: 'cpu'|'gpu'|undefined;
|
||||||
}
|
}
|
||||||
|
|
|
@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner {
|
||||||
*/
|
*/
|
||||||
async setOptions(options: TextClassifierOptions): Promise<void> {
|
async setOptions(options: TextClassifierOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
if (options.baseOptions) {
|
||||||
const baseOptionsProto =
|
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||||
await convertBaseOptionsToProto(options.baseOptions);
|
options.baseOptions, this.options.getBaseOptions());
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner {
|
||||||
*/
|
*/
|
||||||
async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
if (options.baseOptions) {
|
||||||
const baseOptionsProto =
|
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||||
await convertBaseOptionsToProto(options.baseOptions);
|
options.baseOptions, this.options.getBaseOptions());
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner {
|
||||||
*/
|
*/
|
||||||
async setOptions(options: ImageClassifierOptions): Promise<void> {
|
async setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
if (options.baseOptions) {
|
||||||
const baseOptionsProto =
|
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||||
await convertBaseOptionsToProto(options.baseOptions);
|
options.baseOptions, this.options.getBaseOptions());
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner {
|
||||||
*/
|
*/
|
||||||
async setOptions(options: ObjectDetectorOptions): Promise<void> {
|
async setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
if (options.baseOptions) {
|
||||||
const baseOptionsProto =
|
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||||
await convertBaseOptionsToProto(options.baseOptions);
|
options.baseOptions, this.options.getBaseOptions());
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user