Added the AudioRecord API
This commit is contained in:
		
							parent
							
								
									05b505c8e2
								
							
						
					
					
						commit
						9787056508
					
				| 
						 | 
					@ -34,5 +34,6 @@ py_library(
 | 
				
			||||||
        "//mediapipe/python:_framework_bindings",
 | 
					        "//mediapipe/python:_framework_bindings",
 | 
				
			||||||
        "//mediapipe/python:packet_creator",
 | 
					        "//mediapipe/python:packet_creator",
 | 
				
			||||||
        "//mediapipe/tasks/python/core:optional_dependencies",
 | 
					        "//mediapipe/tasks/python/core:optional_dependencies",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -22,6 +22,7 @@ from mediapipe.python._framework_bindings import task_runner as task_runner_modu
 | 
				
			||||||
from mediapipe.python._framework_bindings import timestamp as timestamp_module
 | 
					from mediapipe.python._framework_bindings import timestamp as timestamp_module
 | 
				
			||||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 | 
					from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 | 
				
			||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | 
					from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import audio_record
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_TaskRunner = task_runner_module.TaskRunner
 | 
					_TaskRunner = task_runner_module.TaskRunner
 | 
				
			||||||
_Packet = packet_module.Packet
 | 
					_Packet = packet_module.Packet
 | 
				
			||||||
| 
						 | 
					@ -126,6 +127,33 @@ class BaseAudioTaskApi(object):
 | 
				
			||||||
          + self._running_mode.name)
 | 
					          + self._running_mode.name)
 | 
				
			||||||
    self._runner.send(inputs)
 | 
					    self._runner.send(inputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @staticmethod
 | 
				
			||||||
 | 
					  def create_audio_record(
 | 
				
			||||||
 | 
					      num_channels: int,
 | 
				
			||||||
 | 
					      sample_rate: int,
 | 
				
			||||||
 | 
					      required_input_buffer_size: int
 | 
				
			||||||
 | 
					  ) -> audio_record.AudioRecord:
 | 
				
			||||||
 | 
					    """Creates an AudioRecord instance to record audio stream.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The returned AudioRecord instance is initialized and the client must call
 | 
				
			||||||
 | 
					    the appropriate method to start recording.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Note that MediaPipe Audio tasks will up/down sample automatically to fit the
 | 
				
			||||||
 | 
					    sample rate required by the model. The default sample rate of the MediaPipe
 | 
				
			||||||
 | 
					    pretrained audio model, Yamnet, is 16 kHz.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      num_channels: The number of audio channels.
 | 
				
			||||||
 | 
					      sample_rate: The audio sample rate.
 | 
				
			||||||
 | 
					      required_input_buffer_size: The required input buffer size in number of
 | 
				
			||||||
 | 
					        float elements.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If there's a problem creating the AudioRecord instance.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return audio_record.AudioRecord(num_channels, sample_rate,
 | 
				
			||||||
 | 
					                                    required_input_buffer_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def close(self) -> None:
 | 
					  def close(self) -> None:
 | 
				
			||||||
    """Shuts down the mediapipe audio task instance.
 | 
					    """Shuts down the mediapipe audio task instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -23,6 +23,11 @@ py_library(
 | 
				
			||||||
    srcs = ["audio_data.py"],
 | 
					    srcs = ["audio_data.py"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_library(
 | 
				
			||||||
 | 
					    name = "audio_record",
 | 
				
			||||||
 | 
					    srcs = ["audio_record.py"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
py_library(
 | 
					py_library(
 | 
				
			||||||
    name = "bounding_box",
 | 
					    name = "bounding_box",
 | 
				
			||||||
    srcs = ["bounding_box.py"],
 | 
					    srcs = ["bounding_box.py"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										126
									
								
								mediapipe/tasks/python/components/containers/audio_record.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								mediapipe/tasks/python/components/containers/audio_record.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,126 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""A module to record audio on a streaming basis."""
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					# pylint: disable=g-import-not-at-top
 | 
				
			||||||
 | 
					  import sounddevice as sd
 | 
				
			||||||
 | 
					# pylint: enable=g-import-not-at-top
 | 
				
			||||||
 | 
					except OSError as oe:
 | 
				
			||||||
 | 
					  sd = None
 | 
				
			||||||
 | 
					  sd_error = oe
 | 
				
			||||||
 | 
					except ImportError as ie:
 | 
				
			||||||
 | 
					  sd = None
 | 
				
			||||||
 | 
					  sd_error = ie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AudioRecord(object):
 | 
				
			||||||
 | 
					  """A class to record audio on a streaming basis."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self, channels: int, sampling_rate: int,
 | 
				
			||||||
 | 
					               buffer_size: int) -> None:
 | 
				
			||||||
 | 
					    """Creates an AudioRecord instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      channels: Number of input channels.
 | 
				
			||||||
 | 
					      sampling_rate: Sampling rate in Hertz.
 | 
				
			||||||
 | 
					      buffer_size: Size of the ring buffer in number of samples.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: if any of the arguments is non-positive.
 | 
				
			||||||
 | 
					      ImportError: if failed to import `sounddevice`.
 | 
				
			||||||
 | 
					      OSError: if failed to load `PortAudio`.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if sd is None:
 | 
				
			||||||
 | 
					      raise sd_error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if channels <= 0:
 | 
				
			||||||
 | 
					      raise ValueError('channels must be postive.')
 | 
				
			||||||
 | 
					    if sampling_rate <= 0:
 | 
				
			||||||
 | 
					      raise ValueError('sampling_rate must be postive.')
 | 
				
			||||||
 | 
					    if buffer_size <= 0:
 | 
				
			||||||
 | 
					      raise ValueError('buffer_size must be postive.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._audio_buffer = []
 | 
				
			||||||
 | 
					    self._buffer_size = buffer_size
 | 
				
			||||||
 | 
					    self._channels = channels
 | 
				
			||||||
 | 
					    self._sampling_rate = sampling_rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create a ring buffer to store the input audio.
 | 
				
			||||||
 | 
					    self._buffer = np.zeros([buffer_size, channels], dtype=float)
 | 
				
			||||||
 | 
					    self._lock = threading.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def audio_callback(data, *_):
 | 
				
			||||||
 | 
					      """A callback to receive recorded audio data from sounddevice."""
 | 
				
			||||||
 | 
					      self._lock.acquire()
 | 
				
			||||||
 | 
					      shift = len(data)
 | 
				
			||||||
 | 
					      if shift > buffer_size:
 | 
				
			||||||
 | 
					        self._buffer = np.copy(data[:buffer_size])
 | 
				
			||||||
 | 
					      else:
 | 
				
			||||||
 | 
					        self._buffer = np.roll(self._buffer, -shift, axis=0)
 | 
				
			||||||
 | 
					        self._buffer[-shift:, :] = np.copy(data)
 | 
				
			||||||
 | 
					      self._lock.release()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create an input stream to continuously capture the audio data.
 | 
				
			||||||
 | 
					    self._stream = sd.InputStream(
 | 
				
			||||||
 | 
					        channels=channels,
 | 
				
			||||||
 | 
					        samplerate=sampling_rate,
 | 
				
			||||||
 | 
					        callback=audio_callback,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @property
 | 
				
			||||||
 | 
					  def channels(self) -> int:
 | 
				
			||||||
 | 
					    return self._channels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @property
 | 
				
			||||||
 | 
					  def sampling_rate(self) -> int:
 | 
				
			||||||
 | 
					    return self._sampling_rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @property
 | 
				
			||||||
 | 
					  def buffer_size(self) -> int:
 | 
				
			||||||
 | 
					    return self._buffer_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def start_recording(self) -> None:
 | 
				
			||||||
 | 
					    """Starts the audio recording."""
 | 
				
			||||||
 | 
					    # Clear the internal ring buffer.
 | 
				
			||||||
 | 
					    self._buffer.fill(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Start recording using sounddevice's InputStream.
 | 
				
			||||||
 | 
					    self._stream.start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def stop(self) -> None:
 | 
				
			||||||
 | 
					    """Stops the audio recording."""
 | 
				
			||||||
 | 
					    self._stream.stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def read(self, size: int) -> np.ndarray:
 | 
				
			||||||
 | 
					    """Reads the latest audio data captured in the buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      size: Number of samples to read from the buffer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      A NumPy array containing the audio data.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If `size` is non-positive or larger than the buffer size.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if size > self._buffer_size:
 | 
				
			||||||
 | 
					      raise ValueError('Cannot read more samples than the size of the buffer.')
 | 
				
			||||||
 | 
					    elif size <= 0:
 | 
				
			||||||
 | 
					      raise ValueError('Size must be positive.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    start_index = self._buffer_size - size
 | 
				
			||||||
 | 
					    return np.copy(self._buffer[start_index:])
 | 
				
			||||||
| 
						 | 
					@ -29,6 +29,7 @@ py_test(
 | 
				
			||||||
        "//mediapipe/tasks/python/audio:audio_classifier",
 | 
					        "//mediapipe/tasks/python/audio:audio_classifier",
 | 
				
			||||||
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
 | 
					        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
 | 
				
			||||||
        "//mediapipe/tasks/python/components/containers:audio_data",
 | 
					        "//mediapipe/tasks/python/components/containers:audio_data",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
				
			||||||
        "//mediapipe/tasks/python/components/containers:classification_result",
 | 
					        "//mediapipe/tasks/python/components/containers:classification_result",
 | 
				
			||||||
        "//mediapipe/tasks/python/core:base_options",
 | 
					        "//mediapipe/tasks/python/core:base_options",
 | 
				
			||||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
| 
						 | 
					@ -46,6 +47,7 @@ py_test(
 | 
				
			||||||
        "//mediapipe/tasks/python/audio:audio_embedder",
 | 
					        "//mediapipe/tasks/python/audio:audio_embedder",
 | 
				
			||||||
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
 | 
					        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
 | 
				
			||||||
        "//mediapipe/tasks/python/components/containers:audio_data",
 | 
					        "//mediapipe/tasks/python/components/containers:audio_data",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
				
			||||||
        "//mediapipe/tasks/python/components/containers:embedding_result",
 | 
					        "//mediapipe/tasks/python/components/containers:embedding_result",
 | 
				
			||||||
        "//mediapipe/tasks/python/core:base_options",
 | 
					        "//mediapipe/tasks/python/core:base_options",
 | 
				
			||||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
					        "//mediapipe/tasks/python/test:test_utils",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,6 +27,7 @@ from mediapipe.tasks.python.audio import audio_classifier
 | 
				
			||||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
 | 
					from mediapipe.tasks.python.audio.core import audio_task_running_mode
 | 
				
			||||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 | 
					from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 | 
				
			||||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
 | 
					from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import audio_record
 | 
				
			||||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
					from mediapipe.tasks.python.core import base_options as base_options_module
 | 
				
			||||||
from mediapipe.tasks.python.test import test_utils
 | 
					from mediapipe.tasks.python.test import test_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,6 +35,7 @@ _AudioClassifier = audio_classifier.AudioClassifier
 | 
				
			||||||
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
 | 
					_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
 | 
				
			||||||
_AudioClassifierResult = classification_result_module.ClassificationResult
 | 
					_AudioClassifierResult = classification_result_module.ClassificationResult
 | 
				
			||||||
_AudioData = audio_data_module.AudioData
 | 
					_AudioData = audio_data_module.AudioData
 | 
				
			||||||
 | 
					_AudioRecord = audio_record.AudioRecord
 | 
				
			||||||
_BaseOptions = base_options_module.BaseOptions
 | 
					_BaseOptions = base_options_module.BaseOptions
 | 
				
			||||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
 | 
					_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -204,6 +206,18 @@ class AudioClassifierTest(parameterized.TestCase):
 | 
				
			||||||
            self._read_wav_file(audio_file))
 | 
					            self._read_wav_file(audio_file))
 | 
				
			||||||
        self._check_yamnet_result(classification_result_list)
 | 
					        self._check_yamnet_result(classification_result_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock())
 | 
				
			||||||
 | 
					  def test_create_audio_record_from_classifier_succeeds(self, _):
 | 
				
			||||||
 | 
					    # Creates AudioRecord instance using the classifier successfully.
 | 
				
			||||||
 | 
					    with _AudioClassifier.create_from_model_path(
 | 
				
			||||||
 | 
					        self.yamnet_model_path) as classifier:
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _AudioClassifier)
 | 
				
			||||||
 | 
					      record = classifier.create_audio_record(1, 16000, 16000)
 | 
				
			||||||
 | 
					      self.assertIsInstance(record, _AudioRecord)
 | 
				
			||||||
 | 
					      self.assertEqual(record.channels, 1)
 | 
				
			||||||
 | 
					      self.assertEqual(record.sampling_rate, 16000)
 | 
				
			||||||
 | 
					      self.assertEqual(record.buffer_size, 16000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_max_result_options(self):
 | 
					  def test_max_result_options(self):
 | 
				
			||||||
    with _AudioClassifier.create_from_options(
 | 
					    with _AudioClassifier.create_from_options(
 | 
				
			||||||
        _AudioClassifierOptions(
 | 
					        _AudioClassifierOptions(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,6 +26,7 @@ from scipy.io import wavfile
 | 
				
			||||||
from mediapipe.tasks.python.audio import audio_embedder
 | 
					from mediapipe.tasks.python.audio import audio_embedder
 | 
				
			||||||
from mediapipe.tasks.python.audio.core import audio_task_running_mode
 | 
					from mediapipe.tasks.python.audio.core import audio_task_running_mode
 | 
				
			||||||
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 | 
					from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import audio_record
 | 
				
			||||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
					from mediapipe.tasks.python.core import base_options as base_options_module
 | 
				
			||||||
from mediapipe.tasks.python.test import test_utils
 | 
					from mediapipe.tasks.python.test import test_utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,6 +34,7 @@ _AudioEmbedder = audio_embedder.AudioEmbedder
 | 
				
			||||||
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
 | 
					_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
 | 
				
			||||||
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
 | 
					_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
 | 
				
			||||||
_AudioData = audio_data_module.AudioData
 | 
					_AudioData = audio_data_module.AudioData
 | 
				
			||||||
 | 
					_AudioRecord = audio_record.AudioRecord
 | 
				
			||||||
_BaseOptions = base_options_module.BaseOptions
 | 
					_BaseOptions = base_options_module.BaseOptions
 | 
				
			||||||
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
 | 
					_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -165,6 +167,18 @@ class AudioEmbedderTest(parameterized.TestCase):
 | 
				
			||||||
      self.assertLen(embedding_result0_list, 5)
 | 
					      self.assertLen(embedding_result0_list, 5)
 | 
				
			||||||
      self.assertLen(embedding_result1_list, 5)
 | 
					      self.assertLen(embedding_result1_list, 5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @mock.patch("sounddevice.InputStream", return_value=mock.MagicMock())
 | 
				
			||||||
 | 
					  def test_create_audio_record_from_embedder_succeeds(self, _):
 | 
				
			||||||
 | 
					    # Creates AudioRecord instance using the embedder successfully.
 | 
				
			||||||
 | 
					    with _AudioEmbedder.create_from_model_path(
 | 
				
			||||||
 | 
					        self.yamnet_model_path) as embedder:
 | 
				
			||||||
 | 
					      self.assertIsInstance(embedder, _AudioEmbedder)
 | 
				
			||||||
 | 
					      record = embedder.create_audio_record(1, 16000, 16000)
 | 
				
			||||||
 | 
					      self.assertIsInstance(record, _AudioRecord)
 | 
				
			||||||
 | 
					      self.assertEqual(record.channels, 1)
 | 
				
			||||||
 | 
					      self.assertEqual(record.sampling_rate, 16000)
 | 
				
			||||||
 | 
					      self.assertEqual(record.buffer_size, 16000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_embed_with_yamnet_model_and_different_inputs(self):
 | 
					  def test_embed_with_yamnet_model_and_different_inputs(self):
 | 
				
			||||||
    with _AudioEmbedder.create_from_model_path(
 | 
					    with _AudioEmbedder.create_from_model_path(
 | 
				
			||||||
        self.yamnet_model_path) as embedder:
 | 
					        self.yamnet_model_path) as embedder:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										27
									
								
								mediapipe/tasks/python/test/audio/core/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								mediapipe/tasks/python/test/audio/core/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,27 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Placeholder for internal Python strict test compatibility macro.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					py_test(
 | 
				
			||||||
 | 
					    name = "audio_record_test",
 | 
				
			||||||
 | 
					    srcs = ["audio_record_test.py"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/python/components/containers:audio_record",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
										97
									
								
								mediapipe/tasks/python/test/audio/core/audio_record_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								mediapipe/tasks/python/test/audio/core/audio_record_test.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,97 @@
 | 
				
			||||||
 | 
					# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					"""Tests for audio_record."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from absl.testing import absltest
 | 
				
			||||||
 | 
					from absl.testing import parameterized
 | 
				
			||||||
 | 
					from mediapipe.tasks.python.components.containers import audio_record
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_mock = unittest.mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_CHANNELS = 2
 | 
				
			||||||
 | 
					_SAMPLING_RATE = 16000
 | 
				
			||||||
 | 
					_BUFFER_SIZE = 15600
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AudioRecordTest(parameterized.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def setUp(self):
 | 
				
			||||||
 | 
					    super().setUp()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Mock sounddevice.InputStream
 | 
				
			||||||
 | 
					    with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method:
 | 
				
			||||||
 | 
					      self.mock_input_stream = _mock.MagicMock()
 | 
				
			||||||
 | 
					      mock_input_stream_new_method.return_value = self.mock_input_stream
 | 
				
			||||||
 | 
					      self.record = audio_record.AudioRecord(_CHANNELS, _SAMPLING_RATE,
 | 
				
			||||||
 | 
					                                             _BUFFER_SIZE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Save the initialization arguments of InputStream for later assertion.
 | 
				
			||||||
 | 
					      _, self.init_args = mock_input_stream_new_method.call_args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_init_args(self):
 | 
				
			||||||
 | 
					    # Assert parameters of InputStream initialization
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        self.init_args["channels"], _CHANNELS,
 | 
				
			||||||
 | 
					        "InputStream's channels doesn't match the initialization argument.")
 | 
				
			||||||
 | 
					    self.assertEqual(
 | 
				
			||||||
 | 
					        self.init_args["samplerate"], _SAMPLING_RATE,
 | 
				
			||||||
 | 
					        "InputStream's samplerate doesn't match the initialization argument.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_life_cycle(self):
 | 
				
			||||||
 | 
					    # Assert start recording routine.
 | 
				
			||||||
 | 
					    self.record.start_recording()
 | 
				
			||||||
 | 
					    self.mock_input_stream.start.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Assert stop recording routine.
 | 
				
			||||||
 | 
					    self.record.stop()
 | 
				
			||||||
 | 
					    self.mock_input_stream.stop.assert_called_once()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_read_succeeds_with_valid_sample_size(self):
 | 
				
			||||||
 | 
					    callback_fn = self.init_args["callback"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create dummy data to feed to the AudioRecord instance.
 | 
				
			||||||
 | 
					    chunk_size = int(_BUFFER_SIZE * 0.5)
 | 
				
			||||||
 | 
					    input_data = []
 | 
				
			||||||
 | 
					    for _ in range(3):
 | 
				
			||||||
 | 
					      dummy_data = np.random.rand(chunk_size, _CHANNELS).astype(float)
 | 
				
			||||||
 | 
					      input_data.append(dummy_data)
 | 
				
			||||||
 | 
					      callback_fn(dummy_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Assert read data of a single chunk.
 | 
				
			||||||
 | 
					    recorded_audio_data = self.record.read(chunk_size)
 | 
				
			||||||
 | 
					    self.assertTrue(np.array_equal(recorded_audio_data, input_data[-1]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Assert read all data in buffer.
 | 
				
			||||||
 | 
					    recorded_audio_data = self.record.read(chunk_size * 2)
 | 
				
			||||||
 | 
					    print(input_data[-2].shape)
 | 
				
			||||||
 | 
					    expected_data = np.concatenate(input_data[-2:])
 | 
				
			||||||
 | 
					    self.assertTrue(np.array_equal(recorded_audio_data, expected_data))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_read_fails_with_invalid_sample_size(self):
 | 
				
			||||||
 | 
					    callback_fn = self.init_args["callback"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create dummy data to feed to the AudioRecord instance.
 | 
				
			||||||
 | 
					    dummy_data = np.zeros([_BUFFER_SIZE, 1], dtype=float)
 | 
				
			||||||
 | 
					    callback_fn(dummy_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Assert exception if read too much data.
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					      self.record.read(_BUFFER_SIZE + 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					  absltest.main()
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user