Added the AudioRecord API
This commit is contained in:
		
							parent
							
								
									05b505c8e2
								
							
						
					
					
						commit
						9787056508
					
				| 
						 | 
				
			
			@ -34,5 +34,6 @@ py_library(
 | 
			
		|||
        "//mediapipe/python:_framework_bindings",
 | 
			
		||||
        "//mediapipe/python:packet_creator",
 | 
			
		||||
        "//mediapipe/tasks/python/core:optional_dependencies",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,6 +22,7 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu
 | 
			
		|||
from mediapipe.python._framework_bindings import timestamp as timestamp_module
 | 
			
		||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 | 
			
		||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | 
			
		||||
from mediapipe.tasks.python.components.containers import audio_record
 | 
			
		||||
 | 
			
		||||
_TaskRunner = task_runner_module.TaskRunner
 | 
			
		||||
_Packet = packet_module.Packet
 | 
			
		||||
| 
						 | 
				
			
			@ -126,6 +127,33 @@ class BaseAudioTaskApi(object):
 | 
			
		|||
          + self._running_mode.name)
 | 
			
		||||
    self._runner.send(inputs)
 | 
			
		||||
 | 
			
		||||
  @staticmethod
 | 
			
		||||
  def create_audio_record(
 | 
			
		||||
      num_channels: int,
 | 
			
		||||
      sample_rate: int,
 | 
			
		||||
      required_input_buffer_size: int
 | 
			
		||||
  ) -> audio_record.AudioRecord:
 | 
			
		||||
    """Creates an AudioRecord instance to record audio stream.
 | 
			
		||||
 | 
			
		||||
    The returned AudioRecord instance is initialized and client needs to call
 | 
			
		||||
    the appropriate method to start recording.
 | 
			
		||||
 | 
			
		||||
    Note that MediaPipe Audio tasks will up/down sample automatically to fit the
 | 
			
		||||
    sample rate required by the model. The default sample rate of the MediaPipe
 | 
			
		||||
    pretrained audio model, Yamnet is 16kHz.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      num_channels: The number of audio channels.
 | 
			
		||||
      sample_rate: The audio sample rate.
 | 
			
		||||
      required_input_buffer_size: The required input buffer size in number of
 | 
			
		||||
        float elements.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: If there's a problem creating the AudioRecord instance.
 | 
			
		||||
    """
 | 
			
		||||
    return audio_record.AudioRecord(num_channels, sample_rate,
 | 
			
		||||
                                    required_input_buffer_size)
 | 
			
		||||
 | 
			
		||||
  def close(self) -> None:
 | 
			
		||||
    """Shuts down the mediapipe audio task instance.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,6 +23,11 @@ py_library(
 | 
			
		|||
    srcs = ["audio_data.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "audio_record",
 | 
			
		||||
    srcs = ["audio_record.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "bounding_box",
 | 
			
		||||
    srcs = ["bounding_box.py"],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										126
									
								
								mediapipe/tasks/python/components/containers/audio_record.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								mediapipe/tasks/python/components/containers/audio_record.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,126 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""A module to record audio in a streaming basis."""
 | 
			
		||||
import threading
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
# pylint: disable=g-import-not-at-top
 | 
			
		||||
  import sounddevice as sd
 | 
			
		||||
# pylint: enable=g-import-not-at-top
 | 
			
		||||
except OSError as oe:
 | 
			
		||||
  sd = None
 | 
			
		||||
  sd_error = oe
 | 
			
		||||
