Apply most graph options synchronously

PiperOrigin-RevId: 498244085
This commit is contained in:
Sebastian Schmidt 2022-12-28 13:57:20 -08:00 committed by Copybara-Service
parent 7e36a5e2ae
commit 9580f04571
27 changed files with 280 additions and 311 deletions

View File

@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
* *
* @param options The options for the audio classifier. * @param options The options for the audio classifier.
*/ */
override async setOptions(options: AudioClassifierOptions): Promise<void> { override setOptions(options: AudioClassifierOptions): Promise<void> {
await super.setOptions(options);
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM);

View File

@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
audioClassifier = new AudioClassifierFake(); audioClassifier = new AudioClassifierFake();
await audioClassifier.setOptions({}); // Initialize graph await audioClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
* *
* @param options The options for the audio embedder. * @param options The options for the audio embedder.
*/ */
override async setOptions(options: AudioEmbedderOptions): Promise<void> { override setOptions(options: AudioEmbedderOptions): Promise<void> {
await super.setOptions(options);
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM);

View File

@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
audioEmbedder = new AudioEmbedderFake(); audioEmbedder = new AudioEmbedderFake();
await audioEmbedder.setOptions({}); // Initialize graph await audioEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', () => { it('initializes graph', () => {

View File

@ -103,29 +103,3 @@ jasmine_node_test(
name = "embedder_options_test", name = "embedder_options_test",
deps = [":embedder_options_test_lib"], deps = [":embedder_options_test_lib"],
) )
mediapipe_ts_library(
name = "base_options",
srcs = [
"base_options.ts",
],
deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core",
],
)
mediapipe_ts_library(
name = "base_options_test_lib",
testonly = True,
srcs = ["base_options.test.ts"],
deps = [":base_options"],
)
jasmine_node_test(
name = "base_options_test",
deps = [":base_options_test_lib"],
)

View File

@ -1,127 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
// Placeholder for internal dependency on trusted resource URL builder
import {convertBaseOptionsToProto} from './base_options';
describe('convertBaseOptionsToProto()', () => {
const mockBytes = new Uint8Array([0, 1, 2, 3]);
const mockBytesResult = {
modelAsset: {
fileContent: Buffer.from(mockBytes).toString('base64'),
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
},
};
let fetchSpy: jasmine.Spy;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
} as unknown as Response;
});
global.fetch = fetchSpy;
});
it('verifies that at least one model asset option is provided', async () => {
await expectAsync(convertBaseOptionsToProto({}))
.toBeRejectedWithError(
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
});
it('verifies that no more than one model asset option is provided', async () => {
await expectAsync(convertBaseOptionsToProto({
modelAssetPath: `foo`,
modelAssetBuffer: new Uint8Array([])
}))
.toBeRejectedWithError(
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
});
it('downloads model', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetPath: `foo`,
});
expect(fetchSpy).toHaveBeenCalled();
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('does not download model when bytes are provided', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
});
expect(fetchSpy).not.toHaveBeenCalled();
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('can enable CPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'CPU',
});
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
it('can enable GPU delegate', async () => {
const baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
});
expect(baseOptionsProto.toObject()).toEqual({
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
});
it('can reset delegate', async () => {
let baseOptionsProto = await convertBaseOptionsToProto({
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
});
// Clear backend
baseOptionsProto =
await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
});
});

View File

@ -1,80 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided
*/
export async function convertBaseOptionsToProto(
updatedOptions: BaseOptions,
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
const result =
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
}
if (!externalFile.hasFileContent()) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
}
/** Configues the `acceleration` option. */
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
const acceleration = proto.getAcceleration() ?? new Acceleration();
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
proto.setAcceleration(acceleration);
}

View File

