Merge pull request #4165 from kinaryml:audio-record-api-python
PiperOrigin-RevId: 522240683
This commit is contained in:
commit
7455022980
|
@ -23,12 +23,18 @@ py_library(
|
|||
srcs = ["audio_task_running_mode.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "audio_record",
|
||||
srcs = ["audio_record.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "base_audio_task_api",
|
||||
srcs = [
|
||||
"base_audio_task_api.py",
|
||||
],
|
||||
deps = [
|
||||
":audio_record",
|
||||
":audio_task_running_mode",
|
||||
"//mediapipe/framework:calculator_py_pb2",
|
||||
"//mediapipe/python:_framework_bindings",
|
||||
|
|
125
mediapipe/tasks/python/audio/core/audio_record.py
Normal file
125
mediapipe/tasks/python/audio/core/audio_record.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A module to record audio in a streaming basis."""
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except OSError as oe:
|
||||
sd = None
|
||||
sd_error = oe
|
||||
except ImportError as ie:
|
||||
sd = None
|
||||
sd_error = ie
|
||||
|
||||
|
||||
class AudioRecord(object):
|
||||
"""A class to record audio in a streaming basis."""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, sampling_rate: int, buffer_size: int
|
||||
) -> None:
|
||||
"""Creates an AudioRecord instance.
|
||||
|
||||
Args:
|
||||
channels: Number of input channels.
|
||||
sampling_rate: Sampling rate in Hertz.
|
||||
buffer_size: Size of the ring buffer in number of samples.
|
||||
|
||||
Raises:
|
||||
ValueError: if any of the arguments is non-positive.
|
||||
ImportError: if failed to import `sounddevice`.
|
||||
OSError: if failed to load `PortAudio`.
|
||||
"""
|
||||
if sd is None:
|
||||
raise sd_error
|
||||
|
||||
if channels <= 0:
|
||||
raise ValueError('channels must be postive.')
|
||||
if sampling_rate <= 0:
|
||||
raise ValueError('sampling_rate must be postive.')
|
||||
if buffer_size <= 0:
|
||||
raise ValueError('buffer_size must be postive.')
|
||||
|
||||
self._audio_buffer = []
|
||||
self._buffer_size = buffer_size
|
||||
self._channels = channels
|
||||
self._sampling_rate = sampling_rate
|
||||
|
||||
# Create a ring buffer to store the input audio.
|
||||
self._buffer = np.zeros([buffer_size, channels], dtype=float)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def audio_callback(data, *_):
|
||||
"""A callback to receive recorded audio data from sounddevice."""
|
||||
self._lock.acquire()
|
||||
shift = len(data)
|
||||
if shift > buffer_size:
|
||||
self._buffer = np.copy(data[:buffer_size])
|
||||
else:
|
||||
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
||||
self._buffer[-shift:, :] = np.copy(data)
|
||||
self._lock.release()
|
||||
|
||||
# Create an input stream to continuously capture the audio data.
|
||||
self._stream = sd.InputStream(
|
||||
channels=channels,
|
||||
samplerate=sampling_rate,
|
||||
callback=audio_callback,
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
return self._channels
|
||||
|
||||
@property
|
||||
def sampling_rate(self) -> int:
|
||||
return self._sampling_rate
|
||||
|
||||
@property
|
||||
def buffer_size(self) -> int:
|
||||
return self._buffer_size
|
||||
|
||||
def start_recording(self) -> None:
|
||||
"""Starts the audio recording."""
|
||||
# Clear the internal ring buffer.
|
||||
self._buffer.fill(0)
|
||||
|
||||
# Start recording using sounddevice's InputStream.
|
||||
self._stream.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the audio recording."""
|
||||
self._stream.stop()
|
||||
|
||||
def read(self, size: int) -> np.ndarray:
|
||||
"""Reads the latest audio data captured in the buffer.
|
||||
|
||||
Args:
|
||||
size: Number of samples to read from the buffer.
|
||||
|
||||
Returns:
|
||||
A NumPy array containing the audio data.
|
||||
|
||||
Raises:
|
||||
ValueError: Raised if `size` is larger than the buffer size.
|
||||
"""
|
||||
if size > self._buffer_size:
|
||||
raise ValueError('Cannot read more samples than the size of the buffer.')
|
||||
elif size <= 0:
|
||||
raise ValueError('Size must be positive.')
|
||||
|
||||
start_index = self._buffer_size - size
|
||||
return np.copy(self._buffer[start_index:])
|
|
@ -20,6 +20,7 @@ from mediapipe.python import packet_creator
|
|||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||
from mediapipe.python._framework_bindings import timestamp as timestamp_module
|
||||
from mediapipe.tasks.python.audio.core import audio_record
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
|
@ -83,12 +84,15 @@ class BaseAudioTaskApi(object):
|
|||
"""
|
||||
if self._running_mode != _RunningMode.AUDIO_CLIPS:
|
||||
raise ValueError(
|
||||
'Task is not initialized with the audio clips mode. Current running mode:'
|
||||
+ self._running_mode.name)
|
||||
'Task is not initialized with the audio clips mode. Current running'
|
||||
' mode:'
|
||||
+ self._running_mode.name
|
||||
)
|
||||
return self._runner.process(inputs)
|
||||
|
||||
def _set_sample_rate(self, sample_rate_stream_name: str,
|
||||
sample_rate: float) -> None:
|
||||
def _set_sample_rate(
|
||||
self, sample_rate_stream_name: str, sample_rate: float
|
||||
) -> None:
|
||||
"""An asynchronous method to set audio sample rate in the audio stream mode.
|
||||
|
||||
Args:
|
||||
|
@ -122,10 +126,40 @@ class BaseAudioTaskApi(object):
|
|||
"""
|
||||
if self._running_mode != _RunningMode.AUDIO_STREAM:
|
||||
raise ValueError(
|
||||
'Task is not initialized with the audio stream mode. Current running mode:'
|
||||
+ self._running_mode.name)
|
||||
'Task is not initialized with the audio stream mode. Current running'
|
||||
' mode:'
|
||||
+ self._running_mode.name
|
||||
)
|
||||
self._runner.send(inputs)
|
||||
|
||||
def create_audio_record(
|
||||
self, num_channels: int, sample_rate: int, required_input_buffer_size: int
|
||||
) -> audio_record.AudioRecord:
|
||||
"""Creates an AudioRecord instance to record audio stream.
|
||||
|
||||
The returned AudioRecord instance is initialized and client needs to call
|
||||
the appropriate method to start recording.
|
||||
|
||||
Note that MediaPipe Audio tasks will up/down sample automatically to fit the
|
||||
sample rate required by the model. The default sample rate of the MediaPipe
|
||||
pretrained audio model, Yamnet is 16kHz.
|
||||
|
||||
Args:
|
||||
num_channels: The number of audio channels.
|
||||
sample_rate: The audio sample rate.
|
||||
required_input_buffer_size: The required input buffer size in number of
|
||||
float elements.
|
||||
|
||||
Returns:
|
||||
An AudioRecord instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a problem creating the AudioRecord instance.
|
||||
"""
|
||||
return audio_record.AudioRecord(
|
||||
num_channels, sample_rate, required_input_buffer_size
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Shuts down the mediapipe audio task instance.
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/audio:audio_classifier",
|
||||
"//mediapipe/tasks/python/audio/core:audio_record",
|
||||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||
|
@ -44,9 +45,9 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/audio:audio_embedder",
|
||||
"//mediapipe/tasks/python/audio/core:audio_record",
|
||||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
|
|
|
@ -19,11 +19,11 @@ from unittest import mock
|
|||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
|
||||
from mediapipe.tasks.python.audio import audio_classifier
|
||||
from mediapipe.tasks.python.audio.core import audio_record
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
|
@ -34,6 +34,7 @@ _AudioClassifier = audio_classifier.AudioClassifier
|
|||
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
|
||||
_AudioClassifierResult = classification_result_module.ClassificationResult
|
||||
_AudioData = audio_data_module.AudioData
|
||||
_AudioRecord = audio_record.AudioRecord
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
||||
|
||||
|
@ -204,6 +205,19 @@ class AudioClassifierTest(parameterized.TestCase):
|
|||
self._read_wav_file(audio_file))
|
||||
self._check_yamnet_result(classification_result_list)
|
||||
|
||||
@mock.patch('sounddevice.InputStream', return_value=mock.MagicMock())
|
||||
def test_create_audio_record_from_classifier_succeeds(self, _):
|
||||
# Creates AudioRecord instance using the classifier successfully.
|
||||
with _AudioClassifier.create_from_model_path(
|
||||
self.yamnet_model_path
|
||||
) as classifier:
|
||||
self.assertIsInstance(classifier, _AudioClassifier)
|
||||
record = classifier.create_audio_record(1, 16000, 16000)
|
||||
self.assertIsInstance(record, _AudioRecord)
|
||||
self.assertEqual(record.channels, 1)
|
||||
self.assertEqual(record.sampling_rate, 16000)
|
||||
self.assertEqual(record.buffer_size, 16000)
|
||||
|
||||
def test_max_result_options(self):
|
||||
with _AudioClassifier.create_from_options(
|
||||
_AudioClassifierOptions(
|
||||
|
|
|
@ -24,6 +24,7 @@ import numpy as np
|
|||
from scipy.io import wavfile
|
||||
|
||||
from mediapipe.tasks.python.audio import audio_embedder
|
||||
from mediapipe.tasks.python.audio.core import audio_record
|
||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder
|
|||
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
|
||||
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
|
||||
_AudioData = audio_data_module.AudioData
|
||||
_AudioRecord = audio_record.AudioRecord
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
||||
|
||||
|
@ -165,6 +167,19 @@ class AudioEmbedderTest(parameterized.TestCase):
|
|||
self.assertLen(embedding_result0_list, 5)
|
||||
self.assertLen(embedding_result1_list, 5)
|
||||
|
||||
@mock.patch('sounddevice.InputStream', return_value=mock.MagicMock())
|
||||
def test_create_audio_record_from_embedder_succeeds(self, _):
|
||||
# Creates AudioRecord instance using the embedder successfully.
|
||||
with _AudioEmbedder.create_from_model_path(
|
||||
self.yamnet_model_path
|
||||
) as embedder:
|
||||
self.assertIsInstance(embedder, _AudioEmbedder)
|
||||
record = embedder.create_audio_record(1, 16000, 16000)
|
||||
self.assertIsInstance(record, _AudioRecord)
|
||||
self.assertEqual(record.channels, 1)
|
||||
self.assertEqual(record.sampling_rate, 16000)
|
||||
self.assertEqual(record.buffer_size, 16000)
|
||||
|
||||
def test_embed_with_yamnet_model_and_different_inputs(self):
|
||||
with _AudioEmbedder.create_from_model_path(
|
||||
self.yamnet_model_path) as embedder:
|
||||
|
|
25
mediapipe/tasks/python/test/audio/core/BUILD
Normal file
25
mediapipe/tasks/python/test/audio/core/BUILD
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Placeholder for internal Python strict test compatibility macro.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
py_test(
|
||||
name = "audio_record_test",
|
||||
srcs = ["audio_record_test.py"],
|
||||
deps = ["//mediapipe/tasks/python/audio/core:audio_record"],
|
||||
)
|
104
mediapipe/tasks/python/test/audio/core/audio_record_test.py
Normal file
104
mediapipe/tasks/python/test/audio/core/audio_record_test.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for audio_record."""
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.audio.core import audio_record
|
||||
|
||||
|
||||
_mock = unittest.mock
|
||||
|
||||
_CHANNELS = 2
|
||||
_SAMPLING_RATE = 16000
|
||||
_BUFFER_SIZE = 15600
|
||||
|
||||
|
||||
class AudioRecordTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# Mock sounddevice.InputStream
|
||||
with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method:
|
||||
self.mock_input_stream = _mock.MagicMock()
|
||||
mock_input_stream_new_method.return_value = self.mock_input_stream
|
||||
self.record = audio_record.AudioRecord(
|
||||
_CHANNELS, _SAMPLING_RATE, _BUFFER_SIZE
|
||||
)
|
||||
|
||||
# Save the initialization arguments of InputStream for later assertion.
|
||||
_, self.init_args = mock_input_stream_new_method.call_args
|
||||
|
||||
def test_init_args(self):
|
||||
# Assert parameters of InputStream initialization
|
||||
self.assertEqual(
|
||||
self.init_args["channels"],
|
||||
_CHANNELS,
|
||||
"InputStream's channels doesn't match the initialization argument.",
|
||||
)
|
||||
self.assertEqual(
|
||||
self.init_args["samplerate"],
|
||||
_SAMPLING_RATE,
|
||||
"InputStream's samplerate doesn't match the initialization argument.",
|
||||
)
|
||||
|
||||
def test_life_cycle(self):
|
||||
# Assert start recording routine.
|
||||
self.record.start_recording()
|
||||
self.mock_input_stream.start.assert_called_once()
|
||||
|
||||
# Assert stop recording routine.
|
||||
self.record.stop()
|
||||
self.mock_input_stream.stop.assert_called_once()
|
||||
|
||||
def test_read_succeeds_with_valid_sample_size(self):
|
||||
callback_fn = self.init_args["callback"]
|
||||
|
||||
# Create dummy data to feed to the AudioRecord instance.
|
||||
chunk_size = int(_BUFFER_SIZE * 0.5)
|
||||
input_data = []
|
||||
for _ in range(3):
|
||||
dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float)
|
||||
input_data.append(dummy_data)
|
||||
callback_fn(dummy_data)
|
||||
|
||||
# Assert read data of a single chunk.
|
||||
recorded_audio_data = self.record.read(chunk_size)
|
||||
self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1]))
|
||||
|
||||
# Assert read all data in buffer.
|
||||
recorded_audio_data = self.record.read(chunk_size * 2)
|
||||
print(input_data[-2].shape)
|
||||
expected_data = np.concatenate(input_data[-2:])
|
||||
self.assertTrue(np.array_equal(recorded_audio_data, expected_data))
|
||||
|
||||
def test_read_fails_with_invalid_sample_size(self):
|
||||
callback_fn = self.init_args["callback"]
|
||||
|
||||
# Create dummy data to feed to the AudioRecord instance.
|
||||
dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float)
|
||||
callback_fn(dummy_data)
|
||||
|
||||
# Assert exception if read too much data.
|
||||
with self.assertRaises(ValueError):
|
||||
self.record.read(_BUFFER_SIZE + 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
Loading…
Reference in New Issue
Block a user