Create AudioTaskRunner

PiperOrigin-RevId: 489613573
This commit is contained in:
Sebastian Schmidt 2022-11-18 19:53:21 -08:00 committed by Copybara-Service
parent eb8ef1ace0
commit e853f04b79
6 changed files with 98 additions and 53 deletions

View File

@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable {
defaultSampleRate = sampleRate; defaultSampleRate = sampleRate;
} }
} }
/** /**
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener. * available in the user-defined result listener.

View File

@ -17,14 +17,14 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/audio/core:audio_task_runner",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )

View File

@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/** Performs audio classification. */ /** Performs audio classification. */
export class AudioClassifier extends TaskRunner { export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
private classificationResults: AudioClassifierResult[] = []; private classificationResults: AudioClassifierResult[] = [];
private defaultSampleRate = 48000;
private readonly options = new AudioClassifierGraphOptions(); private readonly options = new AudioClassifierGraphOptions();
/** /**
@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner {
wasmLoaderOptions, new Uint8Array(graphData)); wasmLoaderOptions, new Uint8Array(graphData));
} }
protected override get baseOptions(): BaseOptionsProto|undefined {
return this.options.getBaseOptions();
}
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
this.options.setBaseOptions(proto);
}
/** /**
* Sets new options for the audio classifier. * Sets new options for the audio classifier.
* *
@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner {
* *
* @param options The options for the audio classifier. * @param options The options for the audio classifier.
*/ */
async setOptions(options: AudioClassifierOptions): Promise<void> { override async setOptions(options: AudioClassifierOptions): Promise<void> {
if (options.baseOptions) { await super.setOptions(options);
const baseOptionsProto = await convertBaseOptionsToProto(
options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto);
}
this.options.setClassifierOptions(convertClassifierOptionsToProto( this.options.setClassifierOptions(convertClassifierOptionsToProto(
options, this.options.getClassifierOptions())); options, this.options.getClassifierOptions()));
this.refreshGraph(); this.refreshGraph();
} }
/** /**
* Sets the sample rate for all calls to `classify()` that omit an explicit * Performs audio classification on the provided audio clip and waits
* sample rate. `48000` is used as a default if this method is not called.
*
* @param sampleRate A sample rate (e.g. `44100`).
*/
setDefaultSampleRate(sampleRate: number) {
this.defaultSampleRate = sampleRate;
}
/**
* Performs audio classification on the provided audio data and waits
* synchronously for the response. * synchronously for the response.
* *
* @param audioData An array of raw audio capture data, like * @param audioData An array of raw audio capture data, like from a call to
* from a call to getChannelData on an AudioBuffer. * `getChannelData()` on an AudioBuffer.
* @param sampleRate The sample rate in Hz of the provided audio data. If not * @param sampleRate The sample rate in Hz of the provided audio data. If not
* set, defaults to the sample rate set via `setDefaultSampleRate()` or * set, defaults to the sample rate set via `setDefaultSampleRate()` or
* `48000` if no custom default was set. * `48000` if no custom default was set.
@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner {
*/ */
classify(audioData: Float32Array, sampleRate?: number): classify(audioData: Float32Array, sampleRate?: number):
AudioClassifierResult[] { AudioClassifierResult[] {
sampleRate = sampleRate ?? this.defaultSampleRate; return this.processAudioClip(audioData, sampleRate);
}
/** Sends an audio package to the graph and returns the classifications. */
protected override process(
audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioClassifierResult[] {
// Configures the number of samples in the WASM layer. We re-configure the // Configures the number of samples in the WASM layer. We re-configure the
// number of samples and the sample rate for every frame, but ignore other // number of samples and the sample rate for every frame, but ignore other
// side effects of this function (such as sending the input side packet and // side effects of this function (such as sending the input side packet and
// the input stream header). // the input stream header).
this.configureAudio( this.configureAudio(
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
const timestamp = performance.now(); this.addAudioToStream(audioData, timestampMs);
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
this.addAudioToStream(audioData, timestamp);
this.classificationResults = []; this.classificationResults = [];
this.finishProcessing(); this.finishProcessing();

View File

@ -1,6 +1,6 @@
# This package contains options shared by all MediaPipe Audio Tasks for Web. # This package contains options shared by all MediaPipe Audio Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"]) package(default_visibility = ["//mediapipe/tasks:internal"])
@ -11,3 +11,15 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
], ],
) )
mediapipe_ts_library(
name = "audio_task_runner",
srcs = ["audio_task_runner.ts"],
deps = [
":audio_task_options",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
],
)

View File

@ -16,29 +16,8 @@
import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {BaseOptions} from '../../../../tasks/web/core/base_options';
/**
* MediaPipe audio task running mode. A MediaPipe audio task can be run with
* two different modes:
* - audio_clips: The mode for running a mediapipe audio task on independent
* audio clips.
* - audio_stream: The mode for running a mediapipe audio task on an audio
* stream, such as from a microphone.
* </ul>
*/
export type RunningMode = 'audio_clips'|'audio_stream';
/** The options for configuring a MediaPipe Audio Task. */ /** The options for configuring a MediaPipe Audio Task. */
export declare interface AudioTaskOptions { export declare interface AudioTaskOptions {
/** Options to configure the loading of the model assets. */ /** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions; baseOptions?: BaseOptions;
/**
* The running mode of the task. Default to the audio_clips mode.
* Audio tasks have two running modes:
* 1) The mode for running a mediapipe audio task on independent
* audio clips.
* 2) The mode for running a mediapipe audio task on an audio
* stream, such as from a microphone.
*/
runningMode?: RunningMode;
} }

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {AudioTaskOptions} from './audio_task_options';
/** Base class for all MediaPipe Audio Tasks. */
export abstract class AudioTaskRunner<T> extends TaskRunner {
protected abstract baseOptions?: BaseOptionsProto|undefined;
private defaultSampleRate = 48000;
/** Configures the shared options of an audio task. */
async setOptions(options: AudioTaskOptions): Promise<void> {
this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
if (options.baseOptions) {
this.baseOptions = await convertBaseOptionsToProto(
options.baseOptions, this.baseOptions);
}
}
/**
* Sets the sample rate for API calls that omit an explicit sample rate.
* `48000` is used as a default if this method is not called.
*
* @param sampleRate A sample rate (e.g. `44100`).
*/
setDefaultSampleRate(sampleRate: number) {
this.defaultSampleRate = sampleRate;
}
/** Sends an audio packet to the graph and awaits results. */
protected abstract process(
audioData: Float32Array, sampleRate: number, timestampMs: number): T;
/** Sends a single audio clip to the graph and awaits results. */
protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
return this.process(
audioData, sampleRate ?? this.defaultSampleRate, performance.now());
}
}