except ImportError as ie:
 | 
			
		||||
  sd = None
 | 
			
		||||
  sd_error = ie
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AudioRecord(object):
 | 
			
		||||
  """A class to record audio in a streaming basis."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, channels: int, sampling_rate: int,
 | 
			
		||||
               buffer_size: int) -> None:
 | 
			
		||||
    """Creates an AudioRecord instance.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      channels: Number of input channels.
 | 
			
		||||
      sampling_rate: Sampling rate in Hertz.
 | 
			
		||||
      buffer_size: Size of the ring buffer in number of samples.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: if any of the arguments is non-positive.
 | 
			
		||||
      ImportError: if failed to import `sounddevice`.
 | 
			
		||||
      OSError: if failed to load `PortAudio`.
 | 
			
		||||
    """
 | 
			
		||||
    if sd is None:
 | 
			
		||||
      raise sd_error
 | 
			
		||||
 | 
			
		||||
    if channels <= 0:
 | 
			
		||||
      raise ValueError('channels must be postive.')
 | 
			
		||||
    if sampling_rate <= 0:
 | 
			
		||||
      raise ValueError('sampling_rate must be postive.')
 | 
			
		||||
    if buffer_size <= 0:
 | 
			
		||||
      raise ValueError('buffer_size must be postive.')
 | 
			
		||||
 | 
			
		||||
    self._audio_buffer = []
 | 
			
		||||
    self._buffer_size = buffer_size
 | 
			
		||||
    self._channels = channels
 | 
			
		||||
    self._sampling_rate = sampling_rate
 | 
			
		||||
 | 
			
		||||
    # Create a ring buffer to store the input audio.
 | 
			
		||||
    self._buffer = np.zeros([buffer_size, channels], dtype=float)
 | 
			
		||||
    self._lock = threading.Lock()
 | 
			
		||||
 | 
			
		||||
    def audio_callback(data, *_):
 | 
			
		||||
      """A callback to receive recorded audio data from sounddevice."""
 | 
			
		||||
      self._lock.acquire()
 | 
			
		||||
      shift = len(data)
 | 
			
		||||
      if shift > buffer_size:
 | 
			
		||||
        self._buffer = np.copy(data[:buffer_size])
 | 
			
		||||
      else:
 | 
			
		||||
        self._buffer = np.roll(self._buffer, -shift, axis=0)
 | 
			
		||||
        self._buffer[-shift:, :] = np.copy(data)
 | 
			
		||||
      self._lock.release()
 | 
			
		||||
 | 
			
		||||
    # Create an input stream to continuously capture the audio data.
 | 
			
		||||
    self._stream = sd.InputStream(
 | 
			
		||||
        channels=channels,
 | 
			
		||||
        samplerate=sampling_rate,
 | 
			
		||||
        callback=audio_callback,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def channels(self) -> int:
 | 
			
		||||
    return self._channels
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def sampling_rate(self) -> int:
 | 
			
		||||
    return self._sampling_rate
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def buffer_size(self) -> int:
 | 
			
		||||
    return self._buffer_size
 | 
			
		||||
 | 
			
		||||
  def start_recording(self) -> None:
 | 
			
		||||
    """Starts the audio recording."""
 | 
			
		||||
    # Clear the internal ring buffer.
 | 
			
		||||
    self._buffer.fill(0)
 | 
			
		||||
 | 
			
		||||
    # Start recording using sounddevice's InputStream.
 | 
			
		||||
    self._stream.start()
 | 
			
		||||
 | 
			
		||||
  def stop(self) -> None:
 | 
			
		||||
    """Stops the audio recording."""
 | 
			
		||||
    self._stream.stop()
 | 
			
		||||
 | 
			
		||||
  def read(self, size: int) -> np.ndarray:
 | 
			
		||||
    """Reads the latest audio data captured in the buffer.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      size: Number of samples to read from the buffer.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
      A NumPy array containing the audio data.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: Raised if `size` is larger than the buffer size.
 | 
			
		||||
    """
 | 
			
		||||
    if size > self._buffer_size:
 | 
			
		||||
      raise ValueError('Cannot read more samples than the size of the buffer.')
 | 
			
		||||
    elif size <= 0:
 | 
			
		||||
      raise ValueError('Size must be positive.')
 | 
			
		||||
 | 
			
		||||
    start_index = self._buffer_size - size
 | 
			
		||||
    return np.copy(self._buffer[start_index:])
 | 
			
		||||
| 
						 | 
				
			
			@ -29,6 +29,7 @@ py_test(
 | 
			
		|||
        "//mediapipe/tasks/python/audio:audio_classifier",
 | 
			
		||||
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:audio_data",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:classification_result",
 | 
			
		||||
        "//mediapipe/tasks/python/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
| 
						 | 
				
			
			@ -46,6 +47,7 @@ py_test(
 | 
			
		|||
        "//mediapipe/tasks/python/audio:audio_embedder",
 | 
			
		||||
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:audio_data",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:embedding_result",
 | 
			
		||||
        "//mediapipe/tasks/python/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,6 +27,7 @@ from mediapipe.tasks.python.audio import audio_classifier
 | 
			
		|||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
 | 
			
		||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import audio_record
 | 
			
		||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +35,7 @@ _AudioClassifier = audio_classifier.AudioClassifier
 | 
			
		|||
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
 | 
			
		||||
_AudioClassifierResult = classification_result_module.ClassificationResult
 | 
			
		||||
_AudioData = audio_data_module.AudioData
 | 
			
		||||
_AudioRecord = audio_record.AudioRecord
 | 
			
		||||
_BaseOptions = base_options_module.BaseOptions
 | 
			
		||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -204,6 +206,18 @@ class AudioClassifierTest(parameterized.TestCase):
 | 
			
		|||
            self._read_wav_file(audio_file))
 | 
			
		||||
        self._check_yamnet_result(classification_result_list)
 | 
			
		||||
 | 
			
		||||
  @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock())
 | 
			
		||||
  def test_create_audio_record_from_classifier_succeeds(self, _):
 | 
			
		||||
    # Creates AudioRecord instance using the classifier successfully.
 | 
			
		||||
    with _AudioClassifier.create_from_model_path(
 | 
			
		||||
        self.yamnet_model_path) as classifier:
 | 
			
		||||
      self.assertIsInstance(classifier, _AudioClassifier)
 | 
			
		||||
      record = classifier.create_audio_record(1, 16000, 16000)
 | 
			
		||||
      self.assertIsInstance(record, _AudioRecord)
 | 
			
		||||
      self.assertEqual(record.channels, 1)
 | 
			
		||||
      self.assertEqual(record.sampling_rate, 16000)
 | 
			
		||||
      self.assertEqual(record.buffer_size, 16000)
 | 
			
		||||
 | 
			
		||||
  def test_max_result_options(self):
 | 
			
		||||
    with _AudioClassifier.create_from_options(
 | 
			
		||||
        _AudioClassifierOptions(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,6 +26,7 @@ from scipy.io import wavfile
 | 
			
		|||
from mediapipe.tasks.python.audio import audio_embedder
 | 
			
		||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
 | 
			
		||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 | 
			
		||||
from mediapipe.tasks.python.components.containers import audio_record
 | 
			
		||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder
 | 
			
		|||
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
 | 
			
		||||
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
 | 
			
		||||
_AudioData = audio_data_module.AudioData
 | 
			
		||||
_AudioRecord = audio_record.AudioRecord
 | 
			
		||||
_BaseOptions = base_options_module.BaseOptions
 | 
			
		||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -165,6 +167,18 @@ class AudioEmbedderTest(parameterized.TestCase):
 | 
			
		|||
      self.assertLen(embedding_result0_list, 5)
 | 
			
		||||
      self.assertLen(embedding_result1_list, 5)
 | 
			
		||||
 | 
			
		||||
  @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock())
 | 
			
		||||
  def test_create_audio_record_from_embedder_succeeds(self, _):
 | 
			
		||||
    # Creates AudioRecord instance using the embedder successfully.
 | 
			
		||||
    with _AudioEmbedder.create_from_model_path(
 | 
			
		||||
        self.yamnet_model_path) as embedder:
 | 
			
		||||
      self.assertIsInstance(embedder, _AudioEmbedder)
 | 
			
		||||
      record = embedder.create_audio_record(1, 16000, 16000)
 | 
			
		||||
      self.assertIsInstance(record, _AudioRecord)
 | 
			
		||||
      self.assertEqual(record.channels, 1)
 | 
			
		||||
      self.assertEqual(record.sampling_rate, 16000)
 | 
			
		||||
      self.assertEqual(record.buffer_size, 16000)
 | 
			
		||||
 | 
			
		||||
  def test_embed_with_yamnet_model_and_different_inputs(self):
 | 
			
		||||
    with _AudioEmbedder.create_from_model_path(
 | 
			
		||||
        self.yamnet_model_path) as embedder:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										27
									
								
								mediapipe/tasks/python/test/audio/core/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								mediapipe/tasks/python/test/audio/core/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,27 @@
 | 
			
		|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
# Placeholder for internal Python strict test compatibility macro.
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "audio_record_test",
 | 
			
		||||
    srcs = ["audio_record_test.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										97
									
								
								mediapipe/tasks/python/test/audio/core/audio_record_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								mediapipe/tasks/python/test/audio/core/audio_record_test.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,97 @@
 | 
			
		|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Tests for audio_record."""
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from absl.testing import absltest
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
from mediapipe.tasks.python.components.containers import audio_record
 | 
			
		||||
 | 
			
		||||
