Add a create_from_array classmethod to the AudioData class.

PiperOrigin-RevId: 486310154
This commit is contained in:
Jiuqiang Tang 2022-11-05 00:07:43 -07:00 committed by Copybara-Service
parent 91782a2772
commit 9504c5e6a1

View File

@ -68,7 +68,11 @@ class AudioData(object):
ValueError: If the input array has an incorrect shape or if
`offset` + `size` exceeds the length of the `src` array.
"""
if src.shape[1] != self._audio_format.num_channels:
if len(src.shape) == 1:
if self._audio_format.num_channels != 1:
raise ValueError(f"Input audio is mono, but the audio data is expected "
f"to have {self._audio_format.num_channels} channels.")
elif src.shape[1] != self._audio_format.num_channels:
raise ValueError(f"Input audio contains an invalid number of channels. "
f"Expect {self._audio_format.num_channels}.")
@ -93,6 +97,28 @@ class AudioData(object):
self._buffer = np.roll(self._buffer, -shift, axis=0)
self._buffer[-shift:, :] = src[offset:offset + size].copy()
@classmethod
def create_from_array(cls,
src: np.ndarray,
sample_rate: Optional[float] = None) -> "AudioData":
"""Creates an `AudioData` object from a NumPy array.
Args:
src: A NumPy source array contains the input audio.
sample_rate: the optional audio sample rate.
Returns:
An `AudioData` object that contains a copy of the NumPy source array as
the data.
"""
obj = cls(
buffer_length=src.shape[0],
audio_format=AudioFormat(
num_channels=1 if len(src.shape) == 1 else src.shape[1],
sample_rate=sample_rate))
obj.load_from_array(src)
return obj
@property
def audio_format(self) -> AudioFormat:
"""Gets the audio format of the audio."""