Merge pull request #3853 from kinaryml:audio-embedder-python

PiperOrigin-RevId: 488434586
This commit is contained in:
Copybara-Service 2022-11-14 12:16:45 -08:00
commit 9a2af2f2a1
5 changed files with 645 additions and 1 deletions

View File

@ -94,10 +94,11 @@ cc_library(
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
] + select({ ] + select({
# TODO: Build text_classifier_graph and text_embedder_graph on Windows. # TODO: Build text_classifier_graph and text_embedder_graph on Windows.
# TODO: Build audio_classifier_graph on Windows. # TODO: Build audio_classifier_graph and audio_embedder_graph on Windows.
"//mediapipe:windows": [], "//mediapipe:windows": [],
"//conditions:default": [ "//conditions:default": [
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
], ],

View File

@ -39,3 +39,26 @@ py_library(
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
], ],
) )
py_library(
name = "audio_embedder",
srcs = [
"audio_embedder.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/components/utils:cosine_similarity",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
],
)

View File

@ -0,0 +1,284 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio embedder task."""
import dataclasses
from typing import Callable, Mapping, List, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module
from mediapipe.tasks.python.components.utils import cosine_similarity
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
AudioEmbedderResult = embedding_result_module.EmbeddingResult
_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options_module.EmbedderOptions
_RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
_AUDIO_IN_STREAM_NAME = 'audio_in'
_AUDIO_TAG = 'AUDIO'
_EMBEDDINGS_STREAM_NAME = 'embeddings_out'
_EMBEDDINGS_TAG = 'EMBEDDINGS'
_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in'
_SAMPLE_RATE_TAG = 'SAMPLE_RATE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'
_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME = 'timestamped_embeddings_out'
_TIMESTAMPTED_EMBEDDINGS_TAG = 'TIMESTAMPED_EMBEDDINGS'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class AudioEmbedderOptions:
"""Options for the audio embedder task.
Attributes:
base_options: Base options for the audio embedder task.
running_mode: The running mode of the task. Default to the audio clips mode.
Audio embedder task has two running modes: 1) The audio clips mode for
running embedding extraction on independent audio clips. 2) The audio
stream mode for running embedding extraction on the audio stream, such as
from microphone. In this mode, the "result_callback" below must be
specified to receive the embedding results asynchronously.
embedder_options: Options for configuring the embedder behavior, such as
l2_normalize and quantize.
result_callback: The user-defined result callback for processing audio
stream data. The result callback should only be specified when the running
mode is set to the audio stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
embedder_options: _EmbedderOptions = _EmbedderOptions()
result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _AudioEmbedderGraphOptionsProto:
"""Generates an AudioEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
embedder_options_proto = self.embedder_options.to_pb2()
return _AudioEmbedderGraphOptionsProto(
base_options=base_options_proto,
embedder_options=embedder_options_proto)
class AudioEmbedder(base_audio_task_api.BaseAudioTaskApi):
"""Class that performs embedding extraction on audio clips or audio stream."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'AudioEmbedder':
"""Creates an `AudioEmbedder` object from a TensorFlow Lite model and the default `AudioEmbedderOptions`.
Note that the created `AudioEmbedder` instance is in audio clips mode, for
embedding extraction on the independent audio clips.
Args:
model_path: Path to the model.
Returns:
`AudioEmbedder` object that's created from the model file and the
default `AudioEmbedderOptions`.
Raises:
ValueError: If failed to create `AudioEmbedder` object from the provided
file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = AudioEmbedderOptions(
base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: AudioEmbedderOptions) -> 'AudioEmbedder':
"""Creates the `AudioEmbedder` object from audio embedder options.
Args:
options: Options for the audio embedder task.
Returns:
`AudioEmbedder` object that's created from `options`.
Raises:
ValueError: If failed to create `AudioEmbedder` object from
`AudioEmbedderOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet.Packet]):
timestamp_ms = output_packets[
_EMBEDDINGS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND
if output_packets[_EMBEDDINGS_STREAM_NAME].is_empty():
options.result_callback(
AudioEmbedderResult(embeddings=[]), timestamp_ms)
return
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_STREAM_NAME]))
options.result_callback(
AudioEmbedderResult.create_from_pb2(embedding_result_proto),
timestamp_ms)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]),
':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME])
],
output_streams=[
':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_STREAM_NAME]), ':'.join([
_TIMESTAMPTED_EMBEDDINGS_TAG,
_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME
])
],
task_options=options)
return cls(
# Audio tasks should not drop input audio due to flow limiting, which
# may cause data inconsistency.
task_info.generate_graph_config(enable_flow_limiting=False),
options.running_mode,
packets_callback if options.result_callback else None)
def embed(self, audio_clip: _AudioData) -> List[AudioEmbedderResult]:
"""Performs embedding extraction on the provided audio clips.
The audio clip is represented as a MediaPipe AudioData. The method accepts
audio clips with various length and audio sample rate. It's required to
provide the corresponding audio sample rate within the `AudioData` object.
The input audio clip may be longer than what the model is able to process
in a single inference. When this occurs, the input audio clip is split into
multiple chunks starting at different timestamps. For this reason, this
function returns a vector of EmbeddingResult objects, each associated
ith a timestamp corresponding to the start (in milliseconds) of the chunk
data on which embedding extraction was carried out.
Args:
audio_clip: MediaPipe AudioData.
Returns:
An `AudioEmbedderResult` object that contains a list of embedding result
objects, each associated with a timestamp corresponding to the start
(in milliseconds) of the chunk data on which embedding extraction was
carried out.
Raises:
ValueError: If any of the input arguments is invalid, such as the sample
rate is not provided in the `AudioData` object.
RuntimeError: If audio embedding extraction failed to run.
"""
if not audio_clip.audio_format.sample_rate:
raise ValueError('Must provide the audio sample rate in audio data.')
output_packets = self._process_audio_clip({
_AUDIO_IN_STREAM_NAME:
packet_creator.create_matrix(audio_clip.buffer, transpose=True),
_SAMPLE_RATE_IN_STREAM_NAME:
packet_creator.create_double(audio_clip.audio_format.sample_rate)
})
output_list = []
embeddings_proto_list = packet_getter.get_proto_list(
output_packets[_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME])
for proto in embeddings_proto_list:
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(proto)
output_list.append(
AudioEmbedderResult.create_from_pb2(embedding_result_proto))
return output_list
def embed_async(self, audio_block: _AudioData, timestamp_ms: int) -> None:
"""Sends audio data (a block in a continuous audio stream) to perform audio embedding extraction.
Only use this method when the AudioEmbedder is created with the audio
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input audio data is accepted. The results will be available via the
`result_callback` provided in the `AudioEmbedderOptions`. The
`embed_async` method is designed to process auido stream data such as
microphone input.
The input audio data may be longer than what the model is able to process
in a single inference. When this occurs, the input audio block is split
into multiple chunks. For this reason, the callback may be called multiple
times (once per chunk) for each call to this function.
The `result_callback` provides:
- An `AudioEmbedderResult` object that contains a list of
embeddings.
- The input timestamp in milliseconds.
Args:
audio_block: MediaPipe AudioData.
timestamp_ms: The timestamp of the input audio data in milliseconds.
Raises:
ValueError: If any of the followings:
1) The sample rate is not provided in the `AudioData` object or the
provided sample rate is inconsistent with the previously received.
2) The current input timestamp is smaller than what the audio
embedder has already processed.
"""
if not audio_block.audio_format.sample_rate:
raise ValueError('Must provide the audio sample rate in audio data.')
if not self._default_sample_rate:
self._default_sample_rate = audio_block.audio_format.sample_rate
self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME,
self._default_sample_rate)
elif audio_block.audio_format.sample_rate != self._default_sample_rate:
raise ValueError(
f'The audio sample rate provided in audio data: '
f'{audio_block.audio_format.sample_rate} is inconsistent with '
f'the previously received: {self._default_sample_rate}.')
self._send_audio_stream_data({
_AUDIO_IN_STREAM_NAME:
packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
@classmethod
def cosine_similarity(cls, u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding) -> float:
"""Utility function to compute cosine similarity between two embedding entries.
May return an InvalidArgumentError if e.g. the feature vectors are
of different types (quantized vs. float), have different sizes, or have a
an L2-norm of 0.
Args:
u: An embedding entry.
v: An embedding entry.
Returns:
The cosine similarity for the two embeddings.
Raises:
ValueError: May return an error if e.g. the feature vectors are of
different types (quantized vs. float), have different sizes, or have
an L2-norm of 0.
"""
return cosine_similarity.cosine_similarity(u, v)

View File

@ -35,3 +35,21 @@ py_test(
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
py_test(
name = "audio_embedder_test",
srcs = ["audio_embedder_test.py"],
data = [
"//mediapipe/tasks/testdata/audio:test_audio_clips",
"//mediapipe/tasks/testdata/audio:test_models",
],
deps = [
"//mediapipe/tasks/python/audio:audio_embedder",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:embedder_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,318 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio embedder."""
import enum
import os
from typing import List, Tuple
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_embedder
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.processors import embedder_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
_AudioEmbedder = audio_embedder.AudioEmbedder
_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions
_AudioEmbedderResult = audio_embedder.AudioEmbedderResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_EmbedderOptions = embedder_options.EmbedderOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite'
_YAMNET_MODEL_SAMPLE_RATE = 16000
_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_SPEECH_SIMILARITIES = [0.985359, 0.994349, 0.993227, 0.996658, 0.996384]
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLSECONDS_PER_SECOND = 1000
# Tolerance for embedding vector coordinate values.
_EPSILON = 3e-6
# Tolerance for cosine similarity evaluation.
_SIMILARITY_TOLERANCE = 1e-6
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class AudioEmbedderTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.yamnet_model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
def _read_wav_file(self, file_name) -> _AudioData:
sample_rate, buffer = wavfile.read(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
return _AudioData.create_from_array(
buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)
def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
sample_rate, buffer = wavfile.read(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
audio_data_list = []
start = 0
step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
while start < len(buffer):
end = min(start + (int)(step_size), len(buffer))
audio_data_list.append((_AudioData.create_from_array(
buffer[start:end].astype(float) / np.iinfo(np.int16).max,
sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND)))
start = end
return audio_data_list
def _check_embedding_value(self, result, expected_first_value):
# Check embedding first value.
self.assertAlmostEqual(
result.embeddings[0].embedding[0], expected_first_value, delta=_EPSILON)
def _check_embedding_size(self, result, quantize, expected_embedding_size):
# Check embedding size.
self.assertLen(result.embeddings, 1)
embedding_result = result.embeddings[0]
self.assertLen(embedding_result.embedding, expected_embedding_size)
if quantize:
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
else:
self.assertEqual(embedding_result.embedding.dtype, float)
def _check_cosine_similarity(self, result0, result1, expected_similarity):
# Checks cosine similarity.
similarity = _AudioEmbedder.cosine_similarity(result0.embeddings[0],
result1.embeddings[0])
self.assertAlmostEqual(
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
def _check_yamnet_result(self,
embedding_result0_list: List[_AudioEmbedderResult],
embedding_result1_list: List[_AudioEmbedderResult],
expected_similarities: List[float]):
expected_size = len(expected_similarities)
self.assertLen(embedding_result0_list, expected_size)
self.assertLen(embedding_result1_list, expected_size)
for idx in range(expected_size):
embedding_result0 = embedding_result0_list[idx]
embedding_result1 = embedding_result1_list[idx]
self._check_cosine_similarity(embedding_result0, embedding_result1,
expected_similarities[idx])
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _AudioEmbedder.create_from_model_path(
self.yamnet_model_path) as embedder:
self.assertIsInstance(embedder, _AudioEmbedder)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
with _AudioEmbedder.create_from_options(
_AudioEmbedderOptions(
base_options=_BaseOptions(
model_asset_path=self.yamnet_model_path))) as embedder:
self.assertIsInstance(embedder, _AudioEmbedder)
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
options = _AudioEmbedderOptions(base_options=base_options)
_AudioEmbedder.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.yamnet_model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _AudioEmbedderOptions(base_options=base_options)
embedder = _AudioEmbedder.create_from_options(options)
self.assertIsInstance(embedder, _AudioEmbedder)
@parameterized.parameters(
# Same audio inputs but different sample rates.
(False, False, ModelFileType.FILE_NAME, _SPEECH_WAV_16K_MONO,
_SPEECH_WAV_48K_MONO, 1024, (0, 0)),
(False, False, ModelFileType.FILE_CONTENT, _SPEECH_WAV_16K_MONO,
_SPEECH_WAV_48K_MONO, 1024, (0, 0)))
def test_embed_with_yamnet_model(self, l2_normalize, quantize,
model_file_type, audio_file0, audio_file1,
expected_size, expected_first_values):
# Creates embedder.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.yamnet_model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.yamnet_model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _AudioEmbedderOptions(
base_options=base_options,
embedder_options=_EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize))
with _AudioEmbedder.create_from_options(options) as embedder:
embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0))
embedding_result1_list = embedder.embed(self._read_wav_file(audio_file1))
# Checks embeddings and cosine similarity.
expected_result0_value, expected_result1_value = expected_first_values
self._check_embedding_size(embedding_result0_list[0], quantize,
expected_size)
self._check_embedding_size(embedding_result1_list[0], quantize,
expected_size)
self._check_embedding_value(embedding_result0_list[0],
expected_result0_value)
self._check_embedding_value(embedding_result1_list[0],
expected_result1_value)
self._check_yamnet_result(
embedding_result0_list,
embedding_result1_list,
expected_similarities=_SPEECH_SIMILARITIES)
def test_embed_with_yamnet_model_and_different_inputs(self):
with _AudioEmbedder.create_from_model_path(
self.yamnet_model_path) as embedder:
embedding_result0_list = embedder.embed(
self._read_wav_file(_SPEECH_WAV_16K_MONO))
embedding_result1_list = embedder.embed(
self._read_wav_file(_TWO_HEADS_WAV_16K_MONO))
self.assertLen(embedding_result0_list, 5)
self.assertLen(embedding_result1_list, 1)
self._check_cosine_similarity(
embedding_result0_list[0],
embedding_result1_list[0],
expected_similarity=0.09017)
def test_missing_sample_rate_in_audio_clips_mode(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS)
with self.assertRaisesRegex(ValueError,
r'Must provide the audio sample rate'):
with _AudioEmbedder.create_from_options(options) as embedder:
embedder.embed(_AudioData(buffer_length=100))
def test_missing_sample_rate_in_audio_stream_mode(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'provide the audio sample rate in audio data'):
with _AudioEmbedder.create_from_options(options) as embedder:
embedder.embed(_AudioData(buffer_length=100))
def test_missing_result_callback(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _AudioEmbedder.create_from_options(options) as unused_embedder:
pass
def test_illegal_result_callback(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _AudioEmbedder.create_from_options(options) as unused_embedder:
pass
def test_calling_embed_in_audio_stream_mode(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with _AudioEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(ValueError,
r'not initialized with the audio clips mode'):
embedder.embed(self._read_wav_file(_SPEECH_WAV_16K_MONO))
def test_calling_embed_async_in_audio_clips_mode(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS)
with _AudioEmbedder.create_from_options(options) as embedder:
with self.assertRaisesRegex(
ValueError, r'not initialized with the audio stream mode'):
embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)
def test_embed_async_calls_with_illegal_timestamp(self):
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with _AudioEmbedder.create_from_options(options) as embedder:
embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)
@parameterized.parameters(
# Same audio inputs but different sample rates.
(False, False, _SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO))
def test_embed_async(self, l2_normalize, quantize, audio_file0, audio_file1):
embedding_result_list = []
embedding_result_list_copy = embedding_result_list.copy()
def save_result(result: _AudioEmbedderResult, timestamp_ms: int):
result.timestamp_ms = timestamp_ms
embedding_result_list.append(result)
options = _AudioEmbedderOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
embedder_options=_EmbedderOptions(
l2_normalize=l2_normalize, quantize=quantize),
result_callback=save_result)
with _AudioEmbedder.create_from_options(options) as embedder:
audio_data0_list = self._read_wav_file_as_stream(audio_file0)
for audio_data, timestamp_ms in audio_data0_list:
embedder.embed_async(audio_data, timestamp_ms)
embedding_result0_list = embedding_result_list
with _AudioEmbedder.create_from_options(options) as embedder:
audio_data1_list = self._read_wav_file_as_stream(audio_file1)
embedding_result_list = embedding_result_list_copy
for audio_data, timestamp_ms in audio_data1_list:
embedder.embed_async(audio_data, timestamp_ms)
embedding_result1_list = embedding_result_list
self._check_yamnet_result(
embedding_result0_list,
embedding_result1_list,
expected_similarities=_SPEECH_SIMILARITIES)
if __name__ == '__main__':
absltest.main()