Create AudioTaskRunner
PiperOrigin-RevId: 489613573
This commit is contained in:
		
							parent
							
								
									eb8ef1ace0
								
							
						
					
					
						commit
						e853f04b79
					
				| 
						 | 
				
			
			@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable {
 | 
			
		|||
      defaultSampleRate = sampleRate;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
 | 
			
		||||
   * available in the user-defined result listener.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,14 +17,14 @@ mediapipe_ts_library(
 | 
			
		|||
        "//mediapipe/framework:calculator_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/web/audio/core:audio_task_runner",
 | 
			
		||||
        "//mediapipe/tasks/web/components/containers:category",
 | 
			
		||||
        "//mediapipe/tasks/web/components/containers:classification_result",
 | 
			
		||||
        "//mediapipe/tasks/web/components/processors:base_options",
 | 
			
		||||
        "//mediapipe/tasks/web/components/processors:classifier_options",
 | 
			
		||||
        "//mediapipe/tasks/web/components/processors:classifier_result",
 | 
			
		||||
        "//mediapipe/tasks/web/core",
 | 
			
		||||
        "//mediapipe/tasks/web/core:classifier_options",
 | 
			
		||||
        "//mediapipe/tasks/web/core:task_runner",
 | 
			
		||||
        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
			
		|||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 | 
			
		||||
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
 | 
			
		||||
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
 | 
			
		||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
 | 
			
		||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
			
		||||
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
 | 
			
		||||
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 | 
			
		||||
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
 | 
			
		||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 | 
			
		||||
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
 | 
			
		||||
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
 | 
			
		||||
// Placeholder for internal dependency on trusted resource url
 | 
			
		||||
| 
						 | 
				
			
			@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
 | 
			
		|||
// tslint:disable:jspb-use-builder-pattern
 | 
			
		||||
 | 
			
		||||
/** Performs audio classification. */
 | 
			
		||||
export class AudioClassifier extends TaskRunner {
 | 
			
		||||
export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
 | 
			
		||||
  private classificationResults: AudioClassifierResult[] = [];
 | 
			
		||||
  private defaultSampleRate = 48000;
 | 
			
		||||
  private readonly options = new AudioClassifierGraphOptions();
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner {
 | 
			
		|||
        wasmLoaderOptions, new Uint8Array(graphData));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override get baseOptions(): BaseOptionsProto|undefined {
 | 
			
		||||
    return this.options.getBaseOptions();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override set baseOptions(proto: BaseOptionsProto|undefined) {
 | 
			
		||||
    this.options.setBaseOptions(proto);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Sets new options for the audio classifier.
 | 
			
		||||
   *
 | 
			
		||||
| 
						 | 
				
			
			@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner {
 | 
			
		|||
   *
 | 
			
		||||
   * @param options The options for the audio classifier.
 | 
			
		||||
   */
 | 
			
		||||
  async setOptions(options: AudioClassifierOptions): Promise<void> {
 | 
			
		||||
    if (options.baseOptions) {
 | 
			
		||||
      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
			
		||||
          options.baseOptions, this.options.getBaseOptions());
 | 
			
		||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  override async setOptions(options: AudioClassifierOptions): Promise<void> {
 | 
			
		||||
    await super.setOptions(options);
 | 
			
		||||
    this.options.setClassifierOptions(convertClassifierOptionsToProto(
 | 
			
		||||
        options, this.options.getClassifierOptions()));
 | 
			
		||||
    this.refreshGraph();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Sets the sample rate for all calls to `classify()` that omit an explicit
 | 
			
		||||
   * sample rate. `48000` is used as a default if this method is not called.
 | 
			
		||||
   *
 | 
			
		||||
   * @param sampleRate A sample rate (e.g. `44100`).
 | 
			
		||||
   */
 | 
			
		||||
  setDefaultSampleRate(sampleRate: number) {
 | 
			
		||||
    this.defaultSampleRate = sampleRate;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Performs audio classification on the provided audio data and waits
 | 
			
		||||
   * Performs audio classification on the provided audio clip and waits
 | 
			
		||||
   * synchronously for the response.
 | 
			
		||||
   *
 | 
			
		||||
   * @param audioData An array of raw audio capture data, like
 | 
			
		||||
   *     from a call to getChannelData on an AudioBuffer.
 | 
			
		||||
   * @param audioData An array of raw audio capture data, like from a call to
 | 
			
		||||
   *     `getChannelData()` on an AudioBuffer.
 | 
			
		||||
   * @param sampleRate The sample rate in Hz of the provided audio data. If not
 | 
			
		||||
   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
 | 
			
		||||
   *     `48000` if no custom default was set.
 | 
			
		||||
| 
						 | 
				
			
			@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner {
 | 
			
		|||
   */
 | 
			
		||||
  classify(audioData: Float32Array, sampleRate?: number):
 | 
			
		||||
      AudioClassifierResult[] {
 | 
			
		||||
    sampleRate = sampleRate ?? this.defaultSampleRate;
 | 
			
		||||
    return this.processAudioClip(audioData, sampleRate);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Sends an audio package to the graph and returns the classifications. */
 | 
			
		||||
  protected override process(
 | 
			
		||||
      audioData: Float32Array, sampleRate: number,
 | 
			
		||||
      timestampMs: number): AudioClassifierResult[] {
 | 
			
		||||
    // Configures the number of samples in the WASM layer. We re-configure the
 | 
			
		||||
    // number of samples and the sample rate for every frame, but ignore other
 | 
			
		||||
    // side effects of this function (such as sending the input side packet and
 | 
			
		||||
    // the input stream header).
 | 
			
		||||
    this.configureAudio(
 | 
			
		||||
        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
 | 
			
		||||
 | 
			
		||||
    const timestamp = performance.now();
 | 
			
		||||
    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
 | 
			
		||||
    this.addAudioToStream(audioData, timestamp);
 | 
			
		||||
    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
 | 
			
		||||
    this.addAudioToStream(audioData, timestampMs);
 | 
			
		||||
 | 
			
		||||
    this.classificationResults = [];
 | 
			
		||||
    this.finishProcessing();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,6 @@
 | 
			
		|||
# This package contains options shared by all MediaPipe Audio Tasks for Web.
 | 
			
		||||
 | 
			
		||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
 | 
			
		||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -11,3 +11,15 @@ mediapipe_ts_declaration(
 | 
			
		|||
        "//mediapipe/tasks/web/core",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "audio_task_runner",
 | 
			
		||||
    srcs = ["audio_task_runner.ts"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":audio_task_options",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/web/components/processors:base_options",
 | 
			
		||||
        "//mediapipe/tasks/web/core",
 | 
			
		||||
        "//mediapipe/tasks/web/core:task_runner",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,29 +16,8 @@
 | 
			
		|||
 | 
			
		||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * MediaPipe audio task running mode. A MediaPipe audio task can be run with
 | 
			
		||||
 * two different modes:
 | 
			
		||||
 * - audio_clips:  The mode for running a mediapipe audio task on independent
 | 
			
		||||
 *                 audio clips.
 | 
			
		||||
 * - audio_stream: The mode for running a mediapipe audio task on an audio
 | 
			
		||||
 *                 stream, such as from a microphone.
 | 
			
		||||
 * </ul>
 | 
			
		||||
 */
 | 
			
		||||
export type RunningMode = 'audio_clips'|'audio_stream';
 | 
			
		||||
 | 
			
		||||
/** The options for configuring a MediaPipe Audio Task. */
 | 
			
		||||
export declare interface AudioTaskOptions {
 | 
			
		||||
  /** Options to configure the loading of the model assets. */
 | 
			
		||||
  baseOptions?: BaseOptions;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * The running mode of the task. Default to the audio_clips mode.
 | 
			
		||||
   * Audio tasks have two running modes:
 | 
			
		||||
   * 1) The mode for running a mediapipe audio task on independent
 | 
			
		||||
   *    audio clips.
 | 
			
		||||
   * 2) The mode for running a mediapipe audio task on an audio
 | 
			
		||||
   *    stream, such as from a microphone.
 | 
			
		||||
   */
 | 
			
		||||
  runningMode?: RunningMode;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										58
									
								
								mediapipe/tasks/web/audio/core/audio_task_runner.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								mediapipe/tasks/web/audio/core/audio_task_runner.ts
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,58 @@
 | 
			
		|||
/**
 | 
			
		||||
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
			
		||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
 | 
			
		||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 | 
			
		||||
 | 
			
		||||
import {AudioTaskOptions} from './audio_task_options';
 | 
			
		||||
 | 
			
		||||
/** Base class for all MediaPipe Audio Tasks. */
 | 
			
		||||
export abstract class AudioTaskRunner<T> extends TaskRunner {
 | 
			
		||||
  protected abstract baseOptions?: BaseOptionsProto|undefined;
 | 
			
		||||
  private defaultSampleRate = 48000;
 | 
			
		||||
 | 
			
		||||
  /** Configures the shared options of an audio task. */
 | 
			
		||||
  async setOptions(options: AudioTaskOptions): Promise<void> {
 | 
			
		||||
    this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
 | 
			
		||||
    if (options.baseOptions) {
 | 
			
		||||
      this.baseOptions = await convertBaseOptionsToProto(
 | 
			
		||||
          options.baseOptions, this.baseOptions);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Sets the sample rate for API calls that omit an explicit sample rate.
 | 
			
		||||
   * `48000` is used as a default if this method is not called.
 | 
			
		||||
   *
 | 
			
		||||
   * @param sampleRate A sample rate (e.g. `44100`).
 | 
			
		||||
   */
 | 
			
		||||
  setDefaultSampleRate(sampleRate: number) {
 | 
			
		||||
    this.defaultSampleRate = sampleRate;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Sends an audio packet to the graph and awaits results. */
 | 
			
		||||
  protected abstract process(
 | 
			
		||||
      audioData: Float32Array, sampleRate: number, timestampMs: number): T;
 | 
			
		||||
 | 
			
		||||
  /** Sends a single audio clip to the graph and awaits results. */
 | 
			
		||||
  protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
 | 
			
		||||
    return this.process(
 | 
			
		||||
        audioData, sampleRate ?? this.defaultSampleRate, performance.now());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user