Move shared code to TaskRunner
PiperOrigin-RevId: 492534879
This commit is contained in:
parent
dabc2af15b
commit
da9587033d
|
@ -25,7 +25,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:classifier_options",
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
"//mediapipe/tasks/web/core:task_runner",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,7 +36,6 @@ mediapipe_ts_declaration(
|
||||||
"audio_classifier_result.d.ts",
|
"audio_classifier_result.d.ts",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/web/audio/core:audio_task_options",
|
|
||||||
"//mediapipe/tasks/web/components/containers:category",
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
|
|
|
@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
|
||||||
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
|
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
|
||||||
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {AudioClassifierOptions} from './audio_classifier_options';
|
import {AudioClassifierOptions} from './audio_classifier_options';
|
||||||
|
@ -56,13 +56,12 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
||||||
* that either a path to the model asset or a model buffer needs to be
|
* that either a path to the model asset or a model buffer needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions):
|
wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions):
|
||||||
Promise<AudioClassifier> {
|
Promise<AudioClassifier> {
|
||||||
const classifier = await TaskRunner.createInstance(
|
return AudioTaskRunner.createInstance(
|
||||||
AudioClassifier, /* initializeCanvas= */ false, wasmFileset);
|
AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
|
||||||
await classifier.setOptions(audioClassifierOptions);
|
audioClassifierOptions);
|
||||||
return classifier;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -75,8 +74,9 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
|
modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
|
||||||
return AudioClassifier.createFromOptions(
|
return AudioTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -86,20 +86,26 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the model asset.
|
* @param modelAssetPath The path to the model asset.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<AudioClassifier> {
|
modelAssetPath: string): Promise<AudioClassifier> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return AudioTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
|
||||||
return AudioClassifier.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
constructor(
|
||||||
return this.options.getBaseOptions();
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options';
|
|
||||||
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
||||||
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
/** Options to configure the MediaPipe Audio Classifier Task */
|
/** Options to configure the MediaPipe Audio Classifier Task */
|
||||||
export declare interface AudioClassifierOptions extends ClassifierOptions,
|
export declare interface AudioClassifierOptions extends ClassifierOptions,
|
||||||
AudioTaskOptions {}
|
TaskRunnerOptions {}
|
||||||
|
|
|
@ -36,7 +36,6 @@ mediapipe_ts_declaration(
|
||||||
"audio_embedder_result.d.ts",
|
"audio_embedder_result.d.ts",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/web/audio/core:audio_task_options",
|
|
||||||
"//mediapipe/tasks/web/components/containers:embedding_result",
|
"//mediapipe/tasks/web/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:embedder_options",
|
"//mediapipe/tasks/web/core:embedder_options",
|
||||||
|
|
|
@ -25,7 +25,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr
|
||||||
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
|
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
|
||||||
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
|
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
|
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {AudioEmbedderOptions} from './audio_embedder_options';
|
import {AudioEmbedderOptions} from './audio_embedder_options';
|
||||||
|
@ -58,23 +58,12 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
||||||
* either a path to the TFLite model or the model itself needs to be
|
* either a path to the TFLite model or the model itself needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> {
|
audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> {
|
||||||
// Create a file locator based on the loader options
|
return AudioTaskRunner.createInstance(
|
||||||
const fileLocator: FileLocator = {
|
AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
|
||||||
locateFile() {
|
audioEmbedderOptions);
|
||||||
// The only file we load is the Wasm binary
|
|
||||||
return wasmFileset.wasmBinaryPath.toString();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const embedder = await createMediaPipeLib(
|
|
||||||
AudioEmbedder, wasmFileset.wasmLoaderPath,
|
|
||||||
/* assetLoaderScript= */ undefined,
|
|
||||||
/* glCanvas= */ undefined, fileLocator);
|
|
||||||
await embedder.setOptions(audioEmbedderOptions);
|
|
||||||
return embedder;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -87,8 +76,9 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> {
|
modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> {
|
||||||
return AudioEmbedder.createFromOptions(
|
return AudioTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -98,20 +88,26 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the TFLite model.
|
* @param modelAssetPath The path to the TFLite model.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<AudioEmbedder> {
|
modelAssetPath: string): Promise<AudioEmbedder> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return AudioTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
|
||||||
return AudioEmbedder.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
constructor(
|
||||||
return this.options.getBaseOptions();
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options';
|
|
||||||
import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
|
import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
|
||||||
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
/** Options to configure the MediaPipe Audio Embedder Task */
|
/** Options to configure the MediaPipe Audio Embedder Task */
|
||||||
export declare interface AudioEmbedderOptions extends EmbedderOptions,
|
export declare interface AudioEmbedderOptions extends EmbedderOptions,
|
||||||
AudioTaskOptions {}
|
TaskRunnerOptions {}
|
||||||
|
|
|
@ -1,24 +1,13 @@
|
||||||
# This package contains options shared by all MediaPipe Audio Tasks for Web.
|
# This package contains options shared by all MediaPipe Audio Tasks for Web.
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
mediapipe_ts_declaration(
|
|
||||||
name = "audio_task_options",
|
|
||||||
srcs = ["audio_task_options.d.ts"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/tasks/web/core",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "audio_task_runner",
|
name = "audio_task_runner",
|
||||||
srcs = ["audio_task_runner.ts"],
|
srcs = ["audio_task_runner.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
":audio_task_options",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:task_runner",
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
],
|
],
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
|
||||||
|
|
||||||
/** The options for configuring a MediaPipe Audio Task. */
|
|
||||||
export declare interface AudioTaskOptions {
|
|
||||||
/** Options to configure the loading of the model assets. */
|
|
||||||
baseOptions?: BaseOptions;
|
|
||||||
}
|
|
|
@ -14,26 +14,13 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
|
||||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
|
||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
import {AudioTaskOptions} from './audio_task_options';
|
|
||||||
|
|
||||||
/** Base class for all MediaPipe Audio Tasks. */
|
/** Base class for all MediaPipe Audio Tasks. */
|
||||||
export abstract class AudioTaskRunner<T> extends TaskRunner {
|
export abstract class AudioTaskRunner<T> extends TaskRunner<TaskRunnerOptions> {
|
||||||
protected abstract baseOptions?: BaseOptionsProto|undefined;
|
|
||||||
private defaultSampleRate = 48000;
|
private defaultSampleRate = 48000;
|
||||||
|
|
||||||
/** Configures the shared options of an audio task. */
|
|
||||||
async setOptions(options: AudioTaskOptions): Promise<void> {
|
|
||||||
this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
|
|
||||||
if (options.baseOptions) {
|
|
||||||
this.baseOptions = await convertBaseOptionsToProto(
|
|
||||||
options.baseOptions, this.baseOptions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the sample rate for API calls that omit an explicit sample rate.
|
* Sets the sample rate for API calls that omit an explicit sample rate.
|
||||||
* `48000` is used as a default if this method is not called.
|
* `48000` is used as a default if this method is not called.
|
||||||
|
|
|
@ -17,7 +17,6 @@ mediapipe_ts_library(
|
||||||
name = "classifier_result",
|
name = "classifier_result",
|
||||||
srcs = ["classifier_result.ts"],
|
srcs = ["classifier_result.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/formats:classification_jspb_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
],
|
],
|
||||||
|
|
|
@ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen
|
||||||
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
|
||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
import {BaseOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
// The OSS JS API does not support the builder pattern.
|
// The OSS JS API does not support the builder pattern.
|
||||||
// tslint:disable:jspb-use-builder-pattern
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
|
@ -7,18 +7,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
mediapipe_ts_declaration(
|
mediapipe_ts_declaration(
|
||||||
name = "core",
|
name = "core",
|
||||||
srcs = [
|
srcs = [
|
||||||
"base_options.d.ts",
|
"task_runner_options.d.ts",
|
||||||
"wasm_fileset.d.ts",
|
"wasm_fileset.d.ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "task_runner",
|
name = "task_runner",
|
||||||
srcs = [
|
srcs = ["task_runner.ts"],
|
||||||
"task_runner.ts",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
":core",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
||||||
|
|
|
@ -14,8 +14,6 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {BaseOptions} from '../../../tasks/web/core/base_options';
|
|
||||||
|
|
||||||
/** Options to configure a MediaPipe Classifier Task. */
|
/** Options to configure a MediaPipe Classifier Task. */
|
||||||
export declare interface ClassifierOptions {
|
export declare interface ClassifierOptions {
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -14,8 +14,6 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {BaseOptions} from '../../../tasks/web/core/base_options';
|
|
||||||
|
|
||||||
/** Options to configure a MediaPipe Embedder Task */
|
/** Options to configure a MediaPipe Embedder Task */
|
||||||
export declare interface EmbedderOptions {
|
export declare interface EmbedderOptions {
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -14,6 +14,9 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||||
|
import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
|
||||||
|
import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
|
||||||
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
|
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
|
||||||
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
|
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
|
||||||
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
||||||
|
@ -28,7 +31,9 @@ const WasmMediaPipeImageLib =
|
||||||
SupportModelResourcesGraphService(SupportImage(GraphRunner));
|
SupportModelResourcesGraphService(SupportImage(GraphRunner));
|
||||||
|
|
||||||
/** Base class for all MediaPipe Tasks. */
|
/** Base class for all MediaPipe Tasks. */
|
||||||
export abstract class TaskRunner extends WasmMediaPipeImageLib {
|
export abstract class TaskRunner<O extends TaskRunnerOptions> extends
|
||||||
|
WasmMediaPipeImageLib {
|
||||||
|
protected abstract baseOptions: BaseOptionsProto;
|
||||||
private processingErrors: Error[] = [];
|
private processingErrors: Error[] = [];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -36,9 +41,10 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
|
||||||
* supported and loads the relevant WASM binary.
|
* supported and loads the relevant WASM binary.
|
||||||
* @return A fully instantiated instance of `T`.
|
* @return A fully instantiated instance of `T`.
|
||||||
*/
|
*/
|
||||||
protected static async createInstance<T extends TaskRunner>(
|
protected static async createInstance<T extends TaskRunner<O>,
|
||||||
|
O extends TaskRunnerOptions>(
|
||||||
type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
|
type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
|
||||||
fileset: WasmFileset): Promise<T> {
|
fileset: WasmFileset, options: O): Promise<T> {
|
||||||
const fileLocator: FileLocator = {
|
const fileLocator: FileLocator = {
|
||||||
locateFile() {
|
locateFile() {
|
||||||
// The only file loaded with this mechanism is the Wasm binary
|
// The only file loaded with this mechanism is the Wasm binary
|
||||||
|
@ -46,19 +52,16 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (initializeCanvas) {
|
// Initialize a canvas if requested. If OffscreenCanvas is availble, we
|
||||||
// Fall back to an OffscreenCanvas created by the GraphRunner if
|
// let the graph runner initialize it by passing `undefined`.
|
||||||
// OffscreenCanvas is available
|
const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
|
||||||
const canvas = typeof OffscreenCanvas === 'undefined' ?
|
document.createElement('canvas') :
|
||||||
document.createElement('canvas') :
|
undefined) :
|
||||||
undefined;
|
null;
|
||||||
return createMediaPipeLib(
|
const instance = await createMediaPipeLib(
|
||||||
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
|
type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
|
||||||
} else {
|
await instance.setOptions(options);
|
||||||
return createMediaPipeLib(
|
return instance;
|
||||||
type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null,
|
|
||||||
fileLocator);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
|
@ -74,6 +77,14 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
|
||||||
this.registerModelResourcesGraphService();
|
this.registerModelResourcesGraphService();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Configures the shared options of a MediaPipe Task. */
|
||||||
|
async setOptions(options: O): Promise<void> {
|
||||||
|
if (options.baseOptions) {
|
||||||
|
this.baseOptions = await convertBaseOptionsToProto(
|
||||||
|
options.baseOptions, this.baseOptions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
||||||
* over the video stream. Will replace the previously running MediaPipe graph,
|
* over the video stream. Will replace the previously running MediaPipe graph,
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
/** Options to configure MediaPipe Tasks in general. */
|
/** Options to configure MediaPipe model loading and processing. */
|
||||||
export declare interface BaseOptions {
|
export declare interface BaseOptions {
|
||||||
/**
|
/**
|
||||||
* The model path to the model asset file. Only one of `modelAssetPath` or
|
* The model path to the model asset file. Only one of `modelAssetPath` or
|
||||||
|
@ -33,3 +33,9 @@ export declare interface BaseOptions {
|
||||||
/** Overrides the default backend to use for the provided model. */
|
/** Overrides the default backend to use for the provided model. */
|
||||||
delegate?: 'cpu'|'gpu'|undefined;
|
delegate?: 'cpu'|'gpu'|undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Options to configure MediaPipe Tasks in general. */
|
||||||
|
export declare interface TaskRunnerOptions {
|
||||||
|
/** Options to configure the loading of the model assets. */
|
||||||
|
baseOptions?: BaseOptions;
|
||||||
|
}
|
|
@ -1,11 +0,0 @@
|
||||||
# This package contains options shared by all MediaPipe Texxt Tasks for Web.
|
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
|
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
|
||||||
|
|
||||||
mediapipe_ts_declaration(
|
|
||||||
name = "text_task_options",
|
|
||||||
srcs = ["text_task_options.d.ts"],
|
|
||||||
deps = ["//mediapipe/tasks/web/core"],
|
|
||||||
)
|
|
|
@ -1,23 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
|
||||||
|
|
||||||
/** The options for configuring a MediaPipe Text task. */
|
|
||||||
export declare interface TextTaskOptions {
|
|
||||||
/** Options to configure the loading of the model assets. */
|
|
||||||
baseOptions?: BaseOptions;
|
|
||||||
}
|
|
|
@ -17,15 +17,16 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/framework:calculator_jspb_proto",
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:category",
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:classifier_options",
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
"//mediapipe/tasks/web/core:task_runner",
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ mediapipe_ts_declaration(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/web/components/containers:category",
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:classifier_options",
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
"//mediapipe/tasks/web/text/core:text_task_options",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,12 +17,13 @@
|
||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
|
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
|
||||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
|
||||||
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {TextClassifierOptions} from './text_classifier_options';
|
import {TextClassifierOptions} from './text_classifier_options';
|
||||||
|
@ -40,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH =
|
||||||
// tslint:disable:jspb-use-builder-pattern
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
/** Performs Natural Language classification. */
|
/** Performs Natural Language classification. */
|
||||||
export class TextClassifier extends TaskRunner {
|
export class TextClassifier extends TaskRunner<TextClassifierOptions> {
|
||||||
private classificationResult: TextClassifierResult = {classifications: []};
|
private classificationResult: TextClassifierResult = {classifications: []};
|
||||||
private readonly options = new TextClassifierGraphOptions();
|
private readonly options = new TextClassifierGraphOptions();
|
||||||
|
|
||||||
|
@ -53,13 +54,12 @@ export class TextClassifier extends TaskRunner {
|
||||||
* either a path to the TFLite model or the model itself needs to be
|
* either a path to the TFLite model or the model itself needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
|
textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
|
||||||
const classifier = await TaskRunner.createInstance(
|
return TaskRunner.createInstance(
|
||||||
TextClassifier, /* initializeCanvas= */ false, wasmFileset);
|
TextClassifier, /* initializeCanvas= */ false, wasmFileset,
|
||||||
await classifier.setOptions(textClassifierOptions);
|
textClassifierOptions);
|
||||||
return classifier;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -72,8 +72,9 @@ export class TextClassifier extends TaskRunner {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
|
modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
|
||||||
return TextClassifier.createFromOptions(
|
return TaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
TextClassifier, /* initializeCanvas= */ false, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -83,13 +84,19 @@ export class TextClassifier extends TaskRunner {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the model asset.
|
* @param modelAssetPath The path to the model asset.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<TextClassifier> {
|
modelAssetPath: string): Promise<TextClassifier> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return TaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
TextClassifier, /* initializeCanvas= */ false, wasmFileset,
|
||||||
return TextClassifier.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
}
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -101,18 +108,20 @@ export class TextClassifier extends TaskRunner {
|
||||||
*
|
*
|
||||||
* @param options The options for the text classifier.
|
* @param options The options for the text classifier.
|
||||||
*/
|
*/
|
||||||
async setOptions(options: TextClassifierOptions): Promise<void> {
|
override async setOptions(options: TextClassifierOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
await super.setOptions(options);
|
||||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
|
||||||
options.baseOptions, this.options.getBaseOptions());
|
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
options, this.options.getClassifierOptions()));
|
options, this.options.getClassifierOptions()));
|
||||||
this.refreshGraph();
|
this.refreshGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
|
this.options.setBaseOptions(proto);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs Natural Language classification on the provided text and waits
|
* Performs Natural Language classification on the provided text and waits
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
||||||
import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options';
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
/** Options to configure the MediaPipe Text Classifier Task */
|
/** Options to configure the MediaPipe Text Classifier Task */
|
||||||
export declare interface TextClassifierOptions extends ClassifierOptions,
|
export declare interface TextClassifierOptions extends ClassifierOptions,
|
||||||
TextTaskOptions {}
|
TaskRunnerOptions {}
|
||||||
|
|
|
@ -17,15 +17,16 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/framework:calculator_jspb_proto",
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto",
|
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:embedding_result",
|
"//mediapipe/tasks/web/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
|
||||||
"//mediapipe/tasks/web/components/processors:embedder_options",
|
"//mediapipe/tasks/web/components/processors:embedder_options",
|
||||||
"//mediapipe/tasks/web/components/processors:embedder_result",
|
"//mediapipe/tasks/web/components/processors:embedder_result",
|
||||||
"//mediapipe/tasks/web/components/utils:cosine_similarity",
|
"//mediapipe/tasks/web/components/utils:cosine_similarity",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:embedder_options",
|
"//mediapipe/tasks/web/core:embedder_options",
|
||||||
"//mediapipe/tasks/web/core:task_runner",
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,6 +40,5 @@ mediapipe_ts_declaration(
|
||||||
"//mediapipe/tasks/web/components/containers:embedding_result",
|
"//mediapipe/tasks/web/components/containers:embedding_result",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:embedder_options",
|
"//mediapipe/tasks/web/core:embedder_options",
|
||||||
"//mediapipe/tasks/web/text/core:text_task_options",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,14 +17,15 @@
|
||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
|
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
|
||||||
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
|
import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
|
||||||
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
|
import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
|
||||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
|
||||||
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
|
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
|
||||||
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
|
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
|
||||||
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
|
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
|
||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {TextEmbedderOptions} from './text_embedder_options';
|
import {TextEmbedderOptions} from './text_embedder_options';
|
||||||
|
@ -44,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR =
|
||||||
/**
|
/**
|
||||||
* Performs embedding extraction on text.
|
* Performs embedding extraction on text.
|
||||||
*/
|
*/
|
||||||
export class TextEmbedder extends TaskRunner {
|
export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
|
||||||
private embeddingResult: TextEmbedderResult = {embeddings: []};
|
private embeddingResult: TextEmbedderResult = {embeddings: []};
|
||||||
private readonly options = new TextEmbedderGraphOptionsProto();
|
private readonly options = new TextEmbedderGraphOptionsProto();
|
||||||
|
|
||||||
|
@ -57,13 +58,12 @@ export class TextEmbedder extends TaskRunner {
|
||||||
* either a path to the TFLite model or the model itself needs to be
|
* either a path to the TFLite model or the model itself needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> {
|
textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> {
|
||||||
const embedder = await TaskRunner.createInstance(
|
return TaskRunner.createInstance(
|
||||||
TextEmbedder, /* initializeCanvas= */ false, wasmFileset);
|
TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
|
||||||
await embedder.setOptions(textEmbedderOptions);
|
textEmbedderOptions);
|
||||||
return embedder;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -76,8 +76,9 @@ export class TextEmbedder extends TaskRunner {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
|
modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
|
||||||
return TextEmbedder.createFromOptions(
|
return TaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -87,13 +88,19 @@ export class TextEmbedder extends TaskRunner {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the TFLite model.
|
* @param modelAssetPath The path to the TFLite model.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<TextEmbedder> {
|
modelAssetPath: string): Promise<TextEmbedder> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return TaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
|
||||||
return TextEmbedder.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
}
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -105,17 +112,21 @@ export class TextEmbedder extends TaskRunner {
|
||||||
*
|
*
|
||||||
* @param options The options for the text embedder.
|
* @param options The options for the text embedder.
|
||||||
*/
|
*/
|
||||||
async setOptions(options: TextEmbedderOptions): Promise<void> {
|
override async setOptions(options: TextEmbedderOptions): Promise<void> {
|
||||||
if (options.baseOptions) {
|
await super.setOptions(options);
|
||||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
|
||||||
options.baseOptions, this.options.getBaseOptions());
|
|
||||||
this.options.setBaseOptions(baseOptionsProto);
|
|
||||||
}
|
|
||||||
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
|
||||||
options, this.options.getEmbedderOptions()));
|
options, this.options.getEmbedderOptions()));
|
||||||
this.refreshGraph();
|
this.refreshGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
|
this.options.setBaseOptions(proto);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs embeding extraction on the provided text and waits synchronously
|
* Performs embeding extraction on the provided text and waits synchronously
|
||||||
* for the response.
|
* for the response.
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
|
import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
|
||||||
import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options';
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
/** Options to configure the MediaPipe Text Embedder Task */
|
/** Options to configure the MediaPipe Text Embedder Task */
|
||||||
export declare interface TextEmbedderOptions extends EmbedderOptions,
|
export declare interface TextEmbedderOptions extends EmbedderOptions,
|
||||||
TextTaskOptions {}
|
TaskRunnerOptions {}
|
||||||
|
|
|
@ -17,8 +17,6 @@ mediapipe_ts_library(
|
||||||
srcs = ["vision_task_runner.ts"],
|
srcs = ["vision_task_runner.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
":vision_task_options",
|
":vision_task_options",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:task_runner",
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The two running modes of a vision task.
|
* The two running modes of a vision task.
|
||||||
|
@ -23,12 +23,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||||
*/
|
*/
|
||||||
export type RunningMode = 'image'|'video';
|
export type RunningMode = 'image'|'video';
|
||||||
|
|
||||||
|
|
||||||
/** The options for configuring a MediaPipe vision task. */
|
/** The options for configuring a MediaPipe vision task. */
|
||||||
export declare interface VisionTaskOptions {
|
export declare interface VisionTaskOptions extends TaskRunnerOptions {
|
||||||
/** Options to configure the loading of the model assets. */
|
|
||||||
baseOptions?: BaseOptions;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The running mode of the task. Default to the image mode.
|
* The running mode of the task. Default to the image mode.
|
||||||
* Vision tasks have two running modes:
|
* Vision tasks have two running modes:
|
||||||
|
|
|
@ -14,24 +14,17 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
|
||||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
|
||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
|
||||||
import {VisionTaskOptions} from './vision_task_options';
|
import {VisionTaskOptions} from './vision_task_options';
|
||||||
|
|
||||||
/** Base class for all MediaPipe Vision Tasks. */
|
/** Base class for all MediaPipe Vision Tasks. */
|
||||||
export abstract class VisionTaskRunner<T> extends TaskRunner {
|
export abstract class VisionTaskRunner<T> extends
|
||||||
protected abstract baseOptions?: BaseOptionsProto|undefined;
|
TaskRunner<VisionTaskOptions> {
|
||||||
|
|
||||||
/** Configures the shared options of a vision task. */
|
/** Configures the shared options of a vision task. */
|
||||||
async setOptions(options: VisionTaskOptions): Promise<void> {
|
override async setOptions(options: VisionTaskOptions): Promise<void> {
|
||||||
this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
|
await super.setOptions(options);
|
||||||
if (options.baseOptions) {
|
|
||||||
this.baseOptions = await convertBaseOptionsToProto(
|
|
||||||
options.baseOptions, this.baseOptions);
|
|
||||||
}
|
|
||||||
if ('runningMode' in options) {
|
if ('runningMode' in options) {
|
||||||
const useStreamMode =
|
const useStreamMode =
|
||||||
!!options.runningMode && options.runningMode !== 'image';
|
!!options.runningMode && options.runningMode !== 'image';
|
||||||
|
|
|
@ -88,14 +88,13 @@ export class GestureRecognizer extends
|
||||||
* Note that either a path to the model asset or a model buffer needs to
|
* Note that either a path to the model asset or a model buffer needs to
|
||||||
* be provided (via `baseOptions`).
|
* be provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
gestureRecognizerOptions: GestureRecognizerOptions):
|
gestureRecognizerOptions: GestureRecognizerOptions):
|
||||||
Promise<GestureRecognizer> {
|
Promise<GestureRecognizer> {
|
||||||
const recognizer = await VisionTaskRunner.createInstance(
|
return VisionTaskRunner.createInstance(
|
||||||
GestureRecognizer, /* initializeCanvas= */ true, wasmFileset);
|
GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
|
||||||
await recognizer.setOptions(gestureRecognizerOptions);
|
gestureRecognizerOptions);
|
||||||
return recognizer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -108,8 +107,9 @@ export class GestureRecognizer extends
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
|
modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
|
||||||
return GestureRecognizer.createFromOptions(
|
return VisionTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -119,13 +119,12 @@ export class GestureRecognizer extends
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the model asset.
|
* @param modelAssetPath The path to the model asset.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<GestureRecognizer> {
|
modelAssetPath: string): Promise<GestureRecognizer> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return VisionTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
|
||||||
return GestureRecognizer.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
|
@ -134,6 +133,7 @@ export class GestureRecognizer extends
|
||||||
super(wasmModule, glCanvas);
|
super(wasmModule, glCanvas);
|
||||||
|
|
||||||
this.options = new GestureRecognizerGraphOptions();
|
this.options = new GestureRecognizerGraphOptions();
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions();
|
this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions();
|
||||||
this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions);
|
this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions);
|
||||||
this.handLandmarksDetectorGraphOptions =
|
this.handLandmarksDetectorGraphOptions =
|
||||||
|
@ -151,11 +151,11 @@ export class GestureRecognizer extends
|
||||||
this.initDefaults();
|
this.initDefaults();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
return this.options.getBaseOptions();
|
return this.options.getBaseOptions()!;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,13 +77,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
* Note that either a path to the model asset or a model buffer needs to
|
* Note that either a path to the model asset or a model buffer needs to
|
||||||
* be provided (via `baseOptions`).
|
* be provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> {
|
handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> {
|
||||||
const landmarker = await VisionTaskRunner.createInstance(
|
return VisionTaskRunner.createInstance(
|
||||||
HandLandmarker, /* initializeCanvas= */ true, wasmFileset);
|
HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
|
||||||
await landmarker.setOptions(handLandmarkerOptions);
|
handLandmarkerOptions);
|
||||||
return landmarker;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -96,8 +95,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<HandLandmarker> {
|
modelAssetBuffer: Uint8Array): Promise<HandLandmarker> {
|
||||||
return HandLandmarker.createFromOptions(
|
return VisionTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -107,13 +107,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the model asset.
|
* @param modelAssetPath The path to the model asset.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<HandLandmarker> {
|
modelAssetPath: string): Promise<HandLandmarker> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return VisionTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
|
||||||
return HandLandmarker.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
|
@ -122,6 +121,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
super(wasmModule, glCanvas);
|
super(wasmModule, glCanvas);
|
||||||
|
|
||||||
this.options = new HandLandmarkerGraphOptions();
|
this.options = new HandLandmarkerGraphOptions();
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
this.handLandmarksDetectorGraphOptions =
|
this.handLandmarksDetectorGraphOptions =
|
||||||
new HandLandmarksDetectorGraphOptions();
|
new HandLandmarksDetectorGraphOptions();
|
||||||
this.options.setHandLandmarksDetectorGraphOptions(
|
this.options.setHandLandmarksDetectorGraphOptions(
|
||||||
|
@ -132,11 +132,11 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
|
||||||
this.initDefaults();
|
this.initDefaults();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
return this.options.getBaseOptions();
|
return this.options.getBaseOptions()!;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/
|
||||||
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {ImageClassifierOptions} from './image_classifier_options';
|
import {ImageClassifierOptions} from './image_classifier_options';
|
||||||
|
@ -55,13 +55,12 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
||||||
* that either a path to the model asset or a model buffer needs to be
|
* that either a path to the model asset or a model buffer needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions):
|
wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions):
|
||||||
Promise<ImageClassifier> {
|
Promise<ImageClassifier> {
|
||||||
const classifier = await VisionTaskRunner.createInstance(
|
return VisionTaskRunner.createInstance(
|
||||||
ImageClassifier, /* initializeCanvas= */ true, wasmFileset);
|
ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
|
||||||
await classifier.setOptions(imageClassifierOptions);
|
imageClassifierOptions);
|
||||||
return classifier;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -74,8 +73,9 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
|
modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
|
||||||
return ImageClassifier.createFromOptions(
|
return VisionTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -85,20 +85,26 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the model asset.
|
* @param modelAssetPath The path to the model asset.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<ImageClassifier> {
|
modelAssetPath: string): Promise<ImageClassifier> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return VisionTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
|
||||||
return ImageClassifier.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
constructor(
|
||||||
return this.options.getBaseOptions();
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/
|
||||||
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
|
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {ImageEmbedderOptions} from './image_embedder_options';
|
import {ImageEmbedderOptions} from './image_embedder_options';
|
||||||
|
@ -57,13 +57,12 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
||||||
* either a path to the TFLite model or the model itself needs to be
|
* either a path to the TFLite model or the model itself needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> {
|
imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> {
|
||||||
const embedder = await VisionTaskRunner.createInstance(
|
return VisionTaskRunner.createInstance(
|
||||||
ImageEmbedder, /* initializeCanvas= */ true, wasmFileset);
|
ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
|
||||||
await embedder.setOptions(imageEmbedderOptions);
|
imageEmbedderOptions);
|
||||||
return embedder;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -76,8 +75,9 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> {
|
modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> {
|
||||||
return ImageEmbedder.createFromOptions(
|
return VisionTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -87,20 +87,26 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
|
||||||
* Wasm binary and its loader.
|
* Wasm binary and its loader.
|
||||||
* @param modelAssetPath The path to the TFLite model.
|
* @param modelAssetPath The path to the TFLite model.
|
||||||
*/
|
*/
|
||||||
static async createFromModelPath(
|
static createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<ImageEmbedder> {
|
modelAssetPath: string): Promise<ImageEmbedder> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return VisionTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
|
||||||
return ImageEmbedder.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
constructor(
|
||||||
return this.options.getBaseOptions();
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
|
||||||
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
|
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {ObjectDetectorOptions} from './object_detector_options';
|
import {ObjectDetectorOptions} from './object_detector_options';
|
||||||
|
@ -54,13 +54,12 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
||||||
* either a path to the model asset or a model buffer needs to be
|
* either a path to the model asset or a model buffer needs to be
|
||||||
* provided (via `baseOptions`).
|
* provided (via `baseOptions`).
|
||||||
*/
|
*/
|
||||||
static async createFromOptions(
|
static createFromOptions(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
|
objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
|
||||||
const detector = await VisionTaskRunner.createInstance(
|
return VisionTaskRunner.createInstance(
|
||||||
ObjectDetector, /* initializeCanvas= */ true, wasmFileset);
|
ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
|
||||||
await detector.setOptions(objectDetectorOptions);
|
objectDetectorOptions);
|
||||||
return detector;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -73,8 +72,9 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
||||||
static createFromModelBuffer(
|
static createFromModelBuffer(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
|
modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
|
||||||
return ObjectDetector.createFromOptions(
|
return VisionTaskRunner.createInstance(
|
||||||
wasmFileset, {baseOptions: {modelAssetBuffer}});
|
ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -87,17 +87,23 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
|
||||||
static async createFromModelPath(
|
static async createFromModelPath(
|
||||||
wasmFileset: WasmFileset,
|
wasmFileset: WasmFileset,
|
||||||
modelAssetPath: string): Promise<ObjectDetector> {
|
modelAssetPath: string): Promise<ObjectDetector> {
|
||||||
const response = await fetch(modelAssetPath.toString());
|
return VisionTaskRunner.createInstance(
|
||||||
const graphData = await response.arrayBuffer();
|
ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
|
||||||
return ObjectDetector.createFromModelBuffer(
|
{baseOptions: {modelAssetPath}});
|
||||||
wasmFileset, new Uint8Array(graphData));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto|undefined {
|
constructor(
|
||||||
return this.options.getBaseOptions();
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(wasmModule, glCanvas);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
this.options.setBaseOptions(proto);
|
this.options.setBaseOptions(proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user