Implement MediaPipe AudioClassifier Tasks Python API. Adjust the AudioClassifier Tasks C++ API to remove "sample_rate" from its options.

PiperOrigin-RevId: 486763992
This commit is contained in:
Jiuqiang Tang 2022-11-07 14:26:21 -08:00 committed by Copybara-Service
parent 51dbd9779c
commit 63a759accc
19 changed files with 959 additions and 85 deletions

View File

@ -87,6 +87,7 @@ cc_library(
cc_library(
name = "builtin_task_graphs",
deps = [
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",

View File

@ -361,7 +361,7 @@ void PublicPacketCreators(pybind11::module* m) {
packet = mp.packet_creator.create_float(0.1)
data = mp.packet_getter.get_float(packet)
)doc",
py::arg().noconvert(), py::return_value_policy::move);
py::return_value_policy::move);
m->def(
"create_double", [](double data) { return MakePacket<double>(data); },
@ -380,7 +380,7 @@ void PublicPacketCreators(pybind11::module* m) {
packet = mp.packet_creator.create_double(0.1)
data = mp.packet_getter.get_float(packet)
)doc",
py::arg().noconvert(), py::return_value_policy::move);
py::return_value_policy::move);
m->def(
"create_int_array",

View File

@ -37,7 +37,6 @@ cc_library(
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",

View File

@ -63,10 +63,8 @@ CalculatorGraphConfig CreateGraphConfig(
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kSubgraphTypeName);
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
if (!options_proto->base_options().use_stream_mode()) {
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
subgraph.In(kSampleRateTag);
}
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
options_proto.get());
subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
@ -93,9 +91,6 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
&(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get());
if (options->sample_rate > 0) {
options_proto->set_default_input_audio_sample_rate(options->sample_rate);
}
return options_proto;
}
@ -129,14 +124,6 @@ absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
/* static */
absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
std::unique_ptr<AudioClassifierOptions> options) {
if (options->running_mode == core::RunningMode::AUDIO_STREAM &&
options->sample_rate < 0) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"The audio classifier is in audio stream mode, the sample rate must be "
"specified in the AudioClassifierOptions.",
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
auto options_proto = ConvertAudioClassifierOptionsToProto(options.get());
tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) {
@ -161,7 +148,9 @@ absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
}
absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
double audio_sample_rate,
int64 timestamp_ms) {
MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
return SendAudioStreamData(
{{kAudioStreamName,
MakePacket<Matrix>(std::move(audio_block))

View File

@ -52,15 +52,10 @@ struct AudioClassifierOptions {
// 1) The audio clips mode for running classification on independent audio
// clips.
// 2) The audio stream mode for running classification on the audio stream,
// such as from microphone. In this mode, the "sample_rate" below must be
// provided, and the "result_callback" below must be specified to receive
// the classification results asynchronously.
// such as from microphone. In this mode, the "result_callback" below must
// be specified to receive the classification results asynchronously.
core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;
// The sample rate of the input audios. Must be set when the running mode is
// set to RunningMode::AUDIO_STREAM.
double sample_rate = -1.0;
// The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM.
@ -160,15 +155,17 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// The audio block is represented as a MediaPipe Matrix that has the number
// of channels rows and the number of samples per channel columns. The audio
// data will be resampled, accumulated, and framed to the proper size for the
// underlying model to consume. It's required to provide a timestamp (in
// milliseconds) to indicate the start time of the input audio block. The
// underlying model to consume. It's required to provide the corresponding
// audio sample rate along with the input audio block as well as a timestamp
// (in milliseconds) to indicate the start time of the input audio block. The
// timestamps must be monotonically increasing.
//
// The input audio block may be longer than what the model is able to process
// in a single inference. When this occurs, the input audio block is split
// into multiple chunks. For this reason, the callback may be called multiple
// times (once per chunk) for each call to this function.
absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
absl::Status ClassifyAsync(mediapipe::Matrix audio_block,
double audio_sample_rate, int64 timestamp_ms);
// Shuts down the AudioClassifier when all works are done.
absl::Status Close() { return runner_->Close(); }

View File

@ -72,18 +72,6 @@ struct AudioClassifierOutputStreams {
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
absl::Status SanityCheckOptions(
const proto::AudioClassifierGraphOptions& options) {
if (options.base_options().use_stream_mode() &&
!options.has_default_input_audio_sample_rate()) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
"In the streaming mode, the default input "
"audio sample rate must be set.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
return absl::OkStatus();
}
// Builds an AudioTensorSpecs for configuring the preprocessing calculators.
absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
const core::ModelResources& model_resources) {
@ -170,19 +158,12 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
const auto* model_resources,
CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
Graph graph;
const bool use_stream_mode =
sc->Options<proto::AudioClassifierGraphOptions>()
.base_options()
.use_stream_mode();
ASSIGN_OR_RETURN(
auto output_streams,
BuildAudioClassificationTask(
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)],
use_stream_mode
? absl::nullopt
: absl::make_optional(graph[Input<double>(kSampleRateTag)]),
graph));
absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
@ -207,7 +188,6 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
const proto::AudioClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
const bool use_stream_mode = task_options.base_options().use_stream_mode();
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
// Checks that metadata is available.

View File

@ -70,6 +70,8 @@ Matrix GetAudioData(absl::string_view filename) {
return matrix_mapping.matrix();
}
// TODO: Compares the exact score values to capture unexpected
// changes in the inference pipeline.
void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
int expected_num_categories = 521) {
EXPECT_EQ(result.size(), 5);
@ -90,13 +92,15 @@ void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
}
}
// TODO: Compares the exact score values to capture unexpected
// changes in the inference pipeline.
void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
EXPECT_GE(result.size(), 1);
EXPECT_LE(result.size(), 2);
// Check first result.
// Check the first result.
EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_EQ(result[0].classifications.size(), 2);
// Check first head.
// Check the first head.
EXPECT_EQ(result[0].classifications[0].head_index, 0);
EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
@ -104,19 +108,19 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
"Environmental noise");
EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
// Check second head.
// Check the second head.
EXPECT_EQ(result[0].classifications[1].head_index, 1);
EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
"Chestnut-crowned Antpitta");
EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
// Check second result, if present.
EXPECT_GT(result[0].classifications[1].categories[0].score, 0.93f);
// Check the second result, if present.
if (result.size() == 2) {
EXPECT_EQ(result[1].timestamp_ms, 975);
EXPECT_EQ(result[1].classifications.size(), 2);
// Check first head.
// Check the first head.
EXPECT_EQ(result[1].classifications[0].head_index, 0);
EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
@ -124,7 +128,7 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
"Silence");
EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
// Check second head.
// Check the second head.
EXPECT_EQ(result[1].classifications[1].head_index, 1);
EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
@ -234,7 +238,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = 16000;
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options));
@ -266,25 +269,6 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback =
[](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options));
EXPECT_EQ(audio_classifier_or.status().code(),
absl::StatusCode::kInvalidArgument);
EXPECT_THAT(audio_classifier_or.status().message(),
HasSubstr("the sample rate must be specified"));
EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
class ClassifyTest : public tflite_shims::testing::Test {};
TEST_F(ClassifyTest, Succeeds) {
@ -493,7 +477,6 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz;
std::vector<AudioClassifierResult> outputs;
options->result_callback =
[&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
@ -506,7 +489,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
int num_samples = std::min((int)(audio_buffer.cols() - start_col),
kYamnetNumOfAudioSamples * 3);
MP_ASSERT_OK(audio_classifier->ClassifyAsync(
audio_buffer.block(0, start_col, 1, num_samples),
audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
start_col * kMilliSecondsPerSecond / kSampleRateHz));
start_col += kYamnetNumOfAudioSamples * 3;
}
@ -523,7 +506,6 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz;
std::vector<AudioClassifierResult> outputs;
options->result_callback =
[&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
@ -538,7 +520,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
std::min((int)(audio_buffer.cols() - start_col),
rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * 3);
MP_ASSERT_OK(audio_classifier->ClassifyAsync(
audio_buffer.block(0, start_col, 1, num_samples),
audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
start_col * kMilliSecondsPerSecond / kSampleRateHz));
start_col += num_samples;
}

View File

@ -72,8 +72,39 @@ class BaseAudioTaskApi : public tasks::core::BaseTaskApi {
return runner_->Send(std::move(inputs));
}
// Checks or sets the sample rate in the audio stream mode.
absl::Status CheckOrSetSampleRate(std::string sample_rate_stream_name,
double sample_rate) {
if (running_mode_ != RunningMode::AUDIO_STREAM) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("Task is not initialized with the audio stream mode. "
"Current running mode:",
GetRunningModeName(running_mode_)),
MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError);
}
if (default_sample_rate_ > 0) {
if (std::fabs(sample_rate - default_sample_rate_) >
std::numeric_limits<double>::epsilon()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("The input audio sample rate: ", sample_rate,
" is inconsistent with the previously provided: ",
default_sample_rate_),
MediaPipeTasksStatus::kInvalidArgumentError);
}
} else {
default_sample_rate_ = sample_rate;
MP_RETURN_IF_ERROR(runner_->Send(
{{sample_rate_stream_name, MakePacket<double>(default_sample_rate_)
.At(Timestamp::PreStream())}}));
}
return absl::OkStatus();
}
private:
RunningMode running_mode_;
double default_sample_rate_ = -1.0;
};
} // namespace core