_mock = unittest.mock
 | 
			
		||||
 | 
			
		||||
_CHANNELS = 2
 | 
			
		||||
_SAMPLING_RATE = 16000
 | 
			
		||||
_BUFFER_SIZE = 15600
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AudioRecordTest(parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
  def setUp(self):
 | 
			
		||||
    super().setUp()
 | 
			
		||||
 | 
			
		||||
    # Mock sounddevice.InputStream
 | 
			
		||||
    with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method:
 | 
			
		||||
      self.mock_input_stream = _mock.MagicMock()
 | 
			
		||||
      mock_input_stream_new_method.return_value = self.mock_input_stream
 | 
			
		||||
      self.record = audio_record.AudioRecord(_CHANNELS, _SAMPLING_RATE,
 | 
			
		||||
                                             _BUFFER_SIZE)
 | 
			
		||||
 | 
			
		||||
      # Save the initialization arguments of InputStream for later assertion.
 | 
			
		||||
      _, self.init_args = mock_input_stream_new_method.call_args
 | 
			
		||||
 | 
			
		||||
  def test_init_args(self):
 | 
			
		||||
    # Assert parameters of InputStream initialization
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        self.init_args["channels"], _CHANNELS,
 | 
			
		||||
        "InputStream's channels doesn't match the initialization argument.")
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        self.init_args["samplerate"], _SAMPLING_RATE,
 | 
			
		||||
        "InputStream's samplerate doesn't match the initialization argument.")
 | 
			
		||||
 | 
			
		||||
  def test_life_cycle(self):
 | 
			
		||||
    # Assert start recording routine.
 | 
			
		||||
    self.record.start_recording()
 | 
			
		||||
    self.mock_input_stream.start.assert_called_once()
 | 
			
		||||
 | 
			
		||||
    # Assert stop recording routine.
 | 
			
		||||
    self.record.stop()
 | 
			
		||||
    self.mock_input_stream.stop.assert_called_once()
 | 
			
		||||
 | 
			
		||||
  def test_read_succeeds_with_valid_sample_size(self):
 | 
			
		||||
    callback_fn = self.init_args["callback"]
 | 
			
		||||
 | 
			
		||||
    # Create dummy data to feed to the AudioRecord instance.
 | 
			
		||||
    chunk_size = int(_BUFFER_SIZE * 0.5)
 | 
			
		||||
    input_data = []
 | 
			
		||||
    for _ in range(3):
 | 
			
		||||
      dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float)
 | 
			
		||||
      input_data.append(dummy_data)
 | 
			
		||||
      callback_fn(dummy_data)
 | 
			
		||||
 | 
			
		||||
    # Assert read data of a single chunk.
 | 
			
		||||
    recorded_audio_data = self.record.read(chunk_size)
 | 
			
		||||
    self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1]))
 | 
			
		||||
 | 
			
		||||
    # Assert read all data in buffer.
 | 
			
		||||
    recorded_audio_data = self.record.read(chunk_size * 2)
 | 
			
		||||
    print(input_data[-2].shape)
 | 
			
		||||
    expected_data = np.concatenate(input_data[-2:])
 | 
			
		||||
    self.assertTrue(np.array_equal(recorded_audio_data, expected_data))
 | 
			
		||||
 | 
			
		||||
  def test_read_fails_with_invalid_sample_size(self):
 | 
			
		||||
    callback_fn = self.init_args["callback"]
 | 
			
		||||
 | 
			
		||||
    # Create dummy data to feed to the AudioRecord instance.
 | 
			
		||||
    dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float)
 | 
			
		||||
    callback_fn(dummy_data)
 | 
			
		||||
 | 
			
		||||
    # Assert exception if read too much data.
 | 
			
		||||
    with self.assertRaises(ValueError):
 | 
			
		||||
      self.record.read(_BUFFER_SIZE + 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
  absltest.main()
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user