mediapipe/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
Sebastian Schmidt 496720308c Migrate remaining MP Tasks Libraries to ts_declarations
PiperOrigin-RevId: 488752799
2022-11-15 14:08:15 -08:00

176 lines
6.8 KiB
TypeScript

/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url
import {TextEmbedderOptions} from './text_embedder_options';
import {TextEmbedderResult} from './text_embedder_result';
export * from './text_embedder_options';
export * from './text_embedder_result';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
const INPUT_STREAM = 'text_in';
const EMBEDDINGS_STREAM = 'embeddings_out';
const TEXT_EMBEDDER_CALCULATOR =
'mediapipe.tasks.text.text_embedder.TextEmbedderGraph';
/**
* Performs embedding extraction on text.
*/
export class TextEmbedder extends TaskRunner {
private embeddingResult: TextEmbedderResult = {embeddings: []};
private readonly options = new TextEmbedderGraphOptionsProto();
/**
* Initializes the Wasm runtime and creates a new text embedder from the
* provided options.
* @param wasmLoaderOptions A configuration object that provides the location
* of the Wasm binary and its loader.
* @param textEmbedderOptions The options for the text embedder. Note that
* either a path to the TFLite model or the model itself needs to be
* provided (via `baseOptions`).
*/
static async createFromOptions(
wasmLoaderOptions: WasmLoaderOptions,
textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> {
// Create a file locator based on the loader options
const fileLocator: FileLocator = {
locateFile() {
// The only file we load is the Wasm binary
return wasmLoaderOptions.wasmBinaryPath.toString();
}
};
const embedder = await createMediaPipeLib(
TextEmbedder, wasmLoaderOptions.wasmLoaderPath,
/* assetLoaderScript= */ undefined,
/* glCanvas= */ undefined, fileLocator);
await embedder.setOptions(textEmbedderOptions);
return embedder;
}
/**
* Initializes the Wasm runtime and creates a new text embedder based on the
* provided model asset buffer.
* @param wasmLoaderOptions A configuration object that provides the location
* of the Wasm binary and its loader.
* @param modelAssetBuffer A binary representation of the TFLite model.
*/
static createFromModelBuffer(
wasmLoaderOptions: WasmLoaderOptions,
modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
return TextEmbedder.createFromOptions(
wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
}
/**
* Initializes the Wasm runtime and creates a new text embedder based on the
* path to the model asset.
* @param wasmLoaderOptions A configuration object that provides the location
* of the Wasm binary and its loader.
* @param modelAssetPath The path to the TFLite model.
*/
static async createFromModelPath(
wasmLoaderOptions: WasmLoaderOptions,
modelAssetPath: string): Promise<TextEmbedder> {
const response = await fetch(modelAssetPath.toString());
const graphData = await response.arrayBuffer();
return TextEmbedder.createFromModelBuffer(
wasmLoaderOptions, new Uint8Array(graphData));
}
/**
* Sets new options for the text embedder.
*
* Calling `setOptions()` with a subset of options only affects those options.
* You can reset an option back to its default value by explicitly setting it
* to `undefined`.
*
* @param options The options for the text embedder.
*/
async setOptions(options: TextEmbedderOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}
this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
options, this.options.getEmbedderOptions()));
this.refreshGraph();
}
/**
* Performs embeding extraction on the provided text and waits synchronously
* for the response.
*
* @param text The text to process.
* @return The embedding resuls of the text
*/
embed(text: string): TextEmbedderResult {
// Get text embeddings by running our MediaPipe graph.
this.addStringToStream(
text, INPUT_STREAM, /* timestamp= */ performance.now());
this.finishProcessing();
return this.embeddingResult;
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(EMBEDDINGS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
TextEmbedderGraphOptionsProto.ext, this.options);
const embedderNode = new CalculatorGraphConfig.Node();
embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR);
embedderNode.addInputStream('TEXT:' + INPUT_STREAM);
embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM);
embedderNode.setOptions(calculatorOptions);
graphConfig.addNode(embedderNode);
this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}
}