@ -18,8 +18,10 @@ mediapipe_ts_library(
srcs = ["task_runner.ts"], srcs = ["task_runner.ts"],
deps = [ deps = [
":core", ":core",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
@ -53,6 +55,7 @@ mediapipe_ts_library(
"task_runner_test.ts", "task_runner_test.ts",
], ],
deps = [ deps = [
":core",
":task_runner", ":task_runner",
":task_runner_test_utils", ":task_runner_test_utils",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",

View File

@ -14,9 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@ -91,14 +93,52 @@ export abstract class TaskRunner {
this.graphRunner.registerModelResourcesGraphService(); this.graphRunner.registerModelResourcesGraphService();
} }
/** Configures the shared options of a MediaPipe Task. */ /** Configures the task with custom options. */
async setOptions(options: TaskRunnerOptions): Promise<void> { abstract setOptions(options: TaskRunnerOptions): Promise<void>;
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto( /**
options.baseOptions, this.baseOptions); * Applies the current set of options, including any base options that have
* not been processed by the task implementation. The options are applied
* synchronously unless a `modelAssetPath` is provided. This ensures that
* for most use cases options are applied directly and immediately affect
* the next inference.
*/
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
const baseOptions: BaseOptions = options.baseOptions || {};
// Validate that exactly one model is configured
if (options.baseOptions?.modelAssetBuffer &&
options.baseOptions?.modelAssetPath) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
options.baseOptions?.modelAssetBuffer ||
options.baseOptions?.modelAssetPath)) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
this.setAcceleration(baseOptions);
if (baseOptions.modelAssetPath) {
// We don't use `await` here since we want to apply most settings
// synchronously.
return fetch(baseOptions.modelAssetPath.toString())
.then(response => response.arrayBuffer())
.then(buffer => {
this.setExternalFile(new Uint8Array(buffer));
this.refreshGraph();
});
} else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer);
this.refreshGraph();
return Promise.resolve();
} }
} }
/** Appliest the current options to the MediaPipe graph. */
protected abstract refreshGraph(): void;
/** /**
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
* over the video stream. Will replace the previously running MediaPipe graph, * over the video stream. Will replace the previously running MediaPipe graph,
@ -140,6 +180,27 @@ export abstract class TaskRunner {
} }
this.processingErrors = []; this.processingErrors = [];
} }
/** Configures the `externalFile` option */
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
if (modelAssetBuffer) {
externalFile.setFileContent(modelAssetBuffer);
}
this.baseOptions.setModelAsset(externalFile);
}
/** Configures the `acceleration` option. */
private setAcceleration(options: BaseOptions) {
const acceleration =
this.baseOptions.getAcceleration() ?? new Acceleration();
if (options.delegate === 'GPU') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
} else {
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
}
this.baseOptions.setAcceleration(acceleration);
}
} }

View File

