Implement MediaPipe Tasks Python AudioData.

PiperOrigin-RevId: 486147173
This commit is contained in:
Jiuqiang Tang 2022-11-04 08:30:30 -07:00 committed by Copybara-Service
parent 5024c815f1
commit 5f5f50d8f7
2 changed files with 114 additions and 0 deletions

View File

@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
py_library(
name = "audio_data",
srcs = ["audio_data.py"],
)
py_library( py_library(
name = "bounding_box", name = "bounding_box",
srcs = ["bounding_box.py"], srcs = ["bounding_box.py"],

View File

@ -0,0 +1,109 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio data."""
import dataclasses
from typing import Optional
import numpy as np
@dataclasses.dataclass
class AudioFormat:
"""Audio format metadata.
Attributes:
num_channels: the number of channels of the audio data.
sample_rate: the audio sample rate.
"""
num_channels: int = 1
sample_rate: Optional[float] = None
class AudioData(object):
"""MediaPipe Tasks' audio container."""
def __init__(
self, buffer_length: int,
audio_format: AudioFormat = AudioFormat()) -> None:
"""Initializes the `AudioData` object.
Args:
buffer_length: the length of the audio buffer.
audio_format: the audio format metadata.
"""
self._audio_format = audio_format
self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
dtype=np.float32)
def clear(self):
"""Clears the internal buffer and fill it with zeros."""
self._buffer.fill(0)
def load_from_array(self,
src: np.ndarray,
offset: int = 0,
size: int = -1) -> None:
"""Loads the audio data from a NumPy array.
Args:
src: A NumPy source array contains the input audio.
offset: An optional offset for loading a slice of the `src` array to the
buffer.
size: An optional size parameter denoting the number of samples to load
from the `src` array.
Raises:
ValueError: If the input array has an incorrect shape or if
`offset` + `size` exceeds the length of the `src` array.
"""
if src.shape[1] != self._audio_format.num_channels:
raise ValueError(f"Input audio contains an invalid number of channels. "
f"Expect {self._audio_format.num_channels}.")
if size < 0:
size = len(src)
if offset + size > len(src):
raise ValueError(
f"Index out of range. offset {offset} + size {size} should be <= "
f"src's length: {len(src)}")
if len(src) >= len(self._buffer):
# If the internal buffer is shorter than the load target (src), copy
# values from the end of the src array to the internal buffer.
new_offset = offset + size - len(self._buffer)
new_size = len(self._buffer)
self._buffer = src[new_offset:new_offset + new_size].copy()
else:
# Shift the internal buffer backward and add the incoming data to the end
# of the buffer.
shift = size
self._buffer = np.roll(self._buffer, -shift, axis=0)
self._buffer[-shift:, :] = src[offset:offset + size].copy()
@property
def audio_format(self) -> AudioFormat:
"""Gets the audio format of the audio."""
return self._audio_format
@property
def buffer_length(self) -> int:
"""Gets the sample count of the audio."""
return self._buffer.shape[0]
@property
def buffer(self) -> np.ndarray:
"""Gets the internal buffer."""
return self._buffer