180 lines
8.2 KiB
C++
180 lines
8.2 KiB
C++
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_CLASSIFIER_AUDIO_CLASSIFIER_H_
|
|
#define MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_CLASSIFIER_AUDIO_CLASSIFIER_H_
|
|
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/status/statusor.h"
|
|
#include "mediapipe/framework/formats/matrix.h"
|
|
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
|
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
|
|
|
namespace mediapipe {
|
|
namespace tasks {
|
|
namespace audio {
|
|
namespace audio_classifier {
|
|
|
|
// Alias the shared ClassificationResult struct as result type.
|
|
using AudioClassifierResult =
|
|
::mediapipe::tasks::components::containers::ClassificationResult;
|
|
|
|
// The options for configuring a mediapipe audio classifier task.
|
|
struct AudioClassifierOptions {
|
|
// Base options for configuring Task library, such as specifying the TfLite
|
|
// model file with metadata, accelerator options, op resolver, etc.
|
|
tasks::core::BaseOptions base_options;
|
|
|
|
// Options for configuring the classifier behavior, such as score threshold,
|
|
// number of results, etc.
|
|
components::processors::ClassifierOptions classifier_options;
|
|
|
|
// The running mode of the audio classifier. Default to the audio clips mode.
|
|
// Audio classifier has two running modes:
|
|
// 1) The audio clips mode for running classification on independent audio
|
|
// clips.
|
|
// 2) The audio stream mode for running classification on the audio stream,
|
|
// such as from microphone. In this mode, the "result_callback" below must
|
|
// be specified to receive the classification results asynchronously.
|
|
core::RunningMode running_mode = core::RunningMode::AUDIO_CLIPS;
|
|
|
|
// The user-defined result callback for processing audio stream data.
|
|
// The result callback should only be specified when the running mode is set
|
|
// to RunningMode::AUDIO_STREAM.
|
|
std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
|
|
nullptr;
|
|
};
|
|
|
|
// Performs audio classification on audio clips or audio stream.
|
|
//
|
|
// This API expects a TFLite model with mandatory TFLite Model Metadata that
|
|
// contains the mandatory AudioProperties of the solo input audio tensor and the
|
|
// optional (but recommended) label items as AssociatedFiles with type
|
|
// TENSOR_AXIS_LABELS per output classification tensor.
|
|
//
|
|
// Input tensor:
|
|
// (kTfLiteFloat32)
|
|
// - input audio buffer of size `[batch * samples]`.
|
|
// - batch inference is not supported (`batch` is required to be 1).
|
|
// - for multi-channel models, the channels need be interleaved.
|
|
// At least one output tensor with:
|
|
// (kTfLiteFloat32)
|
|
// - `[1 x N]` array with `N` represents the number of categories.
|
|
// - optional (but recommended) label items as AssociatedFiles with type
|
|
// TENSOR_AXIS_LABELS, containing one label per line. The first such
|
|
// AssociatedFile (if any) is used to fill the `category_name` field of the
|
|
// results. The `display_name` field is filled from the AssociatedFile (if
|
|
// any) whose locale matches the `display_names_locale` field of the
|
|
// `AudioClassifierOptions` used at creation time ("en" by default, i.e.
|
|
// English). If none of these are available, only the `index` field of the
|
|
// results will be filled.
|
|
// TODO: Create an audio container to replace the matrix, the
|
|
// sample rate, and the timestamp.
|
|
class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|
public:
|
|
using BaseAudioTaskApi::BaseAudioTaskApi;
|
|
|
|
// Creates an AudioClassifier to process either audio clips (e.g., audio
|
|
// files) or audio stream data (e.g., microphone live input). Audio classifier
|
|
// can be created with one of following two running modes:
|
|
// 1) Audio clips mode for running audio classification on audio clips.
|
|
// Users feed audio clips to the `Classify` method, and will
|
|
// receive the classification results as the return value.
|
|
// 2) Audio stream mode for running audio classification on the audio stream,
|
|
// such as from microphone. Users call `ClassifyAsync` to push the audio
|
|
// data into the AudioClassifier, the classification results will be
|
|
// available in the result callback when the audio classifier finishes the
|
|
// work.
|
|
static absl::StatusOr<std::unique_ptr<AudioClassifier>> Create(
|
|
std::unique_ptr<AudioClassifierOptions> options);
|
|
|
|
// Performs audio classification on the provided audio clip. Only use this
|
|
// method when the AudioClassifier is created with the audio clips running
|
|
// mode.
|
|
//
|
|
// The audio clip is represented as a MediaPipe Matrix that has the number of
|
|
// channels rows and the number of samples per channel columns. The method
|
|
// accepts audio clips with various length and audio sample rate. It's
|
|
// required to provide the corresponding audio sample rate along with the
|
|
// input audio clips.
|
|
//
|
|
// The input audio clip may be longer than what the model is able to process
|
|
// in a single inference. When this occurs, the input audio clip is split into
|
|
// multiple chunks starting at different timestamps. For this reason, this
|
|
// function returns a vector of ClassificationResult objects, each associated
|
|
// with a timestamp corresponding to the start (in milliseconds) of the chunk
|
|
// data that was classified, e.g:
|
|
//
|
|
// ClassificationResult #0 (first chunk of data):
|
|
// timestamp_ms: 0 (starts at 0ms)
|
|
// classifications #0 (single head model):
|
|
// category #0:
|
|
// category_name: "Speech"
|
|
// score: 0.6
|
|
// category #1:
|
|
// category_name: "Music"
|
|
// score: 0.2
|
|
// ClassificationResult #1 (second chunk of data):
|
|
// timestamp_ms: 800 (starts at 800ms)
|
|
// classifications #0 (single head model):
|
|
// category #0:
|
|
// category_name: "Speech"
|
|
// score: 0.5
|
|
// category #1:
|
|
// category_name: "Silence"
|
|
// score: 0.1
|
|
// ...
|
|
//
|
|
// TODO: Use `sample_rate` in AudioClassifierOptions by default
|
|
// and makes `audio_sample_rate` optional.
|
|
absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
|
|
mediapipe::Matrix audio_clip, double audio_sample_rate);
|
|
|
|
// Sends audio data (a block in a continuous audio stream) to perform audio
|
|
// classification. Only use this method when the AudioClassifier is created
|
|
// with the audio stream running mode.
|
|
//
|
|
// The audio block is represented as a MediaPipe Matrix that has the number
|
|
// of channels rows and the number of samples per channel columns. The audio
|
|
// data will be resampled, accumulated, and framed to the proper size for the
|
|
// underlying model to consume. It's required to provide the corresponding
|
|
// audio sample rate along with the input audio block as well as a timestamp
|
|
// (in milliseconds) to indicate the start time of the input audio block. The
|
|
// timestamps must be monotonically increasing.
|
|
//
|
|
// The input audio block may be longer than what the model is able to process
|
|
// in a single inference. When this occurs, the input audio block is split
|
|
// into multiple chunks. For this reason, the callback may be called multiple
|
|
// times (once per chunk) for each call to this function.
|
|
absl::Status ClassifyAsync(mediapipe::Matrix audio_block,
|
|
double audio_sample_rate, int64 timestamp_ms);
|
|
|
|
// Shuts down the AudioClassifier when all works are done.
|
|
absl::Status Close() { return runner_->Close(); }
|
|
};
|
|
|
|
} // namespace audio_classifier
|
|
} // namespace audio
|
|
} // namespace tasks
|
|
} // namespace mediapipe
|
|
|
|
#endif // MEDIAPIPE_TASKS_CC_AUDIO_AUDIO_CLASSIFIER_AUDIO_CLASSIFIER_H_
|