Apply most graph options synchronously
PiperOrigin-RevId: 498244085
This commit is contained in:
		
							parent
							
								
									7e36a5e2ae
								
							
						
					
					
						commit
						9580f04571
					
				|  | @ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | |||
|    * | ||||
|    * @param options The options for the audio classifier. | ||||
|    */ | ||||
|   override async setOptions(options: AudioClassifierOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override setOptions(options: AudioClassifierOptions): Promise<void> { | ||||
|     this.options.setClassifierOptions(convertClassifierOptionsToProto( | ||||
|         options, this.options.getClassifierOptions())); | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(AUDIO_STREAM); | ||||
|     graphConfig.addInputStream(SAMPLE_RATE_STREAM); | ||||
|  |  | |||
|  | @ -79,7 +79,8 @@ describe('AudioClassifier', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     audioClassifier = new AudioClassifierFake(); | ||||
|     await audioClassifier.setOptions({});  // Initialize graph
 | ||||
|     await audioClassifier.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | |||
|    * | ||||
|    * @param options The options for the audio embedder. | ||||
|    */ | ||||
|   override async setOptions(options: AudioEmbedderOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override setOptions(options: AudioEmbedderOptions): Promise<void> { | ||||
|     this.options.setEmbedderOptions(convertEmbedderOptionsToProto( | ||||
|         options, this.options.getEmbedderOptions())); | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(AUDIO_STREAM); | ||||
|     graphConfig.addInputStream(SAMPLE_RATE_STREAM); | ||||
|  |  | |||
|  | @ -70,7 +70,8 @@ describe('AudioEmbedder', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     audioEmbedder = new AudioEmbedderFake(); | ||||
|     await audioEmbedder.setOptions({});  // Initialize graph
 | ||||
|     await audioEmbedder.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', () => { | ||||
|  |  | |||
|  | @ -103,29 +103,3 @@ jasmine_node_test( | |||
|     name = "embedder_options_test", | ||||
|     deps = [":embedder_options_test_lib"], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_library( | ||||
|     name = "base_options", | ||||
|     srcs = [ | ||||
|         "base_options.ts", | ||||
|     ], | ||||
|     deps = [ | ||||
|         "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", | ||||
|         "//mediapipe/tasks/web/core", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_library( | ||||
|     name = "base_options_test_lib", | ||||
|     testonly = True, | ||||
|     srcs = ["base_options.test.ts"], | ||||
|     deps = [":base_options"], | ||||
| ) | ||||
| 
 | ||||
| jasmine_node_test( | ||||
|     name = "base_options_test", | ||||
|     deps = [":base_options_test_lib"], | ||||
| ) | ||||
|  |  | |||
|  | @ -1,127 +0,0 @@ | |||
| /** | ||||
|  * Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| import 'jasmine'; | ||||
| 
 | ||||
| // Placeholder for internal dependency on encodeByteArray
 | ||||
| // Placeholder for internal dependency on trusted resource URL builder
 | ||||
| 
 | ||||
| import {convertBaseOptionsToProto} from './base_options'; | ||||
| 
 | ||||
| describe('convertBaseOptionsToProto()', () => { | ||||
|   const mockBytes = new Uint8Array([0, 1, 2, 3]); | ||||
|   const mockBytesResult = { | ||||
|     modelAsset: { | ||||
|       fileContent: Buffer.from(mockBytes).toString('base64'), | ||||
|       fileName: undefined, | ||||
|       fileDescriptorMeta: undefined, | ||||
|       filePointerMeta: undefined, | ||||
|     }, | ||||
|     useStreamMode: false, | ||||
|     acceleration: { | ||||
|       xnnpack: undefined, | ||||
|       gpu: undefined, | ||||
|       tflite: {}, | ||||
|     }, | ||||
|   }; | ||||
| 
 | ||||
|   let fetchSpy: jasmine.Spy; | ||||
| 
 | ||||
|   beforeEach(() => { | ||||
|     fetchSpy = jasmine.createSpy().and.callFake(async url => { | ||||
|       expect(url).toEqual('foo'); | ||||
|       return { | ||||
|         arrayBuffer: () => mockBytes.buffer, | ||||
|       } as unknown as Response; | ||||
|     }); | ||||
|     global.fetch = fetchSpy; | ||||
|   }); | ||||
| 
 | ||||
|   it('verifies that at least one model asset option is provided', async () => { | ||||
|     await expectAsync(convertBaseOptionsToProto({})) | ||||
|         .toBeRejectedWithError( | ||||
|             /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); | ||||
|   }); | ||||
| 
 | ||||
|   it('verifies that no more than one model asset option is provided', async () => { | ||||
|     await expectAsync(convertBaseOptionsToProto({ | ||||
|       modelAssetPath: `foo`, | ||||
|       modelAssetBuffer: new Uint8Array([]) | ||||
|     })) | ||||
|         .toBeRejectedWithError( | ||||
|             /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); | ||||
|   }); | ||||
| 
 | ||||
|   it('downloads model', async () => { | ||||
|     const baseOptionsProto = await convertBaseOptionsToProto({ | ||||
|       modelAssetPath: `foo`, | ||||
|     }); | ||||
| 
 | ||||
|     expect(fetchSpy).toHaveBeenCalled(); | ||||
|     expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| 
 | ||||
|   it('does not download model when bytes are provided', async () => { | ||||
|     const baseOptionsProto = await convertBaseOptionsToProto({ | ||||
|       modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|     }); | ||||
| 
 | ||||
|     expect(fetchSpy).not.toHaveBeenCalled(); | ||||
|     expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| 
 | ||||
|   it('can enable CPU delegate', async () => { | ||||
|     const baseOptionsProto = await convertBaseOptionsToProto({ | ||||
|       modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|       delegate: 'CPU', | ||||
|     }); | ||||
|     expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| 
 | ||||
|   it('can enable GPU delegate', async () => { | ||||
|     const baseOptionsProto = await convertBaseOptionsToProto({ | ||||
|       modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|       delegate: 'GPU', | ||||
|     }); | ||||
|     expect(baseOptionsProto.toObject()).toEqual({ | ||||
|       ...mockBytesResult, | ||||
|       acceleration: { | ||||
|         xnnpack: undefined, | ||||
|         gpu: { | ||||
|           useAdvancedGpuApi: false, | ||||
|           api: 0, | ||||
|           allowPrecisionLoss: true, | ||||
|           cachedKernelPath: undefined, | ||||
|           serializedModelDir: undefined, | ||||
|           modelToken: undefined, | ||||
|           usage: 2, | ||||
|         }, | ||||
|         tflite: undefined, | ||||
|       }, | ||||
|     }); | ||||
|   }); | ||||
| 
 | ||||
|   it('can reset delegate', async () => { | ||||
|     let baseOptionsProto = await convertBaseOptionsToProto({ | ||||
|       modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|       delegate: 'GPU', | ||||
|     }); | ||||
|     // Clear backend
 | ||||
|     baseOptionsProto = | ||||
|         await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); | ||||
|     expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| }); | ||||
|  | @ -1,80 +0,0 @@ | |||
| /** | ||||
|  * Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; | ||||
| import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; | ||||
| import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; | ||||
| import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; | ||||
| import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||
| 
 | ||||
| // The OSS JS API does not support the builder pattern.
 | ||||
| // tslint:disable:jspb-use-builder-pattern
 | ||||
| 
 | ||||
| /** | ||||
|  * Converts a BaseOptions API object to its Protobuf representation. | ||||
|  * @throws If neither a model assset path or buffer is provided | ||||
|  */ | ||||
| export async function convertBaseOptionsToProto( | ||||
|     updatedOptions: BaseOptions, | ||||
|     currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> { | ||||
|   const result = | ||||
|       currentOptions ? currentOptions.clone() : new BaseOptionsProto(); | ||||
| 
 | ||||
|   await configureExternalFile(updatedOptions, result); | ||||
|   configureAcceleration(updatedOptions, result); | ||||
| 
 | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| /** | ||||
|  * Configues the `externalFile` option and validates that a single model is | ||||
|  * provided. | ||||
|  */ | ||||
| async function configureExternalFile( | ||||
|     options: BaseOptions, proto: BaseOptionsProto) { | ||||
|   const externalFile = proto.getModelAsset() || new ExternalFile(); | ||||
|   proto.setModelAsset(externalFile); | ||||
| 
 | ||||
|   if (options.modelAssetPath || options.modelAssetBuffer) { | ||||
|     if (options.modelAssetPath && options.modelAssetBuffer) { | ||||
|       throw new Error( | ||||
|           'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); | ||||
|     } | ||||
| 
 | ||||
|     let modelAssetBuffer = options.modelAssetBuffer; | ||||
|     if (!modelAssetBuffer) { | ||||
|       const response = await fetch(options.modelAssetPath!.toString()); | ||||
|       modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); | ||||
|     } | ||||
|     externalFile.setFileContent(modelAssetBuffer); | ||||
|   } | ||||
| 
 | ||||
|   if (!externalFile.hasFileContent()) { | ||||
|     throw new Error( | ||||
|         'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /** Configues the `acceleration` option. */ | ||||
| function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { | ||||
|   const acceleration = proto.getAcceleration() ?? new Acceleration(); | ||||
|   if (options.delegate === 'GPU') { | ||||
|     acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); | ||||
|   } else { | ||||
|     acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); | ||||
|   } | ||||
|   proto.setAcceleration(acceleration); | ||||
| } | ||||
|  | @ -18,8 +18,10 @@ mediapipe_ts_library( | |||
|     srcs = ["task_runner.ts"], | ||||
|     deps = [ | ||||
|         ":core", | ||||
|         "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||
|         "//mediapipe/tasks/web/components/processors:base_options", | ||||
|         "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", | ||||
|         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", | ||||
|         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||
|         "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", | ||||
|  | @ -53,6 +55,7 @@ mediapipe_ts_library( | |||
|         "task_runner_test.ts", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":core", | ||||
|         ":task_runner", | ||||
|         ":task_runner_test_utils", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||
|  |  | |||
|  | @ -14,9 +14,11 @@ | |||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; | ||||
| import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; | ||||
| import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; | ||||
| import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; | ||||
| import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; | ||||
| import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; | ||||
| import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; | ||||
| import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; | ||||
| import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; | ||||
| import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; | ||||
|  | @ -91,14 +93,52 @@ export abstract class TaskRunner { | |||
|     this.graphRunner.registerModelResourcesGraphService(); | ||||
|   } | ||||
| 
 | ||||
|   /** Configures the shared options of a MediaPipe Task. */ | ||||
|   async setOptions(options: TaskRunnerOptions): Promise<void> { | ||||
|     if (options.baseOptions) { | ||||
|       this.baseOptions = await convertBaseOptionsToProto( | ||||
|           options.baseOptions, this.baseOptions); | ||||
|   /** Configures the task with custom options. */ | ||||
|   abstract setOptions(options: TaskRunnerOptions): Promise<void>; | ||||
| 
 | ||||
|   /** | ||||
|    * Applies the current set of options, including any base options that have | ||||
|    * not been processed by the task implementation. The options are applied | ||||
|    * synchronously unless a `modelAssetPath` is provided. This ensures that | ||||
|    * for most use cases options are applied directly and immediately affect | ||||
|    * the next inference. | ||||
|    */ | ||||
|   protected applyOptions(options: TaskRunnerOptions): Promise<void> { | ||||
|     const baseOptions: BaseOptions = options.baseOptions || {}; | ||||
| 
 | ||||
|     // Validate that exactly one model is configured
 | ||||
|     if (options.baseOptions?.modelAssetBuffer && | ||||
|         options.baseOptions?.modelAssetPath) { | ||||
|       throw new Error( | ||||
|           'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); | ||||
|     } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || | ||||
|                  options.baseOptions?.modelAssetBuffer || | ||||
|                  options.baseOptions?.modelAssetPath)) { | ||||
|       throw new Error( | ||||
|           'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); | ||||
|     } | ||||
| 
 | ||||
|     this.setAcceleration(baseOptions); | ||||
|     if (baseOptions.modelAssetPath) { | ||||
|       // We don't use `await` here since we want to apply most settings
 | ||||
|       // synchronously.
 | ||||
|       return fetch(baseOptions.modelAssetPath.toString()) | ||||
|           .then(response => response.arrayBuffer()) | ||||
|           .then(buffer => { | ||||
|             this.setExternalFile(new Uint8Array(buffer)); | ||||
|             this.refreshGraph(); | ||||
|           }); | ||||
|     } else { | ||||
|       // Apply the setting synchronously.
 | ||||
|       this.setExternalFile(baseOptions.modelAssetBuffer); | ||||
|       this.refreshGraph(); | ||||
|       return Promise.resolve(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /** Appliest the current options to the MediaPipe graph. */ | ||||
|   protected abstract refreshGraph(): void; | ||||
| 
 | ||||
|   /** | ||||
|    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run | ||||
|    * over the video stream. Will replace the previously running MediaPipe graph, | ||||
|  | @ -140,6 +180,27 @@ export abstract class TaskRunner { | |||
|     } | ||||
|     this.processingErrors = []; | ||||
|   } | ||||
| 
 | ||||
|   /** Configures the `externalFile` option */ | ||||
|   private setExternalFile(modelAssetBuffer?: Uint8Array): void { | ||||
|     const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); | ||||
|     if (modelAssetBuffer) { | ||||
|       externalFile.setFileContent(modelAssetBuffer); | ||||
|     } | ||||
|     this.baseOptions.setModelAsset(externalFile); | ||||
|   } | ||||
| 
 | ||||
|   /** Configures the `acceleration` option. */ | ||||
|   private setAcceleration(options: BaseOptions) { | ||||
|     const acceleration = | ||||
|         this.baseOptions.getAcceleration() ?? new Acceleration(); | ||||
|     if (options.delegate === 'GPU') { | ||||
|       acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); | ||||
|     } else { | ||||
|       acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); | ||||
|     } | ||||
|     this.baseOptions.setAcceleration(acceleration); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,18 +15,22 @@ | |||
|  */ | ||||
| import 'jasmine'; | ||||
| 
 | ||||
| // Placeholder for internal dependency on encodeByteArray
 | ||||
| import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; | ||||
| import {TaskRunner} from '../../../tasks/web/core/task_runner'; | ||||
| import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; | ||||
| import {ErrorListener} from '../../../web/graph_runner/graph_runner'; | ||||
| // Placeholder for internal dependency on trusted resource URL builder
 | ||||
| 
 | ||||
| import {GraphRunnerImageLib} from './task_runner'; | ||||
| import {TaskRunnerOptions} from './task_runner_options.d'; | ||||
| 
 | ||||
| class TaskRunnerFake extends TaskRunner { | ||||
|   protected baseOptions = new BaseOptionsProto(); | ||||
|   private errorListener: ErrorListener|undefined; | ||||
|   private errors: string[] = []; | ||||
| 
 | ||||
|   baseOptions = new BaseOptionsProto(); | ||||
| 
 | ||||
|   static createFake(): TaskRunnerFake { | ||||
|     const wasmModule = createSpyWasmModule(); | ||||
|     return new TaskRunnerFake(wasmModule); | ||||
|  | @ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner { | |||
|     super.finishProcessing(); | ||||
|   } | ||||
| 
 | ||||
|   override refreshGraph(): void {} | ||||
| 
 | ||||
|   override setGraph(graphData: Uint8Array, isBinary: boolean): void { | ||||
|     super.setGraph(graphData, isBinary); | ||||
|   } | ||||
| 
 | ||||
|   setOptions(options: TaskRunnerOptions): Promise<void> { | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   private throwErrors(): void { | ||||
|     expect(this.errorListener).toBeDefined(); | ||||
|     for (const error of this.errors) { | ||||
|  | @ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner { | |||
| } | ||||
| 
 | ||||
| describe('TaskRunner', () => { | ||||
|   const mockBytes = new Uint8Array([0, 1, 2, 3]); | ||||
|   const mockBytesResult = { | ||||
|     modelAsset: { | ||||
|       fileContent: Buffer.from(mockBytes).toString('base64'), | ||||
|       fileName: undefined, | ||||
|       fileDescriptorMeta: undefined, | ||||
|       filePointerMeta: undefined, | ||||
|     }, | ||||
|     useStreamMode: false, | ||||
|     acceleration: { | ||||
|       xnnpack: undefined, | ||||
|       gpu: undefined, | ||||
|       tflite: {}, | ||||
|     }, | ||||
|   }; | ||||
| 
 | ||||
|   let fetchSpy: jasmine.Spy; | ||||
|   let taskRunner: TaskRunnerFake; | ||||
| 
 | ||||
|   beforeEach(() => { | ||||
|     fetchSpy = jasmine.createSpy().and.callFake(async url => { | ||||
|       expect(url).toEqual('foo'); | ||||
|       return { | ||||
|         arrayBuffer: () => mockBytes.buffer, | ||||
|       } as unknown as Response; | ||||
|     }); | ||||
|     global.fetch = fetchSpy; | ||||
| 
 | ||||
|     taskRunner = TaskRunnerFake.createFake(); | ||||
|   }); | ||||
| 
 | ||||
|   it('handles errors during graph update', () => { | ||||
|     const taskRunner = TaskRunnerFake.createFake(); | ||||
|     taskRunner.enqueueError('Test error'); | ||||
| 
 | ||||
|     expect(() => { | ||||
|  | @ -85,7 +125,6 @@ describe('TaskRunner', () => { | |||
|   }); | ||||
| 
 | ||||
|   it('handles errors during graph execution', () => { | ||||
|     const taskRunner = TaskRunnerFake.createFake(); | ||||
|     taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); | ||||
| 
 | ||||
|     taskRunner.enqueueError('Test error'); | ||||
|  | @ -96,7 +135,6 @@ describe('TaskRunner', () => { | |||
|   }); | ||||
| 
 | ||||
|   it('can handle multiple errors', () => { | ||||
|     const taskRunner = TaskRunnerFake.createFake(); | ||||
|     taskRunner.enqueueError('Test error 1'); | ||||
|     taskRunner.enqueueError('Test error 2'); | ||||
| 
 | ||||
|  | @ -104,4 +142,106 @@ describe('TaskRunner', () => { | |||
|       taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); | ||||
|     }).toThrowError(/Test error 1, Test error 2/); | ||||
|   }); | ||||
| 
 | ||||
|   it('verifies that at least one model asset option is provided', () => { | ||||
|     expect(() => { | ||||
|       taskRunner.setOptions({}); | ||||
|     }) | ||||
|         .toThrowError( | ||||
|             /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); | ||||
|   }); | ||||
| 
 | ||||
|   it('verifies that no more than one model asset option is provided', () => { | ||||
|     expect(() => { | ||||
|       taskRunner.setOptions({ | ||||
|         baseOptions: { | ||||
|           modelAssetPath: `foo`, | ||||
|           modelAssetBuffer: new Uint8Array([]) | ||||
|         } | ||||
|       }); | ||||
|     }) | ||||
|         .toThrowError( | ||||
|             /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); | ||||
|   }); | ||||
| 
 | ||||
|   it('doesn\'t require model once it is configured', async () => { | ||||
|     await taskRunner.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); | ||||
|     expect(() => { | ||||
|       taskRunner.setOptions({}); | ||||
|     }).not.toThrowError(); | ||||
|   }); | ||||
| 
 | ||||
|   it('downloads model', async () => { | ||||
|     await taskRunner.setOptions( | ||||
|         {baseOptions: {modelAssetPath: `foo`}}); | ||||
| 
 | ||||
|     expect(fetchSpy).toHaveBeenCalled(); | ||||
|     expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| 
 | ||||
|   it('does not download model when bytes are provided', async () => { | ||||
|     await taskRunner.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); | ||||
| 
 | ||||
|     expect(fetchSpy).not.toHaveBeenCalled(); | ||||
|     expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| 
 | ||||
|   it('changes model synchronously when bytes are provided', () => { | ||||
|     const resolvedPromise = taskRunner.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); | ||||
| 
 | ||||
|     // Check that the change has been applied even though we do not await the
 | ||||
|     // above Promise
 | ||||
|     expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); | ||||
|     return resolvedPromise; | ||||
|   }); | ||||
| 
 | ||||
|   it('can enable CPU delegate', async () => { | ||||
|     await taskRunner.setOptions({ | ||||
|       baseOptions: { | ||||
|         modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|         delegate: 'CPU', | ||||
|       } | ||||
|     }); | ||||
|     expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| 
 | ||||
|   it('can enable GPU delegate', async () => { | ||||
|     await taskRunner.setOptions({ | ||||
|       baseOptions: { | ||||
|         modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|         delegate: 'GPU', | ||||
|       } | ||||
|     }); | ||||
|     expect(taskRunner.baseOptions.toObject()).toEqual({ | ||||
|       ...mockBytesResult, | ||||
|       acceleration: { | ||||
|         xnnpack: undefined, | ||||
|         gpu: { | ||||
|           useAdvancedGpuApi: false, | ||||
|           api: 0, | ||||
|           allowPrecisionLoss: true, | ||||
|           cachedKernelPath: undefined, | ||||
|           serializedModelDir: undefined, | ||||
|           modelToken: undefined, | ||||
|           usage: 2, | ||||
|         }, | ||||
|         tflite: undefined, | ||||
|       }, | ||||
|     }); | ||||
|   }); | ||||
| 
 | ||||
|   it('can reset delegate', async () => { | ||||
|     await taskRunner.setOptions({ | ||||
|       baseOptions: { | ||||
|         modelAssetBuffer: new Uint8Array(mockBytes), | ||||
|         delegate: 'GPU', | ||||
|       } | ||||
|     }); | ||||
|     // Clear backend
 | ||||
|     await taskRunner.setOptions({baseOptions: {delegate: undefined}}); | ||||
|     expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); | ||||
|   }); | ||||
| }); | ||||
|  |  | |||
|  | @ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner { | |||
|    * | ||||
|    * @param options The options for the text classifier. | ||||
|    */ | ||||
|   override async setOptions(options: TextClassifierOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override setOptions(options: TextClassifierOptions): Promise<void> { | ||||
|     this.options.setClassifierOptions(convertClassifierOptionsToProto( | ||||
|         options, this.options.getClassifierOptions())); | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   protected override get baseOptions(): BaseOptionsProto { | ||||
|  | @ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(INPUT_STREAM); | ||||
|     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); | ||||
|  |  | |||
|  | @ -56,7 +56,8 @@ describe('TextClassifier', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     textClassifier = new TextClassifierFake(); | ||||
|     await textClassifier.setOptions({});  // Initialize graph
 | ||||
|     await textClassifier.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner { | |||
|    * | ||||
|    * @param options The options for the text embedder. | ||||
|    */ | ||||
|   override async setOptions(options: TextEmbedderOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override setOptions(options: TextEmbedderOptions): Promise<void> { | ||||
|     this.options.setEmbedderOptions(convertEmbedderOptionsToProto( | ||||
|         options, this.options.getEmbedderOptions())); | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   protected override get baseOptions(): BaseOptionsProto { | ||||
|  | @ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(INPUT_STREAM); | ||||
|     graphConfig.addOutputStream(EMBEDDINGS_STREAM); | ||||
|  |  | |||
|  | @ -56,7 +56,8 @@ describe('TextEmbedder', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     textEmbedder = new TextEmbedderFake(); | ||||
|     await textEmbedder.setOptions({});  // Initialize graph
 | ||||
|     await textEmbedder.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -29,6 +29,7 @@ mediapipe_ts_library( | |||
|     testonly = True, | ||||
|     srcs = ["vision_task_runner.test.ts"], | ||||
|     deps = [ | ||||
|         ":vision_task_options", | ||||
|         ":vision_task_runner", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||
|         "//mediapipe/tasks/web/core:task_runner_test_utils", | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b | |||
| import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; | ||||
| import {ImageSource} from '../../../../web/graph_runner/graph_runner'; | ||||
| 
 | ||||
| import {VisionTaskOptions} from './vision_task_options'; | ||||
| import {VisionTaskRunner} from './vision_task_runner'; | ||||
| 
 | ||||
| class VisionTaskRunnerFake extends VisionTaskRunner<void> { | ||||
|  | @ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> { | |||
| 
 | ||||
|   protected override process(): void {} | ||||
| 
 | ||||
|   protected override refreshGraph(): void {} | ||||
| 
 | ||||
|   override setOptions(options: VisionTaskOptions): Promise<void> { | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   override processImageData(image: ImageSource): void { | ||||
|     super.processImageData(image); | ||||
|   } | ||||
|  | @ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> { | |||
| } | ||||
| 
 | ||||
| describe('VisionTaskRunner', () => { | ||||
|   const streamMode = { | ||||
|     modelAsset: undefined, | ||||
|     useStreamMode: true, | ||||
|     acceleration: undefined, | ||||
|   }; | ||||
| 
 | ||||
|   const imageMode = { | ||||
|     modelAsset: undefined, | ||||
|     useStreamMode: false, | ||||
|     acceleration: undefined, | ||||
|   }; | ||||
| 
 | ||||
|   let visionTaskRunner: VisionTaskRunnerFake; | ||||
| 
 | ||||
|   beforeEach(() => { | ||||
|   beforeEach(async () => { | ||||
|     visionTaskRunner = new VisionTaskRunnerFake(); | ||||
|     await visionTaskRunner.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('can enable image mode', async () => { | ||||
|     await visionTaskRunner.setOptions({runningMode: 'image'}); | ||||
|     expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); | ||||
|     expect(visionTaskRunner.baseOptions.toObject()) | ||||
|         .toEqual(jasmine.objectContaining({useStreamMode: false})); | ||||
|   }); | ||||
| 
 | ||||
|   it('can enable video mode', async () => { | ||||
|     await visionTaskRunner.setOptions({runningMode: 'video'}); | ||||
|     expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); | ||||
|     expect(visionTaskRunner.baseOptions.toObject()) | ||||
|         .toEqual(jasmine.objectContaining({useStreamMode: true})); | ||||
|   }); | ||||
| 
 | ||||
|   it('can clear running mode', async () => { | ||||
|  | @ -74,7 +73,8 @@ describe('VisionTaskRunner', () => { | |||
| 
 | ||||
|     // Clear running mode
 | ||||
|     await visionTaskRunner.setOptions({runningMode: undefined}); | ||||
|     expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); | ||||
|     expect(visionTaskRunner.baseOptions.toObject()) | ||||
|         .toEqual(jasmine.objectContaining({useStreamMode: false})); | ||||
|   }); | ||||
| 
 | ||||
|   it('cannot process images with video mode', async () => { | ||||
|  |  | |||
|  | @ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options'; | |||
| /** Base class for all MediaPipe Vision Tasks. */ | ||||
| export abstract class VisionTaskRunner<T> extends TaskRunner { | ||||
|   /** Configures the shared options of a vision task. */ | ||||
|   override async setOptions(options: VisionTaskOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override applyOptions(options: VisionTaskOptions): Promise<void> { | ||||
|     if ('runningMode' in options) { | ||||
|       const useStreamMode = | ||||
|           !!options.runningMode && options.runningMode !== 'image'; | ||||
|       this.baseOptions.setUseStreamMode(useStreamMode); | ||||
|     } | ||||
|     return super.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** Sends an image packet to the graph and awaits results. */ | ||||
|  |  | |||
|  | @ -169,9 +169,7 @@ export class GestureRecognizer extends | |||
|    * | ||||
|    * @param options The options for the gesture recognizer. | ||||
|    */ | ||||
|   override async setOptions(options: GestureRecognizerOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
| 
 | ||||
|   override setOptions(options: GestureRecognizerOptions): Promise<void> { | ||||
|     if ('numHands' in options) { | ||||
|       this.handDetectorGraphOptions.setNumHands( | ||||
|           options.numHands ?? DEFAULT_NUM_HANDS); | ||||
|  | @ -221,7 +219,7 @@ export class GestureRecognizer extends | |||
|           ?.clearClassifierOptions(); | ||||
|     } | ||||
| 
 | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -342,7 +340,7 @@ export class GestureRecognizer extends | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(IMAGE_STREAM); | ||||
|     graphConfig.addInputStream(NORM_RECT_STREAM); | ||||
|  |  | |||
|  | @ -109,7 +109,8 @@ describe('GestureRecognizer', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     gestureRecognizer = new GestureRecognizerFake(); | ||||
|     await gestureRecognizer.setOptions({});  // Initialize graph
 | ||||
|     await gestureRecognizer.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | |||
|    * | ||||
|    * @param options The options for the hand landmarker. | ||||
|    */ | ||||
|   override async setOptions(options: HandLandmarkerOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
| 
 | ||||
|   override setOptions(options: HandLandmarkerOptions): Promise<void> { | ||||
|     // Configure hand detector options.
 | ||||
|     if ('numHands' in options) { | ||||
|       this.handDetectorGraphOptions.setNumHands( | ||||
|  | @ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | |||
|           options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); | ||||
|     } | ||||
| 
 | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(IMAGE_STREAM); | ||||
|     graphConfig.addInputStream(NORM_RECT_STREAM); | ||||
|  |  | |||
|  | @ -98,7 +98,8 @@ describe('HandLandmarker', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     handLandmarker = new HandLandmarkerFake(); | ||||
|     await handLandmarker.setOptions({});  // Initialize graph
 | ||||
|     await handLandmarker.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | |||
|    * | ||||
|    * @param options The options for the image classifier. | ||||
|    */ | ||||
|   override async setOptions(options: ImageClassifierOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override setOptions(options: ImageClassifierOptions): Promise<void> { | ||||
|     this.options.setClassifierOptions(convertClassifierOptionsToProto( | ||||
|         options, this.options.getClassifierOptions())); | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(INPUT_STREAM); | ||||
|     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); | ||||
|  |  | |||
|  | @ -61,7 +61,8 @@ describe('ImageClassifier', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     imageClassifier = new ImageClassifierFake(); | ||||
|     await imageClassifier.setOptions({});  // Initialize graph
 | ||||
|     await imageClassifier.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | |||
|    * | ||||
|    * @param options The options for the image embedder. | ||||
|    */ | ||||
|   override async setOptions(options: ImageEmbedderOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
|   override setOptions(options: ImageEmbedderOptions): Promise<void> { | ||||
|     this.options.setEmbedderOptions(convertEmbedderOptionsToProto( | ||||
|         options, this.options.getEmbedderOptions())); | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(INPUT_STREAM); | ||||
|     graphConfig.addOutputStream(EMBEDDINGS_STREAM); | ||||
|  |  | |||
|  | @ -57,7 +57,8 @@ describe('ImageEmbedder', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     imageEmbedder = new ImageEmbedderFake(); | ||||
|     await imageEmbedder.setOptions({});  // Initialize graph
 | ||||
|     await imageEmbedder.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
|  | @ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | |||
|    * | ||||
|    * @param options The options for the object detector. | ||||
|    */ | ||||
|   override async setOptions(options: ObjectDetectorOptions): Promise<void> { | ||||
|     await super.setOptions(options); | ||||
| 
 | ||||
|   override setOptions(options: ObjectDetectorOptions): Promise<void> { | ||||
|     // Note that we have to support both JSPB and ProtobufJS, hence we
 | ||||
|     // have to expliclity clear the values instead of setting them to
 | ||||
|     // `undefined`.
 | ||||
|  | @ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | |||
|       this.options.clearCategoryDenylistList(); | ||||
|     } | ||||
| 
 | ||||
|     this.refreshGraph(); | ||||
|     return this.applyOptions(options); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|  | @ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | |||
|   } | ||||
| 
 | ||||
|   /** Updates the MediaPipe graph configuration. */ | ||||
|   private refreshGraph(): void { | ||||
|   protected override refreshGraph(): void { | ||||
|     const graphConfig = new CalculatorGraphConfig(); | ||||
|     graphConfig.addInputStream(INPUT_STREAM); | ||||
|     graphConfig.addOutputStream(DETECTIONS_STREAM); | ||||
|  |  | |||
|  | @ -61,7 +61,8 @@ describe('ObjectDetector', () => { | |||
|   beforeEach(async () => { | ||||
|     addJasmineCustomFloatEqualityTester(); | ||||
|     objectDetector = new ObjectDetectorFake(); | ||||
|     await objectDetector.setOptions({});  // Initialize graph
 | ||||
|     await objectDetector.setOptions( | ||||
|         {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); | ||||
|   }); | ||||
| 
 | ||||
|   it('initializes graph', async () => { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user