Small TS audio API improvement

PiperOrigin-RevId: 490374083
This commit is contained in:
MediaPipe Team 2022-11-22 17:23:48 -08:00 committed by Copybara-Service
parent efa9e737f8
commit fac97554df
3 changed files with 105 additions and 52 deletions

View File

@ -35,11 +35,7 @@ export * from './audio_classifier_result';
const MEDIAPIPE_GRAPH = const MEDIAPIPE_GRAPH =
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and const AUDIO_STREAM = 'audio_in';
// cannot be changed
// TODO: Change this to `audio_in` to match the name in the CC
// implementation
const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate'; const SAMPLE_RATE_STREAM = 'sample_rate';
const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
@ -154,14 +150,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
protected override process( protected override process(
audioData: Float32Array, sampleRate: number, audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioClassifierResult[] { timestampMs: number): AudioClassifierResult[] {
// Configures the number of samples in the WASM layer. We re-configure the
// number of samples and the sample rate for every frame, but ignore other
// side effects of this function (such as sending the input side packet and
// the input stream header).
this.configureAudio(
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
this.addAudioToStream(audioData, timestampMs); this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
this.classificationResults = []; this.classificationResults = [];
this.finishProcessing(); this.finishProcessing();

View File

@ -35,11 +35,7 @@ export * from './audio_embedder_result';
// The OSS JS API does not support the builder pattern. // The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot const AUDIO_STREAM = 'audio_in';
// be changed
// TODO: Change this to `audio_in` to match the name in the CC
// implementation
const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate'; const SAMPLE_RATE_STREAM = 'sample_rate';
const EMBEDDINGS_STREAM = 'embeddings_out'; const EMBEDDINGS_STREAM = 'embeddings_out';
const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out'; const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';
@ -151,14 +147,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
protected override process( protected override process(
audioData: Float32Array, sampleRate: number, audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioEmbedderResult[] { timestampMs: number): AudioEmbedderResult[] {
// Configures the number of samples in the WASM layer. We re-configure the
// number of samples and the sample rate for every frame, but ignore other
// side effects of this function (such as sending the input side packet and
// the input stream header).
this.configureAudio(
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
this.addAudioToStream(audioData, timestampMs); this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
this.embeddingResults = []; this.embeddingResults = [];
this.finishProcessing(); this.finishProcessing();

View File

@ -15,9 +15,6 @@ export declare interface FileLocator {
locateFile: (filename: string) => string; locateFile: (filename: string) => string;
} }
/** Listener to be passed in by user for handling output audio data. */
export type AudioOutputListener = (output: Float32Array) => void;
/** /**
* Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
* doesn't break our JS/C++ bridge. * doesn't break our JS/C++ bridge.
@ -32,19 +29,14 @@ export declare interface WasmModule {
_bindTextureToCanvas: () => boolean; _bindTextureToCanvas: () => boolean;
_changeBinaryGraph: (size: number, dataPtr: number) => void; _changeBinaryGraph: (size: number, dataPtr: number) => void;
_changeTextGraph: (size: number, dataPtr: number) => void; _changeTextGraph: (size: number, dataPtr: number) => void;
_configureAudio:
(channels: number, samples: number, sampleRate: number) => void;
_free: (ptr: number) => void; _free: (ptr: number) => void;
_malloc: (size: number) => number; _malloc: (size: number) => number;
_processAudio: (dataPtr: number, timestamp: number) => void;
_processFrame: (width: number, height: number, timestamp: number) => void; _processFrame: (width: number, height: number, timestamp: number) => void;
_setAutoRenderToScreen: (enabled: boolean) => void; _setAutoRenderToScreen: (enabled: boolean) => void;
_waitUntilIdle: () => void; _waitUntilIdle: () => void;
// Exposed so that clients of this lib can access this field // Exposed so that clients of this lib can access this field
dataFileDownloads?: {[url: string]: {loaded: number, total: number}}; dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
// Wasm module will call us back at this function when given audio data.
onAudioOutput?: AudioOutputListener;
// Wasm Module multistream entrypoints. Require // Wasm Module multistream entrypoints. Require
// gl_graph_runner_internal_multi_input as a build dependency. // gl_graph_runner_internal_multi_input as a build dependency.
@ -100,11 +92,14 @@ export declare interface WasmModule {
_attachProtoVectorListener: _attachProtoVectorListener:
(streamNamePtr: number, makeDeepCopy?: boolean) => void; (streamNamePtr: number, makeDeepCopy?: boolean) => void;
// Requires dependency ":gl_graph_runner_audio_out", and will register an // Require dependency ":gl_graph_runner_audio_out"
// audio output listening function which can be tapped into dynamically during _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
// graph running via onAudioOutput. This call must be made before graph is
// initialized, but after wasmModule is instantiated. // Require dependency ":gl_graph_runner_audio"
_attachAudioOutputListener: () => void; _addAudioToInputStream: (dataPtr: number, numChannels: number,
numSamples: number, streamNamePtr: number, timestamp: number) => void;
_configureAudio: (channels: number, samples: number, sampleRate: number,
streamNamePtr: number, headerNamePtr: number) => void;
// TODO: Refactor to just use a few numbers (perhaps refactor away // TODO: Refactor to just use a few numbers (perhaps refactor away
// from gl_graph_runner_internal.cc entirely to use something a little more // from gl_graph_runner_internal.cc entirely to use something a little more
@ -235,19 +230,38 @@ export class GraphRunner {
} }
/** /**
* Configures the current graph to handle audio in a certain way. Must be * Configures the current graph to handle audio processing in a certain way
* called before the graph is set/started in order to use processAudio. * for all its audio input streams. Additionally can configure audio headers
* (both input side packets as well as input stream headers), but these
* configurations only take effect if called before the graph is set/started.
* @param numChannels The number of channels of audio input. Only 1 * @param numChannels The number of channels of audio input. Only 1
* is supported for now. * is supported for now.
* @param numSamples The number of samples that are taken in each * @param numSamples The number of samples that are taken in each
* audio capture. * audio capture.
* @param sampleRate The rate, in Hz, of the sampling. * @param sampleRate The rate, in Hz, of the sampling.
* @param streamName The optional name of the input stream to additionally
* configure with audio information. This configuration only occurs before
* the graph is set/started. If unset, a default stream name will be used.
* @param headerName The optional name of the header input side packet to
* additionally configure with audio information. This configuration only
* occurs before the graph is set/started. If unset, a default header name
* will be used.
*/ */
configureAudio(numChannels: number, numSamples: number, sampleRate: number) { configureAudio(numChannels: number, numSamples: number, sampleRate: number,
this.wasmModule._configureAudio(numChannels, numSamples, sampleRate); streamName?: string, headerName?: string) {
if (this.wasmModule._attachAudioOutputListener) { if (!this.wasmModule._configureAudio) {
this.wasmModule._attachAudioOutputListener(); console.warn(
'Attempting to use configureAudio without support for input audio. ' +
'Is build dep ":gl_graph_runner_audio" missing?');
} }
streamName = streamName || 'input_audio';
this.wrapStringPtr(streamName, (streamNamePtr: number) => {
headerName = headerName || 'audio_header';
this.wrapStringPtr(headerName, (headerNamePtr: number) => {
this.wasmModule._configureAudio(streamNamePtr, headerNamePtr,
numChannels, numSamples, sampleRate);
});
});
} }
/** /**
@ -437,9 +451,36 @@ export class GraphRunner {
* processed. * processed.
* @param audioData An array of raw audio capture data, like * @param audioData An array of raw audio capture data, like
* from a call to getChannelData on an AudioBuffer. * from a call to getChannelData on an AudioBuffer.
* @param streamName The name of the MediaPipe graph stream to add the audio
* data to.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
*/ */
addAudioToStream(audioData: Float32Array, timestamp: number) { addAudioToStream(
audioData: Float32Array, streamName: string, timestamp: number) {
// numChannels and numSamples being 0 will cause defaults to be used,
// which will reflect values from last call to configureAudio.
this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp);
}
/**
* Takes the raw data from a JS audio capture array, and sends it to C++ to be
* processed, shaping the audioData array into an audio matrix according to
* the numChannels and numSamples parameters.
* @param audioData An array of raw audio capture data, like
* from a call to getChannelData on an AudioBuffer.
* @param numChannels The number of audio channels this data represents. If 0
* is passed, then the value will be taken from the last call to
* configureAudio.
* @param numSamples The number of audio samples captured in this data packet.
* If 0 is passed, then the value will be taken from the last call to
* configureAudio.
* @param streamName The name of the MediaPipe graph stream to add the audio
* data to.
* @param timestamp The timestamp of the current frame, in ms.
*/
addAudioToStreamWithShape(
audioData: Float32Array, numChannels: number, numSamples: number,
streamName: string, timestamp: number) {
// 4 bytes for each F32 // 4 bytes for each F32
const size = audioData.length * 4; const size = audioData.length * 4;
if (this.audioSize !== size) { if (this.audioSize !== size) {
@ -450,7 +491,11 @@ export class GraphRunner {
this.audioSize = size; this.audioSize = size;
} }
this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4);
this.wasmModule._processAudio(this.audioPtr!, timestamp);
this.wrapStringPtr(streamName, (streamNamePtr: number) => {
this.wasmModule._addAudioToInputStream(
this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp);
});
} }
/** /**
@ -943,17 +988,45 @@ export class GraphRunner {
} }
/** /**
* Sets a listener to be called back with audio output packet data, as a * Attaches an audio packet listener to the specified output_stream, to be
* Float32Array, when graph has finished processing it. * given a Float32Array as output.
* @param audioOutputListener The caller's listener function. * @param outputStreamName The name of the graph output stream to grab audio
* data from.
* @param callbackFcn The function that will be called back with the data, as
* it is received. Note that the data is only guaranteed to exist for the
* duration of the callback, and the callback will be called inline, so it
* should not perform overly complicated (or any async) behavior. If the
* audio data needs to be able to outlive the call, you may set the
* optional makeDeepCopy parameter to true, or can manually deep-copy the
* data yourself.
* @param makeDeepCopy Optional convenience parameter which, if set to true,
* will override the default memory management behavior and make a deep
* copy of the underlying data, rather than just returning a view into the
* C++-managed memory. At the cost of a data copy, this allows the
* returned data to outlive the callback lifetime (and it will be cleaned
* up automatically by JS garbage collection whenever the user is finished
* with it).
*/ */
setOnAudioOutput(audioOutputListener: AudioOutputListener) { attachAudioListener(outputStreamName: string,
this.wasmModule.onAudioOutput = audioOutputListener; callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void {
if (!this.wasmModule._attachAudioOutputListener) { if (!this.wasmModule._attachAudioListener) {
console.warn( console.warn(
'Attempting to use AudioOutputListener without support for ' + 'Attempting to use attachAudioListener without support for ' +
'output audio. Is build dep ":gl_graph_runner_audio_out" missing?'); 'output audio. Is build dep ":gl_graph_runner_audio_out" missing?');
} }
// Set up our TS listener to receive any packets for this stream, and
// additionally reformat our Uint8Array into a Float32Array for the user.
this.setListener(outputStreamName, (data: Uint8Array) => {
const floatArray = new Float32Array(data.buffer); // Should be very fast
callbackFcn(floatArray);
});
// Tell our graph to listen for string packets on this stream.
this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
this.wasmModule._attachAudioListener(
outputStreamNamePtr, makeDeepCopy || false);
});
} }
/** /**