Move shared code to TaskRunner
PiperOrigin-RevId: 492534879
This commit is contained in:
		
							parent
							
								
									dabc2af15b
								
							
						
					
					
						commit
						da9587033d
					
				|  | @ -25,7 +25,7 @@ mediapipe_ts_library( | ||||||
|         "//mediapipe/tasks/web/components/processors:classifier_result", |         "//mediapipe/tasks/web/components/processors:classifier_result", | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:classifier_options", |         "//mediapipe/tasks/web/core:classifier_options", | ||||||
|         "//mediapipe/tasks/web/core:task_runner", |         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -36,7 +36,6 @@ mediapipe_ts_declaration( | ||||||
|         "audio_classifier_result.d.ts", |         "audio_classifier_result.d.ts", | ||||||
|     ], |     ], | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/tasks/web/audio/core:audio_task_options", |  | ||||||
|         "//mediapipe/tasks/web/components/containers:category", |         "//mediapipe/tasks/web/components/containers:category", | ||||||
|         "//mediapipe/tasks/web/components/containers:classification_result", |         "//mediapipe/tasks/web/components/containers:classification_result", | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|  |  | ||||||
|  | @ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b | ||||||
| import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; | import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; | ||||||
| import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; | import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; | ||||||
| import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; | import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; | ||||||
| import {TaskRunner} from '../../../../tasks/web/core/task_runner'; |  | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
|  | import {WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {AudioClassifierOptions} from './audio_classifier_options'; | import {AudioClassifierOptions} from './audio_classifier_options'; | ||||||
|  | @ -56,13 +56,12 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | ||||||
|    *     that either a path to the model asset or a model buffer needs to be |    *     that either a path to the model asset or a model buffer needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): |       wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): | ||||||
|       Promise<AudioClassifier> { |       Promise<AudioClassifier> { | ||||||
|     const classifier = await TaskRunner.createInstance( |     return AudioTaskRunner.createInstance( | ||||||
|         AudioClassifier, /* initializeCanvas= */ false, wasmFileset); |         AudioClassifier, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     await classifier.setOptions(audioClassifierOptions); |         audioClassifierOptions); | ||||||
|     return classifier; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -75,8 +74,9 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<AudioClassifier> { |       modelAssetBuffer: Uint8Array): Promise<AudioClassifier> { | ||||||
|     return AudioClassifier.createFromOptions( |     return AudioTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         AudioClassifier, /* initializeCanvas= */ false, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -86,20 +86,26 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the model asset. |    * @param modelAssetPath The path to the model asset. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<AudioClassifier> { |       modelAssetPath: string): Promise<AudioClassifier> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return AudioTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         AudioClassifier, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     return AudioClassifier.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   constructor( | ||||||
|     return this.options.getBaseOptions(); |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,9 +14,9 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; |  | ||||||
| import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; | import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; | ||||||
|  | import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| 
 | 
 | ||||||
| /** Options to configure the MediaPipe Audio Classifier Task */ | /** Options to configure the MediaPipe Audio Classifier Task */ | ||||||
| export declare interface AudioClassifierOptions extends ClassifierOptions, | export declare interface AudioClassifierOptions extends ClassifierOptions, | ||||||
|                                                         AudioTaskOptions {} |                                                         TaskRunnerOptions {} | ||||||
|  |  | ||||||
|  | @ -36,7 +36,6 @@ mediapipe_ts_declaration( | ||||||
|         "audio_embedder_result.d.ts", |         "audio_embedder_result.d.ts", | ||||||
|     ], |     ], | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/tasks/web/audio/core:audio_task_options", |  | ||||||
|         "//mediapipe/tasks/web/components/containers:embedding_result", |         "//mediapipe/tasks/web/components/containers:embedding_result", | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:embedder_options", |         "//mediapipe/tasks/web/core:embedder_options", | ||||||
|  |  | ||||||
|  | @ -25,7 +25,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr | ||||||
| import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; | import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; | ||||||
| import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; | import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
| import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; | import {WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {AudioEmbedderOptions} from './audio_embedder_options'; | import {AudioEmbedderOptions} from './audio_embedder_options'; | ||||||
|  | @ -58,23 +58,12 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | ||||||
|    *     either a path to the TFLite model or the model itself needs to be |    *     either a path to the TFLite model or the model itself needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> { |       audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> { | ||||||
|     // Create a file locator based on the loader options
 |     return AudioTaskRunner.createInstance( | ||||||
|     const fileLocator: FileLocator = { |         AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, | ||||||
|       locateFile() { |         audioEmbedderOptions); | ||||||
|         // The only file we load is the Wasm binary
 |  | ||||||
|         return wasmFileset.wasmBinaryPath.toString(); |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     const embedder = await createMediaPipeLib( |  | ||||||
|         AudioEmbedder, wasmFileset.wasmLoaderPath, |  | ||||||
|         /* assetLoaderScript= */ undefined, |  | ||||||
|         /* glCanvas= */ undefined, fileLocator); |  | ||||||
|     await embedder.setOptions(audioEmbedderOptions); |  | ||||||
|     return embedder; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -87,8 +76,9 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> { |       modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> { | ||||||
|     return AudioEmbedder.createFromOptions( |     return AudioTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -98,20 +88,26 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the TFLite model. |    * @param modelAssetPath The path to the TFLite model. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<AudioEmbedder> { |       modelAssetPath: string): Promise<AudioEmbedder> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return AudioTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     return AudioEmbedder.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   constructor( | ||||||
|     return this.options.getBaseOptions(); |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,9 +14,9 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; |  | ||||||
| import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; | import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; | ||||||
|  | import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| 
 | 
 | ||||||
| /** Options to configure the MediaPipe Audio Embedder Task */ | /** Options to configure the MediaPipe Audio Embedder Task */ | ||||||
| export declare interface AudioEmbedderOptions extends EmbedderOptions, | export declare interface AudioEmbedderOptions extends EmbedderOptions, | ||||||
|                                                       AudioTaskOptions {} |                                                       TaskRunnerOptions {} | ||||||
|  |  | ||||||
|  | @ -1,24 +1,13 @@ | ||||||
| # This package contains options shared by all MediaPipe Audio Tasks for Web. | # This package contains options shared by all MediaPipe Audio Tasks for Web. | ||||||
| 
 | 
 | ||||||
| load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") | load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") | ||||||
| 
 | 
 | ||||||
| package(default_visibility = ["//mediapipe/tasks:internal"]) | package(default_visibility = ["//mediapipe/tasks:internal"]) | ||||||
| 
 | 
 | ||||||
| mediapipe_ts_declaration( |  | ||||||
|     name = "audio_task_options", |  | ||||||
|     srcs = ["audio_task_options.d.ts"], |  | ||||||
|     deps = [ |  | ||||||
|         "//mediapipe/tasks/web/core", |  | ||||||
|     ], |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| mediapipe_ts_library( | mediapipe_ts_library( | ||||||
|     name = "audio_task_runner", |     name = "audio_task_runner", | ||||||
|     srcs = ["audio_task_runner.ts"], |     srcs = ["audio_task_runner.ts"], | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":audio_task_options", |  | ||||||
|         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", |  | ||||||
|         "//mediapipe/tasks/web/components/processors:base_options", |  | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:task_runner", |         "//mediapipe/tasks/web/core:task_runner", | ||||||
|     ], |     ], | ||||||
|  |  | ||||||
|  | @ -1,23 +0,0 @@ | ||||||
| /** |  | ||||||
|  * Copyright 2022 The MediaPipe Authors. All Rights Reserved. |  | ||||||
|  * |  | ||||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
|  * you may not use this file except in compliance with the License. |  | ||||||
|  * You may obtain a copy of the License at |  | ||||||
|  * |  | ||||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 |  | ||||||
|  * |  | ||||||
|  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  * distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
|  * See the License for the specific language governing permissions and |  | ||||||
|  * limitations under the License. |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| import {BaseOptions} from '../../../../tasks/web/core/base_options'; |  | ||||||
| 
 |  | ||||||
| /** The options for configuring a MediaPipe Audio Task. */ |  | ||||||
| export declare interface AudioTaskOptions { |  | ||||||
|   /** Options to configure the loading of the model assets. */ |  | ||||||
|   baseOptions?: BaseOptions; |  | ||||||
| } |  | ||||||
|  | @ -14,26 +14,13 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; |  | ||||||
| import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; |  | ||||||
| import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | ||||||
| 
 | import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| import {AudioTaskOptions} from './audio_task_options'; |  | ||||||
| 
 | 
 | ||||||
| /** Base class for all MediaPipe Audio Tasks. */ | /** Base class for all MediaPipe Audio Tasks. */ | ||||||
| export abstract class AudioTaskRunner<T> extends TaskRunner { | export abstract class AudioTaskRunner<T> extends TaskRunner<TaskRunnerOptions> { | ||||||
|   protected abstract baseOptions?: BaseOptionsProto|undefined; |  | ||||||
|   private defaultSampleRate = 48000; |   private defaultSampleRate = 48000; | ||||||
| 
 | 
 | ||||||
|   /** Configures the shared options of an audio task. */ |  | ||||||
|   async setOptions(options: AudioTaskOptions): Promise<void> { |  | ||||||
|     this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); |  | ||||||
|     if (options.baseOptions) { |  | ||||||
|       this.baseOptions = await convertBaseOptionsToProto( |  | ||||||
|           options.baseOptions, this.baseOptions); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /** |   /** | ||||||
|    * Sets the sample rate for API calls that omit an explicit sample rate. |    * Sets the sample rate for API calls that omit an explicit sample rate. | ||||||
|    * `48000` is used as a default if this method is not called. |    * `48000` is used as a default if this method is not called. | ||||||
|  |  | ||||||
|  | @ -17,7 +17,6 @@ mediapipe_ts_library( | ||||||
|     name = "classifier_result", |     name = "classifier_result", | ||||||
|     srcs = ["classifier_result.ts"], |     srcs = ["classifier_result.ts"], | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/framework/formats:classification_jspb_proto", |  | ||||||
|         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", |         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", | ||||||
|         "//mediapipe/tasks/web/components/containers:classification_result", |         "//mediapipe/tasks/web/components/containers:classification_result", | ||||||
|     ], |     ], | ||||||
|  |  | ||||||
|  | @ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen | ||||||
| import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; | import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; | ||||||
| import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; | import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; | ||||||
| import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; | import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; | ||||||
| import {BaseOptions} from '../../../../tasks/web/core/base_options'; | import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| 
 | 
 | ||||||
| // The OSS JS API does not support the builder pattern.
 | // The OSS JS API does not support the builder pattern.
 | ||||||
| // tslint:disable:jspb-use-builder-pattern
 | // tslint:disable:jspb-use-builder-pattern
 | ||||||
|  |  | ||||||
|  | @ -7,18 +7,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) | ||||||
| mediapipe_ts_declaration( | mediapipe_ts_declaration( | ||||||
|     name = "core", |     name = "core", | ||||||
|     srcs = [ |     srcs = [ | ||||||
|         "base_options.d.ts", |         "task_runner_options.d.ts", | ||||||
|         "wasm_fileset.d.ts", |         "wasm_fileset.d.ts", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| mediapipe_ts_library( | mediapipe_ts_library( | ||||||
|     name = "task_runner", |     name = "task_runner", | ||||||
|     srcs = [ |     srcs = ["task_runner.ts"], | ||||||
|         "task_runner.ts", |  | ||||||
|     ], |  | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":core", |         ":core", | ||||||
|  |         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||||
|  |         "//mediapipe/tasks/web/components/processors:base_options", | ||||||
|         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", |         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", | ||||||
|         "//mediapipe/web/graph_runner:graph_runner_ts", |         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||||
|         "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", |         "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", | ||||||
|  |  | ||||||
|  | @ -14,8 +14,6 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {BaseOptions} from '../../../tasks/web/core/base_options'; |  | ||||||
| 
 |  | ||||||
| /** Options to configure a MediaPipe Classifier Task. */ | /** Options to configure a MediaPipe Classifier Task. */ | ||||||
| export declare interface ClassifierOptions { | export declare interface ClassifierOptions { | ||||||
|   /** |   /** | ||||||
|  |  | ||||||
|  | @ -14,8 +14,6 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {BaseOptions} from '../../../tasks/web/core/base_options'; |  | ||||||
| 
 |  | ||||||
| /** Options to configure a MediaPipe Embedder Task */ | /** Options to configure a MediaPipe Embedder Task */ | ||||||
| export declare interface EmbedderOptions { | export declare interface EmbedderOptions { | ||||||
|   /** |   /** | ||||||
|  |  | ||||||
|  | @ -14,6 +14,9 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
|  | import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; | ||||||
|  | import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; | ||||||
|  | import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; | ||||||
| import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; | import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; | ||||||
| import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; | import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; | ||||||
| import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; | import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; | ||||||
|  | @ -28,7 +31,9 @@ const WasmMediaPipeImageLib = | ||||||
|     SupportModelResourcesGraphService(SupportImage(GraphRunner)); |     SupportModelResourcesGraphService(SupportImage(GraphRunner)); | ||||||
| 
 | 
 | ||||||
| /** Base class for all MediaPipe Tasks. */ | /** Base class for all MediaPipe Tasks. */ | ||||||
| export abstract class TaskRunner extends WasmMediaPipeImageLib { | export abstract class TaskRunner<O extends TaskRunnerOptions> extends | ||||||
|  |     WasmMediaPipeImageLib { | ||||||
|  |   protected abstract baseOptions: BaseOptionsProto; | ||||||
|   private processingErrors: Error[] = []; |   private processingErrors: Error[] = []; | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -36,9 +41,10 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { | ||||||
|    * supported and loads the relevant WASM binary. |    * supported and loads the relevant WASM binary. | ||||||
|    * @return A fully instantiated instance of `T`. |    * @return A fully instantiated instance of `T`. | ||||||
|    */ |    */ | ||||||
|   protected static async createInstance<T extends TaskRunner>( |   protected static async createInstance<T extends TaskRunner<O>, | ||||||
|  |                                                   O extends TaskRunnerOptions>( | ||||||
|       type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, |       type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, | ||||||
|       fileset: WasmFileset): Promise<T> { |       fileset: WasmFileset, options: O): Promise<T> { | ||||||
|     const fileLocator: FileLocator = { |     const fileLocator: FileLocator = { | ||||||
|       locateFile() { |       locateFile() { | ||||||
|         // The only file loaded with this mechanism is the Wasm binary
 |         // The only file loaded with this mechanism is the Wasm binary
 | ||||||
|  | @ -46,19 +52,16 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { | ||||||
|       } |       } | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     if (initializeCanvas) { |     // Initialize a canvas if requested. If OffscreenCanvas is availble, we
 | ||||||
|       // Fall back to an OffscreenCanvas created by the GraphRunner if
 |     // let the graph runner initialize it by passing `undefined`.
 | ||||||
|       // OffscreenCanvas is available
 |     const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ? | ||||||
|       const canvas = typeof OffscreenCanvas === 'undefined' ? |  | ||||||
|                                            document.createElement('canvas') : |                                            document.createElement('canvas') : | ||||||
|           undefined; |                                            undefined) : | ||||||
|       return createMediaPipeLib( |                                       null; | ||||||
|  |     const instance = await createMediaPipeLib( | ||||||
|         type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); |         type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); | ||||||
|     } else { |     await instance.setOptions(options); | ||||||
|       return createMediaPipeLib( |     return instance; | ||||||
|           type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null, |  | ||||||
|           fileLocator); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   constructor( |   constructor( | ||||||
|  | @ -74,6 +77,14 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib { | ||||||
|     this.registerModelResourcesGraphService(); |     this.registerModelResourcesGraphService(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /** Configures the shared options of a MediaPipe Task. */ | ||||||
|  |   async setOptions(options: O): Promise<void> { | ||||||
|  |     if (options.baseOptions) { | ||||||
|  |       this.baseOptions = await convertBaseOptionsToProto( | ||||||
|  |           options.baseOptions, this.baseOptions); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** |   /** | ||||||
|    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run |    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run | ||||||
|    * over the video stream. Will replace the previously running MediaPipe graph, |    * over the video stream. Will replace the previously running MediaPipe graph, | ||||||
|  |  | ||||||
|  | @ -16,7 +16,7 @@ | ||||||
| 
 | 
 | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| /** Options to configure MediaPipe Tasks in general. */ | /** Options to configure MediaPipe model loading and processing. */ | ||||||
| export declare interface BaseOptions { | export declare interface BaseOptions { | ||||||
|   /** |   /** | ||||||
|    * The model path to the model asset file. Only one of `modelAssetPath` or |    * The model path to the model asset file. Only one of `modelAssetPath` or | ||||||
|  | @ -33,3 +33,9 @@ export declare interface BaseOptions { | ||||||
|   /** Overrides the default backend to use for the provided model. */ |   /** Overrides the default backend to use for the provided model. */ | ||||||
|   delegate?: 'cpu'|'gpu'|undefined; |   delegate?: 'cpu'|'gpu'|undefined; | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | /** Options to configure MediaPipe Tasks in general. */ | ||||||
|  | export declare interface TaskRunnerOptions { | ||||||
|  |   /** Options to configure the loading of the model assets. */ | ||||||
|  |   baseOptions?: BaseOptions; | ||||||
|  | } | ||||||
|  | @ -1,11 +0,0 @@ | ||||||
| # This package contains options shared by all MediaPipe Texxt Tasks for Web. |  | ||||||
| 
 |  | ||||||
| load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") |  | ||||||
| 
 |  | ||||||
| package(default_visibility = ["//mediapipe/tasks:internal"]) |  | ||||||
| 
 |  | ||||||
| mediapipe_ts_declaration( |  | ||||||
|     name = "text_task_options", |  | ||||||
|     srcs = ["text_task_options.d.ts"], |  | ||||||
|     deps = ["//mediapipe/tasks/web/core"], |  | ||||||
| ) |  | ||||||
|  | @ -1,23 +0,0 @@ | ||||||
| /** |  | ||||||
|  * Copyright 2022 The MediaPipe Authors. All Rights Reserved. |  | ||||||
|  * |  | ||||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
|  * you may not use this file except in compliance with the License. |  | ||||||
|  * You may obtain a copy of the License at |  | ||||||
|  * |  | ||||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 |  | ||||||
|  * |  | ||||||
|  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  * distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
|  * See the License for the specific language governing permissions and |  | ||||||
|  * limitations under the License. |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| import {BaseOptions} from '../../../../tasks/web/core/base_options'; |  | ||||||
| 
 |  | ||||||
| /** The options for configuring a MediaPipe Text task. */ |  | ||||||
| export declare interface TextTaskOptions { |  | ||||||
|   /** Options to configure the loading of the model assets. */ |  | ||||||
|   baseOptions?: BaseOptions; |  | ||||||
| } |  | ||||||
|  | @ -17,15 +17,16 @@ mediapipe_ts_library( | ||||||
|         "//mediapipe/framework:calculator_jspb_proto", |         "//mediapipe/framework:calculator_jspb_proto", | ||||||
|         "//mediapipe/framework:calculator_options_jspb_proto", |         "//mediapipe/framework:calculator_options_jspb_proto", | ||||||
|         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", |         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", | ||||||
|  |         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||||
|         "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", |         "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", | ||||||
|         "//mediapipe/tasks/web/components/containers:category", |         "//mediapipe/tasks/web/components/containers:category", | ||||||
|         "//mediapipe/tasks/web/components/containers:classification_result", |         "//mediapipe/tasks/web/components/containers:classification_result", | ||||||
|         "//mediapipe/tasks/web/components/processors:base_options", |  | ||||||
|         "//mediapipe/tasks/web/components/processors:classifier_options", |         "//mediapipe/tasks/web/components/processors:classifier_options", | ||||||
|         "//mediapipe/tasks/web/components/processors:classifier_result", |         "//mediapipe/tasks/web/components/processors:classifier_result", | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:classifier_options", |         "//mediapipe/tasks/web/core:classifier_options", | ||||||
|         "//mediapipe/tasks/web/core:task_runner", |         "//mediapipe/tasks/web/core:task_runner", | ||||||
|  |         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -38,7 +39,7 @@ mediapipe_ts_declaration( | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/tasks/web/components/containers:category", |         "//mediapipe/tasks/web/components/containers:category", | ||||||
|         "//mediapipe/tasks/web/components/containers:classification_result", |         "//mediapipe/tasks/web/components/containers:classification_result", | ||||||
|  |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:classifier_options", |         "//mediapipe/tasks/web/core:classifier_options", | ||||||
|         "//mediapipe/tasks/web/text/core:text_task_options", |  | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -17,12 +17,13 @@ | ||||||
| import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; | import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; | ||||||
| import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; | import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; | ||||||
| import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; | import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; | ||||||
|  | import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; | ||||||
| import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; | import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; | ||||||
| import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; |  | ||||||
| import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; | import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; | ||||||
| import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; | import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; | ||||||
| import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
|  | import {WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {TextClassifierOptions} from './text_classifier_options'; | import {TextClassifierOptions} from './text_classifier_options'; | ||||||
|  | @ -40,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = | ||||||
| // tslint:disable:jspb-use-builder-pattern
 | // tslint:disable:jspb-use-builder-pattern
 | ||||||
| 
 | 
 | ||||||
| /** Performs Natural Language classification. */ | /** Performs Natural Language classification. */ | ||||||
| export class TextClassifier extends TaskRunner { | export class TextClassifier extends TaskRunner<TextClassifierOptions> { | ||||||
|   private classificationResult: TextClassifierResult = {classifications: []}; |   private classificationResult: TextClassifierResult = {classifications: []}; | ||||||
|   private readonly options = new TextClassifierGraphOptions(); |   private readonly options = new TextClassifierGraphOptions(); | ||||||
| 
 | 
 | ||||||
|  | @ -53,13 +54,12 @@ export class TextClassifier extends TaskRunner { | ||||||
|    *     either a path to the TFLite model or the model itself needs to be |    *     either a path to the TFLite model or the model itself needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> { |       textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> { | ||||||
|     const classifier = await TaskRunner.createInstance( |     return TaskRunner.createInstance( | ||||||
|         TextClassifier, /* initializeCanvas= */ false, wasmFileset); |         TextClassifier, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     await classifier.setOptions(textClassifierOptions); |         textClassifierOptions); | ||||||
|     return classifier; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -72,8 +72,9 @@ export class TextClassifier extends TaskRunner { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<TextClassifier> { |       modelAssetBuffer: Uint8Array): Promise<TextClassifier> { | ||||||
|     return TextClassifier.createFromOptions( |     return TaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         TextClassifier, /* initializeCanvas= */ false, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -83,13 +84,19 @@ export class TextClassifier extends TaskRunner { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the model asset. |    * @param modelAssetPath The path to the model asset. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<TextClassifier> { |       modelAssetPath: string): Promise<TextClassifier> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return TaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         TextClassifier, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     return TextClassifier.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |   } | ||||||
|  | 
 | ||||||
|  |   constructor( | ||||||
|  |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -101,18 +108,20 @@ export class TextClassifier extends TaskRunner { | ||||||
|    * |    * | ||||||
|    * @param options The options for the text classifier. |    * @param options The options for the text classifier. | ||||||
|    */ |    */ | ||||||
|   async setOptions(options: TextClassifierOptions): Promise<void> { |   override async setOptions(options: TextClassifierOptions): Promise<void> { | ||||||
|     if (options.baseOptions) { |     await super.setOptions(options); | ||||||
|       const baseOptionsProto = await convertBaseOptionsToProto( |  | ||||||
|           options.baseOptions, this.options.getBaseOptions()); |  | ||||||
|       this.options.setBaseOptions(baseOptionsProto); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     this.options.setClassifierOptions(convertClassifierOptionsToProto( |     this.options.setClassifierOptions(convertClassifierOptionsToProto( | ||||||
|         options, this.options.getClassifierOptions())); |         options, this.options.getClassifierOptions())); | ||||||
|     this.refreshGraph(); |     this.refreshGraph(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|  |     this.options.setBaseOptions(proto); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|    * Performs Natural Language classification on the provided text and waits |    * Performs Natural Language classification on the provided text and waits | ||||||
|  |  | ||||||
|  | @ -15,8 +15,8 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; | import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; | ||||||
| import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; | import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| 
 | 
 | ||||||
| /** Options to configure the MediaPipe Text Classifier Task */ | /** Options to configure the MediaPipe Text Classifier Task */ | ||||||
| export declare interface TextClassifierOptions extends ClassifierOptions, | export declare interface TextClassifierOptions extends ClassifierOptions, | ||||||
|                                                        TextTaskOptions {} |                                                        TaskRunnerOptions {} | ||||||
|  |  | ||||||
|  | @ -17,15 +17,16 @@ mediapipe_ts_library( | ||||||
|         "//mediapipe/framework:calculator_jspb_proto", |         "//mediapipe/framework:calculator_jspb_proto", | ||||||
|         "//mediapipe/framework:calculator_options_jspb_proto", |         "//mediapipe/framework:calculator_options_jspb_proto", | ||||||
|         "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", |         "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", | ||||||
|  |         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", | ||||||
|         "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", |         "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", | ||||||
|         "//mediapipe/tasks/web/components/containers:embedding_result", |         "//mediapipe/tasks/web/components/containers:embedding_result", | ||||||
|         "//mediapipe/tasks/web/components/processors:base_options", |  | ||||||
|         "//mediapipe/tasks/web/components/processors:embedder_options", |         "//mediapipe/tasks/web/components/processors:embedder_options", | ||||||
|         "//mediapipe/tasks/web/components/processors:embedder_result", |         "//mediapipe/tasks/web/components/processors:embedder_result", | ||||||
|         "//mediapipe/tasks/web/components/utils:cosine_similarity", |         "//mediapipe/tasks/web/components/utils:cosine_similarity", | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:embedder_options", |         "//mediapipe/tasks/web/core:embedder_options", | ||||||
|         "//mediapipe/tasks/web/core:task_runner", |         "//mediapipe/tasks/web/core:task_runner", | ||||||
|  |         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -39,6 +40,5 @@ mediapipe_ts_declaration( | ||||||
|         "//mediapipe/tasks/web/components/containers:embedding_result", |         "//mediapipe/tasks/web/components/containers:embedding_result", | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:embedder_options", |         "//mediapipe/tasks/web/core:embedder_options", | ||||||
|         "//mediapipe/tasks/web/text/core:text_task_options", |  | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -17,14 +17,15 @@ | ||||||
| import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; | import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; | ||||||
| import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; | import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; | ||||||
| import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; | import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; | ||||||
|  | import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; | ||||||
| import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; | import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; | ||||||
| import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; | import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; | ||||||
| import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; |  | ||||||
| import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; | import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; | ||||||
| import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; | import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; | ||||||
| import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; | import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; | ||||||
| import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
|  | import {WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {TextEmbedderOptions} from './text_embedder_options'; | import {TextEmbedderOptions} from './text_embedder_options'; | ||||||
|  | @ -44,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR = | ||||||
| /** | /** | ||||||
|  * Performs embedding extraction on text. |  * Performs embedding extraction on text. | ||||||
|  */ |  */ | ||||||
| export class TextEmbedder extends TaskRunner { | export class TextEmbedder extends TaskRunner<TextEmbedderOptions> { | ||||||
|   private embeddingResult: TextEmbedderResult = {embeddings: []}; |   private embeddingResult: TextEmbedderResult = {embeddings: []}; | ||||||
|   private readonly options = new TextEmbedderGraphOptionsProto(); |   private readonly options = new TextEmbedderGraphOptionsProto(); | ||||||
| 
 | 
 | ||||||
|  | @ -57,13 +58,12 @@ export class TextEmbedder extends TaskRunner { | ||||||
|    *     either a path to the TFLite model or the model itself needs to be |    *     either a path to the TFLite model or the model itself needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> { |       textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> { | ||||||
|     const embedder = await TaskRunner.createInstance( |     return TaskRunner.createInstance( | ||||||
|         TextEmbedder, /* initializeCanvas= */ false, wasmFileset); |         TextEmbedder, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     await embedder.setOptions(textEmbedderOptions); |         textEmbedderOptions); | ||||||
|     return embedder; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -76,8 +76,9 @@ export class TextEmbedder extends TaskRunner { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<TextEmbedder> { |       modelAssetBuffer: Uint8Array): Promise<TextEmbedder> { | ||||||
|     return TextEmbedder.createFromOptions( |     return TaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         TextEmbedder, /* initializeCanvas= */ false, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -87,13 +88,19 @@ export class TextEmbedder extends TaskRunner { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the TFLite model. |    * @param modelAssetPath The path to the TFLite model. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<TextEmbedder> { |       modelAssetPath: string): Promise<TextEmbedder> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return TaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         TextEmbedder, /* initializeCanvas= */ false, wasmFileset, | ||||||
|     return TextEmbedder.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |   } | ||||||
|  | 
 | ||||||
|  |   constructor( | ||||||
|  |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -105,17 +112,21 @@ export class TextEmbedder extends TaskRunner { | ||||||
|    * |    * | ||||||
|    * @param options The options for the text embedder. |    * @param options The options for the text embedder. | ||||||
|    */ |    */ | ||||||
|   async setOptions(options: TextEmbedderOptions): Promise<void> { |   override async setOptions(options: TextEmbedderOptions): Promise<void> { | ||||||
|     if (options.baseOptions) { |     await super.setOptions(options); | ||||||
|       const baseOptionsProto = await convertBaseOptionsToProto( |  | ||||||
|           options.baseOptions, this.options.getBaseOptions()); |  | ||||||
|       this.options.setBaseOptions(baseOptionsProto); |  | ||||||
|     } |  | ||||||
|     this.options.setEmbedderOptions(convertEmbedderOptionsToProto( |     this.options.setEmbedderOptions(convertEmbedderOptionsToProto( | ||||||
|         options, this.options.getEmbedderOptions())); |         options, this.options.getEmbedderOptions())); | ||||||
|     this.refreshGraph(); |     this.refreshGraph(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|  |     this.options.setBaseOptions(proto); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** |   /** | ||||||
|    * Performs embeding extraction on the provided text and waits synchronously |    * Performs embeding extraction on the provided text and waits synchronously | ||||||
|    * for the response. |    * for the response. | ||||||
|  |  | ||||||
|  | @ -15,8 +15,8 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; | import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; | ||||||
| import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; | import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| 
 | 
 | ||||||
| /** Options to configure the MediaPipe Text Embedder Task */ | /** Options to configure the MediaPipe Text Embedder Task */ | ||||||
| export declare interface TextEmbedderOptions extends EmbedderOptions, | export declare interface TextEmbedderOptions extends EmbedderOptions, | ||||||
|                                                      TextTaskOptions {} |                                                      TaskRunnerOptions {} | ||||||
|  |  | ||||||
|  | @ -17,8 +17,6 @@ mediapipe_ts_library( | ||||||
|     srcs = ["vision_task_runner.ts"], |     srcs = ["vision_task_runner.ts"], | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":vision_task_options", |         ":vision_task_options", | ||||||
|         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", |  | ||||||
|         "//mediapipe/tasks/web/components/processors:base_options", |  | ||||||
|         "//mediapipe/tasks/web/core", |         "//mediapipe/tasks/web/core", | ||||||
|         "//mediapipe/tasks/web/core:task_runner", |         "//mediapipe/tasks/web/core:task_runner", | ||||||
|         "//mediapipe/web/graph_runner:graph_runner_ts", |         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {BaseOptions} from '../../../../tasks/web/core/base_options'; | import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * The two running modes of a vision task. |  * The two running modes of a vision task. | ||||||
|  | @ -23,12 +23,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; | ||||||
|  */ |  */ | ||||||
| export type RunningMode = 'image'|'video'; | export type RunningMode = 'image'|'video'; | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| /** The options for configuring a MediaPipe vision task. */ | /** The options for configuring a MediaPipe vision task. */ | ||||||
| export declare interface VisionTaskOptions { | export declare interface VisionTaskOptions extends TaskRunnerOptions { | ||||||
|   /** Options to configure the loading of the model assets. */ |  | ||||||
|   baseOptions?: BaseOptions; |  | ||||||
| 
 |  | ||||||
|   /** |   /** | ||||||
|    * The running mode of the task. Default to the image mode. |    * The running mode of the task. Default to the image mode. | ||||||
|    * Vision tasks have two running modes: |    * Vision tasks have two running modes: | ||||||
|  |  | ||||||
|  | @ -14,24 +14,17 @@ | ||||||
|  * limitations under the License. |  * limitations under the License. | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; |  | ||||||
| import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; |  | ||||||
| import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | import {TaskRunner} from '../../../../tasks/web/core/task_runner'; | ||||||
| import {ImageSource} from '../../../../web/graph_runner/graph_runner'; | import {ImageSource} from '../../../../web/graph_runner/graph_runner'; | ||||||
| 
 | 
 | ||||||
| import {VisionTaskOptions} from './vision_task_options'; | import {VisionTaskOptions} from './vision_task_options'; | ||||||
| 
 | 
 | ||||||
| /** Base class for all MediaPipe Vision Tasks. */ | /** Base class for all MediaPipe Vision Tasks. */ | ||||||
| export abstract class VisionTaskRunner<T> extends TaskRunner { | export abstract class VisionTaskRunner<T> extends | ||||||
|   protected abstract baseOptions?: BaseOptionsProto|undefined; |     TaskRunner<VisionTaskOptions> { | ||||||
| 
 |  | ||||||
|   /** Configures the shared options of a vision task. */ |   /** Configures the shared options of a vision task. */ | ||||||
|   async setOptions(options: VisionTaskOptions): Promise<void> { |   override async setOptions(options: VisionTaskOptions): Promise<void> { | ||||||
|     this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); |     await super.setOptions(options); | ||||||
|     if (options.baseOptions) { |  | ||||||
|       this.baseOptions = await convertBaseOptionsToProto( |  | ||||||
|           options.baseOptions, this.baseOptions); |  | ||||||
|     } |  | ||||||
|     if ('runningMode' in options) { |     if ('runningMode' in options) { | ||||||
|       const useStreamMode = |       const useStreamMode = | ||||||
|           !!options.runningMode && options.runningMode !== 'image'; |           !!options.runningMode && options.runningMode !== 'image'; | ||||||
|  |  | ||||||
|  | @ -88,14 +88,13 @@ export class GestureRecognizer extends | ||||||
|    *     Note that either a path to the model asset or a model buffer needs to |    *     Note that either a path to the model asset or a model buffer needs to | ||||||
|    *     be provided (via `baseOptions`). |    *     be provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       gestureRecognizerOptions: GestureRecognizerOptions): |       gestureRecognizerOptions: GestureRecognizerOptions): | ||||||
|       Promise<GestureRecognizer> { |       Promise<GestureRecognizer> { | ||||||
|     const recognizer = await VisionTaskRunner.createInstance( |     return VisionTaskRunner.createInstance( | ||||||
|         GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); |         GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     await recognizer.setOptions(gestureRecognizerOptions); |         gestureRecognizerOptions); | ||||||
|     return recognizer; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -108,8 +107,9 @@ export class GestureRecognizer extends | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> { |       modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> { | ||||||
|     return GestureRecognizer.createFromOptions( |     return VisionTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -119,13 +119,12 @@ export class GestureRecognizer extends | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the model asset. |    * @param modelAssetPath The path to the model asset. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<GestureRecognizer> { |       modelAssetPath: string): Promise<GestureRecognizer> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return VisionTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     return GestureRecognizer.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   constructor( |   constructor( | ||||||
|  | @ -134,6 +133,7 @@ export class GestureRecognizer extends | ||||||
|     super(wasmModule, glCanvas); |     super(wasmModule, glCanvas); | ||||||
| 
 | 
 | ||||||
|     this.options = new GestureRecognizerGraphOptions(); |     this.options = new GestureRecognizerGraphOptions(); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|     this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); |     this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); | ||||||
|     this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); |     this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); | ||||||
|     this.handLandmarksDetectorGraphOptions = |     this.handLandmarksDetectorGraphOptions = | ||||||
|  | @ -151,11 +151,11 @@ export class GestureRecognizer extends | ||||||
|     this.initDefaults(); |     this.initDefaults(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|     return this.options.getBaseOptions(); |     return this.options.getBaseOptions()!; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -77,13 +77,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | ||||||
|    *     Note that either a path to the model asset or a model buffer needs to |    *     Note that either a path to the model asset or a model buffer needs to | ||||||
|    *     be provided (via `baseOptions`). |    *     be provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> { |       handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> { | ||||||
|     const landmarker = await VisionTaskRunner.createInstance( |     return VisionTaskRunner.createInstance( | ||||||
|         HandLandmarker, /* initializeCanvas= */ true, wasmFileset); |         HandLandmarker, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     await landmarker.setOptions(handLandmarkerOptions); |         handLandmarkerOptions); | ||||||
|     return landmarker; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -96,8 +95,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<HandLandmarker> { |       modelAssetBuffer: Uint8Array): Promise<HandLandmarker> { | ||||||
|     return HandLandmarker.createFromOptions( |     return VisionTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         HandLandmarker, /* initializeCanvas= */ true, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -107,13 +107,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the model asset. |    * @param modelAssetPath The path to the model asset. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<HandLandmarker> { |       modelAssetPath: string): Promise<HandLandmarker> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return VisionTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         HandLandmarker, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     return HandLandmarker.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   constructor( |   constructor( | ||||||
|  | @ -122,6 +121,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | ||||||
|     super(wasmModule, glCanvas); |     super(wasmModule, glCanvas); | ||||||
| 
 | 
 | ||||||
|     this.options = new HandLandmarkerGraphOptions(); |     this.options = new HandLandmarkerGraphOptions(); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|     this.handLandmarksDetectorGraphOptions = |     this.handLandmarksDetectorGraphOptions = | ||||||
|         new HandLandmarksDetectorGraphOptions(); |         new HandLandmarksDetectorGraphOptions(); | ||||||
|     this.options.setHandLandmarksDetectorGraphOptions( |     this.options.setHandLandmarksDetectorGraphOptions( | ||||||
|  | @ -132,11 +132,11 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | ||||||
|     this.initDefaults(); |     this.initDefaults(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|     return this.options.getBaseOptions(); |     return this.options.getBaseOptions()!; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ | ||||||
| import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; | import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
| import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | ||||||
| import {ImageSource} from '../../../../web/graph_runner/graph_runner'; | import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {ImageClassifierOptions} from './image_classifier_options'; | import {ImageClassifierOptions} from './image_classifier_options'; | ||||||
|  | @ -55,13 +55,12 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | ||||||
|    *     that either a path to the model asset or a model buffer needs to be |    *     that either a path to the model asset or a model buffer needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): |       wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): | ||||||
|       Promise<ImageClassifier> { |       Promise<ImageClassifier> { | ||||||
|     const classifier = await VisionTaskRunner.createInstance( |     return VisionTaskRunner.createInstance( | ||||||
|         ImageClassifier, /* initializeCanvas= */ true, wasmFileset); |         ImageClassifier, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     await classifier.setOptions(imageClassifierOptions); |         imageClassifierOptions); | ||||||
|     return classifier; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -74,8 +73,9 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<ImageClassifier> { |       modelAssetBuffer: Uint8Array): Promise<ImageClassifier> { | ||||||
|     return ImageClassifier.createFromOptions( |     return VisionTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         ImageClassifier, /* initializeCanvas= */ true, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -85,20 +85,26 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the model asset. |    * @param modelAssetPath The path to the model asset. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<ImageClassifier> { |       modelAssetPath: string): Promise<ImageClassifier> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return VisionTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         ImageClassifier, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     return ImageClassifier.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   constructor( | ||||||
|     return this.options.getBaseOptions(); |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -25,7 +25,7 @@ import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/ | ||||||
| import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; | import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
| import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | ||||||
| import {ImageSource} from '../../../../web/graph_runner/graph_runner'; | import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {ImageEmbedderOptions} from './image_embedder_options'; | import {ImageEmbedderOptions} from './image_embedder_options'; | ||||||
|  | @ -57,13 +57,12 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | ||||||
|    *     either a path to the TFLite model or the model itself needs to be |    *     either a path to the TFLite model or the model itself needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> { |       imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> { | ||||||
|     const embedder = await VisionTaskRunner.createInstance( |     return VisionTaskRunner.createInstance( | ||||||
|         ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); |         ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     await embedder.setOptions(imageEmbedderOptions); |         imageEmbedderOptions); | ||||||
|     return embedder; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -76,8 +75,9 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> { |       modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> { | ||||||
|     return ImageEmbedder.createFromOptions( |     return VisionTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -87,20 +87,26 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | ||||||
|    *     Wasm binary and its loader. |    *     Wasm binary and its loader. | ||||||
|    * @param modelAssetPath The path to the TFLite model. |    * @param modelAssetPath The path to the TFLite model. | ||||||
|    */ |    */ | ||||||
|   static async createFromModelPath( |   static createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<ImageEmbedder> { |       modelAssetPath: string): Promise<ImageEmbedder> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return VisionTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     return ImageEmbedder.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   constructor( | ||||||
|     return this.options.getBaseOptions(); |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b | ||||||
| import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; | import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; | ||||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||||
| import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | ||||||
| import {ImageSource} from '../../../../web/graph_runner/graph_runner'; | import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||||
| // Placeholder for internal dependency on trusted resource url
 | // Placeholder for internal dependency on trusted resource url
 | ||||||
| 
 | 
 | ||||||
| import {ObjectDetectorOptions} from './object_detector_options'; | import {ObjectDetectorOptions} from './object_detector_options'; | ||||||
|  | @ -54,13 +54,12 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | ||||||
|    *     either a path to the model asset or a model buffer needs to be |    *     either a path to the model asset or a model buffer needs to be | ||||||
|    *     provided (via `baseOptions`). |    *     provided (via `baseOptions`). | ||||||
|    */ |    */ | ||||||
|   static async createFromOptions( |   static createFromOptions( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> { |       objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> { | ||||||
|     const detector = await VisionTaskRunner.createInstance( |     return VisionTaskRunner.createInstance( | ||||||
|         ObjectDetector, /* initializeCanvas= */ true, wasmFileset); |         ObjectDetector, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     await detector.setOptions(objectDetectorOptions); |         objectDetectorOptions); | ||||||
|     return detector; |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -73,8 +72,9 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | ||||||
|   static createFromModelBuffer( |   static createFromModelBuffer( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetBuffer: Uint8Array): Promise<ObjectDetector> { |       modelAssetBuffer: Uint8Array): Promise<ObjectDetector> { | ||||||
|     return ObjectDetector.createFromOptions( |     return VisionTaskRunner.createInstance( | ||||||
|         wasmFileset, {baseOptions: {modelAssetBuffer}}); |         ObjectDetector, /* initializeCanvas= */ true, wasmFileset, | ||||||
|  |         {baseOptions: {modelAssetBuffer}}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -87,17 +87,23 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | ||||||
|   static async createFromModelPath( |   static async createFromModelPath( | ||||||
|       wasmFileset: WasmFileset, |       wasmFileset: WasmFileset, | ||||||
|       modelAssetPath: string): Promise<ObjectDetector> { |       modelAssetPath: string): Promise<ObjectDetector> { | ||||||
|     const response = await fetch(modelAssetPath.toString()); |     return VisionTaskRunner.createInstance( | ||||||
|     const graphData = await response.arrayBuffer(); |         ObjectDetector, /* initializeCanvas= */ true, wasmFileset, | ||||||
|     return ObjectDetector.createFromModelBuffer( |         {baseOptions: {modelAssetPath}}); | ||||||
|         wasmFileset, new Uint8Array(graphData)); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override get baseOptions(): BaseOptionsProto|undefined { |   constructor( | ||||||
|     return this.options.getBaseOptions(); |       wasmModule: WasmModule, | ||||||
|  |       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||||
|  |     super(wasmModule, glCanvas); | ||||||
|  |     this.options.setBaseOptions(new BaseOptionsProto()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   protected override set baseOptions(proto: BaseOptionsProto|undefined) { |   protected override get baseOptions(): BaseOptionsProto { | ||||||
|  |     return this.options.getBaseOptions()!; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   protected override set baseOptions(proto: BaseOptionsProto) { | ||||||
|     this.options.setBaseOptions(proto); |     this.options.setBaseOptions(proto); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user