Implement MediaPipe AudioClassifier Tasks Python API. Adjust the AudioClassifier Tasks C++ API to remove "sample_rate" from its options.
PiperOrigin-RevId: 486763992
Parent: 51dbd9779c
Commit: 63a759accc
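
For orientation, here is a minimal usage sketch of the new Python API in its default audio clips mode, based on the audio_classifier.py added below. The model path and the AudioData construction helper are illustrative assumptions, not part of this commit:

# Sketch only. Assumes a YAMNet-style classification model on disk and that
# AudioData.create_from_array(...) exists to wrap a numpy buffer together
# with its sample rate; both are assumptions for illustration.
import numpy as np

from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module

classifier = audio_classifier.AudioClassifier.create_from_model_path(
    'yamnet.tflite')  # hypothetical model path
clip = audio_data_module.AudioData.create_from_array(
    np.zeros([16000], dtype=np.float32), sample_rate=16000)
# classify() returns one AudioClassifierResult per processed chunk, each
# tagged with the chunk's start timestamp in milliseconds.
for result in classifier.classify(clip):
  print(result.timestamp_ms, result.classifications)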
@@ -87,6 +87,7 @@ cc_library(
 cc_library(
     name = "builtin_task_graphs",
     deps = [
+        "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
         "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
         "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
@@ -361,7 +361,7 @@ void PublicPacketCreators(pybind11::module* m) {
     packet = mp.packet_creator.create_float(0.1)
     data = mp.packet_getter.get_float(packet)
 )doc",
-      py::arg().noconvert(), py::return_value_policy::move);
+      py::return_value_policy::move);
 
   m->def(
       "create_double", [](double data) { return MakePacket<double>(data); },
@@ -380,7 +380,7 @@ void PublicPacketCreators(pybind11::module* m) {
     packet = mp.packet_creator.create_double(0.1)
     data = mp.packet_getter.get_float(packet)
 )doc",
-      py::arg().noconvert(), py::return_value_policy::move);
+      py::return_value_policy::move);
 
   m->def(
       "create_int_array",
@@ -37,7 +37,6 @@ cc_library(
         "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
-        "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
@@ -63,10 +63,8 @@ CalculatorGraphConfig CreateGraphConfig(
   api2::builder::Graph graph;
   auto& subgraph = graph.AddNode(kSubgraphTypeName);
   graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
-  if (!options_proto->base_options().use_stream_mode()) {
-    graph.In(kSampleRateTag).SetName(kSampleRateName) >>
-        subgraph.In(kSampleRateTag);
-  }
+  graph.In(kSampleRateTag).SetName(kSampleRateName) >>
+      subgraph.In(kSampleRateTag);
   subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
       options_proto.get());
   subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
@@ -93,9 +91,6 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
               &(options->classifier_options)));
   options_proto->mutable_classifier_options()->Swap(
       classifier_options_proto.get());
-  if (options->sample_rate > 0) {
-    options_proto->set_default_input_audio_sample_rate(options->sample_rate);
-  }
   return options_proto;
 }
 
@@ -129,14 +124,6 @@ absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
 /* static */
 absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
     std::unique_ptr<AudioClassifierOptions> options) {
-  if (options->running_mode == core::RunningMode::AUDIO_STREAM &&
-      options->sample_rate < 0) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "The audio classifier is in audio stream mode, the sample rate must be "
-        "specified in the AudioClassifierOptions.",
-        MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
-  }
   auto options_proto = ConvertAudioClassifierOptionsToProto(options.get());
   tasks::core::PacketsCallback packets_callback = nullptr;
   if (options->result_callback) {
@@ -161,7 +148,9 @@ absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
 }
 
 absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
+                                            double audio_sample_rate,
                                             int64 timestamp_ms) {
+  MP_RETURN_IF_ERROR(CheckOrSetSampleRate(kSampleRateName, audio_sample_rate));
   return SendAudioStreamData(
       {{kAudioStreamName,
         MakePacket<Matrix>(std::move(audio_block))
@@ -52,15 +52,10 @@ struct AudioClassifierOptions {
   // 1) The audio clips mode for running classification on independent audio
   //    clips.
   // 2) The audio stream mode for running classification on the audio stream,
-  //    such as from microphone. In this mode, the "sample_rate" below must be
-  //    provided, and the "result_callback" below must be specified to receive
-  //    the classification results asynchronously.
+  //    such as from microphone. In this mode, the "result_callback" below must
+  //    be specified to receive the classification results asynchronously.
   core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;
 
-  // The sample rate of the input audios. Must be set when the running mode is
-  // set to RunningMode::AUDIO_STREAM.
-  double sample_rate = -1.0;
-
   // The user-defined result callback for processing audio stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::AUDIO_STREAM.
@@ -160,15 +155,17 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // The audio block is represented as a MediaPipe Matrix that has the number
   // of channels rows and the number of samples per channel columns. The audio
   // data will be resampled, accumulated, and framed to the proper size for the
-  // underlying model to consume. It's required to provide a timestamp (in
-  // milliseconds) to indicate the start time of the input audio block. The
+  // underlying model to consume. It's required to provide the corresponding
+  // audio sample rate along with the input audio block as well as a timestamp
+  // (in milliseconds) to indicate the start time of the input audio block. The
   // timestamps must be monotonically increasing.
   //
   // The input audio block may be longer than what the model is able to process
   // in a single inference. When this occurs, the input audio block is split
   // into multiple chunks. For this reason, the callback may be called multiple
   // times (once per chunk) for each call to this function.
-  absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
+  absl::Status ClassifyAsync(mediapipe::Matrix audio_block,
+                             double audio_sample_rate, int64 timestamp_ms);
 
   // Shuts down the AudioClassifier when all works are done.
   absl::Status Close() { return runner_->Close(); }
@@ -72,18 +72,6 @@ struct AudioClassifierOutputStreams {
   Source<std::vector<ClassificationResult>> timestamped_classifications;
 };
 
-absl::Status SanityCheckOptions(
-    const proto::AudioClassifierGraphOptions& options) {
-  if (options.base_options().use_stream_mode() &&
-      !options.has_default_input_audio_sample_rate()) {
-    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
-                                   "In the streaming mode, the default input "
-                                   "audio sample rate must be set.",
-                                   MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  return absl::OkStatus();
-}
-
 // Builds an AudioTensorSpecs for configuring the preprocessing calculators.
 absl::StatusOr<AudioTensorSpecs> BuildPreprocessingSpecs(
     const core::ModelResources& model_resources) {
@@ -170,19 +158,12 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
         const auto* model_resources,
         CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
     Graph graph;
-    const bool use_stream_mode =
-        sc->Options<proto::AudioClassifierGraphOptions>()
-            .base_options()
-            .use_stream_mode();
     ASSIGN_OR_RETURN(
         auto output_streams,
         BuildAudioClassificationTask(
             sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
             graph[Input<Matrix>(kAudioTag)],
-            use_stream_mode
-                ? absl::nullopt
-                : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
-            graph));
+            absl::make_optional(graph[Input<double>(kSampleRateTag)]), graph));
     output_streams.classifications >>
         graph[Output<ClassificationResult>(kClassificationsTag)];
     output_streams.timestamped_classifications >>
@@ -207,7 +188,6 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
       const proto::AudioClassifierGraphOptions& task_options,
       const core::ModelResources& model_resources, Source<Matrix> audio_in,
       absl::optional<Source<double>> sample_rate_in, Graph& graph) {
-    MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
     const bool use_stream_mode = task_options.base_options().use_stream_mode();
     const auto* metadata_extractor = model_resources.GetMetadataExtractor();
     // Checks that metadata is available.
@@ -70,6 +70,8 @@ Matrix GetAudioData(absl::string_view filename) {
   return matrix_mapping.matrix();
 }
 
+// TODO: Compares the exact score values to capture unexpected
+// changes in the inference pipeline.
 void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
                        int expected_num_categories = 521) {
   EXPECT_EQ(result.size(), 5);
@@ -90,13 +92,15 @@ void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
   }
 }
 
+// TODO: Compares the exact score values to capture unexpected
+// changes in the inference pipeline.
 void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
   EXPECT_GE(result.size(), 1);
   EXPECT_LE(result.size(), 2);
-  // Check first result.
+  // Check the first result.
   EXPECT_EQ(result[0].timestamp_ms, 0);
   EXPECT_EQ(result[0].classifications.size(), 2);
-  // Check first head.
+  // Check the first head.
   EXPECT_EQ(result[0].classifications[0].head_index, 0);
   EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
   EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
@@ -104,19 +108,19 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
   EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
             "Environmental noise");
   EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
-  // Check second head.
+  // Check the second head.
   EXPECT_EQ(result[0].classifications[1].head_index, 1);
   EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
   EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
   EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
   EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
             "Chestnut-crowned Antpitta");
-  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
-  // Check second result, if present.
+  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.93f);
+  // Check the second result, if present.
   if (result.size() == 2) {
     EXPECT_EQ(result[1].timestamp_ms, 975);
     EXPECT_EQ(result[1].classifications.size(), 2);
-    // Check first head.
+    // Check the first head.
     EXPECT_EQ(result[1].classifications[0].head_index, 0);
     EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
     EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
@@ -124,7 +128,7 @@ void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
     EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
               "Silence");
     EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
-    // Check second head.
+    // Check the second head.
     EXPECT_EQ(result[1].classifications[1].head_index, 1);
     EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
     EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
 | 
				
			||||||
  options->base_options.model_asset_path =
 | 
					  options->base_options.model_asset_path =
 | 
				
			||||||
      JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
 | 
					      JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
 | 
				
			||||||
  options->running_mode = core::RunningMode::AUDIO_STREAM;
 | 
					  options->running_mode = core::RunningMode::AUDIO_STREAM;
 | 
				
			||||||
  options->sample_rate = 16000;
 | 
					 | 
				
			||||||
  StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
 | 
					  StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
 | 
				
			||||||
      AudioClassifier::Create(std::move(options));
 | 
					      AudioClassifier::Create(std::move(options));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
@@ -266,25 +269,6 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
                   MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
 }
 
-TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
-  auto options = std::make_unique<AudioClassifierOptions>();
-  options->base_options.model_asset_path =
-      JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
-  options->running_mode = core::RunningMode::AUDIO_STREAM;
-  options->result_callback =
-      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
-  StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
-      AudioClassifier::Create(std::move(options));
-
-  EXPECT_EQ(audio_classifier_or.status().code(),
-            absl::StatusCode::kInvalidArgument);
-  EXPECT_THAT(audio_classifier_or.status().message(),
-              HasSubstr("the sample rate must be specified"));
-  EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload),
-              Optional(absl::Cord(absl::StrCat(
-                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
-}
-
 class ClassifyTest : public tflite_shims::testing::Test {};
 
 TEST_F(ClassifyTest, Succeeds) {
@@ -493,7 +477,6 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
   options->classifier_options.max_results = 1;
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
-  options->sample_rate = kSampleRateHz;
   std::vector<AudioClassifierResult> outputs;
   options->result_callback =
       [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
 | 
				
			||||||
    int num_samples = std::min((int)(audio_buffer.cols() - start_col),
 | 
					    int num_samples = std::min((int)(audio_buffer.cols() - start_col),
 | 
				
			||||||
                               kYamnetNumOfAudioSamples * 3);
 | 
					                               kYamnetNumOfAudioSamples * 3);
 | 
				
			||||||
    MP_ASSERT_OK(audio_classifier->ClassifyAsync(
 | 
					    MP_ASSERT_OK(audio_classifier->ClassifyAsync(
 | 
				
			||||||
        audio_buffer.block(0, start_col, 1, num_samples),
 | 
					        audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
 | 
				
			||||||
        start_col * kMilliSecondsPerSecond / kSampleRateHz));
 | 
					        start_col * kMilliSecondsPerSecond / kSampleRateHz));
 | 
				
			||||||
    start_col += kYamnetNumOfAudioSamples * 3;
 | 
					    start_col += kYamnetNumOfAudioSamples * 3;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -523,7 +506,6 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
 | 
				
			||||||
  options->classifier_options.max_results = 1;
 | 
					  options->classifier_options.max_results = 1;
 | 
				
			||||||
  options->classifier_options.score_threshold = 0.3f;
 | 
					  options->classifier_options.score_threshold = 0.3f;
 | 
				
			||||||
  options->running_mode = core::RunningMode::AUDIO_STREAM;
 | 
					  options->running_mode = core::RunningMode::AUDIO_STREAM;
 | 
				
			||||||
  options->sample_rate = kSampleRateHz;
 | 
					 | 
				
			||||||
  std::vector<AudioClassifierResult> outputs;
 | 
					  std::vector<AudioClassifierResult> outputs;
 | 
				
			||||||
  options->result_callback =
 | 
					  options->result_callback =
 | 
				
			||||||
      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
 | 
					      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
 | 
				
			||||||
| 
						 | 
					@ -538,7 +520,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
 | 
				
			||||||
        std::min((int)(audio_buffer.cols() - start_col),
 | 
					        std::min((int)(audio_buffer.cols() - start_col),
 | 
				
			||||||
                 rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * 3);
 | 
					                 rand_r(&rseed) % 10 + kYamnetNumOfAudioSamples * 3);
 | 
				
			||||||
    MP_ASSERT_OK(audio_classifier->ClassifyAsync(
 | 
					    MP_ASSERT_OK(audio_classifier->ClassifyAsync(
 | 
				
			||||||
        audio_buffer.block(0, start_col, 1, num_samples),
 | 
					        audio_buffer.block(0, start_col, 1, num_samples), kSampleRateHz,
 | 
				
			||||||
        start_col * kMilliSecondsPerSecond / kSampleRateHz));
 | 
					        start_col * kMilliSecondsPerSecond / kSampleRateHz));
 | 
				
			||||||
    start_col += num_samples;
 | 
					    start_col += num_samples;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
@@ -72,8 +72,39 @@ class BaseAudioTaskApi : public tasks::core::BaseTaskApi {
     return runner_->Send(std::move(inputs));
   }
 
+  // Checks or sets the sample rate in the audio stream mode.
+  absl::Status CheckOrSetSampleRate(std::string sample_rate_stream_name,
+                                    double sample_rate) {
+    if (running_mode_ != RunningMode::AUDIO_STREAM) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          absl::StrCat("Task is not initialized with the audio stream mode. "
+                       "Current running mode:",
+                       GetRunningModeName(running_mode_)),
+          MediaPipeTasksStatus::kRunnerApiCalledInWrongModeError);
+    }
+    if (default_sample_rate_ > 0) {
+      if (std::fabs(sample_rate - default_sample_rate_) >
+          std::numeric_limits<double>::epsilon()) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            absl::StrCat("The input audio sample rate: ", sample_rate,
+                         " is inconsistent with the previously provided: ",
+                         default_sample_rate_),
+            MediaPipeTasksStatus::kInvalidArgumentError);
+      }
+    } else {
+      default_sample_rate_ = sample_rate;
+      MP_RETURN_IF_ERROR(runner_->Send(
+          {{sample_rate_stream_name, MakePacket<double>(default_sample_rate_)
+                                         .At(Timestamp::PreStream())}}));
+    }
+    return absl::OkStatus();
+  }
+
  private:
   RunningMode running_mode_;
+  double default_sample_rate_ = -1.0;
 };
 
 }  // namespace core
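
The Python task added below mirrors this check in classify_async: the first audio block latches the stream's sample rate, which is sent once at the pre-stream timestamp, and any later block with a different rate is rejected. A minimal sketch of that contract, assuming a classifier already created in audio stream mode (see the stream-mode sketch after audio_classifier.py) and the same hypothetical AudioData helper:

# Sketch of the stream-mode sample-rate contract; 'classifier' is assumed to
# be an AudioClassifier created in AUDIO_STREAM mode with a result_callback.
import numpy as np
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module

samples = np.zeros([15600], dtype=np.float32)
# create_from_array with an explicit sample rate is an assumed helper.
block_16k = audio_data_module.AudioData.create_from_array(samples, sample_rate=16000)
block_44k = audio_data_module.AudioData.create_from_array(samples, sample_rate=44100)

classifier.classify_async(block_16k, 0)    # first call latches 16000 Hz
classifier.classify_async(block_16k, 975)  # accepted: rate is consistent
try:
  classifier.classify_async(block_44k, 1950)
except ValueError:
  pass  # rejected: inconsistent with the previously received sample rate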
							
								
								
									
mediapipe/tasks/python/audio/BUILD (new file, 41 lines)
@@ -0,0 +1,41 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Placeholder for internal Python strict library and test compatibility macro.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+py_library(
+    name = "audio_classifier",
+    srcs = [
+        "audio_classifier.py",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_py_pb2",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
+        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
+        "//mediapipe/tasks/python/audio/core:base_audio_task_api",
+        "//mediapipe/tasks/python/components/containers:audio_data",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/components/processors:classifier_options",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+    ],
+)

mediapipe/tasks/python/audio/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

mediapipe/tasks/python/audio/audio_classifier.py (new file, 280 lines)
@@ -0,0 +1,280 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe audio classifier task."""
+
+import dataclasses
+from typing import Callable, Mapping, List, Optional
+
+from mediapipe.python import packet_creator
+from mediapipe.python import packet_getter
+from mediapipe.python._framework_bindings import packet
+from mediapipe.tasks.cc.audio.audio_classifier.proto import audio_classifier_graph_options_pb2
+from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
+from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
+from mediapipe.tasks.python.audio.core import base_audio_task_api
+from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
+from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
+from mediapipe.tasks.python.components.processors import classifier_options as classifier_options_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.core import task_info as task_info_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+AudioClassifierResult = classification_result_module.ClassificationResult
+_AudioClassifierGraphOptionsProto = audio_classifier_graph_options_pb2.AudioClassifierGraphOptions
+_AudioData = audio_data_module.AudioData
+_BaseOptions = base_options_module.BaseOptions
+_ClassifierOptions = classifier_options_module.ClassifierOptions
+_RunningMode = running_mode_module.AudioTaskRunningMode
+_TaskInfo = task_info_module.TaskInfo
+
+_AUDIO_IN_STREAM_NAME = 'audio_in'
+_AUDIO_TAG = 'AUDIO'
+_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
+_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
+_SAMPLE_RATE_IN_STREAM_NAME = 'sample_rate_in'
+_SAMPLE_RATE_TAG = 'SAMPLE_RATE'
+_TASK_GRAPH_NAME = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'
+_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME = 'timestamped_classifications_out'
+_TIMESTAMPED_CLASSIFICATIONS_TAG = 'TIMESTAMPED_CLASSIFICATIONS'
+_MICRO_SECONDS_PER_MILLISECOND = 1000
+
+
+@dataclasses.dataclass
+class AudioClassifierOptions:
+  """Options for the audio classifier task.
+
+  Attributes:
+    base_options: Base options for the audio classifier task.
+    running_mode: The running mode of the task. Defaults to the audio clips
+      mode. Audio classifier task has two running modes: 1) The audio clips
+      mode for running classification on independent audio clips. 2) The audio
+      stream mode for running classification on the audio stream, such as from
+      microphone. In this mode, the "result_callback" below must be specified
+      to receive the classification results asynchronously.
+    classifier_options: Options for configuring the classifier behavior, such as
+      score threshold, number of results, etc.
+    result_callback: The user-defined result callback for processing audio
+      stream data. The result callback should only be specified when the running
+      mode is set to the audio stream mode.
+  """
+  base_options: _BaseOptions
+  running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS
+  classifier_options: _ClassifierOptions = _ClassifierOptions()
+  result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _AudioClassifierGraphOptionsProto:
+    """Generates an AudioClassifierOptions protobuf object."""
+    base_options_proto = self.base_options.to_pb2()
+    base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.AUDIO_CLIPS else True
+    classifier_options_proto = self.classifier_options.to_pb2()
+
+    return _AudioClassifierGraphOptionsProto(
+        base_options=base_options_proto,
+        classifier_options=classifier_options_proto)
+
+
+class AudioClassifier(base_audio_task_api.BaseAudioTaskApi):
+  """Class that performs audio classification on audio data."""
+
+  @classmethod
+  def create_from_model_path(cls, model_path: str) -> 'AudioClassifier':
+    """Creates an `AudioClassifier` object from a TensorFlow Lite model and the default `AudioClassifierOptions`.
+
+    Note that the created `AudioClassifier` instance is in audio clips mode, for
+    classifying on independent audio clips.
+
+    Args:
+      model_path: Path to the model.
+
+    Returns:
+      `AudioClassifier` object that's created from the model file and the
+      default `AudioClassifierOptions`.
+
+    Raises:
+      ValueError: If failed to create `AudioClassifier` object from the provided
+        file such as invalid file path.
+      RuntimeError: If other types of error occurred.
+    """
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = AudioClassifierOptions(
+        base_options=base_options, running_mode=_RunningMode.AUDIO_CLIPS)
+    return cls.create_from_options(options)
+
+  @classmethod
+  def create_from_options(cls,
+                          options: AudioClassifierOptions) -> 'AudioClassifier':
+    """Creates the `AudioClassifier` object from audio classifier options.
+
+    Args:
+      options: Options for the audio classifier task.
+
+    Returns:
+      `AudioClassifier` object that's created from `options`.
+
+    Raises:
+      ValueError: If failed to create `AudioClassifier` object from
+        `AudioClassifierOptions` such as missing the model.
+      RuntimeError: If other types of error occurred.
+    """
+
+    def packets_callback(output_packets: Mapping[str, packet.Packet]):
+      timestamp_ms = output_packets[
+          _CLASSIFICATIONS_STREAM_NAME].timestamp.value // _MICRO_SECONDS_PER_MILLISECOND
+      if output_packets[_CLASSIFICATIONS_STREAM_NAME].is_empty():
+        options.result_callback(
+            AudioClassifierResult(classifications=[]), timestamp_ms)
+        return
+      classification_result_proto = classifications_pb2.ClassificationResult()
+      classification_result_proto.CopyFrom(
+          packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
+      options.result_callback(
+          AudioClassifierResult.create_from_pb2(classification_result_proto),
+          timestamp_ms)
+
+    task_info = _TaskInfo(
+        task_graph=_TASK_GRAPH_NAME,
+        input_streams=[
+            ':'.join([_AUDIO_TAG, _AUDIO_IN_STREAM_NAME]),
+            ':'.join([_SAMPLE_RATE_TAG, _SAMPLE_RATE_IN_STREAM_NAME])
+        ],
+        output_streams=[
+            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
+            ':'.join([
+                _TIMESTAMPED_CLASSIFICATIONS_TAG,
+                _TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME
+            ])
+        ],
+        task_options=options)
+    return cls(
+        # Audio tasks should not drop input audio due to flow limiting, which
+        # may cause data inconsistency.
+        task_info.generate_graph_config(enable_flow_limiting=False),
+        options.running_mode,
+        packets_callback if options.result_callback else None)
+
+  def classify(self, audio_clip: _AudioData) -> List[AudioClassifierResult]:
+    """Performs audio classification on the provided audio clip.
+
+    The audio clip is represented as a MediaPipe AudioData. The method accepts
+    audio clips with various lengths and audio sample rates. It's required to
+    provide the corresponding audio sample rate within the `AudioData` object.
+
+    The input audio clip may be longer than what the model is able to process
+    in a single inference. When this occurs, the input audio clip is split into
+    multiple chunks starting at different timestamps. For this reason, this
+    function returns a vector of ClassificationResult objects, each associated
+    with a timestamp corresponding to the start (in milliseconds) of the chunk
+    data that was classified, e.g.:
+
+    ClassificationResult #0 (first chunk of data):
+      timestamp_ms: 0 (starts at 0ms)
+      classifications #0 (single head model):
+        category #0:
+          category_name: "Speech"
+          score: 0.6
+        category #1:
+          category_name: "Music"
+          score: 0.2
+    ClassificationResult #1 (second chunk of data):
+      timestamp_ms: 800 (starts at 800ms)
+      classifications #0 (single head model):
+        category #0:
+          category_name: "Speech"
+          score: 0.5
+        category #1:
+          category_name: "Silence"
+          score: 0.1
+
+    Args:
+      audio_clip: MediaPipe AudioData.
+
+    Returns:
+      An `AudioClassifierResult` object that contains a list of
+      classification result objects, each associated with a timestamp
+      corresponding to the start (in milliseconds) of the chunk data that was
+      classified.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid, such as the sample
+        rate is not provided in the `AudioData` object.
+      RuntimeError: If audio classification failed to run.
+    """
+    if not audio_clip.audio_format.sample_rate:
+      raise ValueError('Must provide the audio sample rate in audio data.')
+    output_packets = self._process_audio_clip({
+        _AUDIO_IN_STREAM_NAME:
+            packet_creator.create_matrix(audio_clip.buffer, transpose=True),
+        _SAMPLE_RATE_IN_STREAM_NAME:
+            packet_creator.create_double(audio_clip.audio_format.sample_rate)
+    })
+    output_list = []
+    classification_result_proto_list = packet_getter.get_proto_list(
+        output_packets[_TIMESTAMPED_CLASSIFICATIONS_STREAM_NAME])
+    for proto in classification_result_proto_list:
+      classification_result_proto = classifications_pb2.ClassificationResult()
+      classification_result_proto.CopyFrom(proto)
+      output_list.append(
+          AudioClassifierResult.create_from_pb2(classification_result_proto))
+    return output_list
+
					  def classify_async(self, audio_block: _AudioData, timestamp_ms: int) -> None:
 | 
				
			||||||
 | 
					    """Sends audio data (a block in a continuous audio stream) to perform audio classification.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Only use this method when the AudioClassifier is created with the audio
 | 
				
			||||||
 | 
					    stream running mode. The input timestamps should be monotonically increasing
 | 
				
			||||||
 | 
					    for adjacent calls of this method. This method will return immediately after
 | 
				
			||||||
 | 
					    the input audio data is accepted. The results will be available via the
 | 
				
			||||||
 | 
					    `result_callback` provided in the `AudioClassifierOptions`. The
 | 
				
			||||||
 | 
					    `classify_async` method is designed to process auido stream data such as
 | 
				
			||||||
 | 
					    microphone input.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The input audio data may be longer than what the model is able to process
 | 
				
			||||||
 | 
					    in a single inference. When this occurs, the input audio block is split
 | 
				
			||||||
 | 
					    into multiple chunks. For this reason, the callback may be called multiple
 | 
				
			||||||
 | 
					    times (once per chunk) for each call to this function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The `result_callback` provides:
 | 
				
			||||||
 | 
					      - An `AudioClassifierResult` object that contains a list of
 | 
				
			||||||
 | 
					        classifications.
 | 
				
			||||||
 | 
					      - The input timestamp in milliseconds.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      audio_block: MediaPipe AudioData.
 | 
				
			||||||
 | 
					      timestamp_ms: The timestamp of the input audio data in milliseconds.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If any of the followings:
 | 
				
			||||||
 | 
					        1) The sample rate is not provided in the `AudioData` object or the
 | 
				
			||||||
 | 
					        provided sample rate is inconsisent with the previously recevied.
 | 
				
			||||||
 | 
					        2) The current input timestamp is smaller than what the audio
 | 
				
			||||||
 | 
					        classifier has already processed.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if not audio_block.audio_format.sample_rate:
 | 
				
			||||||
 | 
					      raise ValueError('Must provide the audio sample rate in audio data.')
 | 
				
			||||||
 | 
					    if not self._default_sample_rate:
 | 
				
			||||||
 | 
					      self._default_sample_rate = audio_block.audio_format.sample_rate
 | 
				
			||||||
 | 
					      self._set_sample_rate(_SAMPLE_RATE_IN_STREAM_NAME,
 | 
				
			||||||
 | 
					                            self._default_sample_rate)
 | 
				
			||||||
 | 
					    elif audio_block.audio_format.sample_rate != self._default_sample_rate:
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          f'The audio sample rate provided in audio data: '
 | 
				
			||||||
 | 
					          f'{audio_block.audio_format.sample_rate} is inconsisent with '
 | 
				
			||||||
 | 
					          f'the previously received: {self._default_sample_rate}.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self._send_audio_stream_data({
 | 
				
			||||||
 | 
					        _AUDIO_IN_STREAM_NAME:
 | 
				
			||||||
 | 
					            packet_creator.create_matrix(audio_block.buffer, transpose=True).at(
 | 
				
			||||||
 | 
					                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
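
Taken together, `classify` covers the audio clips running mode and
`classify_async` the audio stream running mode. A minimal usage sketch of the
new API (not part of this change; the model path and the zero-filled buffer
are illustrative assumptions):

    import numpy as np

    from mediapipe.tasks.python.audio import audio_classifier
    from mediapipe.tasks.python.audio.core import audio_task_running_mode
    from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
    from mediapipe.tasks.python.core import base_options as base_options_module

    _MODEL = 'yamnet_audio_classifier_with_metadata.tflite'  # assumed local path

    # Audio clips mode: a single blocking call returns one result per chunk.
    clip = audio_data_module.AudioData.create_from_array(
        np.zeros(15600), 16000)  # 975 ms of silence, purely for illustration
    with audio_classifier.AudioClassifier.create_from_model_path(
        _MODEL) as classifier:
      for result in classifier.classify(clip):
        print(result.timestamp_ms, result.classifications[0].categories[0])

    # Audio stream mode: results are delivered through `result_callback`.
    options = audio_classifier.AudioClassifierOptions(
        base_options=base_options_module.BaseOptions(model_asset_path=_MODEL),
        running_mode=audio_task_running_mode.AudioTaskRunningMode.AUDIO_STREAM,
        result_callback=lambda result, timestamp_ms: print(timestamp_ms, result))
    with audio_classifier.AudioClassifier.create_from_options(
        options) as classifier:
      classifier.classify_async(clip, 0)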

@@ -32,6 +32,7 @@ py_library(
         ":audio_task_running_mode",
         "//mediapipe/framework:calculator_py_pb2",
         "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
         "//mediapipe/tasks/python/core:optional_dependencies",
     ],
 )
@@ -16,14 +16,17 @@
 from typing import Callable, Mapping, Optional
 
 from mediapipe.framework import calculator_pb2
+from mediapipe.python import packet_creator
 from mediapipe.python._framework_bindings import packet as packet_module
 from mediapipe.python._framework_bindings import task_runner as task_runner_module
+from mediapipe.python._framework_bindings import timestamp as timestamp_module
 from mediapipe.tasks.python.audio.core import audio_task_running_mode as running_mode_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 
 _TaskRunner = task_runner_module.TaskRunner
 _Packet = packet_module.Packet
 _RunningMode = running_mode_module.AudioTaskRunningMode
+_Timestamp = timestamp_module.Timestamp
 
 
 class BaseAudioTaskApi(object):
@@ -59,6 +62,7 @@ class BaseAudioTaskApi(object):
           'callback should not be provided.')
     self._runner = _TaskRunner.create(graph_config, packet_callback)
     self._running_mode = running_mode
+    self._default_sample_rate = None
 
   def _process_audio_clip(
       self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
@@ -82,6 +86,27 @@ class BaseAudioTaskApi(object):
           + self._running_mode.name)
     return self._runner.process(inputs)
 
+  def _set_sample_rate(self, sample_rate_stream_name: str,
+                       sample_rate: float) -> None:
+    """An asynchronous method to set the audio sample rate in the audio stream mode.
+
+    Args:
+      sample_rate_stream_name: The audio sample rate stream name.
+      sample_rate: The audio sample rate.
+
+    Raises:
+      ValueError: If the task's running mode is not set to the audio stream
+        mode.
+    """
+    if self._running_mode != _RunningMode.AUDIO_STREAM:
+      raise ValueError(
+          'Task is not initialized with the audio stream mode. Current running mode:'
+          + self._running_mode.name)
+    self._runner.send({
+        sample_rate_stream_name:
+            packet_creator.create_double(sample_rate).at(_Timestamp.PRESTREAM)
+    })
+
   def _send_audio_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
     """An asynchronous method to send audio stream data to the runner.
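
With "sample_rate" removed from the C++ options, the rate now travels as a
prestream packet instead: it is emitted exactly once, before any timestamped
audio packet, so the running graph can configure itself for all subsequent
blocks. A small sketch of that contract (the 16 kHz value is an illustrative
assumption):

    from mediapipe.python import packet_creator
    from mediapipe.python._framework_bindings import timestamp as timestamp_module

    # Emitted once per stream, ahead of all audio packets.
    rate_packet = packet_creator.create_double(16000.0).at(
        timestamp_module.Timestamp.PRESTREAM)

Because the rate is fixed at prestream time, `classify_async` latches the
first block's rate into `_default_sample_rate` and rejects any later block
that disagrees, instead of trying to reconfigure the running graph.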

@@ -94,3 +94,13 @@ py_library(
         "//mediapipe/tasks/python/core:optional_dependencies",
     ],
 )
+
+py_library(
+    name = "classification_result",
+    srcs = ["classification_result.py"],
+    deps = [
+        ":category",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)

@@ -20,7 +20,7 @@ import numpy as np
 
 
 @dataclasses.dataclass
-class AudioFormat:
+class AudioDataFormat:
   """Audio format metadata.
 
   Attributes:
@@ -35,8 +35,10 @@ class AudioData(object):
   """MediaPipe Tasks' audio container."""
 
   def __init__(
-      self, buffer_length: int,
-      audio_format: AudioFormat = AudioFormat()) -> None:
+      self,
+      buffer_length: int,
+      audio_format: AudioDataFormat = AudioDataFormat()
+  ) -> None:
     """Initializes the `AudioData` object.
 
     Args:
@@ -113,14 +115,14 @@ class AudioData(object):
     """
     obj = cls(
         buffer_length=src.shape[0],
-        audio_format=AudioFormat(
+        audio_format=AudioDataFormat(
             num_channels=1 if len(src.shape) == 1 else src.shape[1],
             sample_rate=sample_rate))
     obj.load_from_array(src)
     return obj
 
   @property
-  def audio_format(self) -> AudioFormat:
+  def audio_format(self) -> AudioDataFormat:
     """Gets the audio format of the audio."""
     return self._audio_format
 
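
For reference, a short sketch of how the renamed container is built from a
NumPy buffer via `create_from_array` (the 16 kHz mono buffer is an
illustrative assumption):

    import numpy as np

    from mediapipe.tasks.python.components.containers import audio_data as audio_data_module

    # A mono buffer of 15600 samples at 16 kHz, i.e. one YAMNet input window.
    samples = np.zeros(15600, dtype=float)
    data = audio_data_module.AudioData.create_from_array(samples, 16000)
    assert data.audio_format.num_channels == 1
    assert data.audio_format.sample_rate == 16000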

@@ -0,0 +1,92 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""

import dataclasses
from typing import List, Optional

from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult


@dataclasses.dataclass
class Classifications:
  """Represents the classification results for a given classifier head.

  Attributes:
    categories: The array of predicted categories, usually sorted by descending
      scores (e.g. from high to low probability).
    head_index: The index of the classifier head these categories refer to. This
      is useful for multi-head models.
    head_name: The name of the classifier head, which is the corresponding
      tensor metadata name.
  """

  categories: List[category_module.Category]
  head_index: int
  head_name: Optional[str] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
    """Creates a `Classifications` object from the given protobuf object."""
    categories = []
    for entry in pb2_obj.classification_list.classification:
      categories.append(
          category_module.Category(
              index=entry.index,
              score=entry.score,
              display_name=entry.display_name,
              category_name=entry.label))

    return Classifications(
        categories=categories,
        head_index=pb2_obj.head_index,
        head_name=pb2_obj.head_name)


@dataclasses.dataclass
class ClassificationResult:
  """Contains the classification results of a model.

  Attributes:
    classifications: A list of `Classifications` objects, each for a head of the
      model.
    timestamp_ms: The optional timestamp (in milliseconds) of the start of the
      chunk of data corresponding to these results. This is only used for
      classification on time series (e.g. audio classification). In these use
      cases, the amount of data to process might exceed the maximum size that
      the model can process: to solve this, the input data is split into
      multiple chunks starting at different timestamps.
  """

  classifications: List[Classifications]
  timestamp_ms: Optional[int] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
    """Creates a `ClassificationResult` object from the given protobuf object.
    """
    return ClassificationResult(
        classifications=[
            Classifications.create_from_pb2(classification)
            for classification in pb2_obj.classifications
        ],
        timestamp_ms=pb2_obj.timestamp_ms)
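
A brief sketch of the proto-to-dataclass path these helpers add, building a
`ClassificationResult` proto by hand (all field values are made up for
illustration; only the fields that `create_from_pb2` reads are set):

    from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
    from mediapipe.tasks.python.components.containers import classification_result as classification_result_module

    proto = classifications_pb2.ClassificationResult()
    proto.timestamp_ms = 975
    head = proto.classifications.add()
    head.head_index = 0
    head.head_name = 'scores'
    entry = head.classification_list.classification.add()
    entry.index = 0
    entry.score = 0.97
    entry.label = 'Speech'

    result = classification_result_module.ClassificationResult.create_from_pb2(
        proto)
    assert result.classifications[0].categories[0].category_name == 'Speech'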
							
								
								
									
mediapipe/tasks/python/test/audio/BUILD (new file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict test compatibility macro.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

py_test(
    name = "audio_classifier_test",
    srcs = ["audio_classifier_test.py"],
    data = [
        "//mediapipe/tasks/testdata/audio:test_audio_clips",
        "//mediapipe/tasks/testdata/audio:test_models",
    ],
    deps = [
        "//mediapipe/tasks/python/audio:audio_classifier",
        "//mediapipe/tasks/python/audio/core:audio_task_running_mode",
        "//mediapipe/tasks/python/components/containers:audio_data",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)
							
								
								
									
mediapipe/tasks/python/test/audio/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
							
								
								
									
mediapipe/tasks/python/test/audio/audio_classifier_test.py (new file, 381 lines)
@@ -0,0 +1,381 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for audio classifier."""

import os
from typing import List, Tuple
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
from scipy.io import wavfile

from mediapipe.tasks.python.audio import audio_classifier
from mediapipe.tasks.python.audio.core import audio_task_running_mode
from mediapipe.tasks.python.components.containers import audio_data as audio_data_module
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils

_AudioClassifier = audio_classifier.AudioClassifier
_AudioClassifierOptions = audio_classifier.AudioClassifierOptions
_AudioClassifierResult = classification_result_module.ClassificationResult
_AudioData = audio_data_module.AudioData
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RUNNING_MODE = audio_task_running_mode.AudioTaskRunningMode

_YAMNET_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite'
_YAMNET_MODEL_SAMPLE_RATE = 16000
_TWO_HEADS_MODEL_FILE = 'two_heads.tflite'
_SPEECH_WAV_16K_MONO = 'speech_16000_hz_mono.wav'
_SPEECH_WAV_48K_MONO = 'speech_48000_hz_mono.wav'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/audio'
_TWO_HEADS_WAV_16K_MONO = 'two_heads_16000_hz_mono.wav'
_TWO_HEADS_WAV_44K_MONO = 'two_heads_44100_hz_mono.wav'
_YAMNET_NUM_OF_SAMPLES = 15600
_MILLISECONDS_PER_SECOND = 1000


class AudioClassifierTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.yamnet_model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _YAMNET_MODEL_FILE))
    self.two_heads_model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _TWO_HEADS_MODEL_FILE))

  def _read_wav_file(self, file_name) -> _AudioData:
    sample_rate, buffer = wavfile.read(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
    return _AudioData.create_from_array(
        buffer.astype(float) / np.iinfo(np.int16).max, sample_rate)

  def _read_wav_file_as_stream(self, file_name) -> List[Tuple[_AudioData, int]]:
    sample_rate, buffer = wavfile.read(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_name)))
    audio_data_list = []
    start = 0
    step_size = _YAMNET_NUM_OF_SAMPLES * sample_rate / _YAMNET_MODEL_SAMPLE_RATE
    while start < len(buffer):
      end = min(start + int(step_size), len(buffer))
      audio_data_list.append((_AudioData.create_from_array(
          buffer[start:end].astype(float) / np.iinfo(np.int16).max,
          sample_rate), int(start / sample_rate * _MILLISECONDS_PER_SECOND)))
      start = end
    return audio_data_list
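
  # Note on the chunking arithmetic above: each streamed block stays aligned
  # with one YAMNet input window regardless of the file's sample rate. For the
  # 48 kHz clip, step_size = 15600 * 48000 / 16000 = 46800 samples, which is
  # 46800 / 48000 = 0.975 s of audio, so successive blocks carry the
  # timestamps 0, 975, 1950, 2925 ms that _check_yamnet_result below expects.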

  # TODO: Compare the exact score values to capture unexpected
  # changes in the inference pipeline.
  def _check_yamnet_result(
      self,
      classification_result_list: List[_AudioClassifierResult],
      expected_num_categories=521):
    self.assertLen(classification_result_list, 5)
    for idx, timestamp in enumerate([0, 975, 1950, 2925]):
      classification_result = classification_result_list[idx]
      self.assertEqual(classification_result.timestamp_ms, timestamp)
      self.assertLen(classification_result.classifications, 1)
      classification = classification_result.classifications[0]
      self.assertEqual(classification.head_index, 0)
      self.assertEqual(classification.head_name, 'scores')
      self.assertLen(classification.categories, expected_num_categories)
      audio_category = classification.categories[0]
      self.assertEqual(audio_category.index, 0)
      self.assertEqual(audio_category.category_name, 'Speech')
      self.assertGreater(audio_category.score, 0.9)

  # TODO: Compare the exact score values to capture unexpected
  # changes in the inference pipeline.
  def _check_two_heads_result(
      self,
      classification_result_list: List[_AudioClassifierResult],
      first_head_expected_num_categories=521,
      second_head_expected_num_categories=5):
    self.assertGreaterEqual(len(classification_result_list), 1)
    self.assertLessEqual(len(classification_result_list), 2)
    # Checks the first result.
    classification_result = classification_result_list[0]
    self.assertEqual(classification_result.timestamp_ms, 0)
    self.assertLen(classification_result.classifications, 2)
    # Checks the first head.
    yamnet_classification = classification_result.classifications[0]
    self.assertEqual(yamnet_classification.head_index, 0)
    self.assertEqual(yamnet_classification.head_name, 'yamnet_classification')
    self.assertLen(yamnet_classification.categories,
                   first_head_expected_num_categories)
    yamnet_category = yamnet_classification.categories[0]
    self.assertEqual(yamnet_category.index, 508)
    self.assertEqual(yamnet_category.category_name, 'Environmental noise')
    self.assertGreater(yamnet_category.score, 0.5)
    # Checks the second head.
    bird_classification = classification_result.classifications[1]
    self.assertEqual(bird_classification.head_index, 1)
    self.assertEqual(bird_classification.head_name, 'bird_classification')
    self.assertLen(bird_classification.categories,
                   second_head_expected_num_categories)
    bird_category = bird_classification.categories[0]
    self.assertEqual(bird_category.index, 4)
    self.assertEqual(bird_category.category_name, 'Chestnut-crowned Antpitta')
    self.assertGreater(bird_category.score, 0.93)
    # Checks the second result, if present.
    if len(classification_result_list) == 2:
      classification_result = classification_result_list[1]
      self.assertEqual(classification_result.timestamp_ms, 975)
      self.assertLen(classification_result.classifications, 2)
      # Checks the first head.
      yamnet_classification = classification_result.classifications[0]
      self.assertEqual(yamnet_classification.head_index, 0)
      self.assertEqual(yamnet_classification.head_name, 'yamnet_classification')
      self.assertLen(yamnet_classification.categories,
                     first_head_expected_num_categories)
      yamnet_category = yamnet_classification.categories[0]
      self.assertEqual(yamnet_category.index, 494)
      self.assertEqual(yamnet_category.category_name, 'Silence')
      self.assertGreater(yamnet_category.score, 0.9)
      # Checks the second head.
      bird_classification = classification_result.classifications[1]
      self.assertEqual(bird_classification.head_index, 1)
      self.assertEqual(bird_classification.head_name, 'bird_classification')
      self.assertLen(bird_classification.categories,
                     second_head_expected_num_categories)
      bird_category = bird_classification.categories[0]
      self.assertEqual(bird_category.index, 1)
      self.assertEqual(bird_category.category_name, 'White-breasted Wood-Wren')
      self.assertGreater(bird_category.score, 0.99)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(
                model_asset_path=self.yamnet_model_path))) as classifier:
      self.assertIsInstance(classifier, _AudioClassifier)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _AudioClassifierOptions(base_options=base_options)
      _AudioClassifier.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.yamnet_model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _AudioClassifierOptions(base_options=base_options)
      classifier = _AudioClassifier.create_from_options(options)
      self.assertIsInstance(classifier, _AudioClassifier)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
  def test_classify_with_yamnet_model(self, audio_file):
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      classification_result_list = classifier.classify(
          self._read_wav_file(audio_file))
      self._check_yamnet_result(classification_result_list)

  def test_classify_with_yamnet_model_and_inputs_at_different_sample_rates(
      self):
    with _AudioClassifier.create_from_model_path(
        self.yamnet_model_path) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_48K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(classification_result_list)

  def test_max_result_options(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                max_results=1))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_score_threshold_options(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                score_threshold=0.9))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_allow_list_option(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
            classifier_options=_ClassifierOptions(
                category_allowlist=['Speech']))) as classifier:
      for audio_file in [_SPEECH_WAV_16K_MONO, _SPEECH_WAV_16K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_yamnet_result(
            classification_result_list, expected_num_categories=1)

  def test_combined_allowlist_and_denylist(self):
    # Fails with combined allowlist and denylist.
    with self.assertRaisesRegex(
        ValueError,
        r'`category_allowlist` and `category_denylist` are mutually '
        r'exclusive options.'):
      options = _AudioClassifierOptions(
          base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
          classifier_options=_ClassifierOptions(
              category_allowlist=['foo'], category_denylist=['bar']))
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  @parameterized.parameters((_TWO_HEADS_WAV_16K_MONO),
                            (_TWO_HEADS_WAV_44K_MONO))
  def test_classify_with_two_heads_model_and_inputs_at_different_sample_rates(
      self, audio_file):
    with _AudioClassifier.create_from_model_path(
        self.two_heads_model_path) as classifier:
      classification_result_list = classifier.classify(
          self._read_wav_file(audio_file))
      self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model(self):
    with _AudioClassifier.create_from_model_path(
        self.two_heads_model_path) as classifier:
      for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_two_heads_result(classification_result_list)

  def test_classify_with_two_heads_model_with_max_results(self):
    with _AudioClassifier.create_from_options(
        _AudioClassifierOptions(
            base_options=_BaseOptions(
                model_asset_path=self.two_heads_model_path),
            classifier_options=_ClassifierOptions(
                max_results=1))) as classifier:
      for audio_file in [_TWO_HEADS_WAV_16K_MONO, _TWO_HEADS_WAV_44K_MONO]:
        classification_result_list = classifier.classify(
            self._read_wav_file(audio_file))
        self._check_two_heads_result(classification_result_list, 1, 1)

  def test_missing_sample_rate_in_audio_clips_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS)
    with self.assertRaisesRegex(ValueError,
                                r'Must provide the audio sample rate'):
      with _AudioClassifier.create_from_options(options) as classifier:
        classifier.classify(_AudioData(buffer_length=100))

  def test_missing_sample_rate_in_audio_stream_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'provide the audio sample rate in audio data'):
      with _AudioClassifier.create_from_options(options) as classifier:
        classifier.classify(_AudioData(buffer_length=100))

  def test_missing_result_callback(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_illegal_result_callback(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _AudioClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_calling_classify_in_audio_stream_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with _AudioClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the audio clips mode'):
        classifier.classify(self._read_wav_file(_SPEECH_WAV_16K_MONO))

  def test_calling_classify_async_in_audio_clips_mode(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_CLIPS)
    with _AudioClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the audio stream mode'):
        classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  def test_classify_async_calls_with_illegal_timestamp(self):
    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        result_callback=mock.MagicMock())
    with _AudioClassifier.create_from_options(options) as classifier:
      classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        classifier.classify_async(self._read_wav_file(_SPEECH_WAV_16K_MONO), 0)

  @parameterized.parameters((_SPEECH_WAV_16K_MONO), (_SPEECH_WAV_48K_MONO))
  def test_classify_async(self, audio_file):
    classification_result_list = []

    def save_result(result: _AudioClassifierResult, timestamp_ms: int):
      result.timestamp_ms = timestamp_ms
      classification_result_list.append(result)

    options = _AudioClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.yamnet_model_path),
        running_mode=_RUNNING_MODE.AUDIO_STREAM,
        classifier_options=_ClassifierOptions(max_results=1),
        result_callback=save_result)
    classifier = _AudioClassifier.create_from_options(options)
    audio_data_list = self._read_wav_file_as_stream(audio_file)
    for audio_data, timestamp_ms in audio_data_list:
      classifier.classify_async(audio_data, timestamp_ms)
    classifier.close()
    self._check_yamnet_result(
        classification_result_list, expected_num_categories=1)


if __name__ == '__main__':
  absltest.main()