Implement MediaPipe AudioClassifier Tasks Python API. Adjust the AudioClassifier Tasks C++ API to remove "sample_rate" from its options.
PiperOrigin-RevId: 486763992
parent 51dbd9779c
commit 63a759accc

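For orientation, here is a minimal sketch of the new Python API in audio clips mode, assembled from the files added in this diff (the .wav and .tflite paths are placeholders; the pattern mirrors audio_classifier_test.py below):

# Sketch only: paths are placeholders.
import numpy as np
from scipy.io import wavfile
from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module

sample_rate, buffer = wavfile.read('/path/to/speech_16000_hz_mono.wav')
# AudioData carries the sample rate, so classify() needs no separate argument.
clip = audio_data_module.AudioData.create_from_array(
    buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)
with audio_classifier.AudioClassifier.create_from_model_path(
    '/path/to/yamnet_audio_classifier_with_metadata.tflite') as classifier:
  results = classifier.classify(clip)  # One result per classified chunk.
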
@@ -87,6 +87,7 @@ cc_library(
cc_library(
    name = "builtin_task_graphs",
    deps = [
        "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
        "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
        "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",

@@ -361,7 +361,7 @@ void PublicPacketCreators(pybind11::module* m) {
    packet = mp.packet_creator.create_float(0.1)
    data = mp.packet_getter.get_float(packet)
)doc",
      py::arg().noconvert(), py::return_value_policy::move);
      py::return_value_policy::move);

  m->def(
      "create_double", [](double data) { return MakePacket<double>(data); },

@@ -380,7 +380,7 @@ void PublicPacketCreators(pybind11::module* m) {
    packet = mp.packet_creator.create_double(0.1)
    data = mp.packet_getter.get_float(packet)
)doc",
      py::arg().noconvert(), py::return_value_policy::move);
      py::return_value_policy::move);

  m->def(
      "create_int_array",

@@ -37,7 +37,6 @@ cc_library(
        "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_task_graph",

@@ -63,10 +63,8 @@ CalculatorGraphConfig CreateGraphConfig(
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kSubgraphTypeName);
  graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
  if (!options_proto->base_options().use_stream_mode()) {
    graph.In(kSampleRateTag).SetName(kSampleRateName) >>
        subgraph.In(kSampleRateTag);
  }
  subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
      options_proto.get());
  subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>

@@ -93,9 +91,6 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
      &(options->classifier_options)));
  options_proto->mutable_classifier_options()->Swap(
      classifier_options_proto.get());
  if (options->sample_rate > 0) {
    options_proto->set_default_input_audio_sample_rate(options->sample_rate);
  }
  return options_proto;
}

@@ -129,14 +124,6 @@ absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
/* static */
absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
    std::unique_ptr<AudioClassifierOptions> options) {
  if (options->running_mode == core::RunningMode::AUDIO_STREAM &&
      options->sample_rate < 0) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "The audio classifier is in audio stream mode, the sample rate must be "
        "specified in the AudioClassifierOptions.",
        MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
  }
  auto options_proto = ConvertAudioClassifierOptionsToProto(options.get());
  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {

@@ -161,7 +148,9 @@ absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
}

absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
                                            double audio_sample_rate,
                                            int64 timestamp_ms) {
  MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
  return SendAudioStreamData(
      {{kAudioStreamName,
        MakePacket<Matrix>(std::move(audio_block))

@@ -52,15 +52,10 @@ struct AudioClassifierOptions {
  // 1) The audio clips mode for running classification on independent audio
  //    clips.
  // 2) The audio stream mode for running classification on the audio stream,
  //    such as from microphone. In this mode, the "sample_rate" below must be
  //    provided, and the "result_callback" below must be specified to receive
  //    the classification results asynchronously.
  //    such as from microphone. In this mode, the "result_callback" below must
  //    be specified to receive the classification results asynchronously.
  core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;

  // The sample rate of the input audios. Must be set when the running mode is
  // set to RunningMode::AUDIO_STREAM.
  double sample_rate = -1.0;

  // The user-defined result callback for processing audio stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::AUDIO_STREAM.

@@ -160,15 +155,17 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
  // The audio block is represented as a MediaPipe Matrix that has the number
  // of channels rows and the number of samples per channel columns. The audio
  // data will be resampled, accumulated, and framed to the proper size for the
  // underlying model to consume. It's required to provide a timestamp (in
  // milliseconds) to indicate the start time of the input audio block. The
  // underlying model to consume. It's required to provide the corresponding
  // audio sample rate along with the input audio block as well as a timestamp
  // (in milliseconds) to indicate the start time of the input audio block. The
  // timestamps must be monotonically increasing.
  //
  // The input audio block may be longer than what the model is able to process
  // in a single inference. When this occurs, the input audio block is split
  // into multiple chunks. For this reason, the callback may be called multiple
  // times (once per chunk) for each call to this function.
  absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
  absl::Status ClassifyAsync(mediapipe::Matrix audio_block,
                             double audio_sample_rate, int64 timestamp_ms);

  // Shuts down the AudioClassifier when all the work is done.
  absl::Status Close() { return runner_->Close(); }

@@ -72,18 +72,6 @@ struct AudioClassifierOutputStreams {
  Source<std::vector<ClassificationResult>> timestamped_classifications;
};

absl::Status SanityCheckOptions(
    const proto::AudioClassifierGraphOptions& options) {
  if (options.base_options().use_stream_mode() &&
      !options.has_default_input_audio_sample_rate()) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "In the streaming mode, the default input "
                                   "audio sample rate must be set.",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  return absl::OkStatus();
}

// Builds an AudioTensorSpecs for configuring the preprocessing calculators.
absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
    const core::ModelResources& model_resources) {

@@ -170,19 +158,12 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
        const auto* model_resources,
        CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
    Graph graph;
    const bool use_stream_mode =
        sc->Options<proto::AudioClassifierGraphOptions>()
            .base_options()
            .use_stream_mode();
    ASSIGN_OR_RETURN(
        auto output_streams,
        BuildAudioClassificationTask(
            sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
            graph[Input<Matrix>(kAudioTag)],
            use_stream_mode
                ? absl::nullopt
                : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
            graph));
            absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    output_streams.timestamped_classifications >>

@@ -207,7 +188,6 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
      const proto::AudioClassifierGraphOptions& task_options,
      const core::ModelResources& model_resources, Source<Matrix> audio_in,
      absl::optional<Source<double>> sample_rate_in, Graph& graph) {
    MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
    const bool use_stream_mode = task_options.base_options().use_stream_mode();
    const auto* metadata_extractor = model_resources.GetMetadataExtractor();
    // Checks that metadata is available.

@@ -70,6 +70,8 @@ Matrix GetAudioData(absl::string_view filename) {
  return matrix_mapping.matrix();
}

// TODO: Compares the exact score values to capture unexpected
// changes in the inference pipeline.
void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
                       int expected_num_categories = 521) {
  EXPECT_EQ(result.size(), 5);

@@ -90,13 +92,15 @@ void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
  }
}

// TODO: Compares the exact score values to capture unexpected
// changes in the inference pipeline.
void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
  EXPECT_GE(result.size(), 1);
  EXPECT_LE(result.size(), 2);
  // Check first result.
  // Check the first result.
  EXPECT_EQ(result[0].timestamp_ms, 0);
  EXPECT_EQ(result[0].classifications.size(), 2);
  // Check first head.
  // Check the first head.
  EXPECT_EQ(result[0].classifications[0].head_index, 0);
  EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
  EXPECT_EQ(result[0].classifications[0].categories.size(), 521);

@@ -104,19 +108,19 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
  EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
            "Environmental noise");
  EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
  // Check second head.
  // Check the second head.
  EXPECT_EQ(result[0].classifications[1].head_index, 1);
  EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
  EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
  EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
  EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
            "Chestnut-crowned Antpitta");
  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
  // Check second result, if present.
  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.93f);
  // Check the second result, if present.
  if (result.size() == 2) {
    EXPECT_EQ(result[1].timestamp_ms, 975);
    EXPECT_EQ(result[1].classifications.size(), 2);
    // Check first head.
    // Check the first head.
    EXPECT_EQ(result[1].classifications[0].head_index, 0);
    EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
    EXPECT_EQ(result[1].classifications[0].categories.size(), 521);

@@ -124,7 +128,7 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
    EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
              "Silence");
    EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
    // Check second head.
    // Check the second head.
    EXPECT_EQ(result[1].classifications[1].head_index, 1);
    EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
    EXPECT_EQ(result[1].classifications[1].categories.size(), 5);

@@ -234,7 +238,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
  options->running_mode = core::RunningMode::AUDIO_STREAM;
  options->sample_rate = 16000;
  StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
      AudioClassifier::Create(std::move(options));

@@ -266,25 +269,6 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}

TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
  auto options = std::make_unique<AudioClassifierOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
  options->running_mode = core::RunningMode::AUDIO_STREAM;
  options->result_callback =
      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
  StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
      AudioClassifier::Create(std::move(options));

  EXPECT_EQ(audio_classifier_or.status().code(),
            absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(audio_classifier_or.status().message(),
              HasSubstr("the sample rate must be specified"));
  EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}

class ClassifyTest : public tflite_shims::testing::Test {};

TEST_F(ClassifyTest, Succeeds) {

@@ -493,7 +477,6 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
  options->classifier_options.max_results = 1;
  options->classifier_options.score_threshold = 0.3f;
  options->running_mode = core::RunningMode::AUDIO_STREAM;
  options->sample_rate = kSampleRateHz;
  std::vector<AudioClassifierResult> outputs;
  options->result_callback =
      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {

@@ -506,7 +489,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
    int num_samples = std::min((int)(audio_buffer.cols() - start_col),
                               kYamnetNumOfAudioSamples * 3);
    MP_ASSERT_OK(audio_classifier->ClassifyAsync(
        audio_buffer.block(0, start_col, 1, num_samples),
        audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
        start_col * kMilliSecondsPerSecond / kSampleRateHz));
    start_col += kYamnetNumOfAudioSamples * 3;
  }

@@ -523,7 +506,6 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
  options->classifier_options.max_results = 1;
  options->classifier_options.score_threshold = 0.3f;
  options->running_mode = core::RunningMode::AUDIO_STREAM;
  options->sample_rate = kSampleRateHz;
  std::vector<AudioClassifierResult> outputs;
  options->result_callback =
      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {

@@ -538,7 +520,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
        std::min((int)(audio_buffer.cols() - start_col),
                 rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * 3);
    MP_ASSERT_OK(audio_classifier->ClassifyAsync(
        audio_buffer.block(0, start_col, 1, num_samples),
        audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
        start_col * kMilliSecondsPerSecond / kSampleRateHz));
    start_col += num_samples;
  }

@@ -72,8 +72,39 @@ class BaseAudioTaskApi : public tasks::core::BaseTaskApi {
    return runner_->Send(std::move(inputs));
  }

  // Checks or sets the sample rate in the audio stream mode.
  absl::Status CheckOrSetSampleRate(std::string sample_rate_stream_name,
                                    double sample_rate) {
    if (running_mode_ != RunningMode::AUDIO_STREAM) {
      return CreateStatusWithPayload(
          absl::StatusCode::kInvalidArgument,
          absl::StrCat("Task is not initialized with the audio stream mode. "
                       "Current running mode:",
                       GetRunningModeName(running_mode_)),
          MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError);
    }
    if (default_sample_rate_ > 0) {
      if (std::fabs(sample_rate - default_sample_rate_) >
          std::numeric_limits<double>::epsilon()) {
        return CreateStatusWithPayload(
            absl::StatusCode::kInvalidArgument,
            absl::StrCat("The input audio sample rate: ", sample_rate,
                         " is inconsistent with the previously provided: ",
                         default_sample_rate_),
            MediaPipeTasksStatus::kInvalidArgumentError);
      }
    } else {
      default_sample_rate_ = sample_rate;
      MP_RETURN_IF_ERROR(runner_->Send(
          {{sample_rate_stream_name, MakePacket<double>(default_sample_rate_)
                                         .At(Timestamp::PreStream())}}));
    }
    return absl::OkStatus();
  }

 private:
  RunningMode running_mode_;
  double default_sample_rate_ = -1.0;
};

}  // namespace core

mediapipe/tasks/python/audio/BUILD (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library and test compatibility macro.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

py_library(
    name = "audio_classifier",
    srcs = [
        "audio_classifier.py",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
        "//mediapipe/tasks/python/audio/core:base_audio_task_api",
        "//mediapipe/tasks/python/components/containers:audio_data",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
    ],
)

mediapipe/tasks/python/audio/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

mediapipe/tasks/python/audio/audio_classifier.py (new file, 280 lines)
@@ -0,0 +1,280 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio classifier task."""

import dataclasses
from typing import Callable, Mapping, List, Optional

from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.audio.core import base_audio_task_api
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

AudioClassifierResult = classification_result_module.ClassificationResult
_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options_module.ClassifierOptions
_RunningMode = running_mode_module.AudioTaskRunningMode
_TaskInfo = task_info_module.TaskInfo

_AUDIO_IN_STREAM_NAME = 'audio_in'
_AUDIO_TAG = 'AUDIO'
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in'
_SAMPLE_RATE_TAG = 'SAMPLE_RATE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'
_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME = 'timestamped_classifications_out'
_TIMESTAMPED_CLASSIFICATIONS_TAG = 'TIMESTAMPED_CLASSIFICATIONS'
_MICRO_SECONDS_PER_MILLISECOND = 1000


@dataclasses.dataclass
class AudioClassifierOptions:
  """Options for the audio classifier task.

  Attributes:
    base_options: Base options for the audio classifier task.
    running_mode: The running mode of the task. Defaults to the audio clips
      mode. The audio classifier task has two running modes: 1) The audio
      clips mode for running classification on independent audio clips. 2) The
      audio stream mode for running classification on the audio stream, such
      as from microphone. In this mode, the "result_callback" below must be
      specified to receive the classification results asynchronously.
    classifier_options: Options for configuring the classifier behavior, such
      as score threshold, number of results, etc.
    result_callback: The user-defined result callback for processing audio
      stream data. The result callback should only be specified when the
      running mode is set to the audio stream mode.
  """
  base_options: _BaseOptions
  running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
  classifier_options: _ClassifierOptions = _ClassifierOptions()
  result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _AudioClassifierGraphOptionsProto:
    """Generates an AudioClassifierOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = (
        self.running_mode != _RunningMode.AUDIO_CLIPS)
    classifier_options_proto = self.classifier_options.to_pb2()

    return _AudioClassifierGraphOptionsProto(
        base_options=base_options_proto,
        classifier_options=classifier_options_proto)


class AudioClassifier(base_audio_task_api.BaseAudioTaskApi):
  """Class that performs audio classification on audio data."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'AudioClassifier':
    """Creates an `AudioClassifier` object from a TensorFlow Lite model and the default `AudioClassifierOptions`.

    Note that the created `AudioClassifier` instance is in audio clips mode,
    for classifying on independent audio clips.

    Args:
      model_path: Path to the model.

    Returns:
      `AudioClassifier` object that's created from the model file and the
      default `AudioClassifierOptions`.

    Raises:
      ValueError: If failed to create `AudioClassifier` object from the
        provided file, such as an invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = AudioClassifierOptions(
        base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(cls,
                          options: AudioClassifierOptions) -> 'AudioClassifier':
    """Creates the `AudioClassifier` object from audio classifier options.

    Args:
      options: Options for the audio classifier task.

    Returns:
      `AudioClassifier` object that's created from `options`.

    Raises:
      ValueError: If failed to create `AudioClassifier` object from
        `AudioClassifierOptions`, such as a missing model.
      RuntimeError: If other types of error occurred.
    """

    def packets_callback(output_packets: Mapping[str, packet.Packet]):
      timestamp_ms = output_packets[
          _CLASSIFICATIONS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND
      if output_packets[_CLASSIFICATIONS_STREAM_NAME].is_empty():
        options.result_callback(
            AudioClassifierResult(classifications=[]), timestamp_ms)
        return
      classification_result_proto = classifications_pb2.ClassificationResult()
      classification_result_proto.CopyFrom(
          packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
      options.result_callback(
          AudioClassifierResult.create_from_pb2(classification_result_proto),
          timestamp_ms)

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]),
            ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME])
        ],
        output_streams=[
            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
            ':'.join([
                _TIMESTAMPED_CLASSIFICATIONS_TAG,
                _TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME
            ])
        ],
        task_options=options)
    return cls(
        # Audio tasks should not drop input audio due to flow limiting, which
        # may cause data inconsistency.
        task_info.generate_graph_config(enable_flow_limiting=False),
        options.running_mode,
        packets_callback if options.result_callback else None)

  def classify(self, audio_clip: _AudioData) -> List[AudioClassifierResult]:
    """Performs audio classification on the provided audio clip.

    The audio clip is represented as a MediaPipe AudioData. The method accepts
    audio clips with various lengths and audio sample rates. It's required to
    provide the corresponding audio sample rate within the `AudioData` object.

    The input audio clip may be longer than what the model is able to process
    in a single inference. When this occurs, the input audio clip is split into
    multiple chunks starting at different timestamps. For this reason, this
    function returns a vector of ClassificationResult objects, each associated
    with a timestamp corresponding to the start (in milliseconds) of the chunk
    data that was classified, e.g.:

    ClassificationResult #0 (first chunk of data):
      timestamp_ms: 0 (starts at 0ms)
      classifications #0 (single head model):
        category #0:
          category_name: "Speech"
          score: 0.6
        category #1:
          category_name: "Music"
          score: 0.2
    ClassificationResult #1 (second chunk of data):
      timestamp_ms: 800 (starts at 800ms)
      classifications #0 (single head model):
        category #0:
          category_name: "Speech"
          score: 0.5
        category #1:
          category_name: "Silence"
          score: 0.1

    Args:
      audio_clip: MediaPipe AudioData.

    Returns:
      A list of `AudioClassifierResult` objects, each associated with a
      timestamp corresponding to the start (in milliseconds) of the chunk data
      that was classified.

    Raises:
      ValueError: If any of the input arguments is invalid, such as the sample
        rate is not provided in the `AudioData` object.
      RuntimeError: If audio classification failed to run.
    """
    if not audio_clip.audio_format.sample_rate:
      raise ValueError('Must provide the audio sample rate in audio data.')
    output_packets = self._process_audio_clip({
        _AUDIO_IN_STREAM_NAME:
            packet_creator.create_matrix(audio_clip.buffer, transpose=True),
        _SAMPLE_RATE_IN_STREAM_NAME:
            packet_creator.create_double(audio_clip.audio_format.sample_rate)
    })
    output_list = []
    classification_result_proto_list = packet_getter.get_proto_list(
        output_packets[_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME])
    for proto in classification_result_proto_list:
      classification_result_proto = classifications_pb2.ClassificationResult()
      classification_result_proto.CopyFrom(proto)
      output_list.append(
          AudioClassifierResult.create_from_pb2(classification_result_proto))
    return output_list

  def classify_async(self, audio_block: _AudioData, timestamp_ms: int) -> None:
    """Sends audio data (a block in a continuous audio stream) to perform audio classification.

    Only use this method when the AudioClassifier is created with the audio
    stream running mode. The input timestamps should be monotonically
    increasing for adjacent calls of this method. This method will return
    immediately after the input audio data is accepted. The results will be
    available via the `result_callback` provided in the
    `AudioClassifierOptions`. The `classify_async` method is designed to
    process audio stream data such as microphone input.

    The input audio data may be longer than what the model is able to process
    in a single inference. When this occurs, the input audio block is split
    into multiple chunks. For this reason, the callback may be called multiple
    times (once per chunk) for each call to this function.

    The `result_callback` provides:
      - An `AudioClassifierResult` object that contains a list of
        classifications.
      - The input timestamp in milliseconds.

    Args:
      audio_block: MediaPipe AudioData.
      timestamp_ms: The timestamp of the input audio data in milliseconds.

    Raises:
      ValueError: If any of the following:
        1) The sample rate is not provided in the `AudioData` object or the
           provided sample rate is inconsistent with the previously received.
        2) The current input timestamp is smaller than what the audio
           classifier has already processed.
    """
    if not audio_block.audio_format.sample_rate:
      raise ValueError('Must provide the audio sample rate in audio data.')
    if not self._default_sample_rate:
      self._default_sample_rate = audio_block.audio_format.sample_rate
      self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME,
                            self._default_sample_rate)
    elif audio_block.audio_format.sample_rate != self._default_sample_rate:
      raise ValueError(
          f'The audio sample rate provided in audio data: '
          f'{audio_block.audio_format.sample_rate} is inconsistent with '
          f'the previously received: {self._default_sample_rate}.')

    self._send_audio_stream_data({
        _AUDIO_IN_STREAM_NAME:
            packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })

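For reference, a minimal sketch of driving the audio stream mode defined above (the model path, callback body, and `audio_blocks` source are placeholders; the pattern follows audio_classifier_test.py):

# Sketch only: 'audio_blocks' is a hypothetical iterable of (AudioData, int).
def on_result(result: AudioClassifierResult, timestamp_ms: int):
  print(timestamp_ms, [c.head_name for c in result.classifications])

options = AudioClassifierOptions(
    base_options=_BaseOptions(model_asset_path='/path/to/model.tflite'),
    running_mode=_RunningMode.AUDIO_STREAM,
    result_callback=on_result)
with AudioClassifier.create_from_options(options) as classifier:
  # Timestamps must be monotonically increasing; the sample rate travels with
  # each AudioData block and must stay consistent across calls.
  for audio_block, timestamp_ms in audio_blocks:
    classifier.classify_async(audio_block, timestamp_ms)
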
@@ -32,6 +32,7 @@ py_library(
        ":audio_task_running_mode",
        "//mediapipe/framework:calculator_py_pb2",
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)

@@ -16,14 +16,17 @@
from typing import Callable, Mapping, Optional

from mediapipe.framework import calculator_pb2
from mediapipe.python import packet_creator
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.python._framework_bindings import timestamp as timestamp_module
from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_TaskRunner = task_runner_module.TaskRunner
_Packet = packet_module.Packet
_RunningMode = running_mode_module.AudioTaskRunningMode
_Timestamp = timestamp_module.Timestamp


class BaseAudioTaskApi(object):

@@ -59,6 +62,7 @@ class BaseAudioTaskApi(object):
          'callback should not be provided.')
    self._runner = _TaskRunner.create(graph_config, packet_callback)
    self._running_mode = running_mode
    self._default_sample_rate = None

  def _process_audio_clip(
      self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:

@@ -82,6 +86,27 @@ class BaseAudioTaskApi(object):
          + self._running_mode.name)
    return self._runner.process(inputs)

  def _set_sample_rate(self, sample_rate_stream_name: str,
                       sample_rate: float) -> None:
    """An asynchronous method to set audio sample rate in the audio stream mode.

    Args:
      sample_rate_stream_name: The audio sample rate stream name.
      sample_rate: The audio sample rate.

    Raises:
      ValueError: If the task's running mode is not set to the audio stream
        mode.
    """
    if self._running_mode != _RunningMode.AUDIO_STREAM:
      raise ValueError(
          'Task is not initialized with the audio stream mode. Current running mode:'
          + self._running_mode.name)
    self._runner.send({
        sample_rate_stream_name:
            packet_creator.create_double(sample_rate).at(_Timestamp.PRESTREAM)
    })

  def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
    """An asynchronous method to send audio stream data to the runner.

@@ -94,3 +94,13 @@ py_library(
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)

py_library(
    name = "classification_result",
    srcs = ["classification_result.py"],
    deps = [
        ":category",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)

@@ -20,7 +20,7 @@ import numpy as np


@dataclasses.dataclass
class AudioFormat:
class AudioDataFormat:
  """Audio format metadata.

  Attributes:

@@ -35,8 +35,10 @@ class AudioData(object):
  """MediaPipe Tasks' audio container."""

  def __init__(
      self, buffer_length: int,
      audio_format: AudioFormat = AudioFormat()) -> None:
      self,
      buffer_length: int,
      audio_format: AudioDataFormat = AudioDataFormat()
  ) -> None:
    """Initializes the `AudioData` object.

    Args:

@@ -113,14 +115,14 @@ class AudioData(object):
    """
    obj = cls(
        buffer_length=src.shape[0],
        audio_format=AudioFormat(
        audio_format=AudioDataFormat(
            num_channels=1 if len(src.shape) == 1 else src.shape[1],
            sample_rate=sample_rate))
    obj.load_from_array(src)
    return obj

  @property
  def audio_format(self) -> AudioFormat:
  def audio_format(self) -> AudioDataFormat:
    """Gets the audio format of the audio."""
    return self._audio_format

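A small sketch of the renamed container in use (values are placeholders; create_from_array infers the channel count from the array shape, as shown above):

import numpy as np
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module

samples = np.zeros(16000, dtype=float)  # Placeholder: one second of silence.
clip = audio_data_module.AudioData.create_from_array(samples, 16000)
assert clip.audio_format.sample_rate == 16000  # AudioDataFormat property.
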
mediapipe/tasks/python/components/containers/classification_result.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""

import dataclasses
from typing import List, Optional

from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult


@dataclasses.dataclass
class Classifications:
  """Represents the classification results for a given classifier head.

  Attributes:
    categories: The array of predicted categories, usually sorted by descending
      scores (e.g. from high to low probability).
    head_index: The index of the classifier head these categories refer to.
      This is useful for multi-head models.
    head_name: The name of the classifier head, which is the corresponding
      tensor metadata name.
  """

  categories: List[category_module.Category]
  head_index: int
  head_name: Optional[str] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
    """Creates a `Classifications` object from the given protobuf object."""
    categories = []
    for entry in pb2_obj.classification_list.classification:
      categories.append(
          category_module.Category(
              index=entry.index,
              score=entry.score,
              display_name=entry.display_name,
              category_name=entry.label))

    return Classifications(
        categories=categories,
        head_index=pb2_obj.head_index,
        head_name=pb2_obj.head_name)


@dataclasses.dataclass
class ClassificationResult:
  """Contains the classification results of a model.

  Attributes:
    classifications: A list of `Classifications` objects, each for a head of
      the model.
    timestamp_ms: The optional timestamp (in milliseconds) of the start of the
      chunk of data corresponding to these results. This is only used for
      classification on time series (e.g. audio classification). In these use
      cases, the amount of data to process might exceed the maximum size that
      the model can process: to solve this, the input data is split into
      multiple chunks starting at different timestamps.
  """

  classifications: List[Classifications]
  timestamp_ms: Optional[int] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
    """Creates a `ClassificationResult` object from the given protobuf object.
    """
    return ClassificationResult(
        classifications=[
            Classifications.create_from_pb2(classification)
            for classification in pb2_obj.classifications
        ],
        timestamp_ms=pb2_obj.timestamp_ms)

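A short sketch of walking the result containers defined above (assumes `result` came from AudioClassifier.classify or the stream callback):

def print_top_categories(result: ClassificationResult) -> None:
  # One Classifications entry per classifier head of the model.
  for head in result.classifications:
    top = head.categories[0]  # Categories are usually sorted by score.
    print(result.timestamp_ms, head.head_name, top.category_name, top.score)
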
mediapipe/tasks/python/test/audio/BUILD (new file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict test compatibility macro.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

py_test(
    name = "audio_classifier_test",
    srcs = ["audio_classifier_test.py"],
    data = [
        "//mediapipe/tasks/testdata/audio:test_audio_clips",
        "//mediapipe/tasks/testdata/audio:test_models",
    ],
    deps = [
        "//mediapipe/tasks/python/audio:audio_classifier",
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
        "//mediapipe/tasks/python/components/containers:audio_data",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

mediapipe/tasks/python/test/audio/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

mediapipe/tasks/python/test/audio/audio_classifier_test.py (new file, 381 lines)
@@ -0,0 +1,381 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio classifier."""

import os
from typing import List, Tuple
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
from scipy.io import wavfile

from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils

_AudioClassifier = audio_classifier.AudioClassifier
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode

_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_YAMNET_MODEL_SAMPLE_RATE = 16000
_TWO_HEADS_MODEL_FILE = 'two_heads.tflite'
_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TWO_HEADS_WAV_44K_MONO = 'two_heads_44100_hz_mono.wav'
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLSECONDS_PER_SECOND = 1000


class AudioClassifierTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.yamnet_model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
    self.two_heads_model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _TWO_HEADS_MODEL_FILE))

  def _read_wav_file(self, file_name) -> _AudioData:
    sample_rate, buffer = wavfile.read(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
    return _AudioData.create_from_array(
        buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)

  def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
    sample_rate, buffer = wavfile.read(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
    audio_data_list = []
    start = 0
    step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
    while start < len(buffer):
      end = min(start + (int)(step_size), len(buffer))
      audio_data_list.append((_AudioData.create_from_array(
          buffer[start:end].astype(float) / np.iinfo(np.int16).max,
          sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND)))
      start = end
    return audio_data_list

  # TODO: Compares the exact score values to capture unexpected
  # changes in the inference pipeline.
  def _check_yamnet_result(
      self,
      classification_result_list: List[_AudioClassifierResult],
      expected_num_categories=521):
    self.assertLen(classification_result_list, 5)
    for idx, timestamp in enumerate([0, 975, 1950, 2925]):
      classification_result = classification_result_list[idx]
      self.assertEqual(classification_result.timestamp_ms, timestamp)
      self.assertLen(classification_result.classifications, 1)
      classification = classification_result.classifications[0]
      self.assertEqual(classification.head_index, 0)
      self.assertEqual(classification.head_name, 'scores')
      self.assertLen(classification.categories, expected_num_categories)
      audio_category = classification.categories[0]
      self.assertEqual(audio_category.index, 0)
      self.assertEqual(audio_category.category_name, 'Speech')
      self.assertGreater(audio_category.score, 0.9)

  # TODO: Compares the exact score values to capture unexpected
  # changes in the inference pipeline.
  def _check_two_heads_result(
      self,
      classification_result_list: List[_AudioClassifierResult],
      first_head_expected_num_categories=521,
      second_head_expected_num_categories=5):
    self.assertGreaterEqual(len(classification_result_list), 1)
    self.assertLessEqual(len(classification_result_list), 2)
    # Checks the first result.
    classification_result = classification_result_list[0]
    self.assertEqual(classification_result.timestamp_ms, 0)
    self.assertLen(classification_result.classifications, 2)
    # Checks the first head.
    yamnet_classification = classification_result.classifications[0]
    self.assertEqual(yamnet_classification.head_index, 0)
    self.assertEqual(yamnet_classification.head_name, 'yamnet_classification')
    self.assertLen(yamnet_classification.categories,
                   first_head_expected_num_categories)
    yamnet_category = yamnet_classification.categories[0]
    self.assertEqual(yamnet_category.index, 508)
    self.assertEqual(yamnet_category.category_name, 'Environmental noise')
    self.assertGreater(yamnet_category.score, 0.5)
    # Checks the second head.
    bird_classification = classification_result.classifications[1]
    self.assertEqual(bird_classification.head_index, 1)
    self.assertEqual(bird_classification.head_name, 'bird_classification')
    self.assertLen(bird_classification.categories,
                   second_head_expected_num_categories)
    bird_category = bird_classification.categories[0]
    self.assertEqual(bird_category.index, 4)
    self.assertEqual(bird_category.category_name, 'Chestnut-crowned Antpitta')
    self.assertGreater(bird_category.score, 0.93)
    # Checks the second result, if present.
    if len(classification_result_list) == 2:
      classification_result = classification_result_list[1]
      self.assertEqual(classification_result.timestamp_ms, 975)
      self.assertLen(classification_result.classifications, 2)
      # Checks the first head.
      yamnet_classification = classification_result.classifications[0]
      self.assertEqual(yamnet_classification.head_index, 0)
      self.assertEqual(yamnet_classification.head_name, 'yamnet_classification')
      self.assertLen(yamnet_classification.categories,
                     first_head_expected_num_categories)
      yamnet_category = yamnet_classification.categories[0]
      self.assertEqual(yamnet_category.index, 494)
      self.assertEqual(yamnet_category.category_name, 'Silence')
      self.assertGreater(yamnet_category.score, 0.9)
      # Checks the second head.
      bird_classification = classification_result.classifications[1]
      self.assertEqual(bird_classification.head_index, 1)
      self.assertEqual(bird_classification.head_name, 'bird_classification')
      self.assertLen(bird_classification.categories,
                     second_head_expected_num_categories)
      bird_category = bird_classification.categories[0]
      self.assertEqual(bird_category.index, 1)
      self.assertEqual(bird_category.category_name, 'White-breasted Wood-Wren')
      self.assertGreater(bird_category.score, 0.99)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(
                model_asset_path=self.yamnet_model_path))) as classifier:
      self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _AudioClassifierOptions(base_options=base_options)
      _AudioClassifier.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.yamnet_model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _AudioClassifierOptions(base_options=base_options)
      classifier = _AudioClassifier.create_from_options(options)
      self.assertIsInstance(classifier, _AudioClassifier)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
  def test_classify_with_yamnet_model(self, audio_file):
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      classification_result_list = classifier.classify(
          self._read_wav_file(audio_file))
      self._check_yamnet_result(classification_result_list)

  def test_classify_with_yamnet_model_and_inputs_at_different_sample_rates(
      self):
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(classification_result_list)

  def test_max_result_options(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                max_results=1))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_score_threshold_options(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                score_threshold=0.9))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_allow_list_option(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                category_allowlist=['Speech']))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_combined_allowlist_and_denylist(self):
    # Fails with combined allowlist and denylist.
    with self.assertRaisesRegex(
        ValueError,
        r'`category_allowlist` and `category_denylist` are mutually '
        r'exclusive options.'):
      options = _AudioClassifierOptions(
          base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
          classifier_options=_ClassifierOptions(
              category_allowlist=['foo'], category_denylist=['bar']))
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  @parameterized.parameters((_TWO_HEADS_WAV_16K_MONO),
                            (_TWO_HEADS_WAV_44K_MONO))
  def test_classify_with_two_heads_model_and_inputs_at_different_sample_rates(
      self, audio_file):
    with _AudioClassifier.create_from_model_path(
        self.two_heads_model_path) as classifier:
      classification_result_list = classifier.classify(
          self._read_wav_file(audio_file))
      self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model(self):
    with _AudioClassifier.create_from_model_path(
        self.two_heads_model_path) as classifier:
      for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model_with_max_results(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(
                model_asset_path=self.two_heads_model_path),
            classifier_options=_ClassifierOptions(
                max_results=1))) as classifier:
      for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_two_heads_result(classification_result_list, 1, 1)

  def test_missing_sample_rate_in_audio_clips_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS)
    with self.assertRaisesRegex(ValueError,
                                r'Must provide the audio sample rate'):
      with _AudioClassifier.create_from_options(options) as classifier:
        classifier.classify(_AudioData(buffer_length=100))

  def test_missing_sample_rate_in_audio_stream_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'provide the audio sample rate in audio data'):
      with _AudioClassifier.create_from_options(options) as classifier:
        classifier.classify(_AudioData(buffer_length=100))

  def test_missing_result_callback(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_illegal_result_callback(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_calling_classify_in_audio_stream_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with _AudioClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the audio clips mode'):
        classifier.classify(self._read_wav_file(_SPEECH_WAV_16K_MONO))

  def test_calling_classify_async_in_audio_clips_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS)
    with _AudioClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the audio stream mode'):
        classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  def test_classify_async_calls_with_illegal_timestamp(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with _AudioClassifier.create_from_options(options) as classifier:
      classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
  def test_classify_async(self, audio_file):
    classification_result_list = []

    def save_result(result: _AudioClassifierResult, timestamp_ms: int):
      result.timestamp_ms = timestamp_ms
      classification_result_list.append(result)

    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        classifier_options=_ClassifierOptions(max_results=1),
        result_callback=save_result)
    classifier = _AudioClassifier.create_from_options(options)
    audio_data_list = self._read_wav_file_as_stream(audio_file)
    for audio_data, timestamp_ms in audio_data_list:
      classifier.classify_async(audio_data, timestamp_ms)
    classifier.close()
    self._check_yamnet_result(
        classification_result_list, expected_num_categories=1)


if __name__ == '__main__':
  absltest.main()