Move shared code to TaskRunner

PiperOrigin-RevId: 492534879
This commit is contained in:
Sebastian Schmidt 2022-12-02 12:40:59 -08:00 committed by Copybara-Service
parent dabc2af15b
commit da9587033d
32 changed files with 262 additions and 305 deletions

View File

@ -25,7 +25,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )
@ -36,7 +36,6 @@ mediapipe_ts_declaration(
"audio_classifier_result.d.ts", "audio_classifier_result.d.ts",
], ],
deps = [ deps = [
"//mediapipe/tasks/web/audio/core:audio_task_options",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",

View File

@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {AudioClassifierOptions} from './audio_classifier_options'; import {AudioClassifierOptions} from './audio_classifier_options';
@ -56,13 +56,12 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
* that either a path to the model asset or a model buffer needs to be * that either a path to the model asset or a model buffer needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions):
Promise<AudioClassifier> { Promise<AudioClassifier> {
const classifier = await TaskRunner.createInstance( return AudioTaskRunner.createInstance(
AudioClassifier, /* initializeCanvas= */ false, wasmFileset); AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
await classifier.setOptions(audioClassifierOptions); audioClassifierOptions);
return classifier;
} }
/** /**
@ -75,8 +74,9 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<AudioClassifier> { modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
return AudioClassifier.createFromOptions( return AudioTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -86,20 +86,26 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the model asset. * @param modelAssetPath The path to the model asset.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<AudioClassifier> { modelAssetPath: string): Promise<AudioClassifier> {
const response = await fetch(modelAssetPath.toString()); return AudioTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
return AudioClassifier.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
protected override get baseOptions(): BaseOptionsProto|undefined { constructor(
return this.options.getBaseOptions(); wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }

View File

@ -14,9 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options';
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/** Options to configure the MediaPipe Audio Classifier Task */ /** Options to configure the MediaPipe Audio Classifier Task */
export declare interface AudioClassifierOptions extends ClassifierOptions, export declare interface AudioClassifierOptions extends ClassifierOptions,
AudioTaskOptions {} TaskRunnerOptions {}

View File

@ -36,7 +36,6 @@ mediapipe_ts_declaration(
"audio_embedder_result.d.ts", "audio_embedder_result.d.ts",
], ],
deps = [ deps = [
"//mediapipe/tasks/web/audio/core:audio_task_options",
"//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/containers:embedding_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:embedder_options",

View File

@ -25,7 +25,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {AudioEmbedderOptions} from './audio_embedder_options'; import {AudioEmbedderOptions} from './audio_embedder_options';
@ -58,23 +58,12 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
* either a path to the TFLite model or the model itself needs to be * either a path to the TFLite model or the model itself needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> { audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> {
// Create a file locator based on the loader options return AudioTaskRunner.createInstance(
const fileLocator: FileLocator = { AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
locateFile() { audioEmbedderOptions);
// The only file we load is the Wasm binary
return wasmFileset.wasmBinaryPath.toString();
}
};
const embedder = await createMediaPipeLib(
AudioEmbedder, wasmFileset.wasmLoaderPath,
/* assetLoaderScript= */ undefined,
/* glCanvas= */ undefined, fileLocator);
await embedder.setOptions(audioEmbedderOptions);
return embedder;
} }
/** /**
@ -87,8 +76,9 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> { modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> {
return AudioEmbedder.createFromOptions( return AudioTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -98,20 +88,26 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the TFLite model. * @param modelAssetPath The path to the TFLite model.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<AudioEmbedder> { modelAssetPath: string): Promise<AudioEmbedder> {
const response = await fetch(modelAssetPath.toString()); return AudioTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
return AudioEmbedder.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
protected override get baseOptions(): BaseOptionsProto|undefined { constructor(
return this.options.getBaseOptions(); wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }

View File

@ -14,9 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options';
import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/** Options to configure the MediaPipe Audio Embedder Task */ /** Options to configure the MediaPipe Audio Embedder Task */
export declare interface AudioEmbedderOptions extends EmbedderOptions, export declare interface AudioEmbedderOptions extends EmbedderOptions,
AudioTaskOptions {} TaskRunnerOptions {}

View File

@ -1,24 +1,13 @@
# This package contains options shared by all MediaPipe Audio Tasks for Web. # This package contains options shared by all MediaPipe Audio Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"]) package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_declaration(
name = "audio_task_options",
srcs = ["audio_task_options.d.ts"],
deps = [
"//mediapipe/tasks/web/core",
],
)
mediapipe_ts_library( mediapipe_ts_library(
name = "audio_task_runner", name = "audio_task_runner",
srcs = ["audio_task_runner.ts"], srcs = ["audio_task_runner.ts"],
deps = [ deps = [
":audio_task_options",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner",
], ],

View File

@ -1,23 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions} from '../../../../tasks/web/core/base_options';
/** The options for configuring a MediaPipe Audio Task. */
export declare interface AudioTaskOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
}

