Implement MediaPipe Tasks Python AudioData.
PiperOrigin-RevId: 486147173
This commit is contained in:
parent
5024c815f1
commit
5f5f50d8f7
|
@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
py_library(
|
||||
name = "audio_data",
|
||||
srcs = ["audio_data.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "bounding_box",
|
||||
srcs = ["bounding_box.py"],
|
||||
|
|
109
mediapipe/tasks/python/components/containers/audio_data.py
Normal file
109
mediapipe/tasks/python/components/containers/audio_data.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe audio data."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclasses.dataclass
class AudioFormat:
  """Describes how the samples held by an audio container are laid out.

  Attributes:
    num_channels: how many channels the audio data has. Defaults to 1.
    sample_rate: the sample rate of the audio. Defaults to `None`, i.e. not
      specified.
  """

  num_channels: int = dataclasses.field(default=1)
  sample_rate: Optional[float] = dataclasses.field(default=None)
|
||||
|
||||
|
||||
class AudioData(object):
  """MediaPipe Tasks' audio container.

  Holds a fixed-length float32 buffer of shape
  `[buffer_length, num_channels]`. New samples are appended at the end of the
  buffer and the oldest samples are pushed out at the front.
  """

  def __init__(
      self, buffer_length: int,
      audio_format: Optional[AudioFormat] = None) -> None:
    """Initializes the `AudioData` object.

    Args:
      buffer_length: the length of the audio buffer.
      audio_format: the audio format metadata. When omitted (or `None`), a
        fresh default `AudioFormat()` is created for this instance.
    """
    # Build the default per instance: a default evaluated in the signature
    # would be one shared, mutable `AudioFormat` across all containers.
    if audio_format is None:
      audio_format = AudioFormat()
    self._audio_format = audio_format
    self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
                            dtype=np.float32)

  def clear(self) -> None:
    """Clears the internal buffer and fills it with zeros."""
    self._buffer.fill(0)

  def load_from_array(self,
                      src: np.ndarray,
                      offset: int = 0,
                      size: int = -1) -> None:
    """Loads the audio data from a NumPy array.

    The selected samples `src[offset:offset + size]` are appended to the end
    of the internal buffer. If they do not fit, only the most recent
    `buffer_length` samples of the selection are kept. The buffer always
    keeps its length and its float32 dtype.

    Args:
      src: A NumPy source array containing the input audio, shaped
        `[num_samples, num_channels]`. A 1-D array is also accepted when the
        audio format has a single channel.
      offset: An optional offset for loading a slice of the `src` array to the
        buffer.
      size: An optional size parameter denoting the number of samples to load
        from the `src` array. A negative value means `len(src)` (not
        `len(src) - offset`), so a non-zero `offset` needs an explicit `size`.
        A size of 0 leaves the buffer untouched.

    Raises:
      ValueError: If the input array has an incorrect shape or if
        `offset` + `size` exceeds the length of the `src` array.
    """
    num_channels = self._audio_format.num_channels
    if src.ndim == 1 and num_channels == 1:
      # View a flat mono signal as a single-channel column.
      src = src[:, np.newaxis]
    if src.ndim != 2 or src.shape[1] != num_channels:
      raise ValueError(f"Input audio contains an invalid number of channels. "
                       f"Expect {num_channels}.")

    if size < 0:
      size = len(src)

    if offset + size > len(src):
      raise ValueError(
          f"Index out of range. offset {offset} + size {size} should be <= "
          f"src's length: {len(src)}")

    if size == 0:
      # Nothing to load; `buffer[-0:]` would address the whole buffer.
      return

    capacity = len(self._buffer)
    incoming = src[offset:offset + size]
    # Branch on the number of samples being loaded, not on the length of
    # `src`: a short slice of a long source must still be appended.
    if size >= capacity:
      # The selection fills the whole buffer: keep only its tail. `astype`
      # returns a new float32 array, so the buffer dtype never changes.
      self._buffer = incoming[size - capacity:].astype(np.float32)
    else:
      # Shift the internal buffer backward and add the incoming data to the end
      # of the buffer.
      self._buffer = np.roll(self._buffer, -size, axis=0)
      self._buffer[-size:, :] = incoming

  @property
  def audio_format(self) -> AudioFormat:
    """Gets the audio format of the audio."""
    return self._audio_format

  @property
  def buffer_length(self) -> int:
    """Gets the sample count of the audio."""
    return self._buffer.shape[0]

  @property
  def buffer(self) -> np.ndarray:
    """Gets the internal buffer."""
    return self._buffer
|
Loading…
Reference in New Issue
Block a user