63a759accc
PiperOrigin-RevId: 486763992
138 lines
4.4 KiB
Python
138 lines
4.4 KiB
Python
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""MediaPipe audio data."""
|
|
|
|
import dataclasses
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
|
|
|
|
@dataclasses.dataclass
class AudioDataFormat:
  """Metadata describing the layout of a block of audio samples.

  Attributes:
    num_channels: number of audio channels; 1 unless specified.
    sample_rate: the audio sample rate, or None when it is not specified.
  """
  num_channels: int = dataclasses.field(default=1)
  sample_rate: Optional[float] = dataclasses.field(default=None)
|
|
|
|
|
|
class AudioData:
  """MediaPipe Tasks' audio container.

  Wraps a NumPy sample buffer whose first axis indexes samples. The buffer is
  allocated in `__init__` as float32 zeros of shape
  `(buffer_length, num_channels)` and is refilled by `load_from_array`.
  """

  def __init__(self,
               buffer_length: int,
               audio_format: Optional[AudioDataFormat] = None) -> None:
    """Initializes the `AudioData` object.

    Args:
      buffer_length: the length of the audio buffer, in samples.
      audio_format: the audio format metadata. When omitted (or None), a new
        `AudioDataFormat()` with its default values is created for this object.
    """
    # Build the default per call: an `AudioDataFormat()` default argument is
    # evaluated only once, so every instance would share one mutable object.
    if audio_format is None:
      audio_format = AudioDataFormat()
    self._audio_format = audio_format
    self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
                            dtype=np.float32)

  def clear(self):
    """Clears the internal buffer and fills it with zeros."""
    self._buffer.fill(0)

  def load_from_array(self,
                      src: np.ndarray,
                      offset: int = 0,
                      size: int = -1) -> None:
    """Loads the audio data from a NumPy array.

    The samples `src[offset:offset + size]` are appended to the end of the
    internal buffer: existing samples are shifted towards the front and the
    oldest ones are dropped, so the buffer length never changes. If the loaded
    slice is at least as long as the buffer, the buffer is replaced by a copy
    of the last `buffer_length` samples of the slice.

    Args:
      src: A NumPy source array containing the input audio, either 1-D (mono)
        or 2-D with one column per channel.
      offset: An optional offset for loading a slice of the `src` array to the
        buffer. Must be non-negative.
      size: An optional size parameter denoting the number of samples to load
        from the `src` array. A negative value loads everything from `offset`
        to the end of `src`.

    Raises:
      ValueError: If the input array has an incorrect shape, if `offset` is
        negative, or if `offset` + `size` exceeds the length of the `src`
        array.
    """
    if src.ndim not in (1, 2):
      raise ValueError(
          f"Input audio must be a 1-D or 2-D array, got {src.ndim} dimensions.")
    if src.ndim == 1:
      if self._audio_format.num_channels != 1:
        raise ValueError(f"Input audio is mono, but the audio data is expected "
                         f"to have {self._audio_format.num_channels} channels.")
    elif src.shape[1] != self._audio_format.num_channels:
      raise ValueError(f"Input audio contains an invalid number of channels. "
                       f"Expect {self._audio_format.num_channels}.")

    if offset < 0:
      raise ValueError(f"offset must be non-negative, got {offset}.")

    if size < 0:
      # Load everything from `offset` to the end of `src`. Clamp at 0 so an
      # `offset` past the end is still reported by the range check below.
      size = max(len(src) - offset, 0)

    if offset + size > len(src):
      raise ValueError(
          f"Index out of range. offset {offset} + size {size} should be <= "
          f"src's length: {len(src)}")

    buffer_length = len(self._buffer)
    if size >= buffer_length:
      # The loaded slice is at least as long as the internal buffer: keep only
      # its last `buffer_length` samples. This must compare `size`, not
      # `len(src)`: a short slice of a long `src` would otherwise yield a
      # negative `new_offset` and a wrong (possibly empty) buffer.
      # NOTE: here the buffer becomes a copy of the `src` slice, so it takes
      # on `src`'s dtype and dimensionality rather than the float32 2-D
      # layout allocated in `__init__`.
      new_offset = offset + size - buffer_length
      self._buffer = src[new_offset:new_offset + buffer_length].copy()
    elif size > 0:
      # Shift the internal buffer backward and add the incoming data to the
      # end of the buffer. The slice is reshaped to the destination's shape so
      # that a 1-D mono slice fits an (n, 1) buffer and vice versa.
      self._buffer = np.roll(self._buffer, -size, axis=0)
      tail = self._buffer[-size:]
      tail[...] = np.reshape(src[offset:offset + size], tail.shape)
    # Otherwise size == 0 with a non-empty buffer: nothing to load. This case
    # must not reach the shift path, where `self._buffer[-0:]` would address
    # the whole buffer.

  @classmethod
  def create_from_array(cls,
                        src: np.ndarray,
                        sample_rate: Optional[float] = None) -> "AudioData":
    """Creates an `AudioData` object from a NumPy array.

    Args:
      src: A NumPy source array containing the input audio, either 1-D (mono)
        or 2-D with one column per channel.
      sample_rate: the optional audio sample rate.

    Returns:
      An `AudioData` object that contains a copy of the NumPy source array as
      the data.
    """
    obj = cls(
        buffer_length=src.shape[0],
        audio_format=AudioDataFormat(
            num_channels=1 if src.ndim == 1 else src.shape[1],
            sample_rate=sample_rate))
    # The buffer is exactly as long as `src`, so this takes the copy path of
    # `load_from_array`.
    obj.load_from_array(src)
    return obj

  @property
  def audio_format(self) -> AudioDataFormat:
    """Gets the audio format of the audio."""
    return self._audio_format

  @property
  def buffer_length(self) -> int:
    """Gets the sample count of the audio."""
    return self._buffer.shape[0]

  @property
  def buffer(self) -> np.ndarray:
    """Gets the internal buffer."""
    return self._buffer
|