Merge pull request #4165 from kinaryml:audio-record-api-python

PiperOrigin-RevId: 522240683
This commit is contained in:
Copybara-Service 2023-04-05 21:46:20 -07:00
commit 7455022980
8 changed files with 332 additions and 8 deletions

View File

@ -23,12 +23,18 @@ py_library(
srcs = ["audio_task_running_mode.py"],
)
py_library(
name = "audio_record",
srcs = ["audio_record.py"],
)
py_library(
name = "base_audio_task_api",
srcs = [
"base_audio_task_api.py",
],
deps = [
":audio_record",
":audio_task_running_mode",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/python:_framework_bindings",

View File

@ -0,0 +1,125 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module to record audio in a streaming basis."""
import threading
import numpy as np
try:
import sounddevice as sd
except OSError as oe:
sd = None
sd_error = oe
except ImportError as ie:
sd = None
sd_error = ie
class AudioRecord(object):
"""A class to record audio in a streaming basis."""
def __init__(
self, channels: int, sampling_rate: int, buffer_size: int
) -> None:
"""Creates an AudioRecord instance.
Args:
channels: Number of input channels.
sampling_rate: Sampling rate in Hertz.
buffer_size: Size of the ring buffer in number of samples.
Raises:
ValueError: if any of the arguments is non-positive.
ImportError: if failed to import `sounddevice`.
OSError: if failed to load `PortAudio`.
"""
if sd is None:
raise sd_error
if channels <= 0:
raise ValueError('channels must be postive.')
if sampling_rate <= 0:
raise ValueError('sampling_rate must be postive.')
if buffer_size <= 0:
raise ValueError('buffer_size must be postive.')
self._audio_buffer = []
self._buffer_size = buffer_size
self._channels = channels
self._sampling_rate = sampling_rate
# Create a ring buffer to store the input audio.
self._buffer = np.zeros([buffer_size, channels], dtype=float)
self._lock = threading.Lock()
def audio_callback(data, *_):
"""A callback to receive recorded audio data from sounddevice."""
self._lock.acquire()
shift = len(data)
if shift > buffer_size:
self._buffer = np.copy(data[:buffer_size])
else:
self._buffer = np.roll(self._buffer, -shift, axis=0)
self._buffer[-shift:, :] = np.copy(data)
self._lock.release()
# Create an input stream to continuously capture the audio data.
self._stream = sd.InputStream(
channels=channels,
samplerate=sampling_rate,
callback=audio_callback,
)
@property
def channels(self) -> int:
return self._channels
@property
def sampling_rate(self) -> int:
return self._sampling_rate
@property
def buffer_size(self) -> int:
return self._buffer_size
def start_recording(self) -> None:
"""Starts the audio recording."""
# Clear the internal ring buffer.
self._buffer.fill(0)
# Start recording using sounddevice's InputStream.
self._stream.start()
def stop(self) -> None:
"""Stops the audio recording."""
self._stream.stop()
def read(self, size: int) -> np.ndarray:
"""Reads the latest audio data captured in the buffer.
Args:
size: Number of samples to read from the buffer.
Returns:
A NumPy array containing the audio data.
Raises:
ValueError: Raised if `size` is larger than the buffer size.
"""
if size > self._buffer_size:
raise ValueError('Cannot read more samples than the size of the buffer.')
elif size <= 0:
raise ValueError('Size must be positive.')
start_index = self._buffer_size - size
return np.copy(self._buffer[start_index:])

View File

@ -20,6 +20,7 @@ from mediapipe.python import packet_creator
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.python._framework_bindings import timestamp as timestamp_module
from mediapipe.tasks.python.audio.core import audio_record
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -83,12 +84,15 @@ class BaseAudioTaskApi(object):
"""
if self._running_mode != _RunningMode.AUDIO_CLIPS:
raise ValueError(
'Task is not initialized with the audio clips mode. Current running mode:'
+ self._running_mode.name)
'Task is not initialized with the audio clips mode. Current running'
' mode:'
+ self._running_mode.name
)
return self._runner.process(inputs)
def _set_sample_rate(self, sample_rate_stream_name: str,
sample_rate: float) -> None:
def _set_sample_rate(
self, sample_rate_stream_name: str, sample_rate: float
) -> None:
"""An asynchronous method to set audio sample rate in the audio stream mode.
Args:
@ -122,10 +126,40 @@ class BaseAudioTaskApi(object):
"""
if self._running_mode != _RunningMode.AUDIO_STREAM:
raise ValueError(
'Task is not initialized with the audio stream mode. Current running mode:'
+ self._running_mode.name)
'Task is not initialized with the audio stream mode. Current running'
' mode:'
+ self._running_mode.name
)
self._runner.send(inputs)
def create_audio_record(
self, num_channels: int, sample_rate: int, required_input_buffer_size: int
) -> audio_record.AudioRecord:
"""Creates an AudioRecord instance to record audio stream.
The returned AudioRecord instance is initialized and client needs to call
the appropriate method to start recording.
Note that MediaPipe Audio tasks will up/down sample automatically to fit the
sample rate required by the model. The default sample rate of the MediaPipe
pretrained audio model, Yamnet is 16kHz.
Args:
num_channels: The number of audio channels.
sample_rate: The audio sample rate.
required_input_buffer_size: The required input buffer size in number of
float elements.
Returns:
An AudioRecord instance.
Raises:
ValueError: If there's a problem creating the AudioRecord instance.
"""
return audio_record.AudioRecord(
num_channels, sample_rate, required_input_buffer_size
)
def close(self) -> None:
"""Shuts down the mediapipe audio task instance.

View File

@ -27,6 +27,7 @@ py_test(
],
deps = [
"//mediapipe/tasks/python/audio:audio_classifier",
"//mediapipe/tasks/python/audio/core:audio_record",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result",
@ -44,9 +45,9 @@ py_test(
],
deps = [
"//mediapipe/tasks/python/audio:audio_embedder",
"//mediapipe/tasks/python/audio/core:audio_record",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
],

View File

@ -19,11 +19,11 @@ from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_record
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
@ -34,6 +34,7 @@ _AudioClassifier = audio_classifier.AudioClassifier
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_AudioRecord = audio_record.AudioRecord
_BaseOptions = base_options_module.BaseOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
@ -204,6 +205,19 @@ class AudioClassifierTest(parameterized.TestCase):
self._read_wav_file(audio_file))
self._check_yamnet_result(classification_result_list)
@mock.patch('sounddevice.InputStream', return_value=mock.MagicMock())
def test_create_audio_record_from_classifier_succeeds(self, _):
# Creates AudioRecord instance using the classifier successfully.
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path
) as classifier:
self.assertIsInstance(classifier, _AudioClassifier)
record = classifier.create_audio_record(1, 16000, 16000)
self.assertIsInstance(record, _AudioRecord)
self.assertEqual(record.channels, 1)
self.assertEqual(record.sampling_rate, 16000)
self.assertEqual(record.buffer_size, 16000)
def test_max_result_options(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(

View File

@ -24,6 +24,7 @@ import numpy as np
from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_embedder
from mediapipe.tasks.python.audio.core import audio_record
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.core import base_options as base_options_module
@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
_AudioData = audio_data_module.AudioData
_AudioRecord = audio_record.AudioRecord
_BaseOptions = base_options_module.BaseOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
@ -165,6 +167,19 @@ class AudioEmbedderTest(parameterized.TestCase):
self.assertLen(embedding_result0_list, 5)
self.assertLen(embedding_result1_list, 5)
@mock.patch('sounddevice.InputStream', return_value=mock.MagicMock())
def test_create_audio_record_from_embedder_succeeds(self, _):
# Creates AudioRecord instance using the embedder successfully.
with _AudioEmbedder.create_from_model_path(
self.yamnet_model_path
) as embedder:
self.assertIsInstance(embedder, _AudioEmbedder)
record = embedder.create_audio_record(1, 16000, 16000)
self.assertIsInstance(record, _AudioRecord)
self.assertEqual(record.channels, 1)
self.assertEqual(record.sampling_rate, 16000)
self.assertEqual(record.buffer_size, 16000)
def test_embed_with_yamnet_model_and_different_inputs(self):
with _AudioEmbedder.create_from_model_path(
self.yamnet_model_path) as embedder:

View File

@ -0,0 +1,25 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_test(
name = "audio_record_test",
srcs = ["audio_record_test.py"],
deps = ["//mediapipe/tasks/python/audio/core:audio_record"],
)

View File

@ -0,0 +1,104 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio_record."""
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from mediapipe.tasks.python.audio.core import audio_record
_mock = unittest.mock
_CHANNELS = 2
_SAMPLING_RATE = 16000
_BUFFER_SIZE = 15600
class AudioRecordTest(parameterized.TestCase):
def setUp(self):
super().setUp()
# Mock sounddevice.InputStream
with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method:
self.mock_input_stream = _mock.MagicMock()
mock_input_stream_new_method.return_value = self.mock_input_stream
self.record = audio_record.AudioRecord(
_CHANNELS, _SAMPLING_RATE, _BUFFER_SIZE
)
# Save the initialization arguments of InputStream for later assertion.
_, self.init_args = mock_input_stream_new_method.call_args
def test_init_args(self):
# Assert parameters of InputStream initialization
self.assertEqual(
self.init_args["channels"],
_CHANNELS,
"InputStream's channels doesn't match the initialization argument.",
)
self.assertEqual(
self.init_args["samplerate"],
_SAMPLING_RATE,
"InputStream's samplerate doesn't match the initialization argument.",
)
def test_life_cycle(self):
# Assert start recording routine.
self.record.start_recording()
self.mock_input_stream.start.assert_called_once()
# Assert stop recording routine.
self.record.stop()
self.mock_input_stream.stop.assert_called_once()
def test_read_succeeds_with_valid_sample_size(self):
callback_fn = self.init_args["callback"]
# Create dummy data to feed to the AudioRecord instance.
chunk_size = int(_BUFFER_SIZE * 0.5)
input_data = []
for _ in range(3):
dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float)
input_data.append(dummy_data)
callback_fn(dummy_data)
# Assert read data of a single chunk.
recorded_audio_data = self.record.read(chunk_size)
self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1]))
# Assert read all data in buffer.
recorded_audio_data = self.record.read(chunk_size * 2)
print(input_data[-2].shape)
expected_data = np.concatenate(input_data[-2:])
self.assertTrue(np.array_equal(recorded_audio_data, expected_data))
def test_read_fails_with_invalid_sample_size(self):
callback_fn = self.init_args["callback"]
# Create dummy data to feed to the AudioRecord instance.
dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float)
callback_fn(dummy_data)
# Assert exception if read too much data.
with self.assertRaises(ValueError):
self.record.read(_BUFFER_SIZE + 1)
if __name__ == "__main__":
absltest.main()