diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 1b0e403ff..d1bc480db 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -19,3 +19,8 @@ mediapipe_ts_library( name = "landmark", srcs = ["landmark.d.ts"], ) + +mediapipe_ts_library( + name = "embedding_result", + srcs = ["embedding_result.d.ts"], +) diff --git a/mediapipe/tasks/web/components/containers/embedding_result.d.ts b/mediapipe/tasks/web/components/containers/embedding_result.d.ts new file mode 100644 index 000000000..e1efd94ce --- /dev/null +++ b/mediapipe/tasks/web/components/containers/embedding_result.d.ts @@ -0,0 +1,66 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * List of embeddings with an optional timestamp. + * + * One and only one of the two 'floatEmbedding' and 'quantizedEmbedding' will + * contain data, based on whether or not the embedder was configured to perform + * scalar quantization. + */ +export interface Embedding { + /** + * Floating-point embedding. Empty if the embedder was configured to perform + * scalar-quantization. + */ + floatEmbedding?: number[]; + + /** + * Scalar-quantized embedding. Empty if the embedder was not configured to + * perform scalar quantization. + */ + quantizedEmbedding?: Uint8Array; + /** + * The index of the classifier head these categories refer to. This is + * useful for multi-head models. + */ + headIndex: number; + + /** + * The name of the classifier head, which is the corresponding tensor + * metadata name. + */ + headName: string; +} + +/** Embedding results for a given embedder model. */ +export interface EmbeddingResult { + /** + * The embedding results for each model head, i.e. one for each output tensor. + */ + embeddings: Embedding[]; + + /** + * The optional timestamp (in milliseconds) of the start of the chunk of + * data corresponding to these results. + * + * This is only used for embedding extraction on time series (e.g. audio + * embedding). In these use cases, the amount of data to process might + * exceed the maximum size that the model can process: to solve this, the + * input data is split into multiple chunks starting at different timestamps. + */ + timestampMs?: number; +} diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index e0d84b632..1b56bf4c9 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -23,9 +23,29 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "embedder_result", + srcs = ["embedder_result.ts"], + deps = [ + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) + +mediapipe_ts_library( + name = "embedder_options", + srcs = ["embedder_options.ts"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_jspb_proto", + "//mediapipe/tasks/web/core:embedder_options", + ], +) + mediapipe_ts_library( name = "base_options", - srcs = ["base_options.ts"], + srcs = [ + "base_options.ts", + ], deps = [ "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", diff --git a/mediapipe/tasks/web/components/processors/embedder_options.ts b/mediapipe/tasks/web/components/processors/embedder_options.ts new file mode 100644 index 000000000..f000dbd64 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_options.ts @@ -0,0 +1,46 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {EmbedderOptions as EmbedderOptionsProto} from '../../../../tasks/cc/components/processors/proto/embedder_options_pb'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +/** + * Converts a EmbedderOptions object to its Proto representation, optionally + * based on existing definition. + * @param options The options object to convert to a Proto. Only options that + * are expliclty provided are set. + * @param baseOptions A base object that options can be merged into. + */ +export function convertEmbedderOptionsToProto( + options: EmbedderOptions, + baseOptions?: EmbedderOptionsProto): EmbedderOptionsProto { + const embedderOptions = + baseOptions ? baseOptions.clone() : new EmbedderOptionsProto(); + + if (options.l2Normalize !== undefined) { + embedderOptions.setL2Normalize(options.l2Normalize); + } else if ('l2Normalize' in options) { // Check for undefined + embedderOptions.clearL2Normalize(); + } + + if (options.quantize !== undefined) { + embedderOptions.setQuantize(options.quantize); + } else if ('quantize' in options) { // Check for undefined + embedderOptions.clearQuantize(); + } + + return embedderOptions; +} diff --git a/mediapipe/tasks/web/components/processors/embedder_result.ts b/mediapipe/tasks/web/components/processors/embedder_result.ts new file mode 100644 index 000000000..285afe68a --- /dev/null +++ b/mediapipe/tasks/web/components/processors/embedder_result.ts @@ -0,0 +1,53 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Embedding as EmbeddingProto, EmbeddingResult as EmbeddingResultProto} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {Embedding, EmbeddingResult} from '../../../../tasks/web/components/containers/embedding_result'; + +const DEFAULT_INDEX = -1; + +/** + * Converts an Embedding proto to the Embedding object. + */ +function convertFromEmbeddingsProto(source: EmbeddingProto): Embedding { + const embedding: Embedding = { + headIndex: source.getHeadIndex() ?? DEFAULT_INDEX, + headName: source.getHeadName() ?? '', + }; + + if (source.hasFloatEmbedding()) { + embedding.floatEmbedding = source.getFloatEmbedding()!.getValuesList(); + } else { + const encodedValue = source.getQuantizedEmbedding()?.getValues() ?? ''; + embedding.quantizedEmbedding = typeof encodedValue == 'string' ? + Uint8Array.from(atob(encodedValue), c => c.charCodeAt(0)) : encodedValue; + } + + return embedding; +} + +/** + * Converts an EmbedderResult proto to an EmbeddingResult object. + */ +export function convertFromEmbeddingResultProto( + embeddingResult: EmbeddingResultProto): EmbeddingResult { + const result: EmbeddingResult = { + embeddings: embeddingResult.getEmbeddingsList().map( + e => convertFromEmbeddingsProto(e)), + timestampMs: embeddingResult.getTimestampMs(), + }; + return result; +} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 4fb57d6c3..edfc1e5c5 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -31,3 +31,11 @@ mediapipe_ts_library( ], deps = [":core"], ) + +mediapipe_ts_library( + name = "embedder_options", + srcs = [ + "embedder_options.d.ts", + ], + deps = [":core"], +) diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts new file mode 100644 index 000000000..78ddad1ae --- /dev/null +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -0,0 +1,39 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../tasks/web/core/base_options'; + +/** Options to configure the MediaPipe Embedder Task */ +export declare interface EmbedderOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * Whether to normalize the returned feature vector with L2 norm. Use this + * option only if the model does not already contain a native L2_NORMALIZATION + * TF Lite Op. In most cases, this is already the case and L2 norm is thus + * achieved through TF Lite inference. + */ + l2Normalize?: boolean|undefined; + + /** + * Whether the returned embedding should be quantized to bytes via scalar + * quantization. Embeddings are implicitly assumed to be unit-norm and + * therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + * the l2_normalize option if this is not the case. + */ + quantize?: boolean|undefined; +} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD new file mode 100644 index 000000000..8e397ce6f --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -0,0 +1,32 @@ +# This contains the MediaPipe Text Embedder Task. +# +# This task takes text input and performs embedding +# + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "text_embedder", + srcs = [ + "text_embedder.ts", + "text_embedder_options.d.ts", + "text_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts new file mode 100644 index 000000000..65df5df6a --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -0,0 +1,173 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {TextEmbedderOptions} from './text_embedder_options'; +import {TextEmbedderResult} from './text_embedder_result'; + + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const INPUT_STREAM = 'text_in'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TEXT_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph'; + +/** + * Performs embedding extraction on text. + */ +export class TextEmbedder extends TaskRunner { + private embeddingResult: TextEmbedderResult = {embeddings: []}; + private readonly options = new TextEmbedderGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new text embedder from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param textEmbedderOptions The options for the text embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + textEmbedderOptions: TextEmbedderOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const embedder = await createMediaPipeLib( + TextEmbedder, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await embedder.setOptions(textEmbedderOptions); + return embedder; + } + + /** + * Initializes the Wasm runtime and creates a new text embedder based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return TextEmbedder.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new text embedder based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return TextEmbedder.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the text embedder. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the text embedder. + */ + async setOptions(options: TextEmbedderOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setEmbedderOptions(convertEmbedderOptionsToProto( + options, this.options.getEmbedderOptions())); + + this.refreshGraph(); + } + + + /** + * Performs embeding extraction on the provided text and waits synchronously + * for the response. + * + * @param text The text to process. + * @return The embedding resuls of the text + */ + embed(text: string): TextEmbedderResult { + // Get text embeddings by running our MediaPipe graph. + this.addStringToStream( + text, INPUT_STREAM, /* timestamp= */ performance.now()); + this.finishProcessing(); + return this.embeddingResult; + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + TextEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('TEXT:' + INPUT_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts new file mode 100644 index 000000000..9af263765 --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {EmbedderOptions as TextEmbedderOptions} from '../../../../tasks/web/core/embedder_options'; diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts new file mode 100644 index 000000000..65640b507 --- /dev/null +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_result.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Embedding, EmbeddingResult as TextEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result';