Small TS audio API improvement
PiperOrigin-RevId: 490374083
This commit is contained in:
parent
efa9e737f8
commit
fac97554df
|
@ -35,11 +35,7 @@ export * from './audio_classifier_result';
|
|||
const MEDIAPIPE_GRAPH =
|
||||
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
|
||||
|
||||
// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
|
||||
// cannot be changed
|
||||
// TODO: Change this to `audio_in` to match the name in the CC
|
||||
// implementation
|
||||
const AUDIO_STREAM = 'input_audio';
|
||||
const AUDIO_STREAM = 'audio_in';
|
||||
const SAMPLE_RATE_STREAM = 'sample_rate';
|
||||
const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
|
||||
|
||||
|
@ -154,14 +150,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
|
|||
protected override process(
|
||||
audioData: Float32Array, sampleRate: number,
|
||||
timestampMs: number): AudioClassifierResult[] {
|
||||
// Configures the number of samples in the WASM layer. We re-configure the
|
||||
// number of samples and the sample rate for every frame, but ignore other
|
||||
// side effects of this function (such as sending the input side packet and
|
||||
// the input stream header).
|
||||
this.configureAudio(
|
||||
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
|
||||
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
|
||||
this.addAudioToStream(audioData, timestampMs);
|
||||
this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
|
||||
|
||||
this.classificationResults = [];
|
||||
this.finishProcessing();
|
||||
|
|
|
@ -35,11 +35,7 @@ export * from './audio_embedder_result';
|
|||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and cannot
|
||||
// be changed
|
||||
// TODO: Change this to `audio_in` to match the name in the CC
|
||||
// implementation
|
||||
const AUDIO_STREAM = 'input_audio';
|
||||
const AUDIO_STREAM = 'audio_in';
|
||||
const SAMPLE_RATE_STREAM = 'sample_rate';
|
||||
const EMBEDDINGS_STREAM = 'embeddings_out';
|
||||
const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';
|
||||
|
@ -151,14 +147,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
|
|||
protected override process(
|
||||
audioData: Float32Array, sampleRate: number,
|
||||
timestampMs: number): AudioEmbedderResult[] {
|
||||
// Configures the number of samples in the WASM layer. We re-configure the
|
||||
// number of samples and the sample rate for every frame, but ignore other
|
||||
// side effects of this function (such as sending the input side packet and
|
||||
// the input stream header).
|
||||
this.configureAudio(
|
||||
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
|
||||
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
|
||||
this.addAudioToStream(audioData, timestampMs);
|
||||
this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
|
||||
|
||||
this.embeddingResults = [];
|
||||
this.finishProcessing();
|
||||
|
|
|
@ -15,9 +15,6 @@ export declare interface FileLocator {
|
|||
locateFile: (filename: string) => string;
|
||||
}
|
||||
|
||||
/** Listener to be passed in by user for handling output audio data. */
|
||||
export type AudioOutputListener = (output: Float32Array) => void;
|
||||
|
||||
/**
|
||||
* Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
|
||||
* doesn't break our JS/C++ bridge.
|
||||
|
@ -32,19 +29,14 @@ export declare interface WasmModule {
|
|||
_bindTextureToCanvas: () => boolean;
|
||||
_changeBinaryGraph: (size: number, dataPtr: number) => void;
|
||||
_changeTextGraph: (size: number, dataPtr: number) => void;
|
||||
_configureAudio:
|
||||
(channels: number, samples: number, sampleRate: number) => void;
|
||||
_free: (ptr: number) => void;
|
||||
_malloc: (size: number) => number;
|
||||
_processAudio: (dataPtr: number, timestamp: number) => void;
|
||||
_processFrame: (width: number, height: number, timestamp: number) => void;
|
||||
_setAutoRenderToScreen: (enabled: boolean) => void;
|
||||
_waitUntilIdle: () => void;
|
||||
|
||||
// Exposed so that clients of this lib can access this field
|
||||
dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
|
||||
// Wasm module will call us back at this function when given audio data.
|
||||
onAudioOutput?: AudioOutputListener;
|
||||
|
||||
// Wasm Module multistream entrypoints. Require
|
||||
// gl_graph_runner_internal_multi_input as a build dependency.
|
||||
|
@ -100,11 +92,14 @@ export declare interface WasmModule {
|
|||
_attachProtoVectorListener:
|
||||
(streamNamePtr: number, makeDeepCopy?: boolean) => void;
|
||||
|
||||
// Requires dependency ":gl_graph_runner_audio_out", and will register an
|
||||
// audio output listening function which can be tapped into dynamically during
|
||||
// graph running via onAudioOutput. This call must be made before graph is
|
||||
// initialized, but after wasmModule is instantiated.
|
||||
_attachAudioOutputListener: () => void;
|
||||
// Requires dependency ":gl_graph_runner_audio_out"
|
||||
_attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
|
||||
|
||||
// Requires dependency ":gl_graph_runner_audio"
|
||||
_addAudioToInputStream: (dataPtr: number, numChannels: number,
|
||||
numSamples: number, streamNamePtr: number, timestamp: number) => void;
|
||||
_configureAudio: (channels: number, samples: number, sampleRate: number,
|
||||
streamNamePtr: number, headerNamePtr: number) => void;
|
||||
|
||||
// TODO: Refactor to just use a few numbers (perhaps refactor away
|
||||
// from gl_graph_runner_internal.cc entirely to use something a little more
|
||||
|
@ -235,19 +230,38 @@ export class GraphRunner {
|
|||
}
|
||||
|
||||
/**
|
||||
* Configures the current graph to handle audio in a certain way. Must be
|
||||
* called before the graph is set/started in order to use processAudio.
|
||||
* Configures the current graph to handle audio processing in a certain way
|
||||
* for all its audio input streams. Additionally can configure audio headers
|
||||
* (both input side packets as well as input stream headers), but these
|
||||
* configurations only take effect if called before the graph is set/started.
|
||||
* @param numChannels The number of channels of audio input. Only 1
|
||||
* is supported for now.
|
||||
* @param numSamples The number of samples that are taken in each
|
||||
* audio capture.
|
||||
* @param sampleRate The rate, in Hz, of the sampling.
|
||||
* @param streamName The optional name of the input stream to additionally
|
||||
* configure with audio information. This configuration only occurs before
|
||||
* the graph is set/started. If unset, a default stream name will be used.
|
||||
* @param headerName The optional name of the header input side packet to
|
||||
* additionally configure with audio information. This configuration only
|
||||
* occurs before the graph is set/started. If unset, a default header name
|
||||
* will be used.
|
||||
*/
|
||||
configureAudio(numChannels: number, numSamples: number, sampleRate: number) {
|
||||
this.wasmModule._configureAudio(numChannels, numSamples, sampleRate);
|
||||
if (this.wasmModule._attachAudioOutputListener) {
|
||||
this.wasmModule._attachAudioOutputListener();
|
||||
configureAudio(numChannels: number, numSamples: number, sampleRate: number,
|
||||
streamName?: string, headerName?: string) {
|
||||
if (!this.wasmModule._configureAudio) {
|
||||
console.warn(
|
||||
'Attempting to use configureAudio without support for input audio. ' +
|
||||
'Is build dep ":gl_graph_runner_audio" missing?');
|
||||
}
|
||||
streamName = streamName || 'input_audio';
|
||||
this.wrapStringPtr(streamName, (streamNamePtr: number) => {
|
||||
headerName = headerName || 'audio_header';
|
||||
this.wrapStringPtr(headerName, (headerNamePtr: number) => {
|
||||
this.wasmModule._configureAudio(streamNamePtr, headerNamePtr,
|
||||
numChannels, numSamples, sampleRate);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -437,9 +451,36 @@ export class GraphRunner {
|
|||
* processed.
|
||||
* @param audioData An array of raw audio capture data, like
|
||||
* from a call to getChannelData on an AudioBuffer.
|
||||
* @param streamName The name of the MediaPipe graph stream to add the audio
|
||||
* data to.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
*/
|
||||
addAudioToStream(audioData: Float32Array, timestamp: number) {
|
||||
addAudioToStream(
|
||||
audioData: Float32Array, streamName: string, timestamp: number) {
|
||||
// numChannels and numSamples being 0 will cause defaults to be used,
|
||||
// which will reflect values from last call to configureAudio.
|
||||
this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp);
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes the raw data from a JS audio capture array, and sends it to C++ to be
|
||||
* processed, shaping the audioData array into an audio matrix according to
|
||||
* the numChannels and numSamples parameters.
|
||||
* @param audioData An array of raw audio capture data, like
|
||||
* from a call to getChannelData on an AudioBuffer.
|
||||
* @param numChannels The number of audio channels this data represents. If 0
|
||||
* is passed, then the value will be taken from the last call to
|
||||
* configureAudio.
|
||||
* @param numSamples The number of audio samples captured in this data packet.
|
||||
* If 0 is passed, then the value will be taken from the last call to
|
||||
* configureAudio.
|
||||
* @param streamName The name of the MediaPipe graph stream to add the audio
|
||||
* data to.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
*/
|
||||
addAudioToStreamWithShape(
|
||||
audioData: Float32Array, numChannels: number, numSamples: number,
|
||||
streamName: string, timestamp: number) {
|
||||
// 4 bytes for each F32
|
||||
const size = audioData.length * 4;
|
||||
if (this.audioSize !== size) {
|
||||
|
@ -450,7 +491,11 @@ export class GraphRunner {
|
|||
this.audioSize = size;
|
||||
}
|
||||
this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4);
|
||||
this.wasmModule._processAudio(this.audioPtr!, timestamp);
|
||||
|
||||
this.wrapStringPtr(streamName, (streamNamePtr: number) => {
|
||||
this.wasmModule._addAudioToInputStream(
|
||||
this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -943,17 +988,45 @@ export class GraphRunner {
|
|||
}
|
||||
|
||||
/**
|
||||
* Sets a listener to be called back with audio output packet data, as a
|
||||
* Float32Array, when graph has finished processing it.
|
||||
* @param audioOutputListener The caller's listener function.
|
||||
* Attaches an audio packet listener to the specified output_stream, to be
|
||||
* given a Float32Array as output.
|
||||
* @param outputStreamName The name of the graph output stream to grab audio
|
||||
* data from.
|
||||
* @param callbackFcn The function that will be called back with the data, as
|
||||
* it is received. Note that the data is only guaranteed to exist for the
|
||||
* duration of the callback, and the callback will be called inline, so it
|
||||
* should not perform overly complicated (or any async) behavior. If the
|
||||
* audio data needs to be able to outlive the call, you may set the
|
||||
* optional makeDeepCopy parameter to true, or can manually deep-copy the
|
||||
* data yourself.
|
||||
* @param makeDeepCopy Optional convenience parameter which, if set to true,
|
||||
* will override the default memory management behavior and make a deep
|
||||
* copy of the underlying data, rather than just returning a view into the
|
||||
* C++-managed memory. At the cost of a data copy, this allows the
|
||||
* returned data to outlive the callback lifetime (and it will be cleaned
|
||||
* up automatically by JS garbage collection whenever the user is finished
|
||||
* with it).
|
||||
*/
|
||||
setOnAudioOutput(audioOutputListener: AudioOutputListener) {
|
||||
this.wasmModule.onAudioOutput = audioOutputListener;
|
||||
if (!this.wasmModule._attachAudioOutputListener) {
|
||||
attachAudioListener(outputStreamName: string,
|
||||
callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void {
|
||||
if (!this.wasmModule._attachAudioListener) {
|
||||
console.warn(
|
||||
'Attempting to use AudioOutputListener without support for ' +
|
||||
'Attempting to use attachAudioListener without support for ' +
|
||||
'output audio. Is build dep ":gl_graph_runner_audio_out" missing?');
|
||||
}
|
||||
|
||||
// Set up our TS listener to receive any packets for this stream, and
|
||||
// additionally reformat our Uint8Array into a Float32Array for the user.
|
||||
this.setListener(outputStreamName, (data: Uint8Array) => {
|
||||
const floatArray = new Float32Array(data.buffer); // Should be very fast
|
||||
callbackFcn(floatArray);
|
||||
});
|
||||
|
||||
// Tell our graph to listen for string packets on this stream.
|
||||
this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
|
||||
this.wasmModule._attachAudioListener(
|
||||
outputStreamNamePtr, makeDeepCopy || false);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue
Block a user