Merge pull request #4165 from kinaryml:audio-record-api-python
PiperOrigin-RevId: 522240683
This commit is contained in:
commit
7455022980
|
@ -23,12 +23,18 @@ py_library(
|
||||||
srcs = ["audio_task_running_mode.py"],
|
srcs = ["audio_task_running_mode.py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "audio_record",
|
||||||
|
srcs = ["audio_record.py"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "base_audio_task_api",
|
name = "base_audio_task_api",
|
||||||
srcs = [
|
srcs = [
|
||||||
"base_audio_task_api.py",
|
"base_audio_task_api.py",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":audio_record",
|
||||||
":audio_task_running_mode",
|
":audio_task_running_mode",
|
||||||
"//mediapipe/framework:calculator_py_pb2",
|
"//mediapipe/framework:calculator_py_pb2",
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
|
|
125
mediapipe/tasks/python/audio/core/audio_record.py
Normal file
125
mediapipe/tasks/python/audio/core/audio_record.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""A module to record audio in a streaming basis."""
|
||||||
|
import threading
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except OSError as oe:
|
||||||
|
sd = None
|
||||||
|
sd_error = oe
|
||||||
|
except ImportError as ie:
|
||||||
|
sd = None
|
||||||
|
sd_error = ie
|
||||||
|
|
||||||
|
|
||||||
|
class AudioRecord(object):
|
||||||
|
"""A class to record audio in a streaming basis."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, channels: int, sampling_rate: int, buffer_size: int
|
||||||
|
) -> None:
|
||||||
|
"""Creates an AudioRecord instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: Number of input channels.
|
||||||
|
sampling_rate: Sampling rate in Hertz.
|
||||||
|
buffer_size: Size of the ring buffer in number of samples.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if any of the arguments is non-positive.
|
||||||
|
ImportError: if failed to import `sounddevice`.
|
||||||
|
OSError: if failed to load `PortAudio`.
|
||||||
|
"""
|
||||||
|
if sd is None:
|
||||||
|
raise sd_error
|
||||||
|
|
||||||
|
if channels <= 0:
|
||||||
|
raise ValueError('channels must be postive.')
|
||||||
|
if sampling_rate <= 0:
|
||||||
|
raise ValueError('sampling_rate must be postive.')
|
||||||
|
if buffer_size <= 0:
|
||||||
|
raise ValueError('buffer_size must be postive.')
|
||||||
|
|
||||||
|
self._audio_buffer = []
|
||||||
|
self._buffer_size = buffer_size
|
||||||
|
self._channels = channels
|
||||||
|
self._sampling_rate = sampling_rate
|
||||||
|
|
||||||
|
# Create a ring buffer to store the input audio.
|
||||||
|
self._buffer = np.zeros([buffer_size, channels], dtype=float)
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def audio_callback(data, *_):
|
||||||
|
"""A callback to receive recorded audio data from sounddevice."""
|
||||||
|
self._lock.acquire()
|
||||||
|
shift = len(data)
|
||||||
|
if shift > buffer_size:
|
||||||
|
self._buffer = np.copy(data[:buffer_size])
|
||||||
|
else:
|
||||||
|
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
||||||
|
self._buffer[-shift:, :] = np.copy(data)
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
# Create an input stream to continuously capture the audio data.
|
||||||
|
self._stream = sd.InputStream(
|
||||||
|
channels=channels,
|
||||||
|
samplerate=sampling_rate,
|
||||||
|
callback=audio_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> int:
|
||||||
|
return self._channels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sampling_rate(self) -> int:
|
||||||
|
return self._sampling_rate
|
||||||
|
|
||||||
|
@property
|
||||||
|
def buffer_size(self) -> int:
|
||||||
|
return self._buffer_size
|
||||||
|
|
||||||
|
def start_recording(self) -> None:
|
||||||
|
"""Starts the audio recording."""
|
||||||
|
# Clear the internal ring buffer.
|
||||||
|
self._buffer.fill(0)
|
||||||
|
|
||||||
|
# Start recording using sounddevice's InputStream.
|
||||||
|
self._stream.start()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stops the audio recording."""
|
||||||
|
self._stream.stop()
|
||||||
|
|
||||||
|
def read(self, size: int) -> np.ndarray:
|
||||||
|
"""Reads the latest audio data captured in the buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: Number of samples to read from the buffer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A NumPy array containing the audio data.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Raised if `size` is larger than the buffer size.
|
||||||
|
"""
|
||||||
|
if size > self._buffer_size:
|
||||||
|
raise ValueError('Cannot read more samples than the size of the buffer.')
|
||||||
|
elif size <= 0:
|
||||||
|
raise ValueError('Size must be positive.')
|
||||||
|
|
||||||
|
start_index = self._buffer_size - size
|
||||||
|
return np.copy(self._buffer[start_index:])
|
|
@ -20,6 +20,7 @@ from mediapipe.python import packet_creator
|
||||||
from mediapipe.python._framework_bindings import packet as packet_module
|
from mediapipe.python._framework_bindings import packet as packet_module
|
||||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||||
from mediapipe.python._framework_bindings import timestamp as timestamp_module
|
from mediapipe.python._framework_bindings import timestamp as timestamp_module
|
||||||
|
from mediapipe.tasks.python.audio.core import audio_record
|
||||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
|
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
@ -83,12 +84,15 @@ class BaseAudioTaskApi(object):
|
||||||
"""
|
"""
|
||||||
if self._running_mode != _RunningMode.AUDIO_CLIPS:
|
if self._running_mode != _RunningMode.AUDIO_CLIPS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Task is not initialized with the audio clips mode. Current running mode:'
|
'Task is not initialized with the audio clips mode. Current running'
|
||||||
+ self._running_mode.name)
|
' mode:'
|
||||||
|
+ self._running_mode.name
|
||||||
|
)
|
||||||
return self._runner.process(inputs)
|
return self._runner.process(inputs)
|
||||||
|
|
||||||
def _set_sample_rate(self, sample_rate_stream_name: str,
|
def _set_sample_rate(
|
||||||
sample_rate: float) -> None:
|
self, sample_rate_stream_name: str, sample_rate: float
|
||||||
|
) -> None:
|
||||||
"""An asynchronous method to set audio sample rate in the audio stream mode.
|
"""An asynchronous method to set audio sample rate in the audio stream mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -122,10 +126,40 @@ class BaseAudioTaskApi(object):
|
||||||
"""
|
"""
|
||||||
if self._running_mode != _RunningMode.AUDIO_STREAM:
|
if self._running_mode != _RunningMode.AUDIO_STREAM:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Task is not initialized with the audio stream mode. Current running mode:'
|
'Task is not initialized with the audio stream mode. Current running'
|
||||||
+ self._running_mode.name)
|
' mode:'
|
||||||
|
+ self._running_mode.name
|
||||||
|
)
|
||||||
self._runner.send(inputs)
|
self._runner.send(inputs)
|
||||||
|
|
||||||
|
def create_audio_record(
|
||||||
|
self, num_channels: int, sample_rate: int, required_input_buffer_size: int
|
||||||
|
) -> audio_record.AudioRecord:
|
||||||
|
"""Creates an AudioRecord instance to record audio stream.
|
||||||
|
|
||||||
|
The returned AudioRecord instance is initialized and client needs to call
|
||||||
|
the appropriate method to start recording.
|
||||||
|
|
||||||
|
Note that MediaPipe Audio tasks will up/down sample automatically to fit the
|
||||||
|
sample rate required by the model. The default sample rate of the MediaPipe
|
||||||
|
pretrained audio model, Yamnet is 16kHz.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_channels: The number of audio channels.
|
||||||
|
sample_rate: The audio sample rate.
|
||||||
|
required_input_buffer_size: The required input buffer size in number of
|
||||||
|
float elements.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An AudioRecord instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there's a problem creating the AudioRecord instance.
|
||||||
|
"""
|
||||||
|
return audio_record.AudioRecord(
|
||||||
|
num_channels, sample_rate, required_input_buffer_size
|
||||||
|
)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Shuts down the mediapipe audio task instance.
|
"""Shuts down the mediapipe audio task instance.
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ py_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/audio:audio_classifier",
|
"//mediapipe/tasks/python/audio:audio_classifier",
|
||||||
|
"//mediapipe/tasks/python/audio/core:audio_record",
|
||||||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||||
"//mediapipe/tasks/python/components/containers:classification_result",
|
"//mediapipe/tasks/python/components/containers:classification_result",
|
||||||
|
@ -44,9 +45,9 @@ py_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/python/audio:audio_embedder",
|
"//mediapipe/tasks/python/audio:audio_embedder",
|
||||||
|
"//mediapipe/tasks/python/audio/core:audio_record",
|
||||||
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
|
||||||
"//mediapipe/tasks/python/components/containers:audio_data",
|
"//mediapipe/tasks/python/components/containers:audio_data",
|
||||||
"//mediapipe/tasks/python/components/containers:embedding_result",
|
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
|
|
|
@ -19,11 +19,11 @@ from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
|
|
||||||
from mediapipe.tasks.python.audio import audio_classifier
|
from mediapipe.tasks.python.audio import audio_classifier
|
||||||
|
from mediapipe.tasks.python.audio.core import audio_record
|
||||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
||||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||||
|
@ -34,6 +34,7 @@ _AudioClassifier = audio_classifier.AudioClassifier
|
||||||
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
|
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
|
||||||
_AudioClassifierResult = classification_result_module.ClassificationResult
|
_AudioClassifierResult = classification_result_module.ClassificationResult
|
||||||
_AudioData = audio_data_module.AudioData
|
_AudioData = audio_data_module.AudioData
|
||||||
|
_AudioRecord = audio_record.AudioRecord
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
||||||
|
|
||||||
|
@ -204,6 +205,19 @@ class AudioClassifierTest(parameterized.TestCase):
|
||||||
self._read_wav_file(audio_file))
|
self._read_wav_file(audio_file))
|
||||||
self._check_yamnet_result(classification_result_list)
|
self._check_yamnet_result(classification_result_list)
|
||||||
|
|
||||||
|
@mock.patch('sounddevice.InputStream', return_value=mock.MagicMock())
|
||||||
|
def test_create_audio_record_from_classifier_succeeds(self, _):
|
||||||
|
# Creates AudioRecord instance using the classifier successfully.
|
||||||
|
with _AudioClassifier.create_from_model_path(
|
||||||
|
self.yamnet_model_path
|
||||||
|
) as classifier:
|
||||||
|
self.assertIsInstance(classifier, _AudioClassifier)
|
||||||
|
record = classifier.create_audio_record(1, 16000, 16000)
|
||||||
|
self.assertIsInstance(record, _AudioRecord)
|
||||||
|
self.assertEqual(record.channels, 1)
|
||||||
|
self.assertEqual(record.sampling_rate, 16000)
|
||||||
|
self.assertEqual(record.buffer_size, 16000)
|
||||||
|
|
||||||
def test_max_result_options(self):
|
def test_max_result_options(self):
|
||||||
with _AudioClassifier.create_from_options(
|
with _AudioClassifier.create_from_options(
|
||||||
_AudioClassifierOptions(
|
_AudioClassifierOptions(
|
||||||
|
|
|
@ -24,6 +24,7 @@ import numpy as np
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
|
|
||||||
from mediapipe.tasks.python.audio import audio_embedder
|
from mediapipe.tasks.python.audio import audio_embedder
|
||||||
|
from mediapipe.tasks.python.audio.core import audio_record
|
||||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
from mediapipe.tasks.python.audio.core import audio_task_running_mode
|
||||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder
|
||||||
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
|
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
|
||||||
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
|
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
|
||||||
_AudioData = audio_data_module.AudioData
|
_AudioData = audio_data_module.AudioData
|
||||||
|
_AudioRecord = audio_record.AudioRecord
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
|
||||||
|
|
||||||
|
@ -165,6 +167,19 @@ class AudioEmbedderTest(parameterized.TestCase):
|
||||||
self.assertLen(embedding_result0_list, 5)
|
self.assertLen(embedding_result0_list, 5)
|
||||||
self.assertLen(embedding_result1_list, 5)
|
self.assertLen(embedding_result1_list, 5)
|
||||||
|
|
||||||
|
@mock.patch('sounddevice.InputStream', return_value=mock.MagicMock())
|
||||||
|
def test_create_audio_record_from_embedder_succeeds(self, _):
|
||||||
|
# Creates AudioRecord instance using the embedder successfully.
|
||||||
|
with _AudioEmbedder.create_from_model_path(
|
||||||
|
self.yamnet_model_path
|
||||||
|
) as embedder:
|
||||||
|
self.assertIsInstance(embedder, _AudioEmbedder)
|
||||||
|
record = embedder.create_audio_record(1, 16000, 16000)
|
||||||
|
self.assertIsInstance(record, _AudioRecord)
|
||||||
|
self.assertEqual(record.channels, 1)
|
||||||
|
self.assertEqual(record.sampling_rate, 16000)
|
||||||
|
self.assertEqual(record.buffer_size, 16000)
|
||||||
|
|
||||||
def test_embed_with_yamnet_model_and_different_inputs(self):
|
def test_embed_with_yamnet_model_and_different_inputs(self):
|
||||||
with _AudioEmbedder.create_from_model_path(
|
with _AudioEmbedder.create_from_model_path(
|
||||||
self.yamnet_model_path) as embedder:
|
self.yamnet_model_path) as embedder:
|
||||||
|
|
25
mediapipe/tasks/python/test/audio/core/BUILD
Normal file
25
mediapipe/tasks/python/test/audio/core/BUILD
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "audio_record_test",
|
||||||
|
srcs = ["audio_record_test.py"],
|
||||||
|
deps = ["//mediapipe/tasks/python/audio/core:audio_record"],
|
||||||
|
)
|
104
mediapipe/tasks/python/test/audio/core/audio_record_test.py
Normal file
104
mediapipe/tasks/python/test/audio/core/audio_record_test.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for audio_record."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.audio.core import audio_record
|
||||||
|
|
||||||
|
|
||||||
|
_mock = unittest.mock
|
||||||
|
|
||||||
|
_CHANNELS = 2
|
||||||
|
_SAMPLING_RATE = 16000
|
||||||
|
_BUFFER_SIZE = 15600
|
||||||
|
|
||||||
|
|
||||||
|
class AudioRecordTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
# Mock sounddevice.InputStream
|
||||||
|
with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method:
|
||||||
|
self.mock_input_stream = _mock.MagicMock()
|
||||||
|
mock_input_stream_new_method.return_value = self.mock_input_stream
|
||||||
|
self.record = audio_record.AudioRecord(
|
||||||
|
_CHANNELS, _SAMPLING_RATE, _BUFFER_SIZE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initialization arguments of InputStream for later assertion.
|
||||||
|
_, self.init_args = mock_input_stream_new_method.call_args
|
||||||
|
|
||||||
|
def test_init_args(self):
|
||||||
|
# Assert parameters of InputStream initialization
|
||||||
|
self.assertEqual(
|
||||||
|
self.init_args["channels"],
|
||||||
|
_CHANNELS,
|
||||||
|
"InputStream's channels doesn't match the initialization argument.",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
self.init_args["samplerate"],
|
||||||
|
_SAMPLING_RATE,
|
||||||
|
"InputStream's samplerate doesn't match the initialization argument.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_life_cycle(self):
|
||||||
|
# Assert start recording routine.
|
||||||
|
self.record.start_recording()
|
||||||
|
self.mock_input_stream.start.assert_called_once()
|
||||||
|
|
||||||
|
# Assert stop recording routine.
|
||||||
|
self.record.stop()
|
||||||
|
self.mock_input_stream.stop.assert_called_once()
|
||||||
|
|
||||||
|
def test_read_succeeds_with_valid_sample_size(self):
|
||||||
|
callback_fn = self.init_args["callback"]
|
||||||
|
|
||||||
|
# Create dummy data to feed to the AudioRecord instance.
|
||||||
|
chunk_size = int(_BUFFER_SIZE * 0.5)
|
||||||
|
input_data = []
|
||||||
|
for _ in range(3):
|
||||||
|
dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float)
|
||||||
|
input_data.append(dummy_data)
|
||||||
|
callback_fn(dummy_data)
|
||||||
|
|
||||||
|
# Assert read data of a single chunk.
|
||||||
|
recorded_audio_data = self.record.read(chunk_size)
|
||||||
|
self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1]))
|
||||||
|
|
||||||
|
# Assert read all data in buffer.
|
||||||
|
recorded_audio_data = self.record.read(chunk_size * 2)
|
||||||
|
print(input_data[-2].shape)
|
||||||
|
expected_data = np.concatenate(input_data[-2:])
|
||||||
|
self.assertTrue(np.array_equal(recorded_audio_data, expected_data))
|
||||||
|
|
||||||
|
def test_read_fails_with_invalid_sample_size(self):
|
||||||
|
callback_fn = self.init_args["callback"]
|
||||||
|
|
||||||
|
# Create dummy data to feed to the AudioRecord instance.
|
||||||
|
dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float)
|
||||||
|
callback_fn(dummy_data)
|
||||||
|
|
||||||
|
# Assert exception if read too much data.
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.record.read(_BUFFER_SIZE + 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
absltest.main()
|
Loading…
Reference in New Issue
Block a user