Apply most graph options synchronously
PiperOrigin-RevId: 498244085
This commit is contained in:
parent
7e36a5e2ae
commit
9580f04571
|
@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
||||||
*
|
*
|
||||||
* @param options The options for the audio classifier.
|
* @param options The options for the audio classifier.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: AudioClassifierOptions): Promise<void> {
|
override setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
options, this.options.getClassifierOptions()));
|
options, this.options.getClassifierOptions()));
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(AUDIO_STREAM);
|
graphConfig.addInputStream(AUDIO_STREAM);
|
||||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||||
|
|
|
@ -79,7 +79,8 @@ describe('AudioClassifier', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
audioClassifier = new AudioClassifierFake();
|
audioClassifier = new AudioClassifierFake();
|
||||||
await audioClassifier.setOptions({}); // Initialize graph
|
await audioClassifier.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
||||||
*
|
*
|
||||||
* @param options The options for the audio embedder.
|
* @param options The options for the audio embedder.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: AudioEmbedderOptions): Promise<void> {
|
override setOptions(options: AudioEmbedderOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||||
options, this.options.getEmbedderOptions()));
|
options, this.options.getEmbedderOptions()));
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(AUDIO_STREAM);
|
graphConfig.addInputStream(AUDIO_STREAM);
|
||||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||||
|
|
|
@ -70,7 +70,8 @@ describe('AudioEmbedder', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
audioEmbedder = new AudioEmbedderFake();
|
audioEmbedder = new AudioEmbedderFake();
|
||||||
await audioEmbedder.setOptions({}); // Initialize graph
|
await audioEmbedder.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', () => {
|
it('initializes graph', () => {
|
||||||
|
|
|
@ -103,29 +103,3 @@ jasmine_node_test(
|
||||||
name = "embedder_options_test",
|
name = "embedder_options_test",
|
||||||
deps = [":embedder_options_test_lib"],
|
deps = [":embedder_options_test_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
|
||||||
name = "base_options",
|
|
||||||
srcs = [
|
|
||||||
"base_options.ts",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
|
||||||
"//mediapipe/tasks/web/core",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_ts_library(
|
|
||||||
name = "base_options_test_lib",
|
|
||||||
testonly = True,
|
|
||||||
srcs = ["base_options.test.ts"],
|
|
||||||
deps = [":base_options"],
|
|
||||||
)
|
|
||||||
|
|
||||||
jasmine_node_test(
|
|
||||||
name = "base_options_test",
|
|
||||||
deps = [":base_options_test_lib"],
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,127 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import 'jasmine';
|
|
||||||
|
|
||||||
// Placeholder for internal dependency on encodeByteArray
|
|
||||||
// Placeholder for internal dependency on trusted resource URL builder
|
|
||||||
|
|
||||||
import {convertBaseOptionsToProto} from './base_options';
|
|
||||||
|
|
||||||
describe('convertBaseOptionsToProto()', () => {
|
|
||||||
const mockBytes = new Uint8Array([0, 1, 2, 3]);
|
|
||||||
const mockBytesResult = {
|
|
||||||
modelAsset: {
|
|
||||||
fileContent: Buffer.from(mockBytes).toString('base64'),
|
|
||||||
fileName: undefined,
|
|
||||||
fileDescriptorMeta: undefined,
|
|
||||||
filePointerMeta: undefined,
|
|
||||||
},
|
|
||||||
useStreamMode: false,
|
|
||||||
acceleration: {
|
|
||||||
xnnpack: undefined,
|
|
||||||
gpu: undefined,
|
|
||||||
tflite: {},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let fetchSpy: jasmine.Spy;
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
fetchSpy = jasmine.createSpy().and.callFake(async url => {
|
|
||||||
expect(url).toEqual('foo');
|
|
||||||
return {
|
|
||||||
arrayBuffer: () => mockBytes.buffer,
|
|
||||||
} as unknown as Response;
|
|
||||||
});
|
|
||||||
global.fetch = fetchSpy;
|
|
||||||
});
|
|
||||||
|
|
||||||
it('verifies that at least one model asset option is provided', async () => {
|
|
||||||
await expectAsync(convertBaseOptionsToProto({}))
|
|
||||||
.toBeRejectedWithError(
|
|
||||||
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('verifies that no more than one model asset option is provided', async () => {
|
|
||||||
await expectAsync(convertBaseOptionsToProto({
|
|
||||||
modelAssetPath: `foo`,
|
|
||||||
modelAssetBuffer: new Uint8Array([])
|
|
||||||
}))
|
|
||||||
.toBeRejectedWithError(
|
|
||||||
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('downloads model', async () => {
|
|
||||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
|
||||||
modelAssetPath: `foo`,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(fetchSpy).toHaveBeenCalled();
|
|
||||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('does not download model when bytes are provided', async () => {
|
|
||||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
|
||||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(fetchSpy).not.toHaveBeenCalled();
|
|
||||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('can enable CPU delegate', async () => {
|
|
||||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
|
||||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
|
||||||
delegate: 'CPU',
|
|
||||||
});
|
|
||||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('can enable GPU delegate', async () => {
|
|
||||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
|
||||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
|
||||||
delegate: 'GPU',
|
|
||||||
});
|
|
||||||
expect(baseOptionsProto.toObject()).toEqual({
|
|
||||||
...mockBytesResult,
|
|
||||||
acceleration: {
|
|
||||||
xnnpack: undefined,
|
|
||||||
gpu: {
|
|
||||||
useAdvancedGpuApi: false,
|
|
||||||
api: 0,
|
|
||||||
allowPrecisionLoss: true,
|
|
||||||
cachedKernelPath: undefined,
|
|
||||||
serializedModelDir: undefined,
|
|
||||||
modelToken: undefined,
|
|
||||||
usage: 2,
|
|
||||||
},
|
|
||||||
tflite: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('can reset delegate', async () => {
|
|
||||||
let baseOptionsProto = await convertBaseOptionsToProto({
|
|
||||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
|
||||||
delegate: 'GPU',
|
|
||||||
});
|
|
||||||
// Clear backend
|
|
||||||
baseOptionsProto =
|
|
||||||
await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto);
|
|
||||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
|
||||||
});
|
|
||||||
});
|
|
|
@ -1,80 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
|
|
||||||
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
|
||||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
|
||||||
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
|
|
||||||
|
|
||||||
// The OSS JS API does not support the builder pattern.
|
|
||||||
// tslint:disable:jspb-use-builder-pattern
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a BaseOptions API object to its Protobuf representation.
|
|
||||||
* @throws If neither a model assset path or buffer is provided
|
|
||||||
*/
|
|
||||||
export async function convertBaseOptionsToProto(
|
|
||||||
updatedOptions: BaseOptions,
|
|
||||||
currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
|
|
||||||
const result =
|
|
||||||
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
|
|
||||||
|
|
||||||
await configureExternalFile(updatedOptions, result);
|
|
||||||
configureAcceleration(updatedOptions, result);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configues the `externalFile` option and validates that a single model is
|
|
||||||
* provided.
|
|
||||||
*/
|
|
||||||
async function configureExternalFile(
|
|
||||||
options: BaseOptions, proto: BaseOptionsProto) {
|
|
||||||
const externalFile = proto.getModelAsset() || new ExternalFile();
|
|
||||||
proto.setModelAsset(externalFile);
|
|
||||||
|
|
||||||
if (options.modelAssetPath || options.modelAssetBuffer) {
|
|
||||||
if (options.modelAssetPath && options.modelAssetBuffer) {
|
|
||||||
throw new Error(
|
|
||||||
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
|
||||||
}
|
|
||||||
|
|
||||||
let modelAssetBuffer = options.modelAssetBuffer;
|
|
||||||
if (!modelAssetBuffer) {
|
|
||||||
const response = await fetch(options.modelAssetPath!.toString());
|
|
||||||
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
|
|
||||||
}
|
|
||||||
externalFile.setFileContent(modelAssetBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!externalFile.hasFileContent()) {
|
|
||||||
throw new Error(
|
|
||||||
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Configues the `acceleration` option. */
|
|
||||||
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
|
||||||
const acceleration = proto.getAcceleration() ?? new Acceleration();
|
|
||||||
if (options.delegate === 'GPU') {
|
|
||||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
|
||||||
} else {
|
|
||||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
|
||||||
}
|
|
||||||
proto.setAcceleration(acceleration);
|
|
||||||
}
|
|
|
@ -18,8 +18,10 @@ mediapipe_ts_library(
|
||||||
srcs = ["task_runner.ts"],
|
srcs = ["task_runner.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
":core",
|
||||||
|
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
||||||
|
@ -53,6 +55,7 @@ mediapipe_ts_library(
|
||||||
"task_runner_test.ts",
|
"task_runner_test.ts",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":core",
|
||||||
":task_runner",
|
":task_runner",
|
||||||
":task_runner_test_utils",
|
":task_runner_test_utils",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
|
|
|
@ -14,9 +14,11 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||||
|
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
|
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
|
||||||
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
||||||
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
|
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
|
||||||
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
|
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
|
||||||
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
||||||
|
@ -91,14 +93,52 @@ export abstract class TaskRunner {
|
||||||
this.graphRunner.registerModelResourcesGraphService();
|
this.graphRunner.registerModelResourcesGraphService();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Configures the shared options of a MediaPipe Task. */
|
/** Configures the task with custom options. */
|
||||||
async setOptions(options: TaskRunnerOptions): Promise<void> {
|
abstract setOptions(options: TaskRunnerOptions): Promise<void>;
|
||||||
if (options.baseOptions) {
|
|
||||||
this.baseOptions = await convertBaseOptionsToProto(
|
/**
|
||||||
options.baseOptions, this.baseOptions);
|
* Applies the current set of options, including any base options that have
|
||||||
|
* not been processed by the task implementation. The options are applied
|
||||||
|
* synchronously unless a `modelAssetPath` is provided. This ensures that
|
||||||
|
* for most use cases options are applied directly and immediately affect
|
||||||
|
* the next inference.
|
||||||
|
*/
|
||||||
|
protected applyOptions(options: TaskRunnerOptions): Promise<void> {
|
||||||
|
const baseOptions: BaseOptions = options.baseOptions || {};
|
||||||
|
|
||||||
|
// Validate that exactly one model is configured
|
||||||
|
if (options.baseOptions?.modelAssetBuffer &&
|
||||||
|
options.baseOptions?.modelAssetPath) {
|
||||||
|
throw new Error(
|
||||||
|
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
|
||||||
|
} else if (!(this.baseOptions.getModelAsset()?.hasFileContent() ||
|
||||||
|
options.baseOptions?.modelAssetBuffer ||
|
||||||
|
options.baseOptions?.modelAssetPath)) {
|
||||||
|
throw new Error(
|
||||||
|
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setAcceleration(baseOptions);
|
||||||
|
if (baseOptions.modelAssetPath) {
|
||||||
|
// We don't use `await` here since we want to apply most settings
|
||||||
|
// synchronously.
|
||||||
|
return fetch(baseOptions.modelAssetPath.toString())
|
||||||
|
.then(response => response.arrayBuffer())
|
||||||
|
.then(buffer => {
|
||||||
|
this.setExternalFile(new Uint8Array(buffer));
|
||||||
|
this.refreshGraph();
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Apply the setting synchronously.
|
||||||
|
this.setExternalFile(baseOptions.modelAssetBuffer);
|
||||||
|
this.refreshGraph();
|
||||||
|
return Promise.resolve();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Appliest the current options to the MediaPipe graph. */
|
||||||
|
protected abstract refreshGraph(): void;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
||||||
* over the video stream. Will replace the previously running MediaPipe graph,
|
* over the video stream. Will replace the previously running MediaPipe graph,
|
||||||
|
@ -140,6 +180,27 @@ export abstract class TaskRunner {
|
||||||
}
|
}
|
||||||
this.processingErrors = [];
|
this.processingErrors = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Configures the `externalFile` option */
|
||||||
|
private setExternalFile(modelAssetBuffer?: Uint8Array): void {
|
||||||
|
const externalFile = this.baseOptions.getModelAsset() || new ExternalFile();
|
||||||
|
if (modelAssetBuffer) {
|
||||||
|
externalFile.setFileContent(modelAssetBuffer);
|
||||||
|
}
|
||||||
|
this.baseOptions.setModelAsset(externalFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Configures the `acceleration` option. */
|
||||||
|
private setAcceleration(options: BaseOptions) {
|
||||||
|
const acceleration =
|
||||||
|
this.baseOptions.getAcceleration() ?? new Acceleration();
|
||||||
|
if (options.delegate === 'GPU') {
|
||||||
|
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||||
|
} else {
|
||||||
|
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||||
|
}
|
||||||
|
this.baseOptions.setAcceleration(acceleration);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,18 +15,22 @@
|
||||||
*/
|
*/
|
||||||
import 'jasmine';
|
import 'jasmine';
|
||||||
|
|
||||||
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../tasks/web/core/task_runner';
|
||||||
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
|
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
|
||||||
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
|
||||||
|
// Placeholder for internal dependency on trusted resource URL builder
|
||||||
|
|
||||||
import {GraphRunnerImageLib} from './task_runner';
|
import {GraphRunnerImageLib} from './task_runner';
|
||||||
|
import {TaskRunnerOptions} from './task_runner_options.d';
|
||||||
|
|
||||||
class TaskRunnerFake extends TaskRunner {
|
class TaskRunnerFake extends TaskRunner {
|
||||||
protected baseOptions = new BaseOptionsProto();
|
|
||||||
private errorListener: ErrorListener|undefined;
|
private errorListener: ErrorListener|undefined;
|
||||||
private errors: string[] = [];
|
private errors: string[] = [];
|
||||||
|
|
||||||
|
baseOptions = new BaseOptionsProto();
|
||||||
|
|
||||||
static createFake(): TaskRunnerFake {
|
static createFake(): TaskRunnerFake {
|
||||||
const wasmModule = createSpyWasmModule();
|
const wasmModule = createSpyWasmModule();
|
||||||
return new TaskRunnerFake(wasmModule);
|
return new TaskRunnerFake(wasmModule);
|
||||||
|
@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner {
|
||||||
super.finishProcessing();
|
super.finishProcessing();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override refreshGraph(): void {}
|
||||||
|
|
||||||
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
|
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
|
||||||
super.setGraph(graphData, isBinary);
|
super.setGraph(graphData, isBinary);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setOptions(options: TaskRunnerOptions): Promise<void> {
|
||||||
|
return this.applyOptions(options);
|
||||||
|
}
|
||||||
|
|
||||||
private throwErrors(): void {
|
private throwErrors(): void {
|
||||||
expect(this.errorListener).toBeDefined();
|
expect(this.errorListener).toBeDefined();
|
||||||
for (const error of this.errors) {
|
for (const error of this.errors) {
|
||||||
|
@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
describe('TaskRunner', () => {
|
describe('TaskRunner', () => {
|
||||||
|
const mockBytes = new Uint8Array([0, 1, 2, 3]);
|
||||||
|
const mockBytesResult = {
|
||||||
|
modelAsset: {
|
||||||
|
fileContent: Buffer.from(mockBytes).toString('base64'),
|
||||||
|
fileName: undefined,
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined,
|
||||||
|
},
|
||||||
|
useStreamMode: false,
|
||||||
|
acceleration: {
|
||||||
|
xnnpack: undefined,
|
||||||
|
gpu: undefined,
|
||||||
|
tflite: {},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let fetchSpy: jasmine.Spy;
|
||||||
|
let taskRunner: TaskRunnerFake;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
fetchSpy = jasmine.createSpy().and.callFake(async url => {
|
||||||
|
expect(url).toEqual('foo');
|
||||||
|
return {
|
||||||
|
arrayBuffer: () => mockBytes.buffer,
|
||||||
|
} as unknown as Response;
|
||||||
|
});
|
||||||
|
global.fetch = fetchSpy;
|
||||||
|
|
||||||
|
taskRunner = TaskRunnerFake.createFake();
|
||||||
|
});
|
||||||
|
|
||||||
it('handles errors during graph update', () => {
|
it('handles errors during graph update', () => {
|
||||||
const taskRunner = TaskRunnerFake.createFake();
|
|
||||||
taskRunner.enqueueError('Test error');
|
taskRunner.enqueueError('Test error');
|
||||||
|
|
||||||
expect(() => {
|
expect(() => {
|
||||||
|
@ -85,7 +125,6 @@ describe('TaskRunner', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('handles errors during graph execution', () => {
|
it('handles errors during graph execution', () => {
|
||||||
const taskRunner = TaskRunnerFake.createFake();
|
|
||||||
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
||||||
|
|
||||||
taskRunner.enqueueError('Test error');
|
taskRunner.enqueueError('Test error');
|
||||||
|
@ -96,7 +135,6 @@ describe('TaskRunner', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can handle multiple errors', () => {
|
it('can handle multiple errors', () => {
|
||||||
const taskRunner = TaskRunnerFake.createFake();
|
|
||||||
taskRunner.enqueueError('Test error 1');
|
taskRunner.enqueueError('Test error 1');
|
||||||
taskRunner.enqueueError('Test error 2');
|
taskRunner.enqueueError('Test error 2');
|
||||||
|
|
||||||
|
@ -104,4 +142,106 @@ describe('TaskRunner', () => {
|
||||||
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true);
|
||||||
}).toThrowError(/Test error 1, Test error 2/);
|
}).toThrowError(/Test error 1, Test error 2/);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('verifies that at least one model asset option is provided', () => {
|
||||||
|
expect(() => {
|
||||||
|
taskRunner.setOptions({});
|
||||||
|
})
|
||||||
|
.toThrowError(
|
||||||
|
/Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('verifies that no more than one model asset option is provided', () => {
|
||||||
|
expect(() => {
|
||||||
|
taskRunner.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetPath: `foo`,
|
||||||
|
modelAssetBuffer: new Uint8Array([])
|
||||||
|
}
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.toThrowError(
|
||||||
|
/Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('doesn\'t require model once it is configured', async () => {
|
||||||
|
await taskRunner.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||||
|
expect(() => {
|
||||||
|
taskRunner.setOptions({});
|
||||||
|
}).not.toThrowError();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('downloads model', async () => {
|
||||||
|
await taskRunner.setOptions(
|
||||||
|
{baseOptions: {modelAssetPath: `foo`}});
|
||||||
|
|
||||||
|
expect(fetchSpy).toHaveBeenCalled();
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not download model when bytes are provided', async () => {
|
||||||
|
await taskRunner.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||||
|
|
||||||
|
expect(fetchSpy).not.toHaveBeenCalled();
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('changes model synchronously when bytes are provided', () => {
|
||||||
|
const resolvedPromise = taskRunner.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}});
|
||||||
|
|
||||||
|
// Check that the change has been applied even though we do not await the
|
||||||
|
// above Promise
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||||
|
return resolvedPromise;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can enable CPU delegate', async () => {
|
||||||
|
await taskRunner.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||||
|
delegate: 'CPU',
|
||||||
|
}
|
||||||
|
});
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can enable GPU delegate', async () => {
|
||||||
|
await taskRunner.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||||
|
delegate: 'GPU',
|
||||||
|
}
|
||||||
|
});
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual({
|
||||||
|
...mockBytesResult,
|
||||||
|
acceleration: {
|
||||||
|
xnnpack: undefined,
|
||||||
|
gpu: {
|
||||||
|
useAdvancedGpuApi: false,
|
||||||
|
api: 0,
|
||||||
|
allowPrecisionLoss: true,
|
||||||
|
cachedKernelPath: undefined,
|
||||||
|
serializedModelDir: undefined,
|
||||||
|
modelToken: undefined,
|
||||||
|
usage: 2,
|
||||||
|
},
|
||||||
|
tflite: undefined,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can reset delegate', async () => {
|
||||||
|
await taskRunner.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||||
|
delegate: 'GPU',
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// Clear backend
|
||||||
|
await taskRunner.setOptions({baseOptions: {delegate: undefined}});
|
||||||
|
expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner {
|
||||||
*
|
*
|
||||||
* @param options The options for the text classifier.
|
* @param options The options for the text classifier.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: TextClassifierOptions): Promise<void> {
|
override setOptions(options: TextClassifierOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
options, this.options.getClassifierOptions()));
|
options, this.options.getClassifierOptions()));
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||||
|
|
|
@ -56,7 +56,8 @@ describe('TextClassifier', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
textClassifier = new TextClassifierFake();
|
textClassifier = new TextClassifierFake();
|
||||||
await textClassifier.setOptions({}); // Initialize graph
|
await textClassifier.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner {
|
||||||
*
|
*
|
||||||
* @param options The options for the text embedder.
|
* @param options The options for the text embedder.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: TextEmbedderOptions): Promise<void> {
|
override setOptions(options: TextEmbedderOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||||
options, this.options.getEmbedderOptions()));
|
options, this.options.getEmbedderOptions()));
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
||||||
|
|
|
@ -56,7 +56,8 @@ describe('TextEmbedder', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
textEmbedder = new TextEmbedderFake();
|
textEmbedder = new TextEmbedderFake();
|
||||||
await textEmbedder.setOptions({}); // Initialize graph
|
await textEmbedder.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -29,6 +29,7 @@ mediapipe_ts_library(
|
||||||
testonly = True,
|
testonly = True,
|
||||||
srcs = ["vision_task_runner.test.ts"],
|
srcs = ["vision_task_runner.test.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":vision_task_options",
|
||||||
":vision_task_runner",
|
":vision_task_runner",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||||
|
|
|
@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
|
||||||
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
|
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
|
||||||
|
import {VisionTaskOptions} from './vision_task_options';
|
||||||
import {VisionTaskRunner} from './vision_task_runner';
|
import {VisionTaskRunner} from './vision_task_runner';
|
||||||
|
|
||||||
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
||||||
|
@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
||||||
|
|
||||||
protected override process(): void {}
|
protected override process(): void {}
|
||||||
|
|
||||||
|
protected override refreshGraph(): void {}
|
||||||
|
|
||||||
|
override setOptions(options: VisionTaskOptions): Promise<void> {
|
||||||
|
return this.applyOptions(options);
|
||||||
|
}
|
||||||
|
|
||||||
override processImageData(image: ImageSource): void {
|
override processImageData(image: ImageSource): void {
|
||||||
super.processImageData(image);
|
super.processImageData(image);
|
||||||
}
|
}
|
||||||
|
@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
|
||||||
}
|
}
|
||||||
|
|
||||||
describe('VisionTaskRunner', () => {
|
describe('VisionTaskRunner', () => {
|
||||||
const streamMode = {
|
|
||||||
modelAsset: undefined,
|
|
||||||
useStreamMode: true,
|
|
||||||
acceleration: undefined,
|
|
||||||
};
|
|
||||||
|
|
||||||
const imageMode = {
|
|
||||||
modelAsset: undefined,
|
|
||||||
useStreamMode: false,
|
|
||||||
acceleration: undefined,
|
|
||||||
};
|
|
||||||
|
|
||||||
let visionTaskRunner: VisionTaskRunnerFake;
|
let visionTaskRunner: VisionTaskRunnerFake;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(async () => {
|
||||||
visionTaskRunner = new VisionTaskRunnerFake();
|
visionTaskRunner = new VisionTaskRunnerFake();
|
||||||
|
await visionTaskRunner.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can enable image mode', async () => {
|
it('can enable image mode', async () => {
|
||||||
await visionTaskRunner.setOptions({runningMode: 'image'});
|
await visionTaskRunner.setOptions({runningMode: 'image'});
|
||||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
|
expect(visionTaskRunner.baseOptions.toObject())
|
||||||
|
.toEqual(jasmine.objectContaining({useStreamMode: false}));
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can enable video mode', async () => {
|
it('can enable video mode', async () => {
|
||||||
await visionTaskRunner.setOptions({runningMode: 'video'});
|
await visionTaskRunner.setOptions({runningMode: 'video'});
|
||||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
|
expect(visionTaskRunner.baseOptions.toObject())
|
||||||
|
.toEqual(jasmine.objectContaining({useStreamMode: true}));
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can clear running mode', async () => {
|
it('can clear running mode', async () => {
|
||||||
|
@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => {
|
||||||
|
|
||||||
// Clear running mode
|
// Clear running mode
|
||||||
await visionTaskRunner.setOptions({runningMode: undefined});
|
await visionTaskRunner.setOptions({runningMode: undefined});
|
||||||
expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
|
expect(visionTaskRunner.baseOptions.toObject())
|
||||||
|
.toEqual(jasmine.objectContaining({useStreamMode: false}));
|
||||||
});
|
});
|
||||||
|
|
||||||
it('cannot process images with video mode', async () => {
|
it('cannot process images with video mode', async () => {
|
||||||
|
|
|
@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options';
|
||||||
/** Base class for all MediaPipe Vision Tasks. */
|
/** Base class for all MediaPipe Vision Tasks. */
|
||||||
export abstract class VisionTaskRunner<T> extends TaskRunner {
|
export abstract class VisionTaskRunner<T> extends TaskRunner {
|
||||||
/** Configures the shared options of a vision task. */
|
/** Configures the shared options of a vision task. */
|
||||||
override async setOptions(options: VisionTaskOptions): Promise<void> {
|
override applyOptions(options: VisionTaskOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
if ('runningMode' in options) {
|
if ('runningMode' in options) {
|
||||||
const useStreamMode =
|
const useStreamMode =
|
||||||
!!options.runningMode && options.runningMode !== 'image';
|
!!options.runningMode && options.runningMode !== 'image';
|
||||||
this.baseOptions.setUseStreamMode(useStreamMode);
|
this.baseOptions.setUseStreamMode(useStreamMode);
|
||||||
}
|
}
|
||||||
|
return super.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Sends an image packet to the graph and awaits results. */
|
/** Sends an image packet to the graph and awaits results. */
|
||||||
|
|
|
@ -169,9 +169,7 @@ export class GestureRecognizer extends
|
||||||
*
|
*
|
||||||
* @param options The options for the gesture recognizer.
|
* @param options The options for the gesture recognizer.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
override setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
|
|
||||||
if ('numHands' in options) {
|
if ('numHands' in options) {
|
||||||
this.handDetectorGraphOptions.setNumHands(
|
this.handDetectorGraphOptions.setNumHands(
|
||||||
options.numHands ?? DEFAULT_NUM_HANDS);
|
options.numHands ?? DEFAULT_NUM_HANDS);
|
||||||
|
@ -221,7 +219,7 @@ export class GestureRecognizer extends
|
||||||
?.clearClassifierOptions();
|
?.clearClassifierOptions();
|
||||||
}
|
}
|
||||||
|
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -342,7 +340,7 @@ export class GestureRecognizer extends
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(IMAGE_STREAM);
|
graphConfig.addInputStream(IMAGE_STREAM);
|
||||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||||
|
|
|
@ -109,7 +109,8 @@ describe('GestureRecognizer', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
gestureRecognizer = new GestureRecognizerFake();
|
gestureRecognizer = new GestureRecognizerFake();
|
||||||
await gestureRecognizer.setOptions({}); // Initialize graph
|
await gestureRecognizer.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
*
|
*
|
||||||
* @param options The options for the hand landmarker.
|
* @param options The options for the hand landmarker.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: HandLandmarkerOptions): Promise<void> {
|
override setOptions(options: HandLandmarkerOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
|
|
||||||
// Configure hand detector options.
|
// Configure hand detector options.
|
||||||
if ('numHands' in options) {
|
if ('numHands' in options) {
|
||||||
this.handDetectorGraphOptions.setNumHands(
|
this.handDetectorGraphOptions.setNumHands(
|
||||||
|
@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
||||||
}
|
}
|
||||||
|
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(IMAGE_STREAM);
|
graphConfig.addInputStream(IMAGE_STREAM);
|
||||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||||
|
|
|
@ -98,7 +98,8 @@ describe('HandLandmarker', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
handLandmarker = new HandLandmarkerFake();
|
handLandmarker = new HandLandmarkerFake();
|
||||||
await handLandmarker.setOptions({}); // Initialize graph
|
await handLandmarker.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
||||||
*
|
*
|
||||||
* @param options The options for the image classifier.
|
* @param options The options for the image classifier.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: ImageClassifierOptions): Promise<void> {
|
override setOptions(options: ImageClassifierOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
options, this.options.getClassifierOptions()));
|
options, this.options.getClassifierOptions()));
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||||
|
|
|
@ -61,7 +61,8 @@ describe('ImageClassifier', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
imageClassifier = new ImageClassifierFake();
|
imageClassifier = new ImageClassifierFake();
|
||||||
await imageClassifier.setOptions({}); // Initialize graph
|
await imageClassifier.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
||||||
*
|
*
|
||||||
* @param options The options for the image embedder.
|
* @param options The options for the image embedder.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: ImageEmbedderOptions): Promise<void> {
|
override setOptions(options: ImageEmbedderOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||||
options, this.options.getEmbedderOptions()));
|
options, this.options.getEmbedderOptions()));
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
|
||||||
|
|
|
@ -57,7 +57,8 @@ describe('ImageEmbedder', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
imageEmbedder = new ImageEmbedderFake();
|
imageEmbedder = new ImageEmbedderFake();
|
||||||
await imageEmbedder.setOptions({}); // Initialize graph
|
await imageEmbedder.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
|
@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
||||||
*
|
*
|
||||||
* @param options The options for the object detector.
|
* @param options The options for the object detector.
|
||||||
*/
|
*/
|
||||||
override async setOptions(options: ObjectDetectorOptions): Promise<void> {
|
override setOptions(options: ObjectDetectorOptions): Promise<void> {
|
||||||
await super.setOptions(options);
|
|
||||||
|
|
||||||
// Note that we have to support both JSPB and ProtobufJS, hence we
|
// Note that we have to support both JSPB and ProtobufJS, hence we
|
||||||
// have to expliclity clear the values instead of setting them to
|
// have to expliclity clear the values instead of setting them to
|
||||||
// `undefined`.
|
// `undefined`.
|
||||||
|
@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
||||||
this.options.clearCategoryDenylistList();
|
this.options.clearCategoryDenylistList();
|
||||||
}
|
}
|
||||||
|
|
||||||
this.refreshGraph();
|
return this.applyOptions(options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(DETECTIONS_STREAM);
|
graphConfig.addOutputStream(DETECTIONS_STREAM);
|
||||||
|
|
|
@ -61,7 +61,8 @@ describe('ObjectDetector', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
addJasmineCustomFloatEqualityTester();
|
addJasmineCustomFloatEqualityTester();
|
||||||
objectDetector = new ObjectDetectorFake();
|
objectDetector = new ObjectDetectorFake();
|
||||||
await objectDetector.setOptions({}); // Initialize graph
|
await objectDetector.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user