Add a create_from_array
classmethod to the AudioData class.
PiperOrigin-RevId: 486310154
This commit is contained in:
parent
91782a2772
commit
9504c5e6a1
|
@ -68,7 +68,11 @@ class AudioData(object):
|
|||
ValueError: If the input array has an incorrect shape or if
|
||||
`offset` + `size` exceeds the length of the `src` array.
|
||||
"""
|
||||
if src.shape[1] != self._audio_format.num_channels:
|
||||
if len(src.shape) == 1:
|
||||
if self._audio_format.num_channels != 1:
|
||||
raise ValueError(f"Input audio is mono, but the audio data is expected "
|
||||
f"to have {self._audio_format.num_channels} channels.")
|
||||
elif src.shape[1] != self._audio_format.num_channels:
|
||||
raise ValueError(f"Input audio contains an invalid number of channels. "
|
||||
f"Expect {self._audio_format.num_channels}.")
|
||||
|
||||
|
@ -93,6 +97,28 @@ class AudioData(object):
|
|||
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
||||
self._buffer[-shift:, :] = src[offset:offset + size].copy()
|
||||
|
||||
@classmethod
|
||||
def create_from_array(cls,
|
||||
src: np.ndarray,
|
||||
sample_rate: Optional[float] = None) -> "AudioData":
|
||||
"""Creates an `AudioData` object from a NumPy array.
|
||||
|
||||
Args:
|
||||
src: A NumPy source array contains the input audio.
|
||||
sample_rate: the optional audio sample rate.
|
||||
|
||||
Returns:
|
||||
An `AudioData` object that contains a copy of the NumPy source array as
|
||||
the data.
|
||||
"""
|
||||
obj = cls(
|
||||
buffer_length=src.shape[0],
|
||||
audio_format=AudioFormat(
|
||||
num_channels=1 if len(src.shape) == 1 else src.shape[1],
|
||||
sample_rate=sample_rate))
|
||||
obj.load_from_array(src)
|
||||
return obj
|
||||
|
||||
@property
|
||||
def audio_format(self) -> AudioFormat:
|
||||
"""Gets the audio format of the audio."""
|
||||
|
|
Loading…
Reference in New Issue
Block a user