View File

@ -14,26 +14,13 @@
* limitations under the License. * limitations under the License.
*/ */
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
import {AudioTaskOptions} from './audio_task_options';
/** Base class for all MediaPipe Audio Tasks. */ /** Base class for all MediaPipe Audio Tasks. */
export abstract class AudioTaskRunner<T> extends TaskRunner { export abstract class AudioTaskRunner<T> extends TaskRunner<TaskRunnerOptions> {
protected abstract baseOptions?: BaseOptionsProto|undefined;
private defaultSampleRate = 48000; private defaultSampleRate = 48000;
/** Configures the shared options of an audio task. */
async setOptions(options: AudioTaskOptions): Promise<void> {
this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto(
options.baseOptions, this.baseOptions);
}
}
/** /**
* Sets the sample rate for API calls that omit an explicit sample rate. * Sets the sample rate for API calls that omit an explicit sample rate.
* `48000` is used as a default if this method is not called. * `48000` is used as a default if this method is not called.

View File

@ -17,7 +17,6 @@ mediapipe_ts_library(
name = "classifier_result", name = "classifier_result",
srcs = ["classifier_result.ts"], srcs = ["classifier_result.ts"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
], ],

View File

@ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
// The OSS JS API does not support the builder pattern. // The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern

View File

@ -7,18 +7,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_declaration( mediapipe_ts_declaration(
name = "core", name = "core",
srcs = [ srcs = [
"base_options.d.ts", "task_runner_options.d.ts",
"wasm_fileset.d.ts", "wasm_fileset.d.ts",
], ],
) )
mediapipe_ts_library( mediapipe_ts_library(
name = "task_runner", name = "task_runner",
srcs = [ srcs = ["task_runner.ts"],
"task_runner.ts",
],
deps = [ deps = [
":core", ":core",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",

View File

@ -14,8 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */
import {BaseOptions} from '../../../tasks/web/core/base_options';
/** Options to configure a MediaPipe Classifier Task. */ /** Options to configure a MediaPipe Classifier Task. */
export declare interface ClassifierOptions { export declare interface ClassifierOptions {
/** /**

View File

@ -14,8 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */
import {BaseOptions} from '../../../tasks/web/core/base_options';
/** Options to configure a MediaPipe Embedder Task */ /** Options to configure a MediaPipe Embedder Task */
export declare interface EmbedderOptions { export declare interface EmbedderOptions {
/** /**

View File

@ -14,6 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@ -28,7 +31,9 @@ const WasmMediaPipeImageLib =
SupportModelResourcesGraphService(SupportImage(GraphRunner)); SupportModelResourcesGraphService(SupportImage(GraphRunner));
/** Base class for all MediaPipe Tasks. */ /** Base class for all MediaPipe Tasks. */
export abstract class TaskRunner extends WasmMediaPipeImageLib { export abstract class TaskRunner<O extends TaskRunnerOptions> extends
WasmMediaPipeImageLib {
protected abstract baseOptions: BaseOptionsProto;
private processingErrors: Error[] = []; private processingErrors: Error[] = [];
/** /**
@ -36,9 +41,10 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
* supported and loads the relevant WASM binary. * supported and loads the relevant WASM binary.
* @return A fully instantiated instance of `T`. * @return A fully instantiated instance of `T`.
*/ */
protected static async createInstance<T extends TaskRunner>( protected static async createInstance<T extends TaskRunner<O>,
O extends TaskRunnerOptions>(
type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
fileset: WasmFileset): Promise<T> { fileset: WasmFileset, options: O): Promise<T> {
const fileLocator: FileLocator = { const fileLocator: FileLocator = {
locateFile() { locateFile() {
// The only file loaded with this mechanism is the Wasm binary // The only file loaded with this mechanism is the Wasm binary
@ -46,19 +52,16 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
} }
}; };
if (initializeCanvas) { // Initialize a canvas if requested. If OffscreenCanvas is availble, we
// Fall back to an OffscreenCanvas created by the GraphRunner if // let the graph runner initialize it by passing `undefined`.
// OffscreenCanvas is available const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
const canvas = typeof OffscreenCanvas === 'undefined' ? document.createElement('canvas') :
document.createElement('canvas') : undefined) :
undefined; null;
return createMediaPipeLib( const instance = await createMediaPipeLib(
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
} else { await instance.setOptions(options);
return createMediaPipeLib( return instance;
type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null,
fileLocator);
}
} }
constructor( constructor(
@ -74,6 +77,14 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
this.registerModelResourcesGraphService(); this.registerModelResourcesGraphService();
} }
/** Configures the shared options of a MediaPipe Task. */
async setOptions(options: O): Promise<void> {
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto(
options.baseOptions, this.baseOptions);
}
}
/** /**
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
* over the video stream. Will replace the previously running MediaPipe graph, * over the video stream. Will replace the previously running MediaPipe graph,

View File

@ -16,7 +16,7 @@
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
/** Options to configure MediaPipe Tasks in general. */ /** Options to configure MediaPipe model loading and processing. */
export declare interface BaseOptions { export declare interface BaseOptions {
/** /**
* The model path to the model asset file. Only one of `modelAssetPath` or * The model path to the model asset file. Only one of `modelAssetPath` or
@ -33,3 +33,9 @@ export declare interface BaseOptions {
/** Overrides the default backend to use for the provided model. */ /** Overrides the default backend to use for the provided model. */
delegate?: 'cpu'|'gpu'|undefined; delegate?: 'cpu'|'gpu'|undefined;
} }
/** Options to configure MediaPipe Tasks in general. */
export declare interface TaskRunnerOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
}

View File

@ -1,11 +0,0 @@
# This package contains options shared by all MediaPipe Texxt Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_declaration(
name = "text_task_options",
srcs = ["text_task_options.d.ts"],
deps = ["//mediapipe/tasks/web/core"],
)

View File

@ -1,23 +0,0 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions} from '../../../../tasks/web/core/base_options';
/** The options for configuring a MediaPipe Text task. */
export declare interface TextTaskOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
}

View File

@ -17,15 +17,16 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )
@ -38,7 +39,7 @@ mediapipe_ts_declaration(
deps = [ deps = [
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/text/core:text_task_options",
], ],
) )

View File

@ -17,12 +17,13 @@
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierOptions} from './text_classifier_options';
@ -40,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH =
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/** Performs Natural Language classification. */ /** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner { export class TextClassifier extends TaskRunner<TextClassifierOptions> {
private classificationResult: TextClassifierResult = {classifications: []}; private classificationResult: TextClassifierResult = {classifications: []};
private readonly options = new TextClassifierGraphOptions(); private readonly options = new TextClassifierGraphOptions();
@ -53,13 +54,12 @@ export class TextClassifier extends TaskRunner {
* either a path to the TFLite model or the model itself needs to be * either a path to the TFLite model or the model itself needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> { textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
const classifier = await TaskRunner.createInstance( return TaskRunner.createInstance(
TextClassifier, /* initializeCanvas= */ false, wasmFileset); TextClassifier, /* initializeCanvas= */ false, wasmFileset,
await classifier.setOptions(textClassifierOptions); textClassifierOptions);
return classifier;
} }
/** /**
@ -72,8 +72,9 @@ export class TextClassifier extends TaskRunner {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<TextClassifier> { modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
return TextClassifier.createFromOptions( return TaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); TextClassifier, /* initializeCanvas= */ false, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -83,13 +84,19 @@ export class TextClassifier extends TaskRunner {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the model asset. * @param modelAssetPath The path to the model asset.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<TextClassifier> { modelAssetPath: string): Promise<TextClassifier> {
const response = await fetch(modelAssetPath.toString()); return TaskRunner.createInstance(
const graphData = await response.arrayBuffer(); TextClassifier, /* initializeCanvas= */ false, wasmFileset,
return TextClassifier.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData)); }
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
/** /**
@ -101,18 +108,20 @@ export class TextClassifier extends TaskRunner {
* *
* @param options The options for the text classifier. * @param options The options for the text classifier.
*/ */
async setOptions(options: TextClassifierOptions): Promise<void> { override async setOptions(options: TextClassifierOptions): Promise<void> {
if (options.baseOptions) { await super.setOptions(options);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); this.refreshGraph();
} }
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto);
}
/** /**
* Performs Natural Language classification on the provided text and waits * Performs Natural Language classification on the provided text and waits

View File

@ -15,8 +15,8 @@
*/ */
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/** Options to configure the MediaPipe Text Classifier Task */ /** Options to configure the MediaPipe Text Classifier Task */
export declare interface TextClassifierOptions extends ClassifierOptions, export declare interface TextClassifierOptions extends ClassifierOptions,
TextTaskOptions {} TaskRunnerOptions {}

