Create AudioTaskRunner
PiperOrigin-RevId: 489613573
This commit is contained in:
		
							parent
							
								
									eb8ef1ace0
								
							
						
					
					
						commit
						e853f04b79
					
				| 
						 | 
					@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable {
 | 
				
			||||||
      defaultSampleRate = sampleRate;
 | 
					      defaultSampleRate = sampleRate;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
 | 
					   * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
 | 
				
			||||||
   * available in the user-defined result listener.
 | 
					   * available in the user-defined result listener.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,14 +17,14 @@ mediapipe_ts_library(
 | 
				
			||||||
        "//mediapipe/framework:calculator_options_jspb_proto",
 | 
					        "//mediapipe/framework:calculator_options_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
 | 
					        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/audio/core:audio_task_runner",
 | 
				
			||||||
        "//mediapipe/tasks/web/components/containers:category",
 | 
					        "//mediapipe/tasks/web/components/containers:category",
 | 
				
			||||||
        "//mediapipe/tasks/web/components/containers:classification_result",
 | 
					        "//mediapipe/tasks/web/components/containers:classification_result",
 | 
				
			||||||
        "//mediapipe/tasks/web/components/processors:base_options",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/web/components/processors:classifier_options",
 | 
					        "//mediapipe/tasks/web/components/processors:classifier_options",
 | 
				
			||||||
        "//mediapipe/tasks/web/components/processors:classifier_result",
 | 
					        "//mediapipe/tasks/web/components/processors:classifier_result",
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
        "//mediapipe/tasks/web/core:classifier_options",
 | 
					        "//mediapipe/tasks/web/core:classifier_options",
 | 
				
			||||||
        "//mediapipe/tasks/web/core:task_runner",
 | 
					 | 
				
			||||||
        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
					        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
				
			||||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 | 
					import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 | 
				
			||||||
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
 | 
					import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
 | 
				
			||||||
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
 | 
					import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
 | 
				
			||||||
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
 | 
					import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
				
			||||||
 | 
					import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
 | 
				
			||||||
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 | 
					import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 | 
				
			||||||
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
 | 
					import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
 | 
				
			||||||
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 | 
					 | 
				
			||||||
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
 | 
					import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
 | 
				
			||||||
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
 | 
					import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
 | 
				
			||||||
// Placeholder for internal dependency on trusted resource url
 | 
					// Placeholder for internal dependency on trusted resource url
 | 
				
			||||||
| 
						 | 
					@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
 | 
				
			||||||
// tslint:disable:jspb-use-builder-pattern
 | 
					// tslint:disable:jspb-use-builder-pattern
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Performs audio classification. */
 | 
					/** Performs audio classification. */
 | 
				
			||||||
export class AudioClassifier extends TaskRunner {
 | 
					export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
 | 
				
			||||||
  private classificationResults: AudioClassifierResult[] = [];
 | 
					  private classificationResults: AudioClassifierResult[] = [];
 | 
				
			||||||
  private defaultSampleRate = 48000;
 | 
					 | 
				
			||||||
  private readonly options = new AudioClassifierGraphOptions();
 | 
					  private readonly options = new AudioClassifierGraphOptions();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
| 
						 | 
					@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner {
 | 
				
			||||||
        wasmLoaderOptions, new Uint8Array(graphData));
 | 
					        wasmLoaderOptions, new Uint8Array(graphData));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  protected override get baseOptions(): BaseOptionsProto|undefined {
 | 
				
			||||||
 | 
					    return this.options.getBaseOptions();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  protected override set baseOptions(proto: BaseOptionsProto|undefined) {
 | 
				
			||||||
 | 
					    this.options.setBaseOptions(proto);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Sets new options for the audio classifier.
 | 
					   * Sets new options for the audio classifier.
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
| 
						 | 
					@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner {
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @param options The options for the audio classifier.
 | 
					   * @param options The options for the audio classifier.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  async setOptions(options: AudioClassifierOptions): Promise<void> {
 | 
					  override async setOptions(options: AudioClassifierOptions): Promise<void> {
 | 
				
			||||||
    if (options.baseOptions) {
 | 
					    await super.setOptions(options);
 | 
				
			||||||
      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
					 | 
				
			||||||
          options.baseOptions, this.options.getBaseOptions());
 | 
					 | 
				
			||||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    this.options.setClassifierOptions(convertClassifierOptionsToProto(
 | 
					    this.options.setClassifierOptions(convertClassifierOptionsToProto(
 | 
				
			||||||
        options, this.options.getClassifierOptions()));
 | 
					        options, this.options.getClassifierOptions()));
 | 
				
			||||||
    this.refreshGraph();
 | 
					    this.refreshGraph();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Sets the sample rate for all calls to `classify()` that omit an explicit
 | 
					   * Performs audio classification on the provided audio clip and waits
 | 
				
			||||||
   * sample rate. `48000` is used as a default if this method is not called.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param sampleRate A sample rate (e.g. `44100`).
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  setDefaultSampleRate(sampleRate: number) {
 | 
					 | 
				
			||||||
    this.defaultSampleRate = sampleRate;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Performs audio classification on the provided audio data and waits
 | 
					 | 
				
			||||||
   * synchronously for the response.
 | 
					   * synchronously for the response.
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @param audioData An array of raw audio capture data, like
 | 
					   * @param audioData An array of raw audio capture data, like from a call to
 | 
				
			||||||
   *     from a call to getChannelData on an AudioBuffer.
 | 
					   *     `getChannelData()` on an AudioBuffer.
 | 
				
			||||||
   * @param sampleRate The sample rate in Hz of the provided audio data. If not
 | 
					   * @param sampleRate The sample rate in Hz of the provided audio data. If not
 | 
				
			||||||
   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
 | 
					   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
 | 
				
			||||||
   *     `48000` if no custom default was set.
 | 
					   *     `48000` if no custom default was set.
 | 
				
			||||||
| 
						 | 
					@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  classify(audioData: Float32Array, sampleRate?: number):
 | 
					  classify(audioData: Float32Array, sampleRate?: number):
 | 
				
			||||||
      AudioClassifierResult[] {
 | 
					      AudioClassifierResult[] {
 | 
				
			||||||
    sampleRate = sampleRate ?? this.defaultSampleRate;
 | 
					    return this.processAudioClip(audioData, sampleRate);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Sends an audio packet to the graph and returns the classifications. */
 | 
				
			||||||
 | 
					  protected override process(
 | 
				
			||||||
 | 
					      audioData: Float32Array, sampleRate: number,
 | 
				
			||||||
 | 
					      timestampMs: number): AudioClassifierResult[] {
 | 
				
			||||||
    // Configures the number of samples in the WASM layer. We re-configure the
 | 
					    // Configures the number of samples in the WASM layer. We re-configure the
 | 
				
			||||||
    // number of samples and the sample rate for every frame, but ignore other
 | 
					    // number of samples and the sample rate for every frame, but ignore other
 | 
				
			||||||
    // side effects of this function (such as sending the input side packet and
 | 
					    // side effects of this function (such as sending the input side packet and
 | 
				
			||||||
    // the input stream header).
 | 
					    // the input stream header).
 | 
				
			||||||
    this.configureAudio(
 | 
					    this.configureAudio(
 | 
				
			||||||
        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
 | 
					        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
 | 
				
			||||||
 | 
					    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
 | 
				
			||||||
    const timestamp = performance.now();
 | 
					    this.addAudioToStream(audioData, timestampMs);
 | 
				
			||||||
    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
 | 
					 | 
				
			||||||
    this.addAudioToStream(audioData, timestamp);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    this.classificationResults = [];
 | 
					    this.classificationResults = [];
 | 
				
			||||||
    this.finishProcessing();
 | 
					    this.finishProcessing();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,6 @@
 | 
				
			||||||
# This package contains options shared by all MediaPipe Audio Tasks for Web.
 | 
					# This package contains options shared by all MediaPipe Audio Tasks for Web.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
 | 
					load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,3 +11,15 @@ mediapipe_ts_declaration(
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mediapipe_ts_library(
 | 
				
			||||||
 | 
					    name = "audio_task_runner",
 | 
				
			||||||
 | 
					    srcs = ["audio_task_runner.ts"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":audio_task_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/components/processors:base_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/core:task_runner",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,29 +16,8 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
 | 
					import {BaseOptions} from '../../../../tasks/web/core/base_options';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					 | 
				
			||||||
 * MediaPipe audio task running mode. A MediaPipe audio task can be run with
 | 
					 | 
				
			||||||
 * two different modes:
 | 
					 | 
				
			||||||
 * - audio_clips:  The mode for running a mediapipe audio task on independent
 | 
					 | 
				
			||||||
 *                 audio clips.
 | 
					 | 
				
			||||||
 * - audio_stream: The mode for running a mediapipe audio task on an audio
 | 
					 | 
				
			||||||
 *                 stream, such as from a microphone.
 | 
					 | 
				
			||||||
 * </ul>
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
export type RunningMode = 'audio_clips'|'audio_stream';
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/** The options for configuring a MediaPipe Audio Task. */
 | 
					/** The options for configuring a MediaPipe Audio Task. */
 | 
				
			||||||
export declare interface AudioTaskOptions {
 | 
					export declare interface AudioTaskOptions {
 | 
				
			||||||
  /** Options to configure the loading of the model assets. */
 | 
					  /** Options to configure the loading of the model assets. */
 | 
				
			||||||
  baseOptions?: BaseOptions;
 | 
					  baseOptions?: BaseOptions;
 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * The running mode of the task. Default to the audio_clips mode.
 | 
					 | 
				
			||||||
   * Audio tasks have two running modes:
 | 
					 | 
				
			||||||
   * 1) The mode for running a mediapipe audio task on independent
 | 
					 | 
				
			||||||
   *    audio clips.
 | 
					 | 
				
			||||||
   * 2) The mode for running a mediapipe audio task on an audio
 | 
					 | 
				
			||||||
   *    stream, such as from a microphone.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  runningMode?: RunningMode;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										58
									
								
								mediapipe/tasks/web/audio/core/audio_task_runner.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								mediapipe/tasks/web/audio/core/audio_task_runner.ts
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,58 @@
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					 * you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					 * You may obtain a copy of the License at
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					 * See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					 * limitations under the License.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
				
			||||||
 | 
					import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
 | 
				
			||||||
 | 
					import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import {AudioTaskOptions} from './audio_task_options';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Base class for all MediaPipe Audio Tasks. */
 | 
				
			||||||
 | 
					export abstract class AudioTaskRunner<T> extends TaskRunner {
 | 
				
			||||||
 | 
					  protected abstract baseOptions?: BaseOptionsProto|undefined;
 | 
				
			||||||
 | 
					  private defaultSampleRate = 48000;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Configures the shared options of an audio task. */
 | 
				
			||||||
 | 
					  async setOptions(options: AudioTaskOptions): Promise<void> {
 | 
				
			||||||
 | 
					    this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
 | 
				
			||||||
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
 | 
					      this.baseOptions = await convertBaseOptionsToProto(
 | 
				
			||||||
 | 
					          options.baseOptions, this.baseOptions);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Sets the sample rate for API calls that omit an explicit sample rate.
 | 
				
			||||||
 | 
					   * `48000` is used as a default if this method is not called.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param sampleRate A sample rate (e.g. `44100`).
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  setDefaultSampleRate(sampleRate: number) {
 | 
				
			||||||
 | 
					    this.defaultSampleRate = sampleRate;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Sends an audio packet to the graph and awaits results. */
 | 
				
			||||||
 | 
					  protected abstract process(
 | 
				
			||||||
 | 
					      audioData: Float32Array, sampleRate: number, timestampMs: number): T;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Sends a single audio clip to the graph and awaits results. */
 | 
				
			||||||
 | 
					  protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
 | 
				
			||||||
 | 
					    return this.process(
 | 
				
			||||||
 | 
					        audioData, sampleRate ?? this.defaultSampleRate, performance.now());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user