From e853f04b79bb47e9542f54ba34065de3c5dcbd73 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 18 Nov 2022 19:53:21 -0800 Subject: [PATCH] Create AudioTaskRunner PiperOrigin-RevId: 489613573 --- .../tasks/audio/core/BaseAudioTaskApi.java | 1 + .../tasks/web/audio/audio_classifier/BUILD | 4 +- .../audio_classifier/audio_classifier.ts | 53 ++++++++--------- mediapipe/tasks/web/audio/core/BUILD | 14 ++++- .../web/audio/core/audio_task_options.d.ts | 21 ------- .../tasks/web/audio/core/audio_task_runner.ts | 58 +++++++++++++++++++ 6 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 mediapipe/tasks/web/audio/core/audio_task_runner.ts diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 8eaf0adcb..2782f8d36 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable { defaultSampleRate = sampleRate; } } + /** * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 9e1fcbc51..498b17845 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -17,14 +17,14 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 5533b0eaa..0c54a4718 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ -export class AudioClassifier extends TaskRunner { +export class AudioClassifier extends AudioTaskRunner { private classificationResults: AudioClassifierResult[] = []; - private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); /** @@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the audio classifier. * @@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner { * * @param options The options for the audio classifier. */ - async setOptions(options: AudioClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: AudioClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } /** - * Sets the sample rate for all calls to `classify()` that omit an explicit - * sample rate. `48000` is used as a default if this method is not called. - * - * @param sampleRate A sample rate (e.g. `44100`). - */ - setDefaultSampleRate(sampleRate: number) { - this.defaultSampleRate = sampleRate; - } - - /** - * Performs audio classification on the provided audio data and waits + * Performs audio classification on the provided audio clip and waits * synchronously for the response. * - * @param audioData An array of raw audio capture data, like - * from a call to getChannelData on an AudioBuffer. + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. * @param sampleRate The sample rate in Hz of the provided audio data. If not * set, defaults to the sample rate set via `setDefaultSampleRate()` or * `48000` if no custom default was set. @@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner { */ classify(audioData: Float32Array, sampleRate?: number): AudioClassifierResult[] { - sampleRate = sampleRate ?? this.defaultSampleRate; + return this.processAudioClip(audioData, sampleRate); + } + /** Sends an audio package to the graph and returns the classifications. */ + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioClassifierResult[] { // Configures the number of samples in the WASM layer. We re-configure the // number of samples and the sample rate for every frame, but ignore other // side effects of this function (such as sending the input side packet and // the input stream header). this.configureAudio( /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); - - const timestamp = performance.now(); - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); - this.addAudioToStream(audioData, timestamp); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.addAudioToStream(audioData, timestampMs); this.classificationResults = []; this.finishProcessing(); diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index ed60f2435..91ebbf524 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -1,6 +1,6 @@ # This package contains options shared by all MediaPipe Audio Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,3 +11,15 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "audio_task_runner", + srcs = ["audio_task_runner.ts"], + deps = [ + ":audio_task_options", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + ], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts index 58a6e55d8..e3068625d 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts @@ -16,29 +16,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; -/** - * MediaPipe audio task running mode. A MediaPipe audio task can be run with - * two different modes: - * - audio_clips: The mode for running a mediapipe audio task on independent - * audio clips. - * - audio_stream: The mode for running a mediapipe audio task on an audio - * stream, such as from a microphone. - * - */ -export type RunningMode = 'audio_clips'|'audio_stream'; - /** The options for configuring a MediaPipe Audio Task. */ export declare interface AudioTaskOptions { /** Options to configure the loading of the model assets. */ baseOptions?: BaseOptions; - - /** - * The running mode of the task. Default to the audio_clips mode. - * Audio tasks have two running modes: - * 1) The mode for running a mediapipe audio task on independent - * audio clips. - * 2) The mode for running a mediapipe audio task on an audio - * stream, such as from a microphone. - */ - runningMode?: RunningMode; } diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts new file mode 100644 index 000000000..ceff3895b --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -0,0 +1,58 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; + +import {AudioTaskOptions} from './audio_task_options'; + +/** Base class for all MediaPipe Audio Tasks. */ +export abstract class AudioTaskRunner extends TaskRunner { + protected abstract baseOptions?: BaseOptionsProto|undefined; + private defaultSampleRate = 48000; + + /** Configures the shared options of an audio task. */ + async setOptions(options: AudioTaskOptions): Promise { + this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); + if (options.baseOptions) { + this.baseOptions = await convertBaseOptionsToProto( + options.baseOptions, this.baseOptions); + } + } + + /** + * Sets the sample rate for API calls that omit an explicit sample rate. + * `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** Sends an audio packet to the graph and awaits results. */ + protected abstract process( + audioData: Float32Array, sampleRate: number, timestampMs: number): T; + + /** Sends a single audio clip to the graph and awaits results. */ + protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { + return this.process( + audioData, sampleRate ?? this.defaultSampleRate, performance.now()); + } +} + +