View File

@ -17,15 +17,16 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/containers:embedding_result",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_options",
"//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/processors:embedder_result",
"//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/components/utils:cosine_similarity",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )
@ -39,6 +40,5 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/containers:embedding_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/tasks/web/text/core:text_task_options",
], ],
) )

View File

@ -17,14 +17,15 @@
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {TextEmbedderOptions} from './text_embedder_options'; import {TextEmbedderOptions} from './text_embedder_options';
@ -44,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR =
/** /**
* Performs embedding extraction on text. * Performs embedding extraction on text.
*/ */
export class TextEmbedder extends TaskRunner { export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
private embeddingResult: TextEmbedderResult = {embeddings: []}; private embeddingResult: TextEmbedderResult = {embeddings: []};
private readonly options = new TextEmbedderGraphOptionsProto(); private readonly options = new TextEmbedderGraphOptionsProto();
@ -57,13 +58,12 @@ export class TextEmbedder extends TaskRunner {
* either a path to the TFLite model or the model itself needs to be * either a path to the TFLite model or the model itself needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> { textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> {
const embedder = await TaskRunner.createInstance( return TaskRunner.createInstance(
TextEmbedder, /* initializeCanvas= */ false, wasmFileset); TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
await embedder.setOptions(textEmbedderOptions); textEmbedderOptions);
return embedder;
} }
/** /**
@ -76,8 +76,9 @@ export class TextEmbedder extends TaskRunner {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<TextEmbedder> { modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
return TextEmbedder.createFromOptions( return TaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -87,13 +88,19 @@ export class TextEmbedder extends TaskRunner {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the TFLite model. * @param modelAssetPath The path to the TFLite model.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<TextEmbedder> { modelAssetPath: string): Promise<TextEmbedder> {
const response = await fetch(modelAssetPath.toString()); return TaskRunner.createInstance(
const graphData = await response.arrayBuffer(); TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
return TextEmbedder.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData)); }
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
/** /**
@ -105,17 +112,21 @@ export class TextEmbedder extends TaskRunner {
* *
* @param options The options for the text embedder. * @param options The options for the text embedder.
*/ */
async setOptions(options: TextEmbedderOptions): Promise<void> { override async setOptions(options: TextEmbedderOptions): Promise<void> {
if (options.baseOptions) { await super.setOptions(options);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}
this.options.setEmbedderOptions(convertEmbedderOptionsToProto( this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions())); options, this.options.getEmbedderOptions()));
this.refreshGraph(); this.refreshGraph();
} }
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto);
}
/** /**
* Performs embeding extraction on the provided text and waits synchronously * Performs embeding extraction on the provided text and waits synchronously
* for the response. * for the response.

View File

@ -15,8 +15,8 @@
*/ */
import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/** Options to configure the MediaPipe Text Embedder Task */ /** Options to configure the MediaPipe Text Embedder Task */
export declare interface TextEmbedderOptions extends EmbedderOptions, export declare interface TextEmbedderOptions extends EmbedderOptions,
TextTaskOptions {} TaskRunnerOptions {}

