Implement MediaPipe Tasks Python AudioData.
PiperOrigin-RevId: 486147173
This commit is contained in:
parent
5024c815f1
commit
5f5f50d8f7
|
@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Library target for audio_data.py, the module that defines the `AudioData`
# audio container and its `AudioFormat` metadata dataclass.
py_library(
    name = "audio_data",
    srcs = ["audio_data.py"],
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "bounding_box",
|
name = "bounding_box",
|
||||||
srcs = ["bounding_box.py"],
|
srcs = ["bounding_box.py"],
|
||||||
|
|
109
mediapipe/tasks/python/components/containers/audio_data.py
Normal file
109
mediapipe/tasks/python/components/containers/audio_data.py
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe audio data."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class AudioFormat:
    """Metadata describing how a block of audio is laid out.

    Attributes:
      num_channels: how many channels the audio data has. Defaults to 1.
      sample_rate: the sample rate of the audio, or `None` when it has not
        been specified.
    """

    # Channel count; a single channel unless stated otherwise.
    num_channels: int = dataclasses.field(default=1)
    # Sample rate; left unset (`None`) by default.
    sample_rate: Optional[float] = dataclasses.field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioData(object):
    """MediaPipe Tasks' audio container.

    Wraps a fixed-length float32 buffer shaped `(buffer_length, num_channels)`.
    Loading new samples either replaces the whole buffer (when at least
    `buffer_length` samples are loaded) or shifts the oldest samples out and
    appends the new ones at the end.
    """

    def __init__(
        self,
        buffer_length: int,
        audio_format: Optional["AudioFormat"] = None,
    ) -> None:
        """Initializes the `AudioData` object.

        Args:
          buffer_length: the length of the audio buffer.
          audio_format: the audio format metadata. When omitted, a new default
            `AudioFormat` is created for this instance.
        """
        # The default is built here rather than in the signature: a default
        # evaluated at definition time would be one mutable object shared by
        # every `AudioData` created without an explicit format.
        if audio_format is None:
            audio_format = AudioFormat()
        self._audio_format = audio_format
        self._buffer = np.zeros(
            [buffer_length, self._audio_format.num_channels], dtype=np.float32)

    def clear(self) -> None:
        """Clears the internal buffer and fills it with zeros."""
        self._buffer.fill(0)

    def load_from_array(self,
                        src: np.ndarray,
                        offset: int = 0,
                        size: int = -1) -> None:
        """Loads the audio data from a NumPy array.

        The samples `src[offset:offset + size]` are loaded. If that slice is at
        least as long as the internal buffer, the buffer is replaced by the
        last `buffer_length` samples of the slice. Otherwise the buffer is
        shifted backward by `size` samples and the slice is written to its end.
        The buffer always stays float32. Loading zero samples leaves the
        buffer untouched.

        Args:
          src: A 2-D NumPy source array of shape `(num_samples, num_channels)`
            that contains the input audio.
          offset: An optional non-negative offset for loading a slice of the
            `src` array to the buffer.
          size: An optional size parameter denoting the number of samples to
            load from the `src` array. A negative value (the default) means
            `len(src)`.

        Raises:
          ValueError: If the input array has an incorrect shape, if `offset`
            is negative, or if `offset` + `size` exceeds the length of the
            `src` array.
        """
        num_channels = self._audio_format.num_channels
        # Check the rank first so that a 1-D input is reported as a
        # ValueError instead of an IndexError from `src.shape[1]`.
        if src.ndim != 2:
            raise ValueError(
                f"Input audio must be a 2-D array of shape (num_samples, "
                f"{num_channels}), but got shape {src.shape}.")
        if src.shape[1] != num_channels:
            raise ValueError(f"Input audio contains an invalid number of channels. "
                             f"Expect {num_channels}.")

        if size < 0:
            size = len(src)

        if offset < 0:
            # A negative offset would silently wrap around in the slice below.
            raise ValueError(f"offset must be non-negative, got {offset}.")

        if offset + size > len(src):
            raise ValueError(
                f"Index out of range. offset {offset} + size {size} should be <= "
                f"src's length: {len(src)}")

        if size == 0:
            # Nothing to load; `buffer[-0:]` below would address the whole
            # buffer rather than an empty tail.
            return

        capacity = len(self._buffer)
        end = offset + size
        if size >= capacity:
            # The requested slice fills the buffer completely: keep only its
            # most recent `capacity` samples. The decision is based on the
            # number of samples being loaded, not on the length of `src`.
            # `astype` returns a new float32 array, so the buffer never
            # aliases `src` and never changes dtype.
            self._buffer = src[end - capacity:end].astype(np.float32, order="C")
        else:
            # Shift the internal buffer backward and add the incoming data to
            # the end of the buffer.
            self._buffer = np.roll(self._buffer, -size, axis=0)
            self._buffer[-size:, :] = src[offset:end]

    @property
    def audio_format(self) -> "AudioFormat":
        """Gets the audio format of the audio."""
        return self._audio_format

    @property
    def buffer_length(self) -> int:
        """Gets the sample count of the audio."""
        return self._buffer.shape[0]

    @property
    def buffer(self) -> np.ndarray:
        """Gets the internal buffer."""
        return self._buffer
|
Loading…
Reference in New Issue
Block a user