# mediapipe/mediapipe/tasks/python/test/audio/audio_classifier_test.py

# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio classifier."""

import os
from typing import List, Tuple
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from scipy.io import wavfile

from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils

_AudioClassifier = audio_classifier.AudioClassifier
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode

_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_YAMNET_MODEL_SAMPLE_RATE = 16000
_TWO_HEADS_MODEL_FILE = 'two_heads.tflite'
_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TWO_HEADS_WAV_44K_MONO = 'two_heads_44100_hz_mono.wav'
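# YAMNet operates on fixed input windows of 15600 samples (0.975 s at 16 kHz).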
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLISECONDS_PER_SECOND = 1000


class AudioClassifierTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.yamnet_model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
self.two_heads_model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _TWO_HEADS_MODEL_FILE))

  def _read_wav_file(self, file_name) -> _AudioData:
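    """Reads a WAV file into an AudioData with samples scaled to [-1, 1]."""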
sample_rate, buffer = wavfile.read(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
return _AudioData.create_from_array(
buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)

  def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
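    """Reads a WAV file as a list of (AudioData, timestamp_ms) pairs.

    Each chunk spans the duration of one YAMNet input window (15600 samples
    at 16 kHz, i.e. 975 ms), rescaled to the file's own sample rate.
    """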
sample_rate, buffer = wavfile.read(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
audio_data_list = []
start = 0
    step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
    while start < len(buffer):
      end = min(start + int(step_size), len(buffer))
      audio_data_list.append((_AudioData.create_from_array(
          buffer[start:end].astype(float) / np.iinfo(np.int16).max,
          sample_rate), int(start / sample_rate * _MILLISECONDS_PER_SECOND)))
      start = end
return audio_data_list

  # TODO: Compare the exact score values to catch unexpected
  # changes in the inference pipeline.
def _check_yamnet_result(
self,
classification_result_list: List[_AudioClassifierResult],
expected_num_categories=521):
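    """Verifies the first four YAMNet results.

    Results are expected every 975 ms, each with a single 'scores' head
    whose top category is 'Speech' with a score above 0.9.
    """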
self.assertLen(classification_result_list, 5)
for idx, timestamp in enumerate([0, 975, 1950, 2925]):
classification_result = classification_result_list[idx]
self.assertEqual(classification_result.timestamp_ms, timestamp)
self.assertLen(classification_result.classifications, 1)
      classification = classification_result.classifications[0]
      self.assertEqual(classification.head_index, 0)
      self.assertEqual(classification.head_name, 'scores')
      self.assertLen(classification.categories, expected_num_categories)
      audio_category = classification.categories[0]
      self.assertEqual(audio_category.index, 0)
      self.assertEqual(audio_category.category_name, 'Speech')
      self.assertGreater(audio_category.score, 0.9)

  # TODO: Compare the exact score values to catch unexpected
  # changes in the inference pipeline.
def _check_two_heads_result(
self,
classification_result_list: List[_AudioClassifierResult],
first_head_expected_num_categories=521,
second_head_expected_num_categories=5):
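    """Verifies results from the two-head (YAMNet + bird) model.

    Checks head metadata and the expected top category of each head for the
    first result, and for the second result when it is present.
    """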
self.assertGreaterEqual(len(classification_result_list), 1)
self.assertLessEqual(len(classification_result_list), 2)
# Checks the first result.
classification_result = classification_result_list[0]
self.assertEqual(classification_result.timestamp_ms, 0)
self.assertLen(classification_result.classifications, 2)
# Checks the first head.
    yamnet_classification = classification_result.classifications[0]
    self.assertEqual(yamnet_classification.head_index, 0)
    self.assertEqual(yamnet_classification.head_name, 'yamnet_classification')
    self.assertLen(yamnet_classification.categories,
                   first_head_expected_num_categories)
    yamnet_category = yamnet_classification.categories[0]
    self.assertEqual(yamnet_category.index, 508)
    self.assertEqual(yamnet_category.category_name, 'Environmental noise')
    self.assertGreater(yamnet_category.score, 0.5)
    # Checks the second head.
    bird_classification = classification_result.classifications[1]
    self.assertEqual(bird_classification.head_index, 1)
    self.assertEqual(bird_classification.head_name, 'bird_classification')
    self.assertLen(bird_classification.categories,
                   second_head_expected_num_categories)
    bird_category = bird_classification.categories[0]
    self.assertEqual(bird_category.index, 4)
    self.assertEqual(bird_category.category_name, 'Chestnut-crowned Antpitta')
    self.assertGreater(bird_category.score, 0.93)
# Checks the second result, if present.
if len(classification_result_list) == 2:
classification_result = classification_result_list[1]
self.assertEqual(classification_result.timestamp_ms, 975)
self.assertLen(classification_result.classifications, 2)
# Checks the first head.
      yamnet_classification = classification_result.classifications[0]
      self.assertEqual(yamnet_classification.head_index, 0)
      self.assertEqual(yamnet_classification.head_name,
                       'yamnet_classification')
      self.assertLen(yamnet_classification.categories,
                     first_head_expected_num_categories)
      yamnet_category = yamnet_classification.categories[0]
      self.assertEqual(yamnet_category.index, 494)
      self.assertEqual(yamnet_category.category_name, 'Silence')
      self.assertGreater(yamnet_category.score, 0.9)
      # Checks the second head.
      bird_classification = classification_result.classifications[1]
      self.assertEqual(bird_classification.head_index, 1)
      self.assertEqual(bird_classification.head_name, 'bird_classification')
      self.assertLen(bird_classification.categories,
                     second_head_expected_num_categories)
      bird_category = bird_classification.categories[0]
      self.assertEqual(bird_category.index, 1)
      self.assertEqual(bird_category.category_name, 'White-breasted Wood-Wren')
      self.assertGreater(bird_category.score, 0.99)

  def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path) as classifier:
self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(
model_asset_path=self.yamnet_model_path))) as classifier:
self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
options = _AudioClassifierOptions(base_options=base_options)
_AudioClassifier.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.yamnet_model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _AudioClassifierOptions(base_options=base_options)
classifier = _AudioClassifier.create_from_options(options)
self.assertIsInstance(classifier, _AudioClassifier)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
def test_classify_with_yamnet_model(self, audio_file):
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path) as classifier:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(classification_result_list)

  def test_classify_with_yamnet_model_and_inputs_at_different_sample_rates(
self):
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(classification_result_list)

  def test_max_result_options(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
max_results=1))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)

  def test_score_threshold_options(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
score_threshold=0.9))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)

  def test_allow_list_option(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
category_allowlist=['Speech']))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)

  def test_combined_allowlist_and_denylist(self):
    # Fails when both an allowlist and a denylist are provided.
with self.assertRaisesRegex(
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
category_allowlist=['foo'], category_denylist=['bar']))
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass

  @parameterized.parameters((_TWO_HEADS_WAV_16K_MONO),
(_TWO_HEADS_WAV_44K_MONO))
def test_classify_with_two_heads_model_and_inputs_at_different_sample_rates(
self, audio_file):
with _AudioClassifier.create_from_model_path(
self.two_heads_model_path) as classifier:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model(self):
with _AudioClassifier.create_from_model_path(
self.two_heads_model_path) as classifier:
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model_with_max_results(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(
model_asset_path=self.two_heads_model_path),
classifier_options=_ClassifierOptions(
max_results=1))) as classifier:
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_two_heads_result(classification_result_list, 1, 1)

  def test_missing_sample_rate_in_audio_clips_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS)
with self.assertRaisesRegex(ValueError,
r'Must provide the audio sample rate'):
with _AudioClassifier.create_from_options(options) as classifier:
classifier.classify(_AudioData(buffer_length=100))

  def test_missing_sample_rate_in_audio_stream_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'provide the audio sample rate in audio data'):
with _AudioClassifier.create_from_options(options) as classifier:
classifier.classify(_AudioData(buffer_length=100))

  def test_missing_result_callback(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass

  def test_illegal_result_callback(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass

  def test_calling_classify_in_audio_stream_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with _AudioClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the audio clips mode'):
classifier.classify(self._read_wav_file(_SPEECH_WAV_16K_MONO))

  def test_calling_classify_async_in_audio_clips_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS)
with _AudioClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(
ValueError, r'not initialized with the audio stream mode'):
classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  def test_classify_async_calls_with_illegal_timestamp(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with _AudioClassifier.create_from_options(options) as classifier:
classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
def test_classify_async(self, audio_file):
classification_result_list = []

    def save_result(result: _AudioClassifierResult, timestamp_ms: int):
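      # Attaches the input timestamp to the result so that
      # _check_yamnet_result can verify it.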
result.timestamp_ms = timestamp_ms
classification_result_list.append(result)

    options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
classifier_options=_ClassifierOptions(max_results=1),
result_callback=save_result)
classifier = _AudioClassifier.create_from_options(options)
audio_data_list = self._read_wav_file_as_stream(audio_file)
for audio_data, timestamp_ms in audio_data_list:
classifier.classify_async(audio_data, timestamp_ms)
classifier.close()
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)


if __name__ == '__main__':
absltest.main()