From 5f5f50d8f72d5ca4ee4de26a1aa42a7d2d3ca506 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 4 Nov 2022 08:30:30 -0700 Subject: [PATCH] Implement MediaPipe Tasks Python AudioData. PiperOrigin-RevId: 486147173 --- .../tasks/python/components/containers/BUILD | 5 + .../components/containers/audio_data.py | 109 ++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 mediapipe/tasks/python/components/containers/audio_data.py diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 20ee501cc..91e115476 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +py_library( + name = "audio_data", + srcs = ["audio_data.py"], +) + py_library( name = "bounding_box", srcs = ["bounding_box.py"], diff --git a/mediapipe/tasks/python/components/containers/audio_data.py b/mediapipe/tasks/python/components/containers/audio_data.py new file mode 100644 index 000000000..21b606079 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/audio_data.py @@ -0,0 +1,109 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""MediaPipe audio data."""

import dataclasses
from typing import Optional

import numpy as np


@dataclasses.dataclass
class AudioFormat:
  """Audio format metadata.

  Attributes:
    num_channels: the number of channels of the audio data.
    sample_rate: the audio sample rate.
  """
  num_channels: int = 1
  sample_rate: Optional[float] = None


class AudioData(object):
  """MediaPipe Tasks' audio container.

  Holds a fixed-length buffer of shape `[buffer_length, num_channels]` that is
  allocated as float32. New samples are appended at the end of the buffer and
  the oldest samples are dropped from the front.
  """

  def __init__(
      self, buffer_length: int,
      audio_format: Optional[AudioFormat] = None) -> None:
    """Initializes the `AudioData` object.

    Args:
      buffer_length: the length of the audio buffer.
      audio_format: the audio format metadata. When omitted, a fresh default
        `AudioFormat()` is created for this instance.
    """
    # Build the default per instance: `AudioFormat` is a mutable dataclass, so
    # a default evaluated once in the signature would be shared by every
    # `AudioData` created without an explicit format.
    if audio_format is None:
      audio_format = AudioFormat()
    self._audio_format = audio_format
    self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
                            dtype=np.float32)

  def clear(self):
    """Clears the internal buffer and fill it with zeros."""
    self._buffer.fill(0)

  def load_from_array(self,
                      src: np.ndarray,
                      offset: int = 0,
                      size: int = -1) -> None:
    """Loads the audio data from a NumPy array.

    The samples `src[offset:offset + size]` are appended to the end of the
    internal buffer. If that window is at least as long as the buffer, only the
    last `buffer_length` samples of the window are kept; otherwise the existing
    content is shifted towards the front by `size` samples to make room.

    Args:
      src: A NumPy source array contains the input audio, with shape
        `[num_samples, num_channels]`.
      offset: An optional offset for loading a slice of the `src` array to the
        buffer.
      size: An optional size parameter denoting the number of samples to load
        from the `src` array. A negative value (the default) loads every sample
        from `offset` to the end of `src`.

    Raises:
      ValueError: If the input array has an incorrect shape, if `offset` is
        negative, or if `offset` + `size` exceeds the length of the `src`
        array.
    """
    # Reject non-2-D input here so a bad shape raises the documented ValueError
    # instead of an IndexError from `src.shape[1]`.
    if src.ndim != 2 or src.shape[1] != self._audio_format.num_channels:
      raise ValueError(f"Input audio contains an invalid number of channels. "
                       f"Expect {self._audio_format.num_channels}.")

    # A negative offset would silently slice from the end of `src`.
    if offset < 0:
      raise ValueError(f"Index out of range. offset {offset} should be >= 0")

    if size < 0:
      # "Everything from `offset` onwards". Clamped at zero so that an offset
      # past the end of `src` is still caught by the range check below.
      size = max(len(src) - offset, 0)

    if offset + size > len(src):
      raise ValueError(
          f"Index out of range. offset {offset} + size {size} should be <= "
          f"src's length: {len(src)}")

    if size == 0:
      # Nothing to load. Also avoids `buffer[-0:]`, which addresses the whole
      # buffer rather than an empty tail.
      return

    window = src[offset:offset + size]
    capacity = len(self._buffer)
    if size >= capacity:
      # The requested window alone fills the buffer: keep only its newest
      # `capacity` samples. The decision is based on `size`, not `len(src)`,
      # so samples outside the requested window are never loaded. `astype`
      # returns a copy and keeps the buffer dtype stable.
      self._buffer = window[size - capacity:].astype(self._buffer.dtype)
    else:
      # Shift the internal buffer backward and add the incoming data to the end
      # of the buffer. `np.roll` returns a new array; the wrapped-around
      # samples land in the tail, which is then overwritten.
      self._buffer = np.roll(self._buffer, -size, axis=0)
      self._buffer[-size:, :] = window

  @property
  def audio_format(self) -> AudioFormat:
    """Gets the audio format of the audio."""
    return self._audio_format

  @property
  def buffer_length(self) -> int:
    """Gets the sample count of the audio."""
    return self._buffer.shape[0]

  @property
  def buffer(self) -> np.ndarray:
    """Gets the internal buffer."""
    return self._buffer