diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index dd94ccede..a1a43128f 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -87,6 +87,7 @@ cc_library( cc_library( name = "builtin_task_graphs", deps = [ + "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index b36fa306a..6f2bd0d3b 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -361,7 +361,7 @@ void PublicPacketCreators(pybind11::module* m) { packet = mp.packet_creator.create_float(0.1) data = mp.packet_getter.get_float(packet) )doc", - py::arg().noconvert(), py::return_value_policy::move); + py::return_value_policy::move); m->def( "create_double", [](double data) { return MakePacket(data); }, @@ -380,7 +380,7 @@ void PublicPacketCreators(pybind11::module* m) { packet = mp.packet_creator.create_double(0.1) data = mp.packet_getter.get_float(packet) )doc", - py::arg().noconvert(), py::return_value_policy::move); + py::return_value_policy::move); m->def( "create_int_array", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 501a9e6fd..1955adfe7 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index 3b01ddb88..91c853120 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -63,10 +63,8 @@ CalculatorGraphConfig CreateGraphConfig( api2::builder::Graph graph; auto& subgraph = graph.AddNode(kSubgraphTypeName); graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag); - if (!options_proto->base_options().use_stream_mode()) { - graph.In(kSampleRateTag).SetName(kSampleRateName) >> - subgraph.In(kSampleRateTag); - } + graph.In(kSampleRateTag).SetName(kSampleRateName) >> + subgraph.In(kSampleRateTag); subgraph.GetOptions().Swap( options_proto.get()); subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >> @@ -93,9 +91,6 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); - if (options->sample_rate > 0) { - options_proto->set_default_input_audio_sample_rate(options->sample_rate); - } return options_proto; } @@ -129,14 +124,6 @@ absl::StatusOr ConvertAsyncOutputPackets( /* static */ absl::StatusOr> AudioClassifier::Create( std::unique_ptr options) { - if (options->running_mode == core::RunningMode::AUDIO_STREAM && - options->sample_rate < 0) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "The audio classifier is in audio stream mode, the sample rate must be " - "specified in the AudioClassifierOptions.", - MediaPipeTasksStatus::kInvalidTaskGraphConfigError); - } auto options_proto = ConvertAudioClassifierOptionsToProto(options.get()); tasks::core::PacketsCallback packets_callback = nullptr; if (options->result_callback) { @@ -161,7 +148,9 @@ absl::StatusOr> AudioClassifier::Classify( } absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block, + double audio_sample_rate, int64 timestamp_ms) { + MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate)); return SendAudioStreamData( {{kAudioStreamName, MakePacket(std::move(audio_block)) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index 4b5d2c04b..dd611ec81 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -52,15 +52,10 @@ struct AudioClassifierOptions { // 1) The audio clips mode for running classification on independent audio // clips. // 2) The audio stream mode for running classification on the audio stream, - // such as from microphone. In this mode, the "sample_rate" below must be - // provided, and the "result_callback" below must be specified to receive - // the classification results asynchronously. + // such as from microphone. In this mode, the "result_callback" below must + // be specified to receive the classification results asynchronously. core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS; - // The sample rate of the input audios. Must be set when the running mode is - // set to RunningMode::AUDIO_STREAM. - double sample_rate = -1.0; - // The user-defined result callback for processing audio stream data. // The result callback should only be specified when the running mode is set // to RunningMode::AUDIO_STREAM. @@ -160,15 +155,17 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // The audio block is represented as a MediaPipe Matrix that has the number // of channels rows and the number of samples per channel columns. The audio // data will be resampled, accumulated, and framed to the proper size for the - // underlying model to consume. It's required to provide a timestamp (in - // milliseconds) to indicate the start time of the input audio block. The + // underlying model to consume. It's required to provide the corresponding + // audio sample rate along with the input audio block as well as a timestamp + // (in milliseconds) to indicate the start time of the input audio block. The // timestamps must be monotonically increasing. // // The input audio block may be longer than what the model is able to process // in a single inference. When this occurs, the input audio block is split // into multiple chunks. For this reason, the callback may be called multiple // times (once per chunk) for each call to this function. - absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms); + absl::Status ClassifyAsync(mediapipe::Matrix audio_block, + double audio_sample_rate, int64 timestamp_ms); // Shuts down the AudioClassifier when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 2b75209bb..b232afc72 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -72,18 +72,6 @@ struct AudioClassifierOutputStreams { Source> timestamped_classifications; }; -absl::Status SanityCheckOptions( - const proto::AudioClassifierGraphOptions& options) { - if (options.base_options().use_stream_mode() && - !options.has_default_input_audio_sample_rate()) { - return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, - "In the streaming mode, the default input " - "audio sample rate must be set.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - return absl::OkStatus(); -} - // Builds an AudioTensorSpecs for configuring the preprocessing calculators. absl::StatusOr BuildPreprocessingSpecs( const core::ModelResources& model_resources) { @@ -170,19 +158,12 @@ class AudioClassifierGraph : public core::ModelTaskGraph { const auto* model_resources, CreateModelResources(sc)); Graph graph; - const bool use_stream_mode = - sc->Options() - .base_options() - .use_stream_mode(); ASSIGN_OR_RETURN( auto output_streams, BuildAudioClassificationTask( sc->Options(), *model_resources, graph[Input(kAudioTag)], - use_stream_mode - ? absl::nullopt - : absl::make_optional(graph[Input(kSampleRateTag)]), - graph)); + absl::make_optional(graph[Input(kSampleRateTag)]), graph)); output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -207,7 +188,6 @@ class AudioClassifierGraph : public core::ModelTaskGraph { const proto::AudioClassifierGraphOptions& task_options, const core::ModelResources& model_resources, Source audio_in, absl::optional> sample_rate_in, Graph& graph) { - MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); const bool use_stream_mode = task_options.base_options().use_stream_mode(); const auto* metadata_extractor = model_resources.GetMetadataExtractor(); // Checks that metadata is available. diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index a4fe5e32e..596b910f8 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -70,6 +70,8 @@ Matrix GetAudioData(absl::string_view filename) { return matrix_mapping.matrix(); } +// TODO: Compares the exact score values to capture unexpected +// changes in the inference pipeline. void CheckSpeechResult(const std::vector& result, int expected_num_categories = 521) { EXPECT_EQ(result.size(), 5); @@ -90,13 +92,15 @@ void CheckSpeechResult(const std::vector& result, } } +// TODO: Compares the exact score values to capture unexpected +// changes in the inference pipeline. void CheckTwoHeadsResult(const std::vector& result) { EXPECT_GE(result.size(), 1); EXPECT_LE(result.size(), 2); - // Check first result. + // Check the first result. EXPECT_EQ(result[0].timestamp_ms, 0); EXPECT_EQ(result[0].classifications.size(), 2); - // Check first head. + // Check the first head. EXPECT_EQ(result[0].classifications[0].head_index, 0); EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification"); EXPECT_EQ(result[0].classifications[0].categories.size(), 521); @@ -104,19 +108,19 @@ void CheckTwoHeadsResult(const std::vector& result) { EXPECT_EQ(result[0].classifications[0].categories[0].category_name, "Environmental noise"); EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f); - // Check second head. + // Check the second head. EXPECT_EQ(result[0].classifications[1].head_index, 1); EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification"); EXPECT_EQ(result[0].classifications[1].categories.size(), 5); EXPECT_EQ(result[0].classifications[1].categories[0].index, 4); EXPECT_EQ(result[0].classifications[1].categories[0].category_name, "Chestnut-crowned Antpitta"); - EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f); - // Check second result, if present. + EXPECT_GT(result[0].classifications[1].categories[0].score, 0.93f); + // Check the second result, if present. if (result.size() == 2) { EXPECT_EQ(result[1].timestamp_ms, 975); EXPECT_EQ(result[1].classifications.size(), 2); - // Check first head. + // Check the first head. EXPECT_EQ(result[1].classifications[0].head_index, 0); EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification"); EXPECT_EQ(result[1].classifications[0].categories.size(), 521); @@ -124,7 +128,7 @@ void CheckTwoHeadsResult(const std::vector& result) { EXPECT_EQ(result[1].classifications[0].categories[0].category_name, "Silence"); EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f); - // Check second head. + // Check the second head. EXPECT_EQ(result[1].classifications[1].head_index, 1); EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification"); EXPECT_EQ(result[1].classifications[1].categories.size(), 5); @@ -234,7 +238,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); options->running_mode = core::RunningMode::AUDIO_STREAM; - options->sample_rate = 16000; StatusOr> audio_classifier_or = AudioClassifier::Create(std::move(options)); @@ -266,25 +269,6 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); } -TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) { - auto options = std::make_unique(); - options->base_options.model_asset_path = - JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); - options->running_mode = core::RunningMode::AUDIO_STREAM; - options->result_callback = - [](absl::StatusOr status_or_result) {}; - StatusOr> audio_classifier_or = - AudioClassifier::Create(std::move(options)); - - EXPECT_EQ(audio_classifier_or.status().code(), - absl::StatusCode::kInvalidArgument); - EXPECT_THAT(audio_classifier_or.status().message(), - HasSubstr("the sample rate must be specified")); - EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload), - Optional(absl::Cord(absl::StrCat( - MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); -} - class ClassifyTest : public tflite_shims::testing::Test {}; TEST_F(ClassifyTest, Succeeds) { @@ -493,7 +477,6 @@ TEST_F(ClassifyAsyncTest, Succeeds) { options->classifier_options.max_results = 1; options->classifier_options.score_threshold = 0.3f; options->running_mode = core::RunningMode::AUDIO_STREAM; - options->sample_rate = kSampleRateHz; std::vector outputs; options->result_callback = [&outputs](absl::StatusOr status_or_result) { @@ -506,7 +489,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) { int num_samples = std::min((int)(audio_buffer.cols() - start_col), kYamnetNumOfAudioSamples * 3); MP_ASSERT_OK(audio_classifier->ClassifyAsync( - audio_buffer.block(0, start_col, 1, num_samples), + audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz, start_col * kMilliSecondsPerSecond / kSampleRateHz)); start_col += kYamnetNumOfAudioSamples * 3; } @@ -523,7 +506,6 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { options->classifier_options.max_results = 1; options->classifier_options.score_threshold = 0.3f; options->running_mode = core::RunningMode::AUDIO_STREAM; - options->sample_rate = kSampleRateHz; std::vector outputs; options->result_callback = [&outputs](absl::StatusOr status_or_result) { @@ -538,7 +520,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { std::min((int)(audio_buffer.cols() - start_col), rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * 3); MP_ASSERT_OK(audio_classifier->ClassifyAsync( - audio_buffer.block(0, start_col, 1, num_samples), + audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz, start_col * kMilliSecondsPerSecond / kSampleRateHz)); start_col += num_samples; } diff --git a/mediapipe/tasks/cc/audio/core/base_audio_task_api.h b/mediapipe/tasks/cc/audio/core/base_audio_task_api.h index 495951bae..c04b3cf32 100644 --- a/mediapipe/tasks/cc/audio/core/base_audio_task_api.h +++ b/mediapipe/tasks/cc/audio/core/base_audio_task_api.h @@ -72,8 +72,39 @@ class BaseAudioTaskApi : public tasks::core::BaseTaskApi { return runner_->Send(std::move(inputs)); } + // Checks or sets the sample rate in the audio stream mode. + absl::Status CheckOrSetSampleRate(std::string sample_rate_stream_name, + double sample_rate) { + if (running_mode_ != RunningMode::AUDIO_STREAM) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Task is not initialized with the audio stream mode. " + "Current running mode:", + GetRunningModeName(running_mode_)), + MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError); + } + if (default_sample_rate_ > 0) { + if (std::fabs(sample_rate - default_sample_rate_) > + std::numeric_limits::epsilon()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("The input audio sample rate: ", sample_rate, + " is inconsistent with the previously provided: ", + default_sample_rate_), + MediaPipeTasksStatus::kInvalidArgumentError); + } + } else { + default_sample_rate_ = sample_rate; + MP_RETURN_IF_ERROR(runner_->Send( + {{sample_rate_stream_name, MakePacket(default_sample_rate_) + .At(Timestamp::PreStream())}})); + } + return absl::OkStatus(); + } + private: RunningMode running_mode_; + double default_sample_rate_ = -1.0; }; } // namespace core diff --git a/mediapipe/tasks/python/audio/BUILD b/mediapipe/tasks/python/audio/BUILD new file mode 100644 index 000000000..dd8719151 --- /dev/null +++ b/mediapipe/tasks/python/audio/BUILD @@ -0,0 +1,41 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "audio_classifier", + srcs = [ + "audio_classifier.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/audio/core:base_audio_task_api", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:classification_result", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + ], +) diff --git a/mediapipe/tasks/python/audio/__init__.py b/mediapipe/tasks/python/audio/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/audio/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py new file mode 100644 index 000000000..e04e778b5 --- /dev/null +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -0,0 +1,280 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe audio classifier task.""" + +import dataclasses +from typing import Callable, Mapping, List, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import packet +from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2 +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module +from mediapipe.tasks.python.audio.core import base_audio_task_api +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +AudioClassifierResult = classification_result_module.ClassificationResult +_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptions = classifier_options_module.ClassifierOptions +_RunningMode = running_mode_module.AudioTaskRunningMode +_TaskInfo = task_info_module.TaskInfo + +_AUDIO_IN_STREAM_NAME = 'audio_in' +_AUDIO_TAG = 'AUDIO' +_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' +_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS' +_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in' +_SAMPLE_RATE_TAG = 'SAMPLE_RATE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph' +_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME = 'timestamped_classifications_out' +_TIMESTAMPED_CLASSIFICATIONS_TAG = 'TIMESTAMPED_CLASSIFICATIONS' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class AudioClassifierOptions: + """Options for the audio classifier task. + + Attributes: + base_options: Base options for the audio classifier task. + running_mode: The running mode of the task. Default to the audio clips mode. + Audio classifier task has two running modes: 1) The audio clips mode for + running classification on independent audio clips. 2) The audio stream + mode for running classification on the audio stream, such as from + microphone. In this mode, the "result_callback" below must be specified + to receive the classification results asynchronously. + classifier_options: Options for configuring the classifier behavior, such as + score threshold, number of results, etc. + result_callback: The user-defined result callback for processing audio + stream data. The result callback should only be specified when the running + mode is set to the audio stream mode. + """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS + classifier_options: _ClassifierOptions = _ClassifierOptions() + result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _AudioClassifierGraphOptionsProto: + """Generates an AudioClassifierOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True + classifier_options_proto = self.classifier_options.to_pb2() + + return _AudioClassifierGraphOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto) + + +class AudioClassifier(base_audio_task_api.BaseAudioTaskApi): + """Class that performs audio classification on audio data.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'AudioClassifier': + """Creates an `AudioClassifier` object from a TensorFlow Lite model and the default `AudioClassifierOptions`. + + Note that the created `AudioClassifier` instance is in audio clips mode, for + classifying on independent audio clips. + + Args: + model_path: Path to the model. + + Returns: + `AudioClassifier` object that's created from the model file and the + default `AudioClassifierOptions`. + + Raises: + ValueError: If failed to create `AudioClassifier` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = AudioClassifierOptions( + base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: AudioClassifierOptions) -> 'AudioClassifier': + """Creates the `AudioClassifier` object from audio classifier options. + + Args: + options: Options for the audio classifier task. + + Returns: + `AudioClassifier` object that's created from `options`. + + Raises: + ValueError: If failed to create `AudioClassifier` object from + `AudioClassifierOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet.Packet]): + timestamp_ms = output_packets[ + _CLASSIFICATIONS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND + if output_packets[_CLASSIFICATIONS_STREAM_NAME].is_empty(): + options.result_callback( + AudioClassifierResult(classifications=[]), timestamp_ms) + return + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) + options.result_callback( + AudioClassifierResult.create_from_pb2(classification_result_proto), + timestamp_ms) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]), + ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME]) + ], + output_streams=[ + ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]), + ':'.join([ + _TIMESTAMPED_CLASSIFICATIONS_TAG, + _TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME + ]) + ], + task_options=options) + return cls( + # Audio tasks should not drop input audio due to flow limiting, which + # may cause data inconsistency. + task_info.generate_graph_config(enable_flow_limiting=False), + options.running_mode, + packets_callback if options.result_callback else None) + + def classify(self, audio_clip: _AudioData) -> List[AudioClassifierResult]: + """Performs audio classification on the provided audio clip. + + The audio clip is represented as a MediaPipe AudioData. The method accepts + audio clips with various length and audio sample rate. It's required to + provide the corresponding audio sample rate within the `AudioData` object. + + The input audio clip may be longer than what the model is able to process + in a single inference. When this occurs, the input audio clip is split into + multiple chunks starting at different timestamps. For this reason, this + function returns a vector of ClassificationResult objects, each associated + ith a timestamp corresponding to the start (in milliseconds) of the chunk + data that was classified, e.g: + + ClassificationResult #0 (first chunk of data): + timestamp_ms: 0 (starts at 0ms) + classifications #0 (single head model): + category #0: + category_name: "Speech" + score: 0.6 + category #1: + category_name: "Music" + score: 0.2 + ClassificationResult #1 (second chunk of data): + timestamp_ms: 800 (starts at 800ms) + classifications #0 (single head model): + category #0: + category_name: "Speech" + score: 0.5 + category #1: + category_name: "Silence" + score: 0.1 + + Args: + audio_clip: MediaPipe AudioData. + + Returns: + An `AudioClassifierResult` object that contains a list of + classification result objects, each associated with a timestamp + corresponding to the start (in milliseconds) of the chunk data that was + classified. + + Raises: + ValueError: If any of the input arguments is invalid, such as the sample + rate is not provided in the `AudioData` object. + RuntimeError: If audio classification failed to run. + """ + if not audio_clip.audio_format.sample_rate: + raise ValueError('Must provide the audio sample rate in audio data.') + output_packets = self._process_audio_clip({ + _AUDIO_IN_STREAM_NAME: + packet_creator.create_matrix(audio_clip.buffer, transpose=True), + _SAMPLE_RATE_IN_STREAM_NAME: + packet_creator.create_double(audio_clip.audio_format.sample_rate) + }) + output_list = [] + classification_result_proto_list = packet_getter.get_proto_list( + output_packets[_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME]) + for proto in classification_result_proto_list: + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom(proto) + output_list.append( + AudioClassifierResult.create_from_pb2(classification_result_proto)) + return output_list + + def classify_async(self, audio_block: _AudioData, timestamp_ms: int) -> None: + """Sends audio data (a block in a continuous audio stream) to perform audio classification. + + Only use this method when the AudioClassifier is created with the audio + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input audio data is accepted. The results will be available via the + `result_callback` provided in the `AudioClassifierOptions`. The + `classify_async` method is designed to process auido stream data such as + microphone input. + + The input audio data may be longer than what the model is able to process + in a single inference. When this occurs, the input audio block is split + into multiple chunks. For this reason, the callback may be called multiple + times (once per chunk) for each call to this function. + + The `result_callback` provides: + - An `AudioClassifierResult` object that contains a list of + classifications. + - The input timestamp in milliseconds. + + Args: + audio_block: MediaPipe AudioData. + timestamp_ms: The timestamp of the input audio data in milliseconds. + + Raises: + ValueError: If any of the followings: + 1) The sample rate is not provided in the `AudioData` object or the + provided sample rate is inconsisent with the previously recevied. + 2) The current input timestamp is smaller than what the audio + classifier has already processed. + """ + if not audio_block.audio_format.sample_rate: + raise ValueError('Must provide the audio sample rate in audio data.') + if not self._default_sample_rate: + self._default_sample_rate = audio_block.audio_format.sample_rate + self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME, + self._default_sample_rate) + elif audio_block.audio_format.sample_rate != self._default_sample_rate: + raise ValueError( + f'The audio sample rate provided in audio data: ' + f'{audio_block.audio_format.sample_rate} is inconsisent with ' + f'the previously received: {self._default_sample_rate}.') + + self._send_audio_stream_data({ + _AUDIO_IN_STREAM_NAME: + packet_creator.create_matrix(audio_block.buffer, transpose=True).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) diff --git a/mediapipe/tasks/python/audio/core/BUILD b/mediapipe/tasks/python/audio/core/BUILD index d9f169c65..3cb9cb8e8 100644 --- a/mediapipe/tasks/python/audio/core/BUILD +++ b/mediapipe/tasks/python/audio/core/BUILD @@ -32,6 +32,7 @@ py_library( ":audio_task_running_mode", "//mediapipe/framework:calculator_py_pb2", "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/audio/core/base_audio_task_api.py b/mediapipe/tasks/python/audio/core/base_audio_task_api.py index b6a2e0e40..b2197c142 100644 --- a/mediapipe/tasks/python/audio/core/base_audio_task_api.py +++ b/mediapipe/tasks/python/audio/core/base_audio_task_api.py @@ -16,14 +16,17 @@ from typing import Callable, Mapping, Optional from mediapipe.framework import calculator_pb2 +from mediapipe.python import packet_creator from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module +from mediapipe.python._framework_bindings import timestamp as timestamp_module from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _TaskRunner = task_runner_module.TaskRunner _Packet = packet_module.Packet _RunningMode = running_mode_module.AudioTaskRunningMode +_Timestamp = timestamp_module.Timestamp class BaseAudioTaskApi(object): @@ -59,6 +62,7 @@ class BaseAudioTaskApi(object): 'callback should not be provided.') self._runner = _TaskRunner.create(graph_config, packet_callback) self._running_mode = running_mode + self._default_sample_rate = None def _process_audio_clip( self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]: @@ -82,6 +86,27 @@ class BaseAudioTaskApi(object): + self._running_mode.name) return self._runner.process(inputs) + def _set_sample_rate(self, sample_rate_stream_name: str, + sample_rate: float) -> None: + """An asynchronous method to set audio sample rate in the audio stream mode. + + Args: + sample_rate_stream_name: The audio sample rate stream name. + sample_rate: The audio sample rate. + + Raises: + ValueError: If the task's running mode is not set to the audio stream + mode. + """ + if self._running_mode != _RunningMode.AUDIO_STREAM: + raise ValueError( + 'Task is not initialized with the audio stream mode. Current running mode:' + + self._running_mode.name) + self._runner.send({ + sample_rate_stream_name: + packet_creator.create_double(sample_rate).at(_Timestamp.PRESTREAM) + }) + def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None: """An asynchronous method to send audio stream data to the runner. diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 91e115476..9e0a90911 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -94,3 +94,13 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "classification_result", + srcs = ["classification_result.py"], + deps = [ + ":category", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/containers/audio_data.py b/mediapipe/tasks/python/components/containers/audio_data.py index 56399dea8..1d0267998 100644 --- a/mediapipe/tasks/python/components/containers/audio_data.py +++ b/mediapipe/tasks/python/components/containers/audio_data.py @@ -20,7 +20,7 @@ import numpy as np @dataclasses.dataclass -class AudioFormat: +class AudioDataFormat: """Audio format metadata. Attributes: @@ -35,8 +35,10 @@ class AudioData(object): """MediaPipe Tasks' audio container.""" def __init__( - self, buffer_length: int, - audio_format: AudioFormat = AudioFormat()) -> None: + self, + buffer_length: int, + audio_format: AudioDataFormat = AudioDataFormat() + ) -> None: """Initializes the `AudioData` object. Args: @@ -113,14 +115,14 @@ class AudioData(object): """ obj = cls( buffer_length=src.shape[0], - audio_format=AudioFormat( + audio_format=AudioDataFormat( num_channels=1 if len(src.shape) == 1 else src.shape[1], sample_rate=sample_rate)) obj.load_from_array(src) return obj @property - def audio_format(self) -> AudioFormat: + def audio_format(self) -> AudioDataFormat: """Gets the audio format of the audio.""" return self._audio_format diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py new file mode 100644 index 000000000..cc25fc708 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -0,0 +1,92 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classifications data class.""" + +import dataclasses +from typing import List, Optional + +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassificationsProto = classifications_pb2.Classifications +_ClassificationResultProto = classifications_pb2.ClassificationResult + + +@dataclasses.dataclass +class Classifications: + """Represents the classification results for a given classifier head. + + Attributes: + categories: The array of predicted categories, usually sorted by descending + scores (e.g. from high to low probability). + head_index: The index of the classifier head these categories refer to. This + is useful for multi-head models. + head_name: The name of the classifier head, which is the corresponding + tensor metadata name. + """ + + categories: List[category_module.Category] + head_index: int + head_name: Optional[str] = None + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': + """Creates a `Classifications` object from the given protobuf object.""" + categories = [] + for entry in pb2_obj.classification_list.classification: + categories.append( + category_module.Category( + index=entry.index, + score=entry.score, + display_name=entry.display_name, + category_name=entry.label)) + + return Classifications( + categories=categories, + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) + + +@dataclasses.dataclass +class ClassificationResult: + """Contains the classification results of a model. + + Attributes: + classifications: A list of `Classifications` objects, each for a head of the + model. + timestamp_ms: The optional timestamp (in milliseconds) of the start of the + chunk of data corresponding to these results. This is only used for + classification on time series (e.g. audio classification). In these use + cases, the amount of data to process might exceed the maximum size that + the model can process: to solve this, the input data is split into + multiple chunks starting at different timestamps. + """ + + classifications: List[Classifications] + timestamp_ms: Optional[int] = None + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult': + """Creates a `ClassificationResult` object from the given protobuf object. + """ + return ClassificationResult( + classifications=[ + Classifications.create_from_pb2(classification) + for classification in pb2_obj.classifications + ], + timestamp_ms=pb2_obj.timestamp_ms) diff --git a/mediapipe/tasks/python/test/audio/BUILD b/mediapipe/tasks/python/test/audio/BUILD new file mode 100644 index 000000000..863449126 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/BUILD @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_test( + name = "audio_classifier_test", + srcs = ["audio_classifier_test.py"], + data = [ + "//mediapipe/tasks/testdata/audio:test_audio_clips", + "//mediapipe/tasks/testdata/audio:test_models", + ], + deps = [ + "//mediapipe/tasks/python/audio:audio_classifier", + "//mediapipe/tasks/python/audio/core:audio_task_running_mode", + "//mediapipe/tasks/python/components/containers:audio_data", + "//mediapipe/tasks/python/components/containers:classification_result", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + ], +) diff --git a/mediapipe/tasks/python/test/audio/__init__.py b/mediapipe/tasks/python/test/audio/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/test/audio/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py new file mode 100644 index 000000000..983e922e7 --- /dev/null +++ b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -0,0 +1,381 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for audio classifier.""" + +import os +from typing import List, Tuple +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +from scipy.io import wavfile + +from mediapipe.tasks.python.audio import audio_classifier +from mediapipe.tasks.python.audio.core import audio_task_running_mode +from mediapipe.tasks.python.components.containers import audio_data as audio_data_module +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils + +_AudioClassifier = audio_classifier.AudioClassifier +_AudioClassifierOptions = audio_classifier.AudioClassifierOptions +_AudioClassifierResult = classification_result_module.ClassificationResult +_AudioData = audio_data_module.AudioData +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode + +_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' +_YAMNET_MODEL_SAMPLE_RATE = 16000 +_TWO_HEADS_MODEL_FILE = 'two_heads.tflite' +_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav' +_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio' +_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav' +_TWO_HEADS_WAV_44K_MONO = 'two_heads_44100_hz_mono.wav' +_YAMNET_NUM_OF_SAMPLES = 15600 +_MILLSECONDS_PER_SECOND = 1000 + + +class AudioClassifierTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.yamnet_model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE)) + self.two_heads_model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _TWO_HEADS_MODEL_FILE)) + + def _read_wav_file(self, file_name) -> _AudioData: + sample_rate, buffer = wavfile.read( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name))) + return _AudioData.create_from_array( + buffer.astype(float) / np.iinfo(np.int16).max, sample_rate) + + def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]: + sample_rate, buffer = wavfile.read( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name))) + audio_data_list = [] + start = 0 + step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE + while start < len(buffer): + end = min(start + (int)(step_size), len(buffer)) + audio_data_list.append((_AudioData.create_from_array( + buffer[start:end].astype(float) / np.iinfo(np.int16).max, + sample_rate), (int)(start / sample_rate * _MILLSECONDS_PER_SECOND))) + start = end + return audio_data_list + + # TODO: Compares the exact score values to capture unexpected + # changes in the inference pipeline. + def _check_yamnet_result( + self, + classification_result_list: List[_AudioClassifierResult], + expected_num_categories=521): + self.assertLen(classification_result_list, 5) + for idx, timestamp in enumerate([0, 975, 1950, 2925]): + classification_result = classification_result_list[idx] + self.assertEqual(classification_result.timestamp_ms, timestamp) + self.assertLen(classification_result.classifications, 1) + classifcation = classification_result.classifications[0] + self.assertEqual(classifcation.head_index, 0) + self.assertEqual(classifcation.head_name, 'scores') + self.assertLen(classifcation.categories, expected_num_categories) + audio_category = classifcation.categories[0] + self.assertEqual(audio_category.index, 0) + self.assertEqual(audio_category.category_name, 'Speech') + self.assertGreater(audio_category.score, 0.9) + + # TODO: Compares the exact score values to capture unexpected + # changes in the inference pipeline. + def _check_two_heads_result( + self, + classification_result_list: List[_AudioClassifierResult], + first_head_expected_num_categories=521, + second_head_expected_num_categories=5): + self.assertGreaterEqual(len(classification_result_list), 1) + self.assertLessEqual(len(classification_result_list), 2) + # Checks the first result. + classification_result = classification_result_list[0] + self.assertEqual(classification_result.timestamp_ms, 0) + self.assertLen(classification_result.classifications, 2) + # Checks the first head. + yamnet_classifcation = classification_result.classifications[0] + self.assertEqual(yamnet_classifcation.head_index, 0) + self.assertEqual(yamnet_classifcation.head_name, 'yamnet_classification') + self.assertLen(yamnet_classifcation.categories, + first_head_expected_num_categories) + # Checks the second head. + yamnet_category = yamnet_classifcation.categories[0] + self.assertEqual(yamnet_category.index, 508) + self.assertEqual(yamnet_category.category_name, 'Environmental noise') + self.assertGreater(yamnet_category.score, 0.5) + bird_classifcation = classification_result.classifications[1] + self.assertEqual(bird_classifcation.head_index, 1) + self.assertEqual(bird_classifcation.head_name, 'bird_classification') + self.assertLen(bird_classifcation.categories, + second_head_expected_num_categories) + bird_category = bird_classifcation.categories[0] + self.assertEqual(bird_category.index, 4) + self.assertEqual(bird_category.category_name, 'Chestnut-crowned Antpitta') + self.assertGreater(bird_category.score, 0.93) + # Checks the second result, if present. + if len(classification_result_list) == 2: + classification_result = classification_result_list[1] + self.assertEqual(classification_result.timestamp_ms, 975) + self.assertLen(classification_result.classifications, 2) + # Checks the first head. + yamnet_classifcation = classification_result.classifications[0] + self.assertEqual(yamnet_classifcation.head_index, 0) + self.assertEqual(yamnet_classifcation.head_name, 'yamnet_classification') + self.assertLen(yamnet_classifcation.categories, + first_head_expected_num_categories) + yamnet_category = yamnet_classifcation.categories[0] + self.assertEqual(yamnet_category.index, 494) + self.assertEqual(yamnet_category.category_name, 'Silence') + self.assertGreater(yamnet_category.score, 0.9) + bird_classifcation = classification_result.classifications[1] + self.assertEqual(bird_classifcation.head_index, 1) + self.assertEqual(bird_classifcation.head_name, 'bird_classification') + self.assertLen(bird_classifcation.categories, + second_head_expected_num_categories) + # Checks the second head. + bird_category = bird_classifcation.categories[0] + self.assertEqual(bird_category.index, 1) + self.assertEqual(bird_category.category_name, 'White-breasted Wood-Wren') + self.assertGreater(bird_category.score, 0.99) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _AudioClassifier.create_from_model_path( + self.yamnet_model_path) as classifier: + self.assertIsInstance(classifier, _AudioClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + with _AudioClassifier.create_from_options( + _AudioClassifierOptions( + base_options=_BaseOptions( + model_asset_path=self.yamnet_model_path))) as classifier: + self.assertIsInstance(classifier, _AudioClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _AudioClassifierOptions(base_options=base_options) + _AudioClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.yamnet_model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _AudioClassifierOptions(base_options=base_options) + classifier = _AudioClassifier.create_from_options(options) + self.assertIsInstance(classifier, _AudioClassifier) + + @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO)) + def test_classify_with_yamnet_model(self, audio_file): + with _AudioClassifier.create_from_model_path( + self.yamnet_model_path) as classifier: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_yamnet_result(classification_result_list) + + def test_classify_with_yamnet_model_and_inputs_at_different_sample_rates( + self): + with _AudioClassifier.create_from_model_path( + self.yamnet_model_path) as classifier: + for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_yamnet_result(classification_result_list) + + def test_max_result_options(self): + with _AudioClassifier.create_from_options( + _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + classifier_options=_ClassifierOptions( + max_results=1))) as classifier: + for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_yamnet_result( + classification_result_list, expected_num_categories=1) + + def test_score_threshold_options(self): + with _AudioClassifier.create_from_options( + _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + classifier_options=_ClassifierOptions( + score_threshold=0.9))) as classifier: + for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_yamnet_result( + classification_result_list, expected_num_categories=1) + + def test_allow_list_option(self): + with _AudioClassifier.create_from_options( + _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + classifier_options=_ClassifierOptions( + category_allowlist=['Speech']))) as classifier: + for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_yamnet_result( + classification_result_list, expected_num_categories=1) + + def test_combined_allowlist_and_denylist(self): + # Fails with combined allowlist and denylist + with self.assertRaisesRegex( + ValueError, + r'`category_allowlist` and `category_denylist` are mutually ' + r'exclusive options.'): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + classifier_options=_ClassifierOptions( + category_allowlist=['foo'], category_denylist=['bar'])) + with _AudioClassifier.create_from_options(options) as unused_classifier: + pass + + @parameterized.parameters((_TWO_HEADS_WAV_16K_MONO), + (_TWO_HEADS_WAV_44K_MONO)) + def test_classify_with_two_heads_model_and_inputs_at_different_sample_rates( + self, audio_file): + with _AudioClassifier.create_from_model_path( + self.two_heads_model_path) as classifier: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_two_heads_result(classification_result_list) + + def test_classify_with_two_heads_model(self): + with _AudioClassifier.create_from_model_path( + self.two_heads_model_path) as classifier: + for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_two_heads_result(classification_result_list) + + def test_classify_with_two_heads_model_with_max_results(self): + with _AudioClassifier.create_from_options( + _AudioClassifierOptions( + base_options=_BaseOptions( + model_asset_path=self.two_heads_model_path), + classifier_options=_ClassifierOptions( + max_results=1))) as classifier: + for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]: + classification_result_list = classifier.classify( + self._read_wav_file(audio_file)) + self._check_two_heads_result(classification_result_list, 1, 1) + + def test_missing_sample_rate_in_audio_clips_mode(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with self.assertRaisesRegex(ValueError, + r'Must provide the audio sample rate'): + with _AudioClassifier.create_from_options(options) as classifier: + classifier.classify(_AudioData(buffer_length=100)) + + def test_missing_sample_rate_in_audio_stream_mode(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'provide the audio sample rate in audio data'): + with _AudioClassifier.create_from_options(options) as classifier: + classifier.classify(_AudioData(buffer_length=100)) + + def test_missing_result_callback(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _AudioClassifier.create_from_options(options) as unused_classifier: + pass + + def test_illegal_result_callback(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _AudioClassifier.create_from_options(options) as unused_classifier: + pass + + def test_calling_classify_in_audio_stream_mode(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the audio clips mode'): + classifier.classify(self._read_wav_file(_SPEECH_WAV_16K_MONO)) + + def test_calling_classify_async_in_audio_clips_mode(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_CLIPS) + with _AudioClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex( + ValueError, r'not initialized with the audio stream mode'): + classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + def test_classify_async_calls_with_illegal_timestamp(self): + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + result_callback=mock.MagicMock()) + with _AudioClassifier.create_from_options(options) as classifier: + classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0) + + @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO)) + def test_classify_async(self, audio_file): + classification_result_list = [] + + def save_result(result: _AudioClassifierResult, timestamp_ms: int): + result.timestamp_ms = timestamp_ms + classification_result_list.append(result) + + options = _AudioClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.yamnet_model_path), + running_mode=_RUNNING_MODE.AUDIO_STREAM, + classifier_options=_ClassifierOptions(max_results=1), + result_callback=save_result) + classifier = _AudioClassifier.create_from_options(options) + audio_data_list = self._read_wav_file_as_stream(audio_file) + for audio_data, timestamp_ms in audio_data_list: + classifier.classify_async(audio_data, timestamp_ms) + classifier.close() + self._check_yamnet_result( + classification_result_list, expected_num_categories=1) + + +if __name__ == '__main__': + absltest.main()