@ -15,18 +15,22 @@
*/ */
import 'jasmine'; import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
import {ErrorListener} from '../../../web/graph_runner/graph_runner'; import {ErrorListener} from '../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource URL builder
import {GraphRunnerImageLib} from './task_runner'; import {GraphRunnerImageLib} from './task_runner';
import {TaskRunnerOptions} from './task_runner_options.d';
class TaskRunnerFake extends TaskRunner { class TaskRunnerFake extends TaskRunner {
protected baseOptions = new BaseOptionsProto();
private errorListener: ErrorListener|undefined; private errorListener: ErrorListener|undefined;
private errors: string[] = []; private errors: string[] = [];
baseOptions = new BaseOptionsProto();
static createFake(): TaskRunnerFake { static createFake(): TaskRunnerFake {
const wasmModule = createSpyWasmModule(); const wasmModule = createSpyWasmModule();
return new TaskRunnerFake(wasmModule); return new TaskRunnerFake(wasmModule);
@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
super.finishProcessing(); super.finishProcessing();
} }
override refreshGraph(): void {}
override setGraph(graphData: Uint8Array, isBinary: boolean): void { override setGraph(graphData: Uint8Array, isBinary: boolean): void {
super.setGraph(graphData, isBinary); super.setGraph(graphData, isBinary);
} }
setOptions(options: TaskRunnerOptions): Promise<void> {
return this.applyOptions(options);
}
private throwErrors(): void { private throwErrors(): void {
expect(this.errorListener).toBeDefined(); expect(this.errorListener).toBeDefined();
for (const error of this.errors) { for (const error of this.errors) {
@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
} }
describe('TaskRunner', () => { describe('TaskRunner', () => {
const mockBytes = new Uint8Array([0, 1, 2, 3]);
const mockBytesResult = {
modelAsset: {
fileContent: Buffer.from(mockBytes).toString('base64'),
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined,
},
useStreamMode: false,
acceleration: {
xnnpack: undefined,
gpu: undefined,
tflite: {},
},
};
let fetchSpy: jasmine.Spy;
let taskRunner: TaskRunnerFake;
beforeEach(() => {
fetchSpy = jasmine.createSpy().and.callFake(async url => {
expect(url).toEqual('foo');
return {
arrayBuffer: () => mockBytes.buffer,
} as unknown as Response;
});
global.fetch = fetchSpy;
taskRunner = TaskRunnerFake.createFake();
});
it('handles errors during graph update', () => { it('handles errors during graph update', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.enqueueError('Test error'); taskRunner.enqueueError('Test error');
expect(() => { expect(() => {
@ -85,7 +125,6 @@ describe('TaskRunner', () => {
}); });
it('handles errors during graph execution', () => { it('handles errors during graph execution', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
taskRunner.enqueueError('Test error'); taskRunner.enqueueError('Test error');
@ -96,7 +135,6 @@ describe('TaskRunner', () => {
}); });
it('can handle multiple errors', () => { it('can handle multiple errors', () => {
const taskRunner = TaskRunnerFake.createFake();
taskRunner.enqueueError('Test error 1'); taskRunner.enqueueError('Test error 1');
taskRunner.enqueueError('Test error 2'); taskRunner.enqueueError('Test error 2');
@ -104,4 +142,106 @@ describe('TaskRunner', () => {
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
}).toThrowError(/Test error 1, Test error 2/); }).toThrowError(/Test error 1, Test error 2/);
}); });
it('verifies that at least one model asset option is provided', () => {
expect(() => {
taskRunner.setOptions({});
})
.toThrowError(
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
});
it('verifies that no more than one model asset option is provided', () => {
expect(() => {
taskRunner.setOptions({
baseOptions: {
modelAssetPath: `foo`,
modelAssetBuffer: new Uint8Array([])
}
});
})
.toThrowError(
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
});
it('doesn\'t require model once it is configured', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
expect(() => {
taskRunner.setOptions({});
}).not.toThrowError();
});
it('downloads model', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetPath: `foo`}});
expect(fetchSpy).toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('does not download model when bytes are provided', async () => {
await taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
expect(fetchSpy).not.toHaveBeenCalled();
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('changes model synchronously when bytes are provided', () => {
const resolvedPromise = taskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
// Check that the change has been applied even though we do not await the
// above Promise
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
return resolvedPromise;
});
it('can enable CPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'CPU',
}
});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
it('can enable GPU delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
expect(taskRunner.baseOptions.toObject()).toEqual({
...mockBytesResult,
acceleration: {
xnnpack: undefined,
gpu: {
useAdvancedGpuApi: false,
api: 0,
allowPrecisionLoss: true,
cachedKernelPath: undefined,
serializedModelDir: undefined,
modelToken: undefined,
usage: 2,
},
tflite: undefined,
},
});
});
it('can reset delegate', async () => {
await taskRunner.setOptions({
baseOptions: {
modelAssetBuffer: new Uint8Array(mockBytes),
delegate: 'GPU',
}
});
// Clear backend
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
});
}); });

View File

@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
* *
* @param options The options for the text classifier. * @param options The options for the text classifier.
*/ */
override async setOptions(options: TextClassifierOptions): Promise<void> { override setOptions(options: TextClassifierOptions): Promise<void> {
await super.setOptions(options);
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
protected override get baseOptions(): BaseOptionsProto { protected override get baseOptions(): BaseOptionsProto {
@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

View File

@ -56,7 +56,8 @@ describe('TextClassifier', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
textClassifier = new TextClassifierFake(); textClassifier = new TextClassifierFake();
await textClassifier.setOptions({}); // Initialize graph await textClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
* *
* @param options The options for the text embedder. * @param options The options for the text embedder.
*/ */
override async setOptions(options: TextEmbedderOptions): Promise<void> { override setOptions(options: TextEmbedderOptions): Promise<void> {
await super.setOptions(options);
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
protected override get baseOptions(): BaseOptionsProto { protected override get baseOptions(): BaseOptionsProto {
@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM);

View File

@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
textEmbedder = new TextEmbedderFake(); textEmbedder = new TextEmbedderFake();
await textEmbedder.setOptions({}); // Initialize graph await textEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -29,6 +29,7 @@ mediapipe_ts_library(
testonly = True, testonly = True,
srcs = ["vision_task_runner.test.ts"], srcs = ["vision_task_runner.test.ts"],
deps = [ deps = [
":vision_task_options",
":vision_task_runner", ":vision_task_runner",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",

View File

@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {VisionTaskOptions} from './vision_task_options';
import {VisionTaskRunner} from './vision_task_runner'; import {VisionTaskRunner} from './vision_task_runner';
class VisionTaskRunnerFake extends VisionTaskRunner<void> { class VisionTaskRunnerFake extends VisionTaskRunner<void> {
@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
protected override process(): void {} protected override process(): void {}
protected override refreshGraph(): void {}
override setOptions(options: VisionTaskOptions): Promise<void> {
return this.applyOptions(options);
}
override processImageData(image: ImageSource): void { override processImageData(image: ImageSource): void {
super.processImageData(image); super.processImageData(image);
} }
@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
} }
describe('VisionTaskRunner', () => { describe('VisionTaskRunner', () => {
const streamMode = {
modelAsset: undefined,
useStreamMode: true,
acceleration: undefined,
};
const imageMode = {
modelAsset: undefined,
useStreamMode: false,
acceleration: undefined,
};
let visionTaskRunner: VisionTaskRunnerFake; let visionTaskRunner: VisionTaskRunnerFake;
beforeEach(() => { beforeEach(async () => {
visionTaskRunner = new VisionTaskRunnerFake(); visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('can enable image mode', async () => { it('can enable image mode', async () => {
await visionTaskRunner.setOptions({runningMode: 'image'}); await visionTaskRunner.setOptions({runningMode: 'image'});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
}); });
it('can enable video mode', async () => { it('can enable video mode', async () => {
await visionTaskRunner.setOptions({runningMode: 'video'}); await visionTaskRunner.setOptions({runningMode: 'video'});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: true}));
}); });
it('can clear running mode', async () => { it('can clear running mode', async () => {
@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
// Clear running mode // Clear running mode
await visionTaskRunner.setOptions({runningMode: undefined}); await visionTaskRunner.setOptions({runningMode: undefined});
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
}); });
it('cannot process images with video mode', async () => { it('cannot process images with video mode', async () => {

View File

@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
/** Base class for all MediaPipe Vision Tasks. */ /** Base class for all MediaPipe Vision Tasks. */
export abstract class VisionTaskRunner<T> extends TaskRunner { export abstract class VisionTaskRunner<T> extends TaskRunner {
/** Configures the shared options of a vision task. */ /** Configures the shared options of a vision task. */
override async setOptions(options: VisionTaskOptions): Promise<void> { override applyOptions(options: VisionTaskOptions): Promise<void> {
await super.setOptions(options);
if ('runningMode' in options) { if ('runningMode' in options) {
const useStreamMode = const useStreamMode =
!!options.runningMode && options.runningMode !== 'image'; !!options.runningMode && options.runningMode !== 'image';
this.baseOptions.setUseStreamMode(useStreamMode); this.baseOptions.setUseStreamMode(useStreamMode);
} }
return super.applyOptions(options);
} }
/** Sends an image packet to the graph and awaits results. */ /** Sends an image packet to the graph and awaits results. */

View File

@ -169,9 +169,7 @@ export class GestureRecognizer extends
* *
* @param options The options for the gesture recognizer. * @param options The options for the gesture recognizer.
*/ */
override async setOptions(options: GestureRecognizerOptions): Promise<void> { override setOptions(options: GestureRecognizerOptions): Promise<void> {
await super.setOptions(options);
if ('numHands' in options) { if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands( this.handDetectorGraphOptions.setNumHands(
options.numHands ?? DEFAULT_NUM_HANDS); options.numHands ?? DEFAULT_NUM_HANDS);
@ -221,7 +219,7 @@ export class GestureRecognizer extends
?.clearClassifierOptions(); ?.clearClassifierOptions();
} }
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -342,7 +340,7 @@ export class GestureRecognizer extends
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM);

View File

@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
gestureRecognizer = new GestureRecognizerFake(); gestureRecognizer = new GestureRecognizerFake();
await gestureRecognizer.setOptions({}); // Initialize graph await gestureRecognizer.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
* *
* @param options The options for the hand landmarker. * @param options The options for the hand landmarker.
*/ */
override async setOptions(options: HandLandmarkerOptions): Promise<void> { override setOptions(options: HandLandmarkerOptions): Promise<void> {
await super.setOptions(options);
// Configure hand detector options. // Configure hand detector options.
if ('numHands' in options) { if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands( this.handDetectorGraphOptions.setNumHands(
@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
} }
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM);

View File

@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
handLandmarker = new HandLandmarkerFake(); handLandmarker = new HandLandmarkerFake();
await handLandmarker.setOptions({}); // Initialize graph await handLandmarker.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
* *
* @param options The options for the image classifier. * @param options The options for the image classifier.
*/ */
override async setOptions(options: ImageClassifierOptions): Promise<void> { override setOptions(options: ImageClassifierOptions): Promise<void> {
await super.setOptions(options);
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

View File

@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
imageClassifier = new ImageClassifierFake(); imageClassifier = new ImageClassifierFake();
await imageClassifier.setOptions({}); // Initialize graph await imageClassifier.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
* *
* @param options The options for the image embedder. * @param options The options for the image embedder.
*/ */
override async setOptions(options: ImageEmbedderOptions): Promise<void> { override setOptions(options: ImageEmbedderOptions): Promise<void> {
await super.setOptions(options);
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM);

View File

@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
imageEmbedder = new ImageEmbedderFake(); imageEmbedder = new ImageEmbedderFake();
await imageEmbedder.setOptions({}); // Initialize graph await imageEmbedder.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {

View File

@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
* *
* @param options The options for the object detector. * @param options The options for the object detector.
*/ */
override async setOptions(options: ObjectDetectorOptions): Promise<void> { override setOptions(options: ObjectDetectorOptions): Promise<void> {
await super.setOptions(options);
// Note that we have to support both JSPB and ProtobufJS, hence we // Note that we have to support both JSPB and ProtobufJS, hence we
// have to expliclity clear the values instead of setting them to // have to expliclity clear the values instead of setting them to
// `undefined`. // `undefined`.
@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
this.options.clearCategoryDenylistList(); this.options.clearCategoryDenylistList();
} }
this.refreshGraph(); return this.applyOptions(options);
} }
/** /**
@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(DETECTIONS_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM);

View File

@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
beforeEach(async () => { beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
objectDetector = new ObjectDetectorFake(); objectDetector = new ObjectDetectorFake();
await objectDetector.setOptions({}); // Initialize graph await objectDetector.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('initializes graph', async () => { it('initializes graph', async () => {