diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index e9703e37a..af76a1fe8 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -26,6 +26,7 @@ mediapipe_ts_library( srcs = ["audio.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 764fd8393..056426f50 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -15,9 +15,11 @@ */ import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; -export {AudioClassifier}; +export {AudioClassifier, AudioEmbedder}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 4f6e48b28..acd7494d7 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -9,5 +9,6 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", ], ) diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD new file mode 100644 index 000000000..7d9a994a3 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -0,0 +1,43 @@ +# This contains the MediaPipe Audio Embedder Task. +# +# This task takes audio input and performs embedding. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_embedder", + srcs = ["audio_embedder.ts"], + deps = [ + ":audio_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_embedder_types", + srcs = [ + "audio_embedder_options.d.ts", + "audio_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/audio/core:audio_task_options", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts new file mode 100644 index 000000000..51cb819de --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -0,0 +1,211 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../../../../tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {AudioEmbedderOptions} from './audio_embedder_options'; +import {AudioEmbedderResult} from './audio_embedder_result'; + +export * from './audio_embedder_options'; +export * from './audio_embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot +// be changed +// TODO: Change this to `audio_in` to match the name in the CC +// implementation +const AUDIO_STREAM = 'input_audio'; +const SAMPLE_RATE_STREAM = 'sample_rate'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out'; +const AUDIO_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + +/** Performs embedding extraction on audio. */ +export class AudioEmbedder extends AudioTaskRunner { + private embeddingResults: AudioEmbedderResult[] = []; + private readonly options = new AudioEmbedderGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new audio embedder from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param audioEmbedderOptions The options for the audio embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + audioEmbedderOptions: AudioEmbedderOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const embedder = await createMediaPipeLib( + AudioEmbedder, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await embedder.setOptions(audioEmbedderOptions); + return embedder; + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return AudioEmbedder.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return AudioEmbedder.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the audio embedder. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the audio embedder. + */ + override async setOptions(options: AudioEmbedderOptions): Promise { + await super.setOptions(options); + this.options.setEmbedderOptions(convertEmbedderOptionsToProto( + options, this.options.getEmbedderOptions())); + this.refreshGraph(); + } + + /** + * Performs embeding extraction on the provided audio clip and waits + * synchronously for the response. + * + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. + * @param sampleRate The sample rate in Hz of the provided audio data. If not + * set, defaults to the sample rate set via `setDefaultSampleRate()` or + * `48000` if no custom default was set. + * @return The embedding resuls of the audio + */ + embed(audioData: Float32Array, sampleRate?: number): AudioEmbedderResult[] { + return this.processAudioClip(audioData, sampleRate); + } + + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioEmbedderResult[] { + // Configures the number of samples in the WASM layer. We re-configure the + // number of samples and the sample rate for every frame, but ignore other + // side effects of this function (such as sending the input side packet and + // the input stream header). + this.configureAudio( + /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.addAudioToStream(audioData, timestampMs); + + this.embeddingResults = []; + this.finishProcessing(); + return this.embeddingResults; + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + graphConfig.addOutputStream(TIMESTAMPED_EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(AUDIO_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('AUDIO:' + AUDIO_STREAM); + embedderNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.addOutputStream( + 'TIMESTAMPED_EMBEDDINGS:' + TIMESTAMPED_EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + }); + + this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts new file mode 100644 index 000000000..98f412d0f --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +/** Options to configure the MediaPipe Audio Embedder Task */ +export declare interface AudioEmbedderOptions extends EmbedderOptions, + AudioTaskOptions {} diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts new file mode 100644 index 000000000..13abc28d9 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Embedding, EmbeddingResult as AudioEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index a5083b326..17a908f30 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -15,3 +15,4 @@ */ export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder';