Small TS audio API improvement
PiperOrigin-RevId: 490374083
commit fac97554df (parent efa9e737f8)

@@ -35,11 +35,7 @@ export * from './audio_classifier_result';
 const MEDIAPIPE_GRAPH =
     'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
 
-// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
-// cannot be changed
-// TODO: Change this to `audio_in` to match the name in the CC
-// implementation
-const AUDIO_STREAM = 'input_audio';
+const AUDIO_STREAM = 'audio_in';
 const SAMPLE_RATE_STREAM = 'sample_rate';
 const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
 
@@ -154,14 +150,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioClassifierResult[] {
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStream(audioData, timestampMs);
+    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
 
     this.classificationResults = [];
     this.finishProcessing();
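
The same change appears in both audio task implementations: instead of reconfiguring the audio shape before every frame, the shape and the target stream now travel with the frame itself. A minimal before/after sketch of the per-frame call from outside the task classes (the `graphRunner` instance is hypothetical; the method names are the ones shown in this diff):

    // Previously: shape configured up front, stream name implicit ('input_audio').
    // graphRunner.configureAudio(/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
    // graphRunner.addAudioToStream(audioData, timestampMs);

    // Now: shape and stream name accompany every audio packet.
    graphRunner.addAudioToStreamWithShape(
        audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
        'audio_in', timestampMs);
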
@@ -35,11 +35,7 @@ export * from './audio_embedder_result';
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern
 
-// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot
-// be changed
-// TODO: Change this to `audio_in` to match the name in the CC
-// implementation
-const AUDIO_STREAM = 'input_audio';
+const AUDIO_STREAM = 'audio_in';
 const SAMPLE_RATE_STREAM = 'sample_rate';
 const EMBEDDINGS_STREAM = 'embeddings_out';
 const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';
@@ -151,14 +147,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioEmbedderResult[] {
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStream(audioData, timestampMs);
+    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
 
     this.embeddingResults = [];
     this.finishProcessing();
@@ -15,9 +15,6 @@ export declare interface FileLocator {
   locateFile: (filename: string) => string;
 }
 
-/** Listener to be passed in by user for handling output audio data. */
-export type AudioOutputListener = (output: Float32Array) => void;
-
 /**
  * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
  * doesn't break our JS/C++ bridge.
@@ -32,19 +29,14 @@ export declare interface WasmModule {
   _bindTextureToCanvas: () => boolean;
   _changeBinaryGraph: (size: number, dataPtr: number) => void;
   _changeTextGraph: (size: number, dataPtr: number) => void;
-  _configureAudio:
-      (channels: number, samples: number, sampleRate: number) => void;
   _free: (ptr: number) => void;
   _malloc: (size: number) => number;
-  _processAudio: (dataPtr: number, timestamp: number) => void;
   _processFrame: (width: number, height: number, timestamp: number) => void;
   _setAutoRenderToScreen: (enabled: boolean) => void;
   _waitUntilIdle: () => void;
 
   // Exposed so that clients of this lib can access this field
   dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
-  // Wasm module will call us back at this function when given audio data.
-  onAudioOutput?: AudioOutputListener;
 
   // Wasm Module multistream entrypoints. Require
   // gl_graph_runner_internal_multi_input as a build dependency.
@@ -100,11 +92,14 @@ export declare interface WasmModule {
   _attachProtoVectorListener:
       (streamNamePtr: number, makeDeepCopy?: boolean) => void;
 
-  // Requires dependency ":gl_graph_runner_audio_out", and will register an
-  // audio output listening function which can be tapped into dynamically during
-  // graph running via onAudioOutput. This call must be made before graph is
-  // initialized, but after wasmModule is instantiated.
-  _attachAudioOutputListener: () => void;
+  // Require dependency ":gl_graph_runner_audio_out"
+  _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
+
+  // Require dependency ":gl_graph_runner_audio"
+  _addAudioToInputStream: (dataPtr: number, numChannels: number,
+      numSamples: number, streamNamePtr: number, timestamp: number) => void;
+  _configureAudio: (channels: number, samples: number, sampleRate: number,
+      streamNamePtr: number, headerNamePtr: number) => void;
 
   // TODO: Refactor to just use a few numbers (perhaps refactor away
   // from gl_graph_runner_internal.cc entirely to use something a little more
@@ -235,19 +230,38 @@ export class GraphRunner {
   }
 
   /**
-   * Configures the current graph to handle audio in a certain way. Must be
-   * called before the graph is set/started in order to use processAudio.
+   * Configures the current graph to handle audio processing in a certain way
+   * for all its audio input streams. Additionally can configure audio headers
+   * (both input side packets as well as input stream headers), but these
+   * configurations only take effect if called before the graph is set/started.
    * @param numChannels The number of channels of audio input. Only 1
    *     is supported for now.
    * @param numSamples The number of samples that are taken in each
    *     audio capture.
    * @param sampleRate The rate, in Hz, of the sampling.
+   * @param streamName The optional name of the input stream to additionally
+   *     configure with audio information. This configuration only occurs before
+   *     the graph is set/started. If unset, a default stream name will be used.
+   * @param headerName The optional name of the header input side packet to
+   *     additionally configure with audio information. This configuration only
+   *     occurs before the graph is set/started. If unset, a default header name
+   *     will be used.
    */
-  configureAudio(numChannels: number, numSamples: number, sampleRate: number) {
-    this.wasmModule._configureAudio(numChannels, numSamples, sampleRate);
-    if (this.wasmModule._attachAudioOutputListener) {
-      this.wasmModule._attachAudioOutputListener();
+  configureAudio(numChannels: number, numSamples: number, sampleRate: number,
+      streamName?: string, headerName?: string) {
+    if (!this.wasmModule._configureAudio) {
+      console.warn(
+          'Attempting to use configureAudio without support for input audio. ' +
+          'Is build dep ":gl_graph_runner_audio" missing?');
     }
+    streamName = streamName || 'input_audio';
+    this.wrapStringPtr(streamName, (streamNamePtr: number) => {
+      headerName = headerName || 'audio_header';
+      this.wrapStringPtr(headerName, (headerNamePtr: number) => {
+        this.wasmModule._configureAudio(streamNamePtr, headerNamePtr,
+            numChannels, numSamples, sampleRate);
+      });
+    });
   }
 
   /**
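
The stream and header names used by configureAudio are now parameters, with the previously hard-coded names kept as defaults ('input_audio' and 'audio_header', per the code above). A short usage sketch, assuming a constructed GraphRunner instance named `graphRunner` (hypothetical):

    // Uses the default stream name 'input_audio' and header name 'audio_header'.
    graphRunner.configureAudio(
        /* numChannels= */ 1, /* numSamples= */ 512, /* sampleRate= */ 16000);

    // Targets a specific input stream and header side packet instead.
    graphRunner.configureAudio(1, 512, 16000, 'audio_in', 'audio_header');
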
@@ -437,9 +451,36 @@ export class GraphRunner {
   *     processed.
   * @param audioData An array of raw audio capture data, like
   *     from a call to getChannelData on an AudioBuffer.
+  * @param streamName The name of the MediaPipe graph stream to add the audio
+  *     data to.
   * @param timestamp The timestamp of the current frame, in ms.
   */
-  addAudioToStream(audioData: Float32Array, timestamp: number) {
+  addAudioToStream(
+      audioData: Float32Array, streamName: string, timestamp: number) {
+    // numChannels and numSamples being 0 will cause defaults to be used,
+    // which will reflect values from last call to configureAudio.
+    this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp);
+  }
+
+  /**
+   * Takes the raw data from a JS audio capture array, and sends it to C++ to be
+   * processed, shaping the audioData array into an audio matrix according to
+   * the numChannels and numSamples parameters.
+   * @param audioData An array of raw audio capture data, like
+   *     from a call to getChannelData on an AudioBuffer.
+   * @param numChannels The number of audio channels this data represents. If 0
+   *     is passed, then the value will be taken from the last call to
+   *     configureAudio.
+   * @param numSamples The number of audio samples captured in this data packet.
+   *     If 0 is passed, then the value will be taken from the last call to
+   *     configureAudio.
+   * @param streamName The name of the MediaPipe graph stream to add the audio
+   *     data to.
+   * @param timestamp The timestamp of the current frame, in ms.
+   */
+  addAudioToStreamWithShape(
+      audioData: Float32Array, numChannels: number, numSamples: number,
+      streamName: string, timestamp: number) {
     // 4 bytes for each F32
     const size = audioData.length * 4;
     if (this.audioSize !== size) {
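
A sketch of feeding captured audio through the new entry points; the AudioBuffer handling and the 'audio_in' stream name are illustrative and must match the loaded graph:

    function sendAudio(graphRunner: GraphRunner, buffer: AudioBuffer, timestampMs: number) {
      const samples = buffer.getChannelData(0);  // mono Float32Array
      // Explicit shape and stream name per packet.
      graphRunner.addAudioToStreamWithShape(
          samples, /* numChannels= */ 1, /* numSamples= */ samples.length,
          'audio_in', timestampMs);
      // Or rely on the shape from the last configureAudio call (0 means "use default"):
      // graphRunner.addAudioToStream(samples, 'audio_in', timestampMs);
    }
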
@@ -450,7 +491,11 @@ export class GraphRunner {
       this.audioSize = size;
     }
     this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4);
-    this.wasmModule._processAudio(this.audioPtr!, timestamp);
+
+    this.wrapStringPtr(streamName, (streamNamePtr: number) => {
+      this.wasmModule._addAudioToInputStream(
+          this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp);
+    });
   }
 
   /**
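
For reference, the division by 4 above is because _malloc returns a byte address while HEAPF32 is indexed in 4-byte floats. A simplified sketch of the same pattern (unlike the real code, which reuses a cached buffer via audioPtr/audioSize, this allocates and frees per call, and streamNamePtr is assumed to come from wrapStringPtr):

    const bytes = audioData.length * 4;          // each F32 sample is 4 bytes
    const ptr = wasmModule._malloc(bytes);       // byte address into the Wasm heap
    wasmModule.HEAPF32.set(audioData, ptr / 4);  // HEAPF32 index = byte address / 4
    wasmModule._addAudioToInputStream(
        ptr, numChannels, numSamples, streamNamePtr, timestamp);
    wasmModule._free(ptr);
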
@@ -943,17 +988,45 @@ export class GraphRunner {
   }
 
   /**
-   * Sets a listener to be called back with audio output packet data, as a
-   * Float32Array, when graph has finished processing it.
-   * @param audioOutputListener The caller's listener function.
+   * Attaches an audio packet listener to the specified output_stream, to be
+   * given a Float32Array as output.
+   * @param outputStreamName The name of the graph output stream to grab audio
+   *     data from.
+   * @param callbackFcn The function that will be called back with the data, as
+   *     it is received. Note that the data is only guaranteed to exist for the
+   *     duration of the callback, and the callback will be called inline, so it
+   *     should not perform overly complicated (or any async) behavior. If the
+   *     audio data needs to be able to outlive the call, you may set the
+   *     optional makeDeepCopy parameter to true, or can manually deep-copy the
+   *     data yourself.
+   * @param makeDeepCopy Optional convenience parameter which, if set to true,
+   *     will override the default memory management behavior and make a deep
+   *     copy of the underlying data, rather than just returning a view into the
+   *     C++-managed memory. At the cost of a data copy, this allows the
+   *     returned data to outlive the callback lifetime (and it will be cleaned
+   *     up automatically by JS garbage collection whenever the user is finished
+   *     with it).
    */
-  setOnAudioOutput(audioOutputListener: AudioOutputListener) {
-    this.wasmModule.onAudioOutput = audioOutputListener;
-    if (!this.wasmModule._attachAudioOutputListener) {
+  attachAudioListener(outputStreamName: string,
+      callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void {
+    if (!this.wasmModule._attachAudioListener) {
       console.warn(
-          'Attempting to use AudioOutputListener without support for ' +
+          'Attempting to use attachAudioListener without support for ' +
           'output audio. Is build dep ":gl_graph_runner_audio_out" missing?');
     }
+
+    // Set up our TS listener to receive any packets for this stream, and
+    // additionally reformat our Uint8Array into a Float32Array for the user.
+    this.setListener(outputStreamName, (data: Uint8Array) => {
+      const floatArray = new Float32Array(data.buffer);  // Should be very fast
+      callbackFcn(floatArray);
+    });
+
+    // Tell our graph to listen for string packets on this stream.
+    this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
+      this.wasmModule._attachAudioListener(
+          outputStreamNamePtr, makeDeepCopy || false);
+    });
   }
 
   /**
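
A usage sketch for the renamed listener API (the output stream name 'audio_out' is illustrative and must match an output stream of the loaded graph; graphRunner is a hypothetical GraphRunner instance):

    // Previously: one global listener, no stream name.
    // graphRunner.setOnAudioOutput((audio: Float32Array) => { /* ... */ });

    // Now: attach a listener per output stream; request a deep copy if the
    // Float32Array must outlive the callback.
    graphRunner.attachAudioListener('audio_out', (audio: Float32Array) => {
      console.log(`received ${audio.length} samples`);
    }, /* makeDeepCopy= */ true);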