Add a create_from_array
classmethod to the AudioData class.
PiperOrigin-RevId: 486310154
This commit is contained in:
parent
91782a2772
commit
9504c5e6a1
|
@ -68,7 +68,11 @@ class AudioData(object):
|
||||||
ValueError: If the input array has an incorrect shape or if
|
ValueError: If the input array has an incorrect shape or if
|
||||||
`offset` + `size` exceeds the length of the `src` array.
|
`offset` + `size` exceeds the length of the `src` array.
|
||||||
"""
|
"""
|
||||||
if src.shape[1] != self._audio_format.num_channels:
|
if len(src.shape) == 1:
|
||||||
|
if self._audio_format.num_channels != 1:
|
||||||
|
raise ValueError(f"Input audio is mono, but the audio data is expected "
|
||||||
|
f"to have {self._audio_format.num_channels} channels.")
|
||||||
|
elif src.shape[1] != self._audio_format.num_channels:
|
||||||
raise ValueError(f"Input audio contains an invalid number of channels. "
|
raise ValueError(f"Input audio contains an invalid number of channels. "
|
||||||
f"Expect {self._audio_format.num_channels}.")
|
f"Expect {self._audio_format.num_channels}.")
|
||||||
|
|
||||||
|
@ -93,6 +97,28 @@ class AudioData(object):
|
||||||
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
||||||
self._buffer[-shift:, :] = src[offset:offset + size].copy()
|
self._buffer[-shift:, :] = src[offset:offset + size].copy()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_array(cls,
|
||||||
|
src: np.ndarray,
|
||||||
|
sample_rate: Optional[float] = None) -> "AudioData":
|
||||||
|
"""Creates an `AudioData` object from a NumPy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: A NumPy source array contains the input audio.
|
||||||
|
sample_rate: the optional audio sample rate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An `AudioData` object that contains a copy of the NumPy source array as
|
||||||
|
the data.
|
||||||
|
"""
|
||||||
|
obj = cls(
|
||||||
|
buffer_length=src.shape[0],
|
||||||
|
audio_format=AudioFormat(
|
||||||
|
num_channels=1 if len(src.shape) == 1 else src.shape[1],
|
||||||
|
sample_rate=sample_rate))
|
||||||
|
obj.load_from_array(src)
|
||||||
|
return obj
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_format(self) -> AudioFormat:
|
def audio_format(self) -> AudioFormat:
|
||||||
"""Gets the audio format of the audio."""
|
"""Gets the audio format of the audio."""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user