View File

@ -0,0 +1,41 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "audio_classifier",
srcs = [
"audio_classifier.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/audio/core:base_audio_task_api",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,280 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio classifier task."""
import dataclasses
from typing import Callable, Mapping, List, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
AudioClassifierResult = classification_result_module.ClassificationResult
_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options_module.ClassifierOptions
_RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
_AUDIO_IN_STREAM_NAME = 'audio_in'
_AUDIO_TAG = 'AUDIO'
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in'
_SAMPLE_RATE_TAG = 'SAMPLE_RATE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'
_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME = 'timestamped_classifications_out'
_TIMESTAMPED_CLASSIFICATIONS_TAG = 'TIMESTAMPED_CLASSIFICATIONS'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class AudioClassifierOptions:
"""Options for the audio classifier task.
Attributes:
base_options: Base options for the audio classifier task.
running_mode: The running mode of the task. Default to the audio clips mode.
Audio classifier task has two running modes: 1) The audio clips mode for
running classification on independent audio clips. 2) The audio stream
mode for running classification on the audio stream, such as from
microphone. In this mode, the "result_callback" below must be specified
to receive the classification results asynchronously.
classifier_options: Options for configuring the classifier behavior, such as
score threshold, number of results, etc.
result_callback: The user-defined result callback for processing audio
stream data. The result callback should only be specified when the running
mode is set to the audio stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _AudioClassifierGraphOptionsProto:
"""Generates an AudioClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
classifier_options_proto = self.classifier_options.to_pb2()
return _AudioClassifierGraphOptionsProto(
base_options=base_options_proto,
classifier_options=classifier_options_proto)
class AudioClassifier(base_audio_task_api.BaseAudioTaskApi):
"""Class that performs audio classification on audio data."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'AudioClassifier':
"""Creates an `AudioClassifier` object from a TensorFlow Lite model and the default `AudioClassifierOptions`.
Note that the created `AudioClassifier` instance is in audio clips mode, for
classifying on independent audio clips.
Args:
model_path: Path to the model.
Returns:
`AudioClassifier` object that's created from the model file and the
default `AudioClassifierOptions`.
Raises:
ValueError: If failed to create `AudioClassifier` object from the provided
file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = AudioClassifierOptions(
base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: AudioClassifierOptions) -> 'AudioClassifier':
"""Creates the `AudioClassifier` object from audio classifier options.
Args:
options: Options for the audio classifier task.
Returns:
`AudioClassifier` object that's created from `options`.
Raises:
ValueError: If failed to create `AudioClassifier` object from
`AudioClassifierOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet.Packet]):
timestamp_ms = output_packets[
_CLASSIFICATIONS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND
if output_packets[_CLASSIFICATIONS_STREAM_NAME].is_empty():
options.result_callback(
AudioClassifierResult(classifications=[]), timestamp_ms)
return
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
options.result_callback(
AudioClassifierResult.create_from_pb2(classification_result_proto),
timestamp_ms)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]),
':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME])
],
output_streams=[
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
':'.join([
_TIMESTAMPED_CLASSIFICATIONS_TAG,
_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME
])
],
task_options=options)
return cls(
# Audio tasks should not drop input audio due to flow limiting, which
# may cause data inconsistency.
task_info.generate_graph_config(enable_flow_limiting=False),
options.running_mode,
packets_callback if options.result_callback else None)
def classify(self, audio_clip: _AudioData) -> List[AudioClassifierResult]:
"""Performs audio classification on the provided audio clip.
The audio clip is represented as a MediaPipe AudioData. The method accepts
audio clips with various length and audio sample rate. It's required to
provide the corresponding audio sample rate within the `AudioData` object.
The input audio clip may be longer than what the model is able to process
in a single inference. When this occurs, the input audio clip is split into
multiple chunks starting at different timestamps. For this reason, this
function returns a vector of ClassificationResult objects, each associated
ith a timestamp corresponding to the start (in milliseconds) of the chunk
data that was classified, e.g:
ClassificationResult #0 (first chunk of data):
timestamp_ms: 0 (starts at 0ms)
classifications #0 (single head model):
category #0:
category_name: "Speech"
score: 0.6
category #1:
category_name: "Music"
score: 0.2
ClassificationResult #1 (second chunk of data):
timestamp_ms: 800 (starts at 800ms)
classifications #0 (single head model):
category #0:
category_name: "Speech"
score: 0.5
category #1:
category_name: "Silence"
score: 0.1
Args:
audio_clip: MediaPipe AudioData.
Returns:
An `AudioClassifierResult` object that contains a list of
classification result objects, each associated with a timestamp
corresponding to the start (in milliseconds) of the chunk data that was
classified.
Raises:
ValueError: If any of the input arguments is invalid, such as the sample
rate is not provided in the `AudioData` object.
RuntimeError: If audio classification failed to run.
"""
if not audio_clip.audio_format.sample_rate:
raise ValueError('Must provide the audio sample rate in audio data.')
output_packets = self._process_audio_clip({
_AUDIO_IN_STREAM_NAME:
packet_creator.create_matrix(audio_clip.buffer, transpose=True),
_SAMPLE_RATE_IN_STREAM_NAME:
packet_creator.create_double(audio_clip.audio_format.sample_rate)
})
output_list = []
classification_result_proto_list = packet_getter.get_proto_list(
output_packets[_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME])
for proto in classification_result_proto_list:
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(proto)
output_list.append(
AudioClassifierResult.create_from_pb2(classification_result_proto))
return output_list
def classify_async(self, audio_block: _AudioData, timestamp_ms: int) -> None:
"""Sends audio data (a block in a continuous audio stream) to perform audio classification.
Only use this method when the AudioClassifier is created with the audio
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input audio data is accepted. The results will be available via the
`result_callback` provided in the `AudioClassifierOptions`. The
`classify_async` method is designed to process auido stream data such as
microphone input.
The input audio data may be longer than what the model is able to process
in a single inference. When this occurs, the input audio block is split
into multiple chunks. For this reason, the callback may be called multiple
times (once per chunk) for each call to this function.
The `result_callback` provides:
- An `AudioClassifierResult` object that contains a list of
classifications.
- The input timestamp in milliseconds.
Args:
audio_block: MediaPipe AudioData.
timestamp_ms: The timestamp of the input audio data in milliseconds.
Raises:
ValueError: If any of the followings:
1) The sample rate is not provided in the `AudioData` object or the
provided sample rate is inconsisent with the previously recevied.
2) The current input timestamp is smaller than what the audio
classifier has already processed.
"""
if not audio_block.audio_format.sample_rate:
raise ValueError('Must provide the audio sample rate in audio data.')
if not self._default_sample_rate:
self._default_sample_rate = audio_block.audio_format.sample_rate
self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME,
self._default_sample_rate)
elif audio_block.audio_format.sample_rate != self._default_sample_rate:
raise ValueError(
f'The audio sample rate provided in audio data: '
f'{audio_block.audio_format.sample_rate} is inconsisent with '
f'the previously received: {self._default_sample_rate}.')
self._send_audio_stream_data({
_AUDIO_IN_STREAM_NAME:
packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})

View File

@ -32,6 +32,7 @@ py_library(
":audio_task_running_mode",
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -16,14 +16,17 @@
from typing import Callable, Mapping, Optional
from mediapipe.framework import calculator_pb2
from mediapipe.python import packet_creator
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.python._framework_bindings import timestamp as timestamp_module
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_TaskRunner = task_runner_module.TaskRunner
_Packet = packet_module.Packet
_RunningMode = running_mode_module.AudioTaskRunningMode
_Timestamp = timestamp_module.Timestamp
class BaseAudioTaskApi(object):
@ -59,6 +62,7 @@ class BaseAudioTaskApi(object):
'callback should not be provided.')
self._runner = _TaskRunner.create(graph_config, packet_callback)
self._running_mode = running_mode
self._default_sample_rate = None
def _process_audio_clip(
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
@ -82,6 +86,27 @@ class BaseAudioTaskApi(object):
+ self._running_mode.name)
return self._runner.process(inputs)
def _set_sample_rate(self, sample_rate_stream_name: str,
sample_rate: float) -> None:
"""An asynchronous method to set audio sample rate in the audio stream mode.
Args:
sample_rate_stream_name: The audio sample rate stream name.
sample_rate: The audio sample rate.
Raises:
ValueError: If the task's running mode is not set to the audio stream
mode.
"""
if self._running_mode != _RunningMode.AUDIO_STREAM:
raise ValueError(
'Task is not initialized with the audio stream mode. Current running mode:'
+ self._running_mode.name)
self._runner.send({
sample_rate_stream_name:
packet_creator.create_double(sample_rate).at(_Timestamp.PRESTREAM)
})
def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
"""An asynchronous method to send audio stream data to the runner.

View File

@ -94,3 +94,13 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "classification_result",
srcs = ["classification_result.py"],
deps = [
":category",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -20,7 +20,7 @@ import numpy as np
@dataclasses.dataclass
class AudioFormat:
class AudioDataFormat:
"""Audio format metadata.
Attributes:
@ -35,8 +35,10 @@ class AudioData(object):
"""MediaPipe Tasks' audio container."""
def __init__(
self, buffer_length: int,
audio_format: AudioFormat = AudioFormat()) -> None:
self,
buffer_length: int,
audio_format: AudioDataFormat = AudioDataFormat()
) -> None:
"""Initializes the `AudioData` object.
Args:
@ -113,14 +115,14 @@ class AudioData(object):
"""
obj = cls(
buffer_length=src.shape[0],
audio_format=AudioFormat(
audio_format=AudioDataFormat(
num_channels=1 if len(src.shape) == 1 else src.shape[1],
sample_rate=sample_rate))
obj.load_from_array(src)
return obj
@property
def audio_format(self) -> AudioFormat:
def audio_format(self) -> AudioDataFormat:
"""Gets the audio format of the audio."""
return self._audio_format

View File

@ -0,0 +1,92 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""
import dataclasses
from typing import List, Optional
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult
@dataclasses.dataclass
class Classifications:
"""Represents the classification results for a given classifier head.
Attributes:
categories: The array of predicted categories, usually sorted by descending
scores (e.g. from high to low probability).
head_index: The index of the classifier head these categories refer to. This
is useful for multi-head models.
head_name: The name of the classifier head, which is the corresponding
tensor metadata name.
"""
categories: List[category_module.Category]
head_index: int
head_name: Optional[str] = None
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
"""Creates a `Classifications` object from the given protobuf object."""
categories = []
for entry in pb2_obj.classification_list.classification:
categories.append(
category_module.Category(
index=entry.index,
score=entry.score,
display_name=entry.display_name,
category_name=entry.label))
return Classifications(
categories=categories,
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
@dataclasses.dataclass
class ClassificationResult:
"""Contains the classification results of a model.
Attributes:
classifications: A list of `Classifications` objects, each for a head of the
model.
timestamp_ms: The optional timestamp (in milliseconds) of the start of the
chunk of data corresponding to these results. This is only used for
classification on time series (e.g. audio classification). In these use
cases, the amount of data to process might exceed the maximum size that
the model can process: to solve this, the input data is split into
multiple chunks starting at different timestamps.
"""
classifications: List[Classifications]
timestamp_ms: Optional[int] = None
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
"""Creates a `ClassificationResult` object from the given protobuf object.
"""
return ClassificationResult(
classifications=[
Classifications.create_from_pb2(classification)
for classification in pb2_obj.classifications
],
timestamp_ms=pb2_obj.timestamp_ms)

View File

@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_test(
name = "audio_classifier_test",
srcs = ["audio_classifier_test.py"],
data = [
"//mediapipe/tasks/testdata/audio:test_audio_clips",
"//mediapipe/tasks/testdata/audio:test_models",
],
deps = [
"//mediapipe/tasks/python/audio:audio_classifier",
"//mediapipe/tasks/python/audio/core:audio_task_running_mode",
"//mediapipe/tasks/python/components/containers:audio_data",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,381 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio classifier."""
import os
from typing import List, Tuple
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
_AudioClassifier = audio_classifier.AudioClassifier
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode
_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_YAMNET_MODEL_SAMPLE_RATE = 16000
_TWO_HEADS_MODEL_FILE = 'two_heads.tflite'
_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TWO_HEADS_WAV_44K_MONO = 'two_heads_44100_hz_mono.wav'
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLSECONDS_PER_SECOND = 1000
class AudioClassifierTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.yamnet_model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
self.two_heads_model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _TWO_HEADS_MODEL_FILE))
def _read_wav_file(self, file_name) -> _AudioData:
sample_rate, buffer = wavfile.read(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
return _AudioData.create_from_array(
buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)
def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
sample_rate, buffer = wavfile.read(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
audio_data_list = []
start = 0
step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
while start < len(buffer):
end = min(start + (int)(step_size), len(buffer))
audio_data_list.append((_AudioData.create_from_array(
buffer[start:end].astype(float) / np.iinfo(np.int16).max,
sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND)))
start = end
return audio_data_list
# TODO: Compares the exact score values to capture unexpected
# changes in the inference pipeline.
def _check_yamnet_result(
self,
classification_result_list: List[_AudioClassifierResult],
expected_num_categories=521):
self.assertLen(classification_result_list, 5)
for idx, timestamp in enumerate([0, 975, 1950, 2925]):
classification_result = classification_result_list[idx]
self.assertEqual(classification_result.timestamp_ms, timestamp)
self.assertLen(classification_result.classifications, 1)
classifcation = classification_result.classifications[0]
self.assertEqual(classifcation.head_index, 0)
self.assertEqual(classifcation.head_name, 'scores')
self.assertLen(classifcation.categories, expected_num_categories)
audio_category = classifcation.categories[0]
self.assertEqual(audio_category.index, 0)
self.assertEqual(audio_category.category_name, 'Speech')
self.assertGreater(audio_category.score, 0.9)
# TODO: Compares the exact score values to capture unexpected
# changes in the inference pipeline.
def _check_two_heads_result(
self,
classification_result_list: List[_AudioClassifierResult],
first_head_expected_num_categories=521,
second_head_expected_num_categories=5):
self.assertGreaterEqual(len(classification_result_list), 1)
self.assertLessEqual(len(classification_result_list), 2)
# Checks the first result.
classification_result = classification_result_list[0]
self.assertEqual(classification_result.timestamp_ms, 0)
self.assertLen(classification_result.classifications, 2)
# Checks the first head.
yamnet_classifcation = classification_result.classifications[0]
self.assertEqual(yamnet_classifcation.head_index, 0)
self.assertEqual(yamnet_classifcation.head_name, 'yamnet_classification')
self.assertLen(yamnet_classifcation.categories,
first_head_expected_num_categories)
# Checks the second head.
yamnet_category = yamnet_classifcation.categories[0]
self.assertEqual(yamnet_category.index, 508)
self.assertEqual(yamnet_category.category_name, 'Environmental noise')
self.assertGreater(yamnet_category.score, 0.5)
bird_classifcation = classification_result.classifications[1]
self.assertEqual(bird_classifcation.head_index, 1)
self.assertEqual(bird_classifcation.head_name, 'bird_classification')
self.assertLen(bird_classifcation.categories,
second_head_expected_num_categories)
bird_category = bird_classifcation.categories[0]
self.assertEqual(bird_category.index, 4)
self.assertEqual(bird_category.category_name, 'Chestnut-crowned Antpitta')
self.assertGreater(bird_category.score, 0.93)
# Checks the second result, if present.
if len(classification_result_list) == 2:
classification_result = classification_result_list[1]
self.assertEqual(classification_result.timestamp_ms, 975)
self.assertLen(classification_result.classifications, 2)
# Checks the first head.
yamnet_classifcation = classification_result.classifications[0]
self.assertEqual(yamnet_classifcation.head_index, 0)
self.assertEqual(yamnet_classifcation.head_name, 'yamnet_classification')
self.assertLen(yamnet_classifcation.categories,
first_head_expected_num_categories)
yamnet_category = yamnet_classifcation.categories[0]
self.assertEqual(yamnet_category.index, 494)
self.assertEqual(yamnet_category.category_name, 'Silence')
self.assertGreater(yamnet_category.score, 0.9)
bird_classifcation = classification_result.classifications[1]
self.assertEqual(bird_classifcation.head_index, 1)
self.assertEqual(bird_classifcation.head_name, 'bird_classification')
self.assertLen(bird_classifcation.categories,
second_head_expected_num_categories)
# Checks the second head.
bird_category = bird_classifcation.categories[0]
self.assertEqual(bird_category.index, 1)
self.assertEqual(bird_category.category_name, 'White-breasted Wood-Wren')
self.assertGreater(bird_category.score, 0.99)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path) as classifier:
self.assertIsInstance(classifier, _AudioClassifier)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(
model_asset_path=self.yamnet_model_path))) as classifier:
self.assertIsInstance(classifier, _AudioClassifier)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
ValueError,
r"ExternalFile must specify at least one of 'file_content', "
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='')
options = _AudioClassifierOptions(base_options=base_options)
_AudioClassifier.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.yamnet_model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _AudioClassifierOptions(base_options=base_options)
classifier = _AudioClassifier.create_from_options(options)
self.assertIsInstance(classifier, _AudioClassifier)
@parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
def test_classify_with_yamnet_model(self, audio_file):
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path) as classifier:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(classification_result_list)
def test_classify_with_yamnet_model_and_inputs_at_different_sample_rates(
self):
with _AudioClassifier.create_from_model_path(
self.yamnet_model_path) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(classification_result_list)
def test_max_result_options(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
max_results=1))) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)
def test_score_threshold_options(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
score_threshold=0.9))) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)
def test_allow_list_option(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
category_allowlist=['Speech']))) as classifier:
for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)
def test_combined_allowlist_and_denylist(self):
# Fails with combined allowlist and denylist
with self.assertRaisesRegex(
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
classifier_options=_ClassifierOptions(
category_allowlist=['foo'], category_denylist=['bar']))
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass
@parameterized.parameters((_TWO_HEADS_WAV_16K_MONO),
(_TWO_HEADS_WAV_44K_MONO))
def test_classify_with_two_heads_model_and_inputs_at_different_sample_rates(
self, audio_file):
with _AudioClassifier.create_from_model_path(
self.two_heads_model_path) as classifier:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_two_heads_result(classification_result_list)
def test_classify_with_two_heads_model(self):
with _AudioClassifier.create_from_model_path(
self.two_heads_model_path) as classifier:
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_two_heads_result(classification_result_list)
def test_classify_with_two_heads_model_with_max_results(self):
with _AudioClassifier.create_from_options(
_AudioClassifierOptions(
base_options=_BaseOptions(
model_asset_path=self.two_heads_model_path),
classifier_options=_ClassifierOptions(
max_results=1))) as classifier:
for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
classification_result_list = classifier.classify(
self._read_wav_file(audio_file))
self._check_two_heads_result(classification_result_list, 1, 1)
def test_missing_sample_rate_in_audio_clips_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS)
with self.assertRaisesRegex(ValueError,
r'Must provide the audio sample rate'):
with _AudioClassifier.create_from_options(options) as classifier:
classifier.classify(_AudioData(buffer_length=100))
def test_missing_sample_rate_in_audio_stream_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'provide the audio sample rate in audio data'):
with _AudioClassifier.create_from_options(options) as classifier:
classifier.classify(_AudioData(buffer_length=100))
def test_missing_result_callback(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass
def test_illegal_result_callback(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _AudioClassifier.create_from_options(options) as unused_classifier:
pass
def test_calling_classify_in_audio_stream_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with _AudioClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the audio clips mode'):
classifier.classify(self._read_wav_file(_SPEECH_WAV_16K_MONO))
def test_calling_classify_async_in_audio_clips_mode(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_CLIPS)
with _AudioClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(
ValueError, r'not initialized with the audio stream mode'):
classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)
def test_classify_async_calls_with_illegal_timestamp(self):
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
result_callback=mock.MagicMock())
with _AudioClassifier.create_from_options(options) as classifier:
classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)
@parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
def test_classify_async(self, audio_file):
classification_result_list = []
def save_result(result: _AudioClassifierResult, timestamp_ms: int):
result.timestamp_ms = timestamp_ms
classification_result_list.append(result)
options = _AudioClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
running_mode=_RUNNING_MODE.AUDIO_STREAM,
classifier_options=_ClassifierOptions(max_results=1),
result_callback=save_result)
classifier = _AudioClassifier.create_from_options(options)
audio_data_list = self._read_wav_file_as_stream(audio_file)
for audio_data, timestamp_ms in audio_data_list:
classifier.classify_async(audio_data, timestamp_ms)
classifier.close()
self._check_yamnet_result(
classification_result_list, expected_num_categories=1)
if __name__ == '__main__':
absltest.main()