Apply most graph options synchronously
PiperOrigin-RevId: 498244085
This commit is contained in:
parent
7e36a5e2ae
commit
9580f04571
|
@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
|||
*
|
||||
* @param options The options for the audio classifier.
|
||||
*/
|
||||
override async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(AUDIO_STREAM);
|
||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||
|
|
|
@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
audioClassifier = new AudioClassifierFake();
|
||||
await audioClassifier.setOptions({}); // Initialize graph
|
||||
await audioClassifier.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
|||
*
|
||||
* @param options The options for the audio embedder.
|
||||
*/
|
||||
override async setOptions(options: AudioEmbedderOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: AudioEmbedderOptions): Promise<void> {
|
||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||
options, this.options.getEmbedderOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(AUDIO_STREAM);
|
||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||
|
|
|
@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
audioEmbedder = new AudioEmbedderFake();
|
||||
await audioEmbedder.setOptions({}); // Initialize graph
|
||||
await audioEmbedder.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', () => {
|
||||
|
|
|
@ -103,29 +103,3 @@ jasmine_node_test(
|
|||
name = "embedder_options_test",
|
||||
deps = [":embedder_options_test_lib"],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "base_options",
|
||||
srcs = [
|
||||
"base_options.ts",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "base_options_test_lib",
|
||||
testonly = True,
|
||||
srcs = ["base_options.test.ts"],
|
||||
deps = [":base_options"],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
name = "base_options_test",
|
||||
deps = [":base_options_test_lib"],
|
||||
)
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
// Placeholder for internal dependency on trusted resource URL builder
|
||||
|
||||
import {convertBaseOptionsToProto} from './base_options';
|
||||
|
||||
describe('convertBaseOptionsToProto()', () => {
|
||||
const mockBytes = new Uint8Array([0, 1, 2, 3]);
|
||||
const mockBytesResult = {
|
||||
modelAsset: {
|
||||
fileContent: Buffer.from(mockBytes).toString('base64'),
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined,
|
||||
},
|
||||
useStreamMode: false,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: undefined,
|
||||
tflite: {},
|
||||
},
|
||||
};
|
||||
|
||||
let fetchSpy: jasmine.Spy;
|
||||
|
||||
beforeEach(() => {
|
||||
fetchSpy = jasmine.createSpy().and.callFake(async url => {
|
||||
expect(url).toEqual('foo');
|
||||
return {
|
||||
arrayBuffer: () => mockBytes.buffer,
|
||||
} as unknown as Response;
|
||||
});
|
||||
global.fetch = fetchSpy;
|
||||
});
|
||||
|
||||
it('verifies that at least one model asset option is provided', async () => {
|
||||
await expectAsync(convertBaseOptionsToProto({}))
|
||||
.toBeRejectedWithError(
|
||||
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
|
||||
});
|
||||
|
||||
it('verifies that no more than one model asset option is provided', async () => {
|
||||
await expectAsync(convertBaseOptionsToProto({
|
||||
modelAssetPath: `foo`,
|
||||
modelAssetBuffer: new Uint8Array([])
|
||||
}))
|
||||
.toBeRejectedWithError(
|
||||
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
|
||||
});
|
||||
|
||||
it('downloads model', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetPath: `foo`,
|
||||
});
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalled();
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('does not download model when bytes are provided', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
});
|
||||
|
||||
expect(fetchSpy).not.toHaveBeenCalled();
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('can enable CPU delegate', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'CPU',
|
||||
});
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('can enable GPU delegate', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
});
|
||||
expect(baseOptionsProto.toObject()).toEqual({
|
||||
...mockBytesResult,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: {
|
||||
useAdvancedGpuApi: false,
|
||||
api: 0,
|
||||
allowPrecisionLoss: true,
|
||||
cachedKernelPath: undefined,
|
||||
serializedModelDir: undefined,
|
||||
modelToken: undefined,
|
||||
usage: 2,
|
||||
},
|
||||
tflite: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('can reset delegate', async () => {
|
||||
let baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
});
|
||||
// Clear backend
|
||||
baseOptionsProto =
|
||||
await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
});
|
|
@ -1,80 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
|
||||
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
||||
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* Converts a BaseOptions API object to its Protobuf representation.
|
||||
* @throws If neither a model assset path or buffer is provided
|
||||
*/
|
||||
export async function convertBaseOptionsToProto(
|
||||
updatedOptions: BaseOptions,
|
||||
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
|
||||
const result =
|
||||
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
|
||||
|
||||
await configureExternalFile(updatedOptions, result);
|
||||
configureAcceleration(updatedOptions, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configues the `externalFile` option and validates that a single model is
|
||||
* provided.
|
||||
*/
|
||||
async function configureExternalFile(
|
||||
options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const externalFile = proto.getModelAsset() || new ExternalFile();
|
||||
proto.setModelAsset(externalFile);
|
||||
|
||||
if (options.modelAssetPath || options.modelAssetBuffer) {
|
||||
if (options.modelAssetPath && options.modelAssetBuffer) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
}
|
||||
|
||||
let modelAssetBuffer = options.modelAssetBuffer;
|
||||
if (!modelAssetBuffer) {
|
||||
const response = await fetch(options.modelAssetPath!.toString());
|
||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
||||
}
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
}
|
||||
|
||||
if (!externalFile.hasFileContent()) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
}
|
||||
|
||||
/** Configues the `acceleration` option. */
|
||||
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const acceleration = proto.getAcceleration() ?? new Acceleration();
|
||||
if (options.delegate === 'GPU') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
} else {
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
}
|
||||
proto.setAcceleration(acceleration);
|
||||
}
|
|
@ -18,8 +18,10 @@ mediapipe_ts_library(
|
|||
srcs = ["task_runner.ts"],
|
||||
deps = [
|
||||
":core",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
||||
|
@ -53,6 +55,7 @@ mediapipe_ts_library(
|
|||
"task_runner_test.ts",
|
||||
],
|
||||
deps = [
|
||||
":core",
|
||||
":task_runner",
|
||||
":task_runner_test_utils",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
|
|
|
@ -14,9 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
|
||||
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
||||
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
|
||||
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
||||
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
|
||||
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
|
||||
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
||||
|
@ -91,14 +93,52 @@ export abstract class TaskRunner {
|
|||
this.graphRunner.registerModelResourcesGraphService();
|
||||
}
|
||||
|
||||
/** Configures the shared options of a MediaPipe Task. */
|
||||
async setOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
this.baseOptions = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.baseOptions);
|
||||
/** Configures the task with custom options. */
|
||||
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
|
||||
|
||||
/**
|
||||
* Applies the current set of options, including any base options that have
|
||||
* not been processed by the task implementation. The options are applied
|
||||
* synchronously unless a `modelAssetPath` is provided. This ensures that
|
||||
* for most use cases options are applied directly and immediately affect
|
||||
* the next inference.
|
||||
*/
|
||||
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
const baseOptions: BaseOptions = options.baseOptions || {};
|
||||
|
||||
// Validate that exactly one model is configured
|
||||
if (options.baseOptions?.modelAssetBuffer &&
|
||||
options.baseOptions?.modelAssetPath) {
|
||||
throw new Error(
|
||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||
options.baseOptions?.modelAssetBuffer ||
|
||||
options.baseOptions?.modelAssetPath)) {
|
||||
throw new Error(
|
||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||
}
|
||||
|
||||
this.setAcceleration(baseOptions);
|
||||
if (baseOptions.modelAssetPath) {
|
||||
// We don't use `await` here since we want to apply most settings
|
||||
// synchronously.
|
||||
return fetch(baseOptions.modelAssetPath.toString())
|
||||
.then(response => response.arrayBuffer())
|
||||
.then(buffer => {
|
||||
this.setExternalFile(new Uint8Array(buffer));
|
||||
this.refreshGraph();
|
||||
});
|
||||
} else {
|
||||
// Apply the setting synchronously.
|
||||
this.setExternalFile(baseOptions.modelAssetBuffer);
|
||||
this.refreshGraph();
|
||||
return Promise.resolve();
|
||||
}
|
||||
}
|
||||
|
||||
/** Appliest the current options to the MediaPipe graph. */
|
||||
protected abstract refreshGraph(): void;
|
||||
|
||||
/**
|
||||
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
||||
* over the video stream. Will replace the previously running MediaPipe graph,
|
||||
|
@ -140,6 +180,27 @@ export abstract class TaskRunner {
|
|||
}
|
||||
this.processingErrors = [];
|
||||
}
|
||||
|
||||
/** Configures the `externalFile` option */
|
||||
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
|
||||
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
||||
if (modelAssetBuffer) {
|
||||
externalFile.setFileContent(modelAssetBuffer);
|
||||
}
|
||||
this.baseOptions.setModelAsset(externalFile);
|
||||
}
|
||||
|
||||
/** Configures the `acceleration` option. */
|
||||
private setAcceleration(options: BaseOptions) {
|
||||
const acceleration =
|
||||
this.baseOptions.getAcceleration() ?? new Acceleration();
|
||||
if (options.delegate === 'GPU') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
} else {
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
}
|
||||
this.baseOptions.setAcceleration(acceleration);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -15,18 +15,22 @@
|
|||
*/
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
|
||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource URL builder
|
||||
|
||||
import {GraphRunnerImageLib} from './task_runner';
|
||||
import {TaskRunnerOptions} from './task_runner_options.d';
|
||||
|
||||
class TaskRunnerFake extends TaskRunner {
|
||||
protected baseOptions = new BaseOptionsProto();
|
||||
private errorListener: ErrorListener|undefined;
|
||||
private errors: string[] = [];
|
||||
|
||||
baseOptions = new BaseOptionsProto();
|
||||
|
||||
static createFake(): TaskRunnerFake {
|
||||
const wasmModule = createSpyWasmModule();
|
||||
return new TaskRunnerFake(wasmModule);
|
||||
|
@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
|
|||
super.finishProcessing();
|
||||
}
|
||||
|
||||
override refreshGraph(): void {}
|
||||
|
||||
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
|
||||
super.setGraph(graphData, isBinary);
|
||||
}
|
||||
|
||||
setOptions(options: TaskRunnerOptions): Promise<void> {
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
private throwErrors(): void {
|
||||
expect(this.errorListener).toBeDefined();
|
||||
for (const error of this.errors) {
|
||||
|
@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
|
|||
}
|
||||
|
||||
describe('TaskRunner', () => {
|
||||
const mockBytes = new Uint8Array([0, 1, 2, 3]);
|
||||
const mockBytesResult = {
|
||||
modelAsset: {
|
||||
fileContent: Buffer.from(mockBytes).toString('base64'),
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined,
|
||||
},
|
||||
useStreamMode: false,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: undefined,
|
||||
tflite: {},
|
||||
},
|
||||
};
|
||||
|
||||
let fetchSpy: jasmine.Spy;
|
||||
let taskRunner: TaskRunnerFake;
|
||||
|
||||
beforeEach(() => {
|
||||
fetchSpy = jasmine.createSpy().and.callFake(async url => {
|
||||
expect(url).toEqual('foo');
|
||||
return {
|
||||
arrayBuffer: () => mockBytes.buffer,
|
||||
} as unknown as Response;
|
||||
});
|
||||
global.fetch = fetchSpy;
|
||||
|
||||
taskRunner = TaskRunnerFake.createFake();
|
||||
});
|
||||
|
||||
it('handles errors during graph update', () => {
|
||||
const taskRunner = TaskRunnerFake.createFake();
|
||||
taskRunner.enqueueError('Test error');
|
||||
|
||||
expect(() => {
|
||||
|
@ -85,7 +125,6 @@ describe('TaskRunner', () => {
|
|||
});
|
||||
|
||||
it('handles errors during graph execution', () => {
|
||||
const taskRunner = TaskRunnerFake.createFake();
|
||||
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
||||
|
||||
taskRunner.enqueueError('Test error');
|
||||
|
@ -96,7 +135,6 @@ describe('TaskRunner', () => {
|
|||
});
|
||||
|
||||
it('can handle multiple errors', () => {
|
||||
const taskRunner = TaskRunnerFake.createFake();
|
||||
taskRunner.enqueueError('Test error 1');
|
||||
taskRunner.enqueueError('Test error 2');
|
||||
|
||||
|
@ -104,4 +142,106 @@ describe('TaskRunner', () => {
|
|||
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
||||
}).toThrowError(/Test error 1, Test error 2/);
|
||||
});
|
||||
|
||||
it('verifies that at least one model asset option is provided', () => {
|
||||
expect(() => {
|
||||
taskRunner.setOptions({});
|
||||
})
|
||||
.toThrowError(
|
||||
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
|
||||
});
|
||||
|
||||
it('verifies that no more than one model asset option is provided', () => {
|
||||
expect(() => {
|
||||
taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetPath: `foo`,
|
||||
modelAssetBuffer: new Uint8Array([])
|
||||
}
|
||||
});
|
||||
})
|
||||
.toThrowError(
|
||||
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
|
||||
});
|
||||
|
||||
it('doesn\'t require model once it is configured', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||
expect(() => {
|
||||
taskRunner.setOptions({});
|
||||
}).not.toThrowError();
|
||||
});
|
||||
|
||||
it('downloads model', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetPath: `foo`}});
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalled();
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('does not download model when bytes are provided', async () => {
|
||||
await taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||
|
||||
expect(fetchSpy).not.toHaveBeenCalled();
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('changes model synchronously when bytes are provided', () => {
|
||||
const resolvedPromise = taskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||
|
||||
// Check that the change has been applied even though we do not await the
|
||||
// above Promise
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
return resolvedPromise;
|
||||
});
|
||||
|
||||
it('can enable CPU delegate', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'CPU',
|
||||
}
|
||||
});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
||||
it('can enable GPU delegate', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual({
|
||||
...mockBytesResult,
|
||||
acceleration: {
|
||||
xnnpack: undefined,
|
||||
gpu: {
|
||||
useAdvancedGpuApi: false,
|
||||
api: 0,
|
||||
allowPrecisionLoss: true,
|
||||
cachedKernelPath: undefined,
|
||||
serializedModelDir: undefined,
|
||||
modelToken: undefined,
|
||||
usage: 2,
|
||||
},
|
||||
tflite: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('can reset delegate', async () => {
|
||||
await taskRunner.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'GPU',
|
||||
}
|
||||
});
|
||||
// Clear backend
|
||||
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
|
||||
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
|
|||
*
|
||||
* @param options The options for the text classifier.
|
||||
*/
|
||||
override async setOptions(options: TextClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: TextClassifierOptions): Promise<void> {
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
|
@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||
|
|
|
@ -56,7 +56,8 @@ describe('TextClassifier', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
textClassifier = new TextClassifierFake();
|
||||
await textClassifier.setOptions({}); // Initialize graph
|
||||
await textClassifier.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
|
|||
*
|
||||
* @param options The options for the text embedder.
|
||||
*/
|
||||
override async setOptions(options: TextEmbedderOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: TextEmbedderOptions): Promise<void> {
|
||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||
options, this.options.getEmbedderOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
|
@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
||||
|
|
|
@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
textEmbedder = new TextEmbedderFake();
|
||||
await textEmbedder.setOptions({}); // Initialize graph
|
||||
await textEmbedder.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -29,6 +29,7 @@ mediapipe_ts_library(
|
|||
testonly = True,
|
||||
srcs = ["vision_task_runner.test.ts"],
|
||||
deps = [
|
||||
":vision_task_options",
|
||||
":vision_task_runner",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||
|
|
|
@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
|
|||
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||
|
||||
import {VisionTaskOptions} from './vision_task_options';
|
||||
import {VisionTaskRunner} from './vision_task_runner';
|
||||
|
||||
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
||||
|
@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
|||
|
||||
protected override process(): void {}
|
||||
|
||||
protected override refreshGraph(): void {}
|
||||
|
||||
override setOptions(options: VisionTaskOptions): Promise<void> {
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
override processImageData(image: ImageSource): void {
|
||||
super.processImageData(image);
|
||||
}
|
||||
|
@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
|||
}
|
||||
|
||||
describe('VisionTaskRunner', () => {
|
||||
const streamMode = {
|
||||
modelAsset: undefined,
|
||||
useStreamMode: true,
|
||||
acceleration: undefined,
|
||||
};
|
||||
|
||||
const imageMode = {
|
||||
modelAsset: undefined,
|
||||
useStreamMode: false,
|
||||
acceleration: undefined,
|
||||
};
|
||||
|
||||
let visionTaskRunner: VisionTaskRunnerFake;
|
||||
|
||||
beforeEach(() => {
|
||||
beforeEach(async () => {
|
||||
visionTaskRunner = new VisionTaskRunnerFake();
|
||||
await visionTaskRunner.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('can enable image mode', async () => {
|
||||
await visionTaskRunner.setOptions({runningMode: 'image'});
|
||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
|
||||
expect(visionTaskRunner.baseOptions.toObject())
|
||||
.toEqual(jasmine.objectContaining({useStreamMode: false}));
|
||||
});
|
||||
|
||||
it('can enable video mode', async () => {
|
||||
await visionTaskRunner.setOptions({runningMode: 'video'});
|
||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
|
||||
expect(visionTaskRunner.baseOptions.toObject())
|
||||
.toEqual(jasmine.objectContaining({useStreamMode: true}));
|
||||
});
|
||||
|
||||
it('can clear running mode', async () => {
|
||||
|
@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
|
|||
|
||||
// Clear running mode
|
||||
await visionTaskRunner.setOptions({runningMode: undefined});
|
||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
|
||||
expect(visionTaskRunner.baseOptions.toObject())
|
||||
.toEqual(jasmine.objectContaining({useStreamMode: false}));
|
||||
});
|
||||
|
||||
it('cannot process images with video mode', async () => {
|
||||
|
|
|
@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
|
|||
/** Base class for all MediaPipe Vision Tasks. */
|
||||
export abstract class VisionTaskRunner<T> extends TaskRunner {
|
||||
/** Configures the shared options of a vision task. */
|
||||
override async setOptions(options: VisionTaskOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override applyOptions(options: VisionTaskOptions): Promise<void> {
|
||||
if ('runningMode' in options) {
|
||||
const useStreamMode =
|
||||
!!options.runningMode && options.runningMode !== 'image';
|
||||
this.baseOptions.setUseStreamMode(useStreamMode);
|
||||
}
|
||||
return super.applyOptions(options);
|
||||
}
|
||||
|
||||
/** Sends an image packet to the graph and awaits results. */
|
||||
|
|
|
@ -169,9 +169,7 @@ export class GestureRecognizer extends
|
|||
*
|
||||
* @param options The options for the gesture recognizer.
|
||||
*/
|
||||
override async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
|
||||
override setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||
if ('numHands' in options) {
|
||||
this.handDetectorGraphOptions.setNumHands(
|
||||
options.numHands ?? DEFAULT_NUM_HANDS);
|
||||
|
@ -221,7 +219,7 @@ export class GestureRecognizer extends
|
|||
?.clearClassifierOptions();
|
||||
}
|
||||
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -342,7 +340,7 @@ export class GestureRecognizer extends
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
|
|
|
@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
gestureRecognizer = new GestureRecognizerFake();
|
||||
await gestureRecognizer.setOptions({}); // Initialize graph
|
||||
await gestureRecognizer.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
|||
*
|
||||
* @param options The options for the hand landmarker.
|
||||
*/
|
||||
override async setOptions(options: HandLandmarkerOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
|
||||
override setOptions(options: HandLandmarkerOptions): Promise<void> {
|
||||
// Configure hand detector options.
|
||||
if ('numHands' in options) {
|
||||
this.handDetectorGraphOptions.setNumHands(
|
||||
|
@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
|||
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
||||
}
|
||||
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
|
|
|
@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
handLandmarker = new HandLandmarkerFake();
|
||||
await handLandmarker.setOptions({}); // Initialize graph
|
||||
await handLandmarker.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
|||
*
|
||||
* @param options The options for the image classifier.
|
||||
*/
|
||||
override async setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||
|
|
|
@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
imageClassifier = new ImageClassifierFake();
|
||||
await imageClassifier.setOptions({}); // Initialize graph
|
||||
await imageClassifier.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
|||
*
|
||||
* @param options The options for the image embedder.
|
||||
*/
|
||||
override async setOptions(options: ImageEmbedderOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
override setOptions(options: ImageEmbedderOptions): Promise<void> {
|
||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||
options, this.options.getEmbedderOptions()));
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
||||
|
|
|
@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
imageEmbedder = new ImageEmbedderFake();
|
||||
await imageEmbedder.setOptions({}); // Initialize graph
|
||||
await imageEmbedder.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
|
@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
|||
*
|
||||
* @param options The options for the object detector.
|
||||
*/
|
||||
override async setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
|
||||
override setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||
// Note that we have to support both JSPB and ProtobufJS, hence we
|
||||
// have to expliclity clear the values instead of setting them to
|
||||
// `undefined`.
|
||||
|
@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
|||
this.options.clearCategoryDenylistList();
|
||||
}
|
||||
|
||||
this.refreshGraph();
|
||||
return this.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
|||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(DETECTIONS_STREAM);
|
||||
|
|
|
@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
|
|||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
objectDetector = new ObjectDetectorFake();
|
||||
await objectDetector.setOptions({}); // Initialize graph
|
||||
await objectDetector.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
|
|
Loading…
Reference in New Issue
Block a user