From 97870565081643d3acd6c64428b50e0485a57709 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Fri, 10 Mar 2023 10:17:03 -0800 Subject: [PATCH] Added the AudioRecord API --- mediapipe/tasks/python/audio/core/BUILD | 1 + .../python/audio/core/base_audio_task_api.py | 28 ++++ .../tasks/python/components/containers/BUILD | 5 + .../components/containers/audio_record.py | 126 ++++++++++++++++++ mediapipe/tasks/python/test/audio/BUILD | 2 + .../test/audio/audio_classifier_test.py | 14 ++ .../python/test/audio/audio_embedder_test.py | 14 ++ mediapipe/tasks/python/test/audio/core/BUILD | 27 ++++ .../test/audio/core/audio_record_test.py | 97 ++++++++++++++ 9 files changed, 314 insertions(+) create mode 100644 mediapipe/tasks/python/components/containers/audio_record.py create mode 100644 mediapipe/tasks/python/test/audio/core/BUILD create mode 100644 mediapipe/tasks/python/test/audio/core/audio_record_test.py diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index 5b4203d7b..28dc4b960 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -34,5 +34,6 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/components/containers:audio_record", ], ) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index 5b08a2b76..80e8ad605 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -22,6 +22,7 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu from mediapipe.python._framework_bindings import timestamp as timestamp_module from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.components.containers import audio_record _TaskRunner = task_runner_module.TaskRunner _Packet = packet_module.Packet @@ -126,6 +127,33 @@ class BaseAudioTaskApi(object): + self._running_mode.name) self._runner.send(inputs) + @staticmethod + def create_audio_record( + num_channels: int, + sample_rate: int, + required_input_buffer_size: int + ) -> audio_record.AudioRecord: + """Creates an AudioRecord instance to record audio stream. + + The returned AudioRecord instance is initialized and client needs to call + the appropriate method to start recording. + + Note that MediaPipe Audio tasks will up/down sample automatically to fit the + sample rate required by the model. The default sample rate of the MediaPipe + pretrained audio model, Yamnet is 16kHz. + + Args: + num_channels: The number of audio channels. + sample_rate: The audio sample rate. + required_input_buffer_size: The required input buffer size in number of + float elements. + + Raises: + ValueError: If there's a problem creating the AudioRecord instance. + """ + return audio_record.AudioRecord(num_channels, sample_rate, + required_input_buffer_size) + def close(self) -> None: """Shuts down the mediapipe audio task instance. diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 7108617ff..61163365c 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -23,6 +23,11 @@ py_library( srcs = ["audio_data.py"], ) +py_library( + name = "audio_record", + srcs = ["audio_record.py"], +) + py_library( name = "bounding_box", srcs = ["bounding_box.py"], diff --git a/mediapipe/tasks/python/components/containers/audio_record.py b/mediapipe/tasks/python/components/containers/audio_record.py new file mode 100644 index 000000000..824f36e3e --- /dev/null +++ b/mediapipe/tasks/python/components/containers/audio_record.py @@ -0,0 +1,126 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A module to record audio in a streaming basis.""" +import threading +import numpy as np + +try: +# pylint: disable=g-import-not-at-top + import sounddevice as sd +# pylint: enable=g-import-not-at-top +except OSError as oe: + sd = None + sd_error = oe +except ImportError as ie: + sd = None + sd_error = ie + + +class AudioRecord(object): + """A class to record audio in a streaming basis.""" + + def __init__(self, channels: int, sampling_rate: int, + buffer_size: int) -> None: + """Creates an AudioRecord instance. + + Args: + channels: Number of input channels. + sampling_rate: Sampling rate in Hertz. + buffer_size: Size of the ring buffer in number of samples. + + Raises: + ValueError: if any of the arguments is non-positive. + ImportError: if failed to import `sounddevice`. + OSError: if failed to load `PortAudio`. + """ + if sd is None: + raise sd_error + + if channels <= 0: + raise ValueError('channels must be postive.') + if sampling_rate <= 0: + raise ValueError('sampling_rate must be postive.') + if buffer_size <= 0: + raise ValueError('buffer_size must be postive.') + + self._audio_buffer = [] + self._buffer_size = buffer_size + self._channels = channels + self._sampling_rate = sampling_rate + + # Create a ring buffer to store the input audio. + self._buffer = np.zeros([buffer_size, channels], dtype=float) + self._lock = threading.Lock() + + def audio_callback(data, *_): + """A callback to receive recorded audio data from sounddevice.""" + self._lock.acquire() + shift = len(data) + if shift > buffer_size: + self._buffer = np.copy(data[:buffer_size]) + else: + self._buffer = np.roll(self._buffer, -shift, axis=0) + self._buffer[-shift:, :] = np.copy(data) + self._lock.release() + + # Create an input stream to continuously capture the audio data. + self._stream = sd.InputStream( + channels=channels, + samplerate=sampling_rate, + callback=audio_callback, + ) + + @property + def channels(self) -> int: + return self._channels + + @property + def sampling_rate(self) -> int: + return self._sampling_rate + + @property + def buffer_size(self) -> int: + return self._buffer_size + + def start_recording(self) -> None: + """Starts the audio recording.""" + # Clear the internal ring buffer. + self._buffer.fill(0) + + # Start recording using sounddevice's InputStream. + self._stream.start() + + def stop(self) -> None: + """Stops the audio recording.""" + self._stream.stop() + + def read(self, size: int) -> np.ndarray: + """Reads the latest audio data captured in the buffer. + + Args: + size: Number of samples to read from the buffer. + + Returns: + A NumPy array containing the audio data. + + Raises: + ValueError: Raised if `size` is larger than the buffer size. + """ + if size > self._buffer_size: + raise ValueError('Cannot read more samples than the size of the buffer.') + elif size <= 0: + raise ValueError('Size must be positive.') + + start_index = self._buffer_size - size + return np.copy(self._buffer[start_index:]) diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 43f1d417c..d6e0788f2 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -29,6 +29,7 @@ py_test( "//mediapipe/tasks/python/audio:audio_classifier", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:audio_record", "//mediapipe/tasks/python/components/containers:classification_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", @@ -46,6 +47,7 @@ py_test( "//mediapipe/tasks/python/audio:audio_embedder", "//mediapipe/tasks/python/audio/core:audio_task_running_mode", "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:audio_record", "//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 75146547c..665a5ca13 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -27,6 +27,7 @@ from mediapipe.tasks.python.audio import audio_classifier from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.components.containers import audio_record from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -34,6 +35,7 @@ _AudioClassifier = audio_classifier.AudioClassifier _AudioClassifierOptions = audio_classifier.AudioClassifierOptions _AudioClassifierResult = classification_result_module.ClassificationResult _AudioData = audio_data_module.AudioData +_AudioRecord = audio_record.AudioRecord _BaseOptions = base_options_module.BaseOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode @@ -204,6 +206,18 @@ class AudioClassifierTest(parameterized.TestCase): self._read_wav_file(audio_file)) self._check_yamnet_result(classification_result_list) + @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock()) + def test_create_audio_record_from_classifier_succeeds(self, _): + # Creates AudioRecord instance using the classifier successfully. + with _AudioClassifier.create_from_model_path( + self.yamnet_model_path) as classifier: + self.assertIsInstance(classifier, _AudioClassifier) + record = classifier.create_audio_record(1, 16000, 16000) + self.assertIsInstance(record, _AudioRecord) + self.assertEqual(record.channels, 1) + self.assertEqual(record.sampling_rate, 16000) + self.assertEqual(record.buffer_size, 16000) + def test_max_result_options(self): with _AudioClassifier.create_from_options( _AudioClassifierOptions( diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py index 934cdc8db..2015d2bce 100644 --- a/mediapipe/tasks/python/test/audio/audio_embedder_test.py +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -26,6 +26,7 @@ from scipy.io import wavfile from mediapipe.tasks.python.audio import audio_embedder from mediapipe.tasks.python.audio.core import audio_task_running_mode from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import audio_record from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils @@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder _AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions _AudioEmbedderResult = audio_embedder.AudioEmbedderResult _AudioData = audio_data_module.AudioData +_AudioRecord = audio_record.AudioRecord _BaseOptions = base_options_module.BaseOptions _RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode @@ -165,6 +167,18 @@ class AudioEmbedderTest(parameterized.TestCase): self.assertLen(embedding_result0_list, 5) self.assertLen(embedding_result1_list, 5) + @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock()) + def test_create_audio_record_from_embedder_succeeds(self, _): + # Creates AudioRecord instance using the embedder successfully. + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + record = embedder.create_audio_record(1, 16000, 16000) + self.assertIsInstance(record, _AudioRecord) + self.assertEqual(record.channels, 1) + self.assertEqual(record.sampling_rate, 16000) + self.assertEqual(record.buffer_size, 16000) + def test_embed_with_yamnet_model_and_different_inputs(self): with _AudioEmbedder.create_from_model_path( self.yamnet_model_path) as embedder: diff --git a/mediapipe/tasks/python/test/audio/core/BUILD b/mediapipe/tasks/python/test/audio/core/BUILD new file mode 100644 index 000000000..14f2e4f6c --- /dev/null +++ b/mediapipe/tasks/python/test/audio/core/BUILD @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_test( + name = "audio_record_test", + srcs = ["audio_record_test.py"], + deps = [ + "//mediapipe/tasks/python/components/containers:audio_record", + ], +) diff --git a/mediapipe/tasks/python/test/audio/core/audio_record_test.py b/mediapipe/tasks/python/test/audio/core/audio_record_test.py new file mode 100644 index 000000000..dfa72a822 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/core/audio_record_test.py @@ -0,0 +1,97 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for audio_record.""" + +import numpy as np +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +from mediapipe.tasks.python.components.containers import audio_record + +_mock = unittest.mock + +_CHANNELS = 2 +_SAMPLING_RATE = 16000 +_BUFFER_SIZE = 15600 + + +class AudioRecordTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + + # Mock sounddevice.InputStream + with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method: + self.mock_input_stream = _mock.MagicMock() + mock_input_stream_new_method.return_value = self.mock_input_stream + self.record = audio_record.AudioRecord(_CHANNELS, _SAMPLING_RATE, + _BUFFER_SIZE) + + # Save the initialization arguments of InputStream for later assertion. + _, self.init_args = mock_input_stream_new_method.call_args + + def test_init_args(self): + # Assert parameters of InputStream initialization + self.assertEqual( + self.init_args["channels"], _CHANNELS, + "InputStream's channels doesn't match the initialization argument.") + self.assertEqual( + self.init_args["samplerate"], _SAMPLING_RATE, + "InputStream's samplerate doesn't match the initialization argument.") + + def test_life_cycle(self): + # Assert start recording routine. + self.record.start_recording() + self.mock_input_stream.start.assert_called_once() + + # Assert stop recording routine. + self.record.stop() + self.mock_input_stream.stop.assert_called_once() + + def test_read_succeeds_with_valid_sample_size(self): + callback_fn = self.init_args["callback"] + + # Create dummy data to feed to the AudioRecord instance. + chunk_size = int(_BUFFER_SIZE * 0.5) + input_data = [] + for _ in range(3): + dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float) + input_data.append(dummy_data) + callback_fn(dummy_data) + + # Assert read data of a single chunk. + recorded_audio_data = self.record.read(chunk_size) + self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1])) + + # Assert read all data in buffer. + recorded_audio_data = self.record.read(chunk_size * 2) + print(input_data[-2].shape) + expected_data = np.concatenate(input_data[-2:]) + self.assertTrue(np.array_equal(recorded_audio_data, expected_data)) + + def test_read_fails_with_invalid_sample_size(self): + callback_fn = self.init_args["callback"] + + # Create dummy data to feed to the AudioRecord instance. + dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float) + callback_fn(dummy_data) + + # Assert exception if read too much data. + with self.assertRaises(ValueError): + self.record.read(_BUFFER_SIZE + 1) + + +if __name__ == "__main__": + absltest.main()