diff --git a/mediapipe/tasks/python/components/containers/audio_data.py b/mediapipe/tasks/python/components/containers/audio_data.py
index 21b606079..56399dea8 100644
--- a/mediapipe/tasks/python/components/containers/audio_data.py
+++ b/mediapipe/tasks/python/components/containers/audio_data.py
@@ -68,7 +68,11 @@ class AudioData(object):
       ValueError: If the input array has an incorrect shape or if
         `offset` + `size` exceeds the length of the `src` array.
     """
-    if src.shape[1] != self._audio_format.num_channels:
+    if len(src.shape) == 1:
+      if self._audio_format.num_channels != 1:
+        raise ValueError(f"Input audio is mono, but the audio data is expected "
+                         f"to have {self._audio_format.num_channels} channels.")
+    elif src.shape[1] != self._audio_format.num_channels:
       raise ValueError(f"Input audio contains an invalid number of channels. "
                        f"Expect {self._audio_format.num_channels}.")
 
@@ -93,6 +97,38 @@ class AudioData(object):
       self._buffer = np.roll(self._buffer, -shift, axis=0)
       self._buffer[-shift:, :] = src[offset:offset + size].copy()
 
+  @classmethod
+  def create_from_array(cls,
+                        src: np.ndarray,
+                        sample_rate: Optional[float] = None) -> "AudioData":
+    """Creates an `AudioData` object from a NumPy array.
+
+    Args:
+      src: A NumPy source array that contains the input audio. A 1-D array is
+        treated as mono audio; for a 2-D array, `src.shape[1]` is used as the
+        number of channels.
+      sample_rate: The optional audio sample rate.
+
+    Returns:
+      An `AudioData` object that contains a copy of the NumPy source array as
+      the data.
+
+    Raises:
+      ValueError: If `src` is neither a 1-D nor a 2-D array.
+    """
+    # Reject 0-D input (which would fail with a bare IndexError below) and
+    # >2-D input (which would trivially pass the channel-count check).
+    if src.ndim not in (1, 2):
+      raise ValueError(f"Input audio must be a 1-D or 2-D array, but got an "
+                       f"array with shape {src.shape}.")
+    obj = cls(
+        buffer_length=src.shape[0],
+        audio_format=AudioFormat(
+            num_channels=1 if src.ndim == 1 else src.shape[1],
+            sample_rate=sample_rate))
+    obj.load_from_array(src)
+    return obj
+
   @property
   def audio_format(self) -> AudioFormat:
     """Gets the audio format of the audio."""