Create AudioTaskRunner
PiperOrigin-RevId: 489613573
This commit is contained in:
parent
eb8ef1ace0
commit
e853f04b79
|
@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable {
|
|||
defaultSampleRate = sampleRate;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
|
||||
* available in the user-defined result listener.
|
||||
|
|
|
@ -17,14 +17,14 @@ mediapipe_ts_library(
|
|||
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/audio/core:audio_task_runner",
|
||||
"//mediapipe/tasks/web/components/containers:category",
|
||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
"//mediapipe/tasks/web/core:task_runner",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
|||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
|
||||
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
|
||||
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
|
||||
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
|
|||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/** Performs audio classification. */
|
||||
export class AudioClassifier extends TaskRunner {
|
||||
export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
||||
private classificationResults: AudioClassifierResult[] = [];
|
||||
private defaultSampleRate = 48000;
|
||||
private readonly options = new AudioClassifierGraphOptions();
|
||||
|
||||
/**
|
||||
|
@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner {
|
|||
wasmLoaderOptions, new Uint8Array(graphData));
|
||||
}
|
||||
|
||||
/** Returns the base options held by this classifier's graph options proto. */
protected override get baseOptions(): BaseOptionsProto|undefined {
|
||||
return this.options.getBaseOptions();
|
||||
}
|
||||
|
||||
/** Stores the given base options on this classifier's graph options proto. */
protected override set baseOptions(proto: BaseOptionsProto|undefined) {
|
||||
this.options.setBaseOptions(proto);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets new options for the audio classifier.
|
||||
*
|
||||
|
@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner {
|
|||
*
|
||||
* @param options The options for the audio classifier.
|
||||
*/
|
||||
async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
if (options.baseOptions) {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto(
|
||||
options.baseOptions, this.options.getBaseOptions());
|
||||
this.options.setBaseOptions(baseOptionsProto);
|
||||
}
|
||||
|
||||
override async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||
await super.setOptions(options);
|
||||
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||
options, this.options.getClassifierOptions()));
|
||||
this.refreshGraph();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the sample rate for all calls to `classify()` that omit an explicit
|
||||
* sample rate. `48000` is used as a default if this method is not called.
|
||||
*
|
||||
* @param sampleRate A sample rate (e.g. `44100`).
|
||||
*/
|
||||
setDefaultSampleRate(sampleRate: number) {
|
||||
this.defaultSampleRate = sampleRate;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs audio classification on the provided audio data and waits
|
||||
* Performs audio classification on the provided audio clip and waits
|
||||
* synchronously for the response.
|
||||
*
|
||||
* @param audioData An array of raw audio capture data, like
|
||||
* from a call to getChannelData on an AudioBuffer.
|
||||
* @param audioData An array of raw audio capture data, like from a call to
|
||||
* `getChannelData()` on an AudioBuffer.
|
||||
* @param sampleRate The sample rate in Hz of the provided audio data. If not
|
||||
* set, defaults to the sample rate set via `setDefaultSampleRate()` or
|
||||
* `48000` if no custom default was set.
|
||||
|
@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner {
|
|||
*/
|
||||
classify(audioData: Float32Array, sampleRate?: number):
|
||||
AudioClassifierResult[] {
|
||||
sampleRate = sampleRate ?? this.defaultSampleRate;
|
||||
return this.processAudioClip(audioData, sampleRate);
|
||||
}
|
||||
|
||||
/** Sends an audio packet to the graph and returns the classifications. */
|
||||
protected override process(
|
||||
audioData: Float32Array, sampleRate: number,
|
||||
timestampMs: number): AudioClassifierResult[] {
|
||||
// Configures the number of samples in the WASM layer. We re-configure the
|
||||
// number of samples and the sample rate for every frame, but ignore other
|
||||
// side effects of this function (such as sending the input side packet and
|
||||
// the input stream header).
|
||||
this.configureAudio(
|
||||
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
|
||||
|
||||
const timestamp = performance.now();
|
||||
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
|
||||
this.addAudioToStream(audioData, timestamp);
|
||||
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
|
||||
this.addAudioToStream(audioData, timestampMs);
|
||||
|
||||
this.classificationResults = [];
|
||||
this.finishProcessing();
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This package contains options and the base task runner shared by all MediaPipe Audio Tasks for Web.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
|
@ -11,3 +11,15 @@ mediapipe_ts_declaration(
|
|||
"//mediapipe/tasks/web/core",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "audio_task_runner",
|
||||
srcs = ["audio_task_runner.ts"],
|
||||
deps = [
|
||||
":audio_task_options",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:task_runner",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -16,29 +16,8 @@
|
|||
|
||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||
|
||||
/**
|
||||
* MediaPipe audio task running mode. A MediaPipe audio task can be run with
|
||||
* two different modes:
|
||||
* - audio_clips: The mode for running a mediapipe audio task on independent
|
||||
* audio clips.
|
||||
* - audio_stream: The mode for running a mediapipe audio task on an audio
|
||||
* stream, such as from a microphone.
|
||||
*
|
||||
*/
|
||||
export type RunningMode = 'audio_clips'|'audio_stream';
|
||||
|
||||
/** The options for configuring a MediaPipe Audio Task. */
|
||||
export declare interface AudioTaskOptions {
|
||||
/** Options to configure the loading of the model assets. */
|
||||
baseOptions?: BaseOptions;
|
||||
|
||||
/**
|
||||
* The running mode of the task. Defaults to the audio_clips mode.
|
||||
* Audio tasks have two running modes:
|
||||
* 1) The mode for running a mediapipe audio task on independent
|
||||
* audio clips.
|
||||
* 2) The mode for running a mediapipe audio task on an audio
|
||||
* stream, such as from a microphone.
|
||||
*/
|
||||
runningMode?: RunningMode;
|
||||
}
|
||||
|
|
58
mediapipe/tasks/web/audio/core/audio_task_runner.ts
Normal file
58
mediapipe/tasks/web/audio/core/audio_task_runner.ts
Normal file
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||
|
||||
import {AudioTaskOptions} from './audio_task_options';
|
||||
|
||||
/**
 * Common base class for all MediaPipe Audio Tasks.
 *
 * Concrete tasks provide the storage for `baseOptions` and implement
 * `process()`; this class contributes the shared option handling and the
 * default sample rate used for audio clips.
 */
export abstract class AudioTaskRunner<T> extends TaskRunner {
  /** Base options proto of the concrete task; supplied by the subclass. */
  protected abstract baseOptions?: BaseOptionsProto|undefined;
  /** Sample rate used when a caller does not pass one explicitly. */
  private defaultSampleRate = 48000;

  /**
   * Applies the options that are shared by every audio task.
   *
   * A fresh base options proto is installed when the task has none yet. If
   * `options.baseOptions` is provided, it is handed to
   * `convertBaseOptionsToProto()` together with the current proto, and the
   * result replaces the current proto.
   *
   * @param options The audio task options to apply.
   */
  async setOptions(options: AudioTaskOptions): Promise<void> {
    this.baseOptions = this.baseOptions ?? new BaseOptionsProto();

    const requestedBaseOptions = options.baseOptions;
    if (!requestedBaseOptions) {
      return;
    }

    this.baseOptions = await convertBaseOptionsToProto(
        requestedBaseOptions, this.baseOptions);
  }

  /**
   * Overrides the sample rate used by API calls that do not specify one.
   * Without a call to this method the default is `48000`.
   *
   * @param sampleRate A sample rate (e.g. `44100`).
   */
  setDefaultSampleRate(sampleRate: number) {
    this.defaultSampleRate = sampleRate;
  }

  /**
   * Task-specific step that sends one audio packet to the graph and returns
   * the task's result.
   *
   * @param audioData The raw audio samples.
   * @param sampleRate The sample rate of `audioData`.
   * @param timestampMs The timestamp to attach to the packet.
   */
  protected abstract process(
      audioData: Float32Array,
      sampleRate: number,
      timestampMs: number,
      ): T;

  /**
   * Processes one standalone audio clip and returns the task's result.
   *
   * @param audioData The raw audio samples of the clip.
   * @param sampleRate The sample rate of the clip; falls back to the
   *     configured default sample rate when omitted.
   */
  protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
    const effectiveSampleRate = sampleRate ?? this.defaultSampleRate;
    // The clip is stamped with the current time taken from performance.now().
    const timestampMs = performance.now();
    return this.process(audioData, effectiveSampleRate, timestampMs);
  }
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user