View File

@ -17,8 +17,6 @@ mediapipe_ts_library(
srcs = ["vision_task_runner.ts"], srcs = ["vision_task_runner.ts"],
deps = [ deps = [
":vision_task_options", ":vision_task_options",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner", "//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/** /**
* The two running modes of a vision task. * The two running modes of a vision task.
@ -23,12 +23,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
*/ */
export type RunningMode = 'image'|'video'; export type RunningMode = 'image'|'video';
/** The options for configuring a MediaPipe vision task. */ /** The options for configuring a MediaPipe vision task. */
export declare interface VisionTaskOptions { export declare interface VisionTaskOptions extends TaskRunnerOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
/** /**
* The running mode of the task. Default to the image mode. * The running mode of the task. Default to the image mode.
* Vision tasks have two running modes: * Vision tasks have two running modes:

View File

@ -14,24 +14,17 @@
* limitations under the License. * limitations under the License.
*/ */
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {VisionTaskOptions} from './vision_task_options'; import {VisionTaskOptions} from './vision_task_options';
/** Base class for all MediaPipe Vision Tasks. */ /** Base class for all MediaPipe Vision Tasks. */
export abstract class VisionTaskRunner<T> extends TaskRunner { export abstract class VisionTaskRunner<T> extends
protected abstract baseOptions?: BaseOptionsProto|undefined; TaskRunner<VisionTaskOptions> {
/** Configures the shared options of a vision task. */ /** Configures the shared options of a vision task. */
async setOptions(options: VisionTaskOptions): Promise<void> { override async setOptions(options: VisionTaskOptions): Promise<void> {
this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); await super.setOptions(options);
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto(
options.baseOptions, this.baseOptions);
}
if ('runningMode' in options) { if ('runningMode' in options) {
const useStreamMode = const useStreamMode =
!!options.runningMode && options.runningMode !== 'image'; !!options.runningMode && options.runningMode !== 'image';

View File

@ -88,14 +88,13 @@ export class GestureRecognizer extends
* Note that either a path to the model asset or a model buffer needs to * Note that either a path to the model asset or a model buffer needs to
* be provided (via `baseOptions`). * be provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
gestureRecognizerOptions: GestureRecognizerOptions): gestureRecognizerOptions: GestureRecognizerOptions):
Promise<GestureRecognizer> { Promise<GestureRecognizer> {
const recognizer = await VisionTaskRunner.createInstance( return VisionTaskRunner.createInstance(
GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
await recognizer.setOptions(gestureRecognizerOptions); gestureRecognizerOptions);
return recognizer;
} }
/** /**
@ -108,8 +107,9 @@ export class GestureRecognizer extends
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> { modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
return GestureRecognizer.createFromOptions( return VisionTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -119,13 +119,12 @@ export class GestureRecognizer extends
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the model asset. * @param modelAssetPath The path to the model asset.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<GestureRecognizer> { modelAssetPath: string): Promise<GestureRecognizer> {
const response = await fetch(modelAssetPath.toString()); return VisionTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
return GestureRecognizer.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
constructor( constructor(
@ -134,6 +133,7 @@ export class GestureRecognizer extends
super(wasmModule, glCanvas); super(wasmModule, glCanvas);
this.options = new GestureRecognizerGraphOptions(); this.options = new GestureRecognizerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());
this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions();
this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions);
this.handLandmarksDetectorGraphOptions = this.handLandmarksDetectorGraphOptions =
@ -151,11 +151,11 @@ export class GestureRecognizer extends
this.initDefaults(); this.initDefaults();
} }
protected override get baseOptions(): BaseOptionsProto|undefined { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions(); return this.options.getBaseOptions()!;
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }

View File

@ -77,13 +77,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
* Note that either a path to the model asset or a model buffer needs to * Note that either a path to the model asset or a model buffer needs to
* be provided (via `baseOptions`). * be provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> { handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> {
const landmarker = await VisionTaskRunner.createInstance( return VisionTaskRunner.createInstance(
HandLandmarker, /* initializeCanvas= */ true, wasmFileset); HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
await landmarker.setOptions(handLandmarkerOptions); handLandmarkerOptions);
return landmarker;
} }
/** /**
@ -96,8 +95,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<HandLandmarker> { modelAssetBuffer: Uint8Array): Promise<HandLandmarker> {
return HandLandmarker.createFromOptions( return VisionTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -107,13 +107,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the model asset. * @param modelAssetPath The path to the model asset.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<HandLandmarker> { modelAssetPath: string): Promise<HandLandmarker> {
const response = await fetch(modelAssetPath.toString()); return VisionTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
return HandLandmarker.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
constructor( constructor(
@ -122,6 +121,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
super(wasmModule, glCanvas); super(wasmModule, glCanvas);
this.options = new HandLandmarkerGraphOptions(); this.options = new HandLandmarkerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());
this.handLandmarksDetectorGraphOptions = this.handLandmarksDetectorGraphOptions =
new HandLandmarksDetectorGraphOptions(); new HandLandmarksDetectorGraphOptions();
this.options.setHandLandmarksDetectorGraphOptions( this.options.setHandLandmarksDetectorGraphOptions(
@ -132,11 +132,11 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
this.initDefaults(); this.initDefaults();
} }
protected override get baseOptions(): BaseOptionsProto|undefined { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions(); return this.options.getBaseOptions()!;
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }

View File

@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {ImageClassifierOptions} from './image_classifier_options'; import {ImageClassifierOptions} from './image_classifier_options';
@ -55,13 +55,12 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
* that either a path to the model asset or a model buffer needs to be * that either a path to the model asset or a model buffer needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions):
Promise<ImageClassifier> { Promise<ImageClassifier> {
const classifier = await VisionTaskRunner.createInstance( return VisionTaskRunner.createInstance(
ImageClassifier, /* initializeCanvas= */ true, wasmFileset); ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
await classifier.setOptions(imageClassifierOptions); imageClassifierOptions);
return classifier;
} }
/** /**
@ -74,8 +73,9 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<ImageClassifier> { modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
return ImageClassifier.createFromOptions( return VisionTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -85,20 +85,26 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the model asset. * @param modelAssetPath The path to the model asset.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<ImageClassifier> { modelAssetPath: string): Promise<ImageClassifier> {
const response = await fetch(modelAssetPath.toString()); return VisionTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
return ImageClassifier.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
protected override get baseOptions(): BaseOptionsProto|undefined { constructor(
return this.options.getBaseOptions(); wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }

View File

@ -25,7 +25,7 @@ import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {ImageEmbedderOptions} from './image_embedder_options'; import {ImageEmbedderOptions} from './image_embedder_options';
@ -57,13 +57,12 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
* either a path to the TFLite model or the model itself needs to be * either a path to the TFLite model or the model itself needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> { imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> {
const embedder = await VisionTaskRunner.createInstance( return VisionTaskRunner.createInstance(
ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
await embedder.setOptions(imageEmbedderOptions); imageEmbedderOptions);
return embedder;
} }
/** /**
@ -76,8 +75,9 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> { modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> {
return ImageEmbedder.createFromOptions( return VisionTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -87,20 +87,26 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
* Wasm binary and its loader. * Wasm binary and its loader.
* @param modelAssetPath The path to the TFLite model. * @param modelAssetPath The path to the TFLite model.
*/ */
static async createFromModelPath( static createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<ImageEmbedder> { modelAssetPath: string): Promise<ImageEmbedder> {
const response = await fetch(modelAssetPath.toString()); return VisionTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
return ImageEmbedder.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
protected override get baseOptions(): BaseOptionsProto|undefined { constructor(
return this.options.getBaseOptions(); wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }

View File

@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {ObjectDetectorOptions} from './object_detector_options'; import {ObjectDetectorOptions} from './object_detector_options';
@ -54,13 +54,12 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
* either a path to the model asset or a model buffer needs to be * either a path to the model asset or a model buffer needs to be
* provided (via `baseOptions`). * provided (via `baseOptions`).
*/ */
static async createFromOptions( static createFromOptions(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> { objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
const detector = await VisionTaskRunner.createInstance( return VisionTaskRunner.createInstance(
ObjectDetector, /* initializeCanvas= */ true, wasmFileset); ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
await detector.setOptions(objectDetectorOptions); objectDetectorOptions);
return detector;
} }
/** /**
@ -73,8 +72,9 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
static createFromModelBuffer( static createFromModelBuffer(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<ObjectDetector> { modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
return ObjectDetector.createFromOptions( return VisionTaskRunner.createInstance(
wasmFileset, {baseOptions: {modelAssetBuffer}}); ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
} }
/** /**
@ -87,17 +87,23 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
static async createFromModelPath( static async createFromModelPath(
wasmFileset: WasmFileset, wasmFileset: WasmFileset,
modelAssetPath: string): Promise<ObjectDetector> { modelAssetPath: string): Promise<ObjectDetector> {
const response = await fetch(modelAssetPath.toString()); return VisionTaskRunner.createInstance(
const graphData = await response.arrayBuffer(); ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
return ObjectDetector.createFromModelBuffer( {baseOptions: {modelAssetPath}});
wasmFileset, new Uint8Array(graphData));
} }
protected override get baseOptions(): BaseOptionsProto|undefined { constructor(
return this.options.getBaseOptions(); wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.options.setBaseOptions(new BaseOptionsProto());
} }
protected override set baseOptions(proto: BaseOptionsProto|undefined) { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto); this.options.setBaseOptions(proto);
} }