diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD index dd8719151..2e5815ff0 100644 --- a/mediapipe/tasks/python/audio/BUILD +++ b/mediapipe/tasks/python/audio/BUILD @@ -39,3 +39,26 @@ py_library( "//mediapipe/tasks/python/core:task_info", ], ) + +py_library( + name = "audio_embedder", + srcs = [ + "audio_embedder.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/audio/core:base_audio_task_api", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/components/processors:embedder_options", + "//mediapipe/tasks/python/components/utils:cosine_similarity", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + ], +) diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index e04e778b5..a081e5ecd 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -257,7 +257,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): Raises: ValueError: If any of the followings: 1) The sample rate is not provided in the `AudioData` object or the - provided sample rate is inconsisent with the previously recevied. + provided sample rate is inconsistent with the previously received. 2) The current input timestamp is smaller than what the audio classifier has already processed. """ @@ -270,7 +270,7 @@ class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): elif audio_block.audio_format.sample_rate != self._default_sample_rate: raise ValueError( f'The audio sample rate provided in audio data: ' - f'{audio_block.audio_format.sample_rate} is inconsisent with ' + f'{audio_block.audio_format.sample_rate} is inconsistent with ' f'the previously received: {self._default_sample_rate}.') self._send_audio_stream_data({ diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py new file mode 100644 index 000000000..0580b3518 --- /dev/null +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -0,0 +1,285 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe audio embedder task.""" + +import dataclasses +from typing import Callable, Mapping, List, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import packet +from mediapipe.tasks.cc.audio.audio_embedder.proto import audio_embedder_graph_options_pb2 +from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2 +from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module +from mediapipe.tasks.python.audio.core import base_audio_task_api +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module +from mediapipe.tasks.python.components.processors import embedder_options as embedder_options_module +from mediapipe.tasks.python.components.utils import cosine_similarity +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +AudioEmbedderResult = embedding_result_module.EmbeddingResult +_AudioEmbedderGraphOptionsProto = audio_embedder_graph_options_pb2.AudioEmbedderGraphOptions +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_EmbedderOptions = embedder_options_module.EmbedderOptions +_RunningMode = running_mode_module.AudioTaskRunningMode +_TaskInfo = task_info_module.TaskInfo + +_AUDIO_IN_STREAM_NAME = 'audio_in' +_AUDIO_TAG = 'AUDIO' +_EMBEDDINGS_STREAM_NAME = 'embeddings_out' +_EMBEDDINGS_TAG = 'EMBEDDINGS' +_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in' +_SAMPLE_RATE_TAG = 'SAMPLE_RATE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph' +_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME = 'timestamped_embeddings_out' +_TIMESTAMPTED_EMBEDDINGS_TAG = 'TIMESTAMPED_EMBEDDINGS' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class AudioEmbedderOptions: + """Options for the audio embedder task. + + Attributes: + base_options: Base options for the audio embedder task. + running_mode: The running mode of the task. Default to the audio clips mode. + Audio embedder task has two running modes: 1) The audio clips mode for + running embedding extraction on independent audio clips. 2) The audio + stream mode for running embedding extraction on the audio stream, such as + from microphone. In this mode, the "result_callback" below must be + specified to receive the embedding results asynchronously. + embedder_options: Options for configuring the embedder behavior, such as + l2_normalize and quantize. + result_callback: The user-defined result callback for processing audio + stream data. The result callback should only be specified when the running + mode is set to the audio stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS + embedder_options: _EmbedderOptions = _EmbedderOptions() + result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _AudioEmbedderGraphOptionsProto: + """Generates an AudioEmbedderOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True + embedder_options_proto = self.embedder_options.to_pb2() + + return _AudioEmbedderGraphOptionsProto( + base_options=base_options_proto, + embedder_options=embedder_options_proto) + + +class AudioEmbedder(base_audio_task_api.BaseAudioTaskApi): + """Class that performs embedding extraction on audio clips or audio stream.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'AudioEmbedder': + """Creates an `AudioEmbedder` object from a TensorFlow Lite model and the default `AudioEmbedderOptions`. + + Note that the created `AudioEmbedder` instance is in audio clips mode, for + embedding extraction on the independent audio clips. + + Args: + model_path: Path to the model. + + Returns: + `AudioEmbedder` object that's created from the model file and the + default `AudioEmbedderOptions`. + + Raises: + ValueError: If failed to create `AudioEmbedder` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = AudioEmbedderOptions( + base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: AudioEmbedderOptions) -> 'AudioEmbedder': + """Creates the `AudioEmbedder` object from audio embedder options. + + Args: + options: Options for the audio embedder task. + + Returns: + `AudioEmbedder` object that's created from `options`. + + Raises: + ValueError: If failed to create `AudioEmbedder` object from + `AudioEmbedderOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet.Packet]): + timestamp_ms = output_packets[ + _EMBEDDINGS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND + if output_packets[_EMBEDDINGS_STREAM_NAME].is_empty(): + options.result_callback( + AudioEmbedderResult(embeddings=[]), timestamp_ms) + return + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_EMBEDDINGS_STREAM_NAME])) + options.result_callback( + AudioEmbedderResult.create_from_pb2(embedding_result_proto), + timestamp_ms) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]), + ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME]) + ], + output_streams=[ + ':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_STREAM_NAME]), + ':'.join([ + _TIMESTAMPTED_EMBEDDINGS_TAG, + _TIMESTAMPTED_EMBEDDINGS_STREAM_NAME + ]) + ], + task_options=options) + return cls( + # Audio tasks should not drop input audio due to flow limiting, which + # may cause data inconsistency. + task_info.generate_graph_config(enable_flow_limiting=False), + options.running_mode, + packets_callback if options.result_callback else None) + + def embed(self, audio_clip: _AudioData) -> List[AudioEmbedderResult]: + """Performs embedding extraction on the provided audio clips. + + The audio clip is represented as a MediaPipe AudioData. The method accepts + audio clips with various length and audio sample rate. It's required to + provide the corresponding audio sample rate within the `AudioData` object. + + The input audio clip may be longer than what the model is able to process + in a single inference. When this occurs, the input audio clip is split into + multiple chunks starting at different timestamps. For this reason, this + function returns a vector of EmbeddingResult objects, each associated + ith a timestamp corresponding to the start (in milliseconds) of the chunk + data on which embedding extraction was carried out. + + Args: + audio_clip: MediaPipe AudioData. + + Returns: + An `AudioEmbedderResult` object that contains a list of embedding result + objects, each associated with a timestamp corresponding to the start + (in milliseconds) of the chunk data on which embedding extraction was + carried out. + + Raises: + ValueError: If any of the input arguments is invalid, such as the sample + rate is not provided in the `AudioData` object. + RuntimeError: If audio embedding extraction failed to run. + """ + if not audio_clip.audio_format.sample_rate: + raise ValueError('Must provide the audio sample rate in audio data.') + output_packets = self._process_audio_clip({ + _AUDIO_IN_STREAM_NAME: + packet_creator.create_matrix(audio_clip.buffer, transpose=True), + _SAMPLE_RATE_IN_STREAM_NAME: + packet_creator.create_double(audio_clip.audio_format.sample_rate) + }) + output_list = [] + embeddings_proto_list = packet_getter.get_proto_list( + output_packets[_TIMESTAMPTED_EMBEDDINGS_STREAM_NAME]) + for proto in embeddings_proto_list: + embedding_result_proto = embeddings_pb2.EmbeddingResult() + embedding_result_proto.CopyFrom(proto) + output_list.append( + AudioEmbedderResult.create_from_pb2(embedding_result_proto)) + return output_list + + def embed_async(self, audio_block: _AudioData, timestamp_ms: int) -> None: + """Sends audio data (a block in a continuous audio stream) to perform audio embedding extraction. + + Only use this method when the AudioEmbedder is created with the audio + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input audio data is accepted. The results will be available via the + `result_callback` provided in the `AudioEmbedderOptions`. The + `embed_async` method is designed to process auido stream data such as + microphone input. + + The input audio data may be longer than what the model is able to process + in a single inference. When this occurs, the input audio block is split + into multiple chunks. For this reason, the callback may be called multiple + times (once per chunk) for each call to this function. + + The `result_callback` provides: + - An `AudioEmbedderResult` object that contains a list of + embeddings. + - The input timestamp in milliseconds. + + Args: + audio_block: MediaPipe AudioData. + timestamp_ms: The timestamp of the input audio data in milliseconds. + + Raises: + ValueError: If any of the followings: + 1) The sample rate is not provided in the `AudioData` object or the + provided sample rate is inconsistent with the previously received. + 2) The current input timestamp is smaller than what the audio + embedder has already processed. + """ + if not audio_block.audio_format.sample_rate: + raise ValueError('Must provide the audio sample rate in audio data.') + if not self._default_sample_rate: + self._default_sample_rate = audio_block.audio_format.sample_rate + self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME, + self._default_sample_rate) + elif audio_block.audio_format.sample_rate != self._default_sample_rate: + raise ValueError( + f'The audio sample rate provided in audio data: ' + f'{audio_block.audio_format.sample_rate} is inconsistent with ' + f'the previously received: {self._default_sample_rate}.') + + self._send_audio_stream_data({ + _AUDIO_IN_STREAM_NAME: + packet_creator.create_matrix(audio_block.buffer, transpose=True).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) + + @classmethod + def cosine_similarity(cls, u: embedding_result_module.Embedding, + v: embedding_result_module.Embedding) -> float: + """Utility function to compute cosine similarity between two embedding entries. + + May return an InvalidArgumentError if e.g. the feature vectors are + of different types (quantized vs. float), have different sizes, or have a + an L2-norm of 0. + + Args: + u: An embedding entry. + v: An embedding entry. + + Returns: + The cosine similarity for the two embeddings. + + Raises: + ValueError: May return an error if e.g. the feature vectors are of + different types (quantized vs. float), have different sizes, or have + an L2-norm of 0. + """ + return cosine_similarity.cosine_similarity(u, v) diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD index 863449126..9278cea55 100644 --- a/mediapipe/tasks/python/test/audio/BUILD +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -35,3 +35,21 @@ py_test( "//mediapipe/tasks/python/test:test_utils", ], ) + +py_test( + name = "audio_embedder_test", + srcs = ["audio_embedder_test.py"], + data = [ + "//mediapipe/tasks/testdata/audio:test_audio_clips", + "//mediapipe/tasks/testdata/audio:test_models", + ], + deps = [ + "//mediapipe/tasks/python/audio:audio_embedder", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:embedding_result", + "//mediapipe/tasks/python/components/processors:embedder_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + ], +) diff --git a/mediapipe/tasks/python/test/audio/audio_embedder_test.py b/mediapipe/tasks/python/test/audio/audio_embedder_test.py new file mode 100644 index 000000000..d085317e6 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/audio_embedder_test.py @@ -0,0 +1,317 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for audio embedder.""" +import enum +import os +from typing import List, Tuple +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +from scipy.io import wavfile + +from mediapipe.tasks.python.audio import audio_embedder +from mediapipe.tasks.python.audio.core import audio_task_running_mode +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module +from mediapipe.tasks.python.components.processors import embedder_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils + +_AudioEmbedder = audio_embedder.AudioEmbedder +_AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions +_AudioEmbedderResult = embedding_result_module.EmbeddingResult +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_EmbedderOptions = embedder_options.EmbedderOptions +_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode + +_YAMNET_MODEL_FILE = 'yamnet_embedding_metadata.tflite' +_YAMNET_MODEL_SAMPLE_RATE = 16000 +_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav' +_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav' +_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio' +_SPEECH_SIMILARITIES = [0.985359, 0.994349, 0.993227, 0.996658, 0.996384] +_YAMNET_NUM_OF_SAMPLES = 15600 +_MILLSECONDS_PER_SECOND = 1000 +# Tolerance for embedding vector coordinate values. +_EPSILON = 3e-6 +# Tolerance for cosine similarity evaluation. +_SIMILARITY_TOLERANCE = 1e-6 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class AudioEmbedderTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.yamnet_model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE)) + + def _read_wav_file(self, file_name) -> _AudioData: + sample_rate, buffer = wavfile.read( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name))) + return _AudioData.create_from_array( + buffer.astype(float) / np.iinfo(np.int16).max, sample_rate) + + def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]: + sample_rate, buffer = wavfile.read( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name))) + audio_data_list = [] + start = 0 + step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE + while start < len(buffer): + end = min(start + (int)(step_size), len(buffer)) + audio_data_list.append((_AudioData.create_from_array( + buffer[start:end].astype(float) / np.iinfo(np.int16).max, + sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND))) + start = end + return audio_data_list + + def _check_embedding_value(self, result, expected_first_value): + # Check embedding first value. + self.assertAlmostEqual( + result.embeddings[0].embedding[0], expected_first_value, delta=_EPSILON) + + def _check_embedding_size(self, result, quantize, expected_embedding_size): + # Check embedding size. + self.assertLen(result.embeddings, 1) + embedding_result = result.embeddings[0] + self.assertLen(embedding_result.embedding, expected_embedding_size) + if quantize: + self.assertEqual(embedding_result.embedding.dtype, np.uint8) + else: + self.assertEqual(embedding_result.embedding.dtype, float) + + def _check_cosine_similarity(self, result0, result1, expected_similarity): + # Checks cosine similarity. + similarity = _AudioEmbedder.cosine_similarity(result0.embeddings[0], + result1.embeddings[0]) + self.assertAlmostEqual( + similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) + + # TODO: Compares the exact score values to capture unexpected + # changes in the inference pipeline. + def _check_yamnet_result( + self, + embedding_result0_list: List[_AudioEmbedderResult], + embedding_result1_list: List[_AudioEmbedderResult], + expected_similarities: List[float]): + expected_size = len(expected_similarities) + self.assertLen(embedding_result0_list, expected_size) + self.assertLen(embedding_result1_list, expected_size) + + for idx in range(expected_size): + embedding_result0 = embedding_result0_list[idx] + embedding_result1 = embedding_result1_list[idx] + self._check_cosine_similarity(embedding_result0, embedding_result1, + expected_similarities[idx]) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + with _AudioEmbedder.create_from_options( + _AudioEmbedderOptions( + base_options=_BaseOptions( + model_asset_path=self.yamnet_model_path))) as embedder: + self.assertIsInstance(embedder, _AudioEmbedder) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') + options = _AudioEmbedderOptions(base_options=base_options) + _AudioEmbedder.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.yamnet_model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _AudioEmbedderOptions(base_options=base_options) + embedder = _AudioEmbedder.create_from_options(options) + self.assertIsInstance(embedder, _AudioEmbedder) + + @parameterized.parameters( + # Same audio inputs but different sample rates. + (False, False, ModelFileType.FILE_NAME, _SPEECH_WAV_16K_MONO, + _SPEECH_WAV_48K_MONO, 1024, (0, 0)), + (False, False, ModelFileType.FILE_CONTENT, _SPEECH_WAV_16K_MONO, + _SPEECH_WAV_48K_MONO, 1024, (0, 0))) + def test_embed_with_yamnet_model( + self, l2_normalize, quantize, model_file_type, audio_file0, audio_file1, + expected_size, expected_first_values): + # Creates embedder. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.yamnet_model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.yamnet_model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + embedder_options = _EmbedderOptions( + l2_normalize=l2_normalize, quantize=quantize) + options = _AudioEmbedderOptions( + base_options=base_options, embedder_options=embedder_options) + + with _AudioEmbedder.create_from_options(options) as embedder: + embedding_result0_list = embedder.embed(self._read_wav_file(audio_file0)) + embedding_result1_list = embedder.embed(self._read_wav_file(audio_file1)) + + # Checks embeddings and cosine similarity. + expected_result0_value, expected_result1_value = expected_first_values + self._check_embedding_size(embedding_result0_list[0], quantize, + expected_size) + self._check_embedding_size(embedding_result1_list[0], quantize, + expected_size) + self._check_embedding_value(embedding_result0_list[0], + expected_result0_value) + self._check_embedding_value(embedding_result1_list[0], + expected_result1_value) + self._check_yamnet_result(embedding_result0_list, embedding_result1_list, + expected_similarities=_SPEECH_SIMILARITIES) + + def test_embed_with_yamnet_model_and_different_inputs(self): + with _AudioEmbedder.create_from_model_path( + self.yamnet_model_path) as embedder: + embedding_result0_list = embedder.embed( + self._read_wav_file(_SPEECH_WAV_16K_MONO)) + embedding_result1_list = embedder.embed( + self._read_wav_file(_TWO_HEADS_WAV_16K_MONO)) + self.assertLen(embedding_result0_list, 5) + self.assertLen(embedding_result1_list, 1) + self._check_cosine_similarity(embedding_result0_list[0], + embedding_result1_list[0], + expected_similarity=0.09017) + + def test_missing_sample_rate_in_audio_clips_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with self.assertRaisesRegex(ValueError, + r'Must provide the audio sample rate'): + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed(_AudioData(buffer_length=100)) + + def test_missing_sample_rate_in_audio_stream_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'provide the audio sample rate in audio data'): + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed(_AudioData(buffer_length=100)) + + def test_missing_result_callback(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _AudioEmbedder.create_from_options(options) as unused_embedder: + pass + + def test_illegal_result_callback(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _AudioEmbedder.create_from_options(options) as unused_embedder: + pass + + def test_calling_embed_in_audio_stream_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioEmbedder.create_from_options(options) as embedder: + with self.assertRaisesRegex(ValueError, + r'not initialized with the audio clips mode'): + embedder.embed(self._read_wav_file(_SPEECH_WAV_16K_MONO)) + + def test_calling_embed_async_in_audio_clips_mode(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with _AudioEmbedder.create_from_options(options) as embedder: + with self.assertRaisesRegex( + ValueError, r'not initialized with the audio stream mode'): + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + def test_embed_async_calls_with_illegal_timestamp(self): + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioEmbedder.create_from_options(options) as embedder: + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + embedder.embed_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + @parameterized.parameters( + # Same audio inputs but different sample rates. + (False, False, _SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO)) + def test_embed_async(self, l2_normalize, quantize, audio_file0, audio_file1): + embedding_result_list = [] + embedding_result_list_copy = embedding_result_list.copy() + + def save_result(result: _AudioEmbedderResult, timestamp_ms: int): + result.timestamp_ms = timestamp_ms + embedding_result_list.append(result) + + options = _AudioEmbedderOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + embedder_options=_EmbedderOptions(l2_normalize=l2_normalize, + quantize=quantize), + result_callback=save_result) + + with _AudioEmbedder.create_from_options(options) as embedder: + audio_data0_list = self._read_wav_file_as_stream(audio_file0) + for audio_data, timestamp_ms in audio_data0_list: + embedder.embed_async(audio_data, timestamp_ms) + embedding_result0_list = embedding_result_list + + with _AudioEmbedder.create_from_options(options) as embedder: + audio_data1_list = self._read_wav_file_as_stream(audio_file1) + embedding_result_list = embedding_result_list_copy + for audio_data, timestamp_ms in audio_data1_list: + embedder.embed_async(audio_data, timestamp_ms) + embedding_result1_list = embedding_result_list + + self._check_yamnet_result(embedding_result0_list, embedding_result1_list, + expected_similarities=_SPEECH_SIMILARITIES) + + +if __name__ == '__main__': + absltest.main()