Migrate AudioClassifier C++ to use new ClassificationResult struct.

PiperOrigin-RevId: 486162683
MediaPipe Team, 2022-11-04 09:44:07 -07:00 (committed by Copybara-Service)
parent 93a587a422
commit 8b2c937b9e
5 changed files with 222 additions and 178 deletions
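
Note: at a high level, this change replaces the proto-generated ClassificationResult message in the AudioClassifier public API with the plain C++ ClassificationResult struct from components/containers, and splits the graph output into a stream-mode output (CLASSIFICATIONS) and a clip-mode output (TIMESTAMPED_CLASSIFICATIONS). A minimal sketch of the caller-visible difference, assuming an existing classifier and audio clip (hypothetical code, not part of this commit):

    // Before: a single proto message, read through generated getters, e.g.
    //   result.classifications(0).entries(0).categories(0).category_name();
    // After: one plain struct per classified chunk, read through data members.
    absl::StatusOr<std::vector<AudioClassifierResult>> results =
        classifier->Classify(std::move(audio_clip), /*audio_sample_rate=*/16000);
    if (results.ok()) {
      for (const AudioClassifierResult& result : *results) {
        // result.timestamp_ms is the start (in milliseconds) of the chunk.
        const auto& top = result.classifications[0].categories[0];
        // Consume top.category_name, top.score, top.index, ...
      }
    }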

mediapipe/tasks/cc/audio/audio_classifier/BUILD

@@ -65,6 +65,7 @@ cc_library(
         "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
         "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
         "//mediapipe/tasks/cc/audio/core:running_mode",
+        "//mediapipe/tasks/cc/components/containers:classification_result",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc

@@ -18,12 +18,14 @@ limitations under the License.
 #include <map>
 #include <memory>
 #include <utility>
+#include <vector>

 #include "absl/status/statusor.h"
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
 #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@@ -38,12 +40,16 @@ namespace audio_classifier {

 namespace {

+using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
 using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

 constexpr char kAudioStreamName[] = "audio_in";
 constexpr char kAudioTag[] = "AUDIO";
-constexpr char kClassificationResultStreamName[] = "classification_result_out";
-constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kClassificationsName[] = "classifications_out";
+constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
+constexpr char kTimestampedClassificationsName[] =
+    "timestamped_classifications_out";
 constexpr char kSampleRateName[] = "sample_rate_in";
 constexpr char kSampleRateTag[] = "SAMPLE_RATE";
 constexpr char kSubgraphTypeName[] =
@@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig(
   }
   subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
       options_proto.get());
-  subgraph.Out(kClassificationResultTag)
-          .SetName(kClassificationResultStreamName) >>
-      graph.Out(kClassificationResultTag);
+  subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
+      graph.Out(kClassificationsTag);
+  subgraph.Out(kTimestampedClassificationsTag)
+          .SetName(kTimestampedClassificationsName) >>
+      graph.Out(kTimestampedClassificationsTag);
   return graph.GetConfig();
 }
@@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
   return options_proto;
 }

-absl::StatusOr<ClassificationResult> ConvertOutputPackets(
+absl::StatusOr<std::vector<AudioClassifierResult>> ConvertOutputPackets(
     absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
   if (!status_or_packets.ok()) {
     return status_or_packets.status();
   }
-  return status_or_packets.value()[kClassificationResultStreamName]
-      .Get<ClassificationResult>();
+  auto classification_results =
+      status_or_packets.value()[kTimestampedClassificationsName]
+          .Get<std::vector<ClassificationResult>>();
+  std::vector<AudioClassifierResult> results;
+  results.reserve(classification_results.size());
+  for (const auto& classification_result : classification_results) {
+    results.emplace_back(ConvertToClassificationResult(classification_result));
+  }
+  return results;
+}
+
+absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
+    absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
+  if (!status_or_packets.ok()) {
+    return status_or_packets.status();
+  }
+  return ConvertToClassificationResult(
+      status_or_packets.value()[kClassificationsName]
+          .Get<ClassificationResult>());
 }

 }  // namespace
@@ -118,7 +143,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
     auto result_callback = options->result_callback;
     packets_callback =
         [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
-          result_callback(ConvertOutputPackets(status_or_packets));
+          result_callback(ConvertAsyncOutputPackets(status_or_packets));
         };
   }
   return core::AudioTaskApiFactory::Create<AudioClassifier,
@@ -128,7 +153,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
       std::move(packets_callback));
 }

-absl::StatusOr<ClassificationResult> AudioClassifier::Classify(
+absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
     Matrix audio_clip, double audio_sample_rate) {
   return ConvertOutputPackets(ProcessAudioClip(
       {{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},

mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h

@@ -18,12 +18,13 @@ limitations under the License.
 #include <memory>
 #include <utility>
+#include <vector>

 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
 #include "mediapipe/tasks/cc/audio/core/running_mode.h"
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
@@ -32,6 +33,10 @@ namespace tasks {
 namespace audio {
 namespace audio_classifier {

+// Alias the shared ClassificationResult struct as result type.
+using AudioClassifierResult =
+    ::mediapipe::tasks::components::containers::ClassificationResult;
+
 // The options for configuring a mediapipe audio classifier task.
 struct AudioClassifierOptions {
   // Base options for configuring Task library, such as specifying the TfLite
@@ -59,9 +64,8 @@ struct AudioClassifierOptions {
   // The user-defined result callback for processing audio stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::AUDIO_STREAM.
-  std::function<void(
-      absl::StatusOr<components::containers::proto::ClassificationResult>)>
-      result_callback = nullptr;
+  std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
+      nullptr;
 };

 // Performs audio classification on audio clips or audio stream.
@@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // required to provide the corresponding audio sample rate along with the
   // input audio clips.
   //
-  // For each audio clip, the output classifications are grouped in a
-  // ClassificationResult object that has three dimensions:
-  //   Classification head:
-  //     The prediction heads targeting different audio classification tasks
-  //     such as audio event classification and bird sound classification.
-  //   Classification timestamp:
-  //     The start time (in milliseconds) of each audio clip that is sent to the
-  //     model for audio classification. As the audio classification models take
-  //     a fixed number of audio samples, long audio clips will be framed to
-  //     multiple buffers (with the desired number of audio samples) during
-  //     preprocessing.
-  //   Classification category:
-  //     The list of the classification categories that model predicts per
-  //     framed audio clip.
+  // The input audio clip may be longer than what the model is able to process
+  // in a single inference. When this occurs, the input audio clip is split into
+  // multiple chunks starting at different timestamps. For this reason, this
+  // function returns a vector of ClassificationResult objects, each associated
+  // with a timestamp corresponding to the start (in milliseconds) of the chunk
+  // data that was classified, e.g.:
+  //
+  // ClassificationResult #0 (first chunk of data):
+  //   timestamp_ms: 0 (starts at 0ms)
+  //   classifications #0 (single head model):
+  //     category #0:
+  //       category_name: "Speech"
+  //       score: 0.6
+  //     category #1:
+  //       category_name: "Music"
+  //       score: 0.2
+  // ClassificationResult #1 (second chunk of data):
+  //   timestamp_ms: 800 (starts at 800ms)
+  //   classifications #0 (single head model):
+  //     category #0:
+  //       category_name: "Speech"
+  //       score: 0.5
+  //     category #1:
+  //       category_name: "Silence"
+  //       score: 0.1
+  // ...
+  //
   // TODO: Use `sample_rate` in AudioClassifierOptions by default
-  // and makes `audio_sample_rate` optional.
-  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
+  // and make `audio_sample_rate` optional.
+  absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
       mediapipe::Matrix audio_clip, double audio_sample_rate);

   // Sends audio data (a block in a continuous audio stream) to perform audio
@@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // milliseconds) to indicate the start time of the input audio block. The
   // timestamps must be monotonically increasing.
   //
-  // The output classifications are grouped in a ClassificationResult object
-  // that has three dimensions:
-  //   Classification head:
-  //     The prediction heads targeting different audio classification tasks
-  //     such as audio event classification and bird sound classification.
-  //   Classification timestamp:
-  //     The start time (in milliseconds) of the framed audio block that is sent
-  //     to the model for audio classification.
-  //   Classification category:
-  //     The list of the classification categories that model predicts per
-  //     framed audio clip.
+  // The input audio block may be longer than what the model is able to process
+  // in a single inference. When this occurs, the input audio block is split
+  // into multiple chunks. For this reason, the callback may be called multiple
+  // times (once per chunk) for each call to this function.
   absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);

   // Shuts down the AudioClassifier when all work is done.
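
Note: in stream mode the same struct is delivered through the callback, once per classified chunk. A usage sketch under stated assumptions (model path and sample rate are illustrative, not from this commit):

    auto options = std::make_unique<AudioClassifierOptions>();
    options->base_options.model_asset_path = "/path/to/model.tflite";
    options->running_mode = core::RunningMode::AUDIO_STREAM;
    options->sample_rate = 16000;
    options->result_callback =
        [](absl::StatusOr<AudioClassifierResult> result) {
          // Called once per chunk; per the tests, result->timestamp_ms is
          // unset here -- the block's start time is the timestamp that was
          // passed to ClassifyAsync.
          if (result.ok()) { /* consume *result */ }
        };
    auto classifier = AudioClassifier::Create(std::move(options));
    if (classifier.ok()) {
      // audio_block: a mediapipe::Matrix holding the next block of samples.
      (*classifier)->ClassifyAsync(std::move(audio_block),
                                   /*timestamp_ms=*/0).IgnoreError();
    }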

mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include <stdint.h>

 #include <utility>
+#include <vector>

 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
 constexpr char kAudioTag[] = "AUDIO";
-constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
 constexpr char kPacketTag[] = "PACKET";
 constexpr char kSampleRateTag[] = "SAMPLE_RATE";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kTimestampsTag[] = "TIMESTAMPS";

+// Struct holding the different output streams produced by the audio classifier
+// graph.
+struct AudioClassifierOutputStreams {
+  Source<ClassificationResult> classifications;
+  Source<std::vector<ClassificationResult>> timestamped_classifications;
+};
+
 absl::Status SanityCheckOptions(
     const proto::AudioClassifierGraphOptions& options) {
   if (options.base_options().use_stream_mode() &&
@@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator(
 //     series stream header with sample rate info.
 //
 // Outputs:
-//   CLASSIFICATION_RESULT - ClassificationResult
-//     The aggregated classification result object that has 3 dimensions:
-//     (classification head, classification timestamp, classification category).
+//   CLASSIFICATIONS - ClassificationResult @Optional
+//     The classification results aggregated by head. Only produces results
+//     if the graph's 'use_stream_mode' option is true.
+//   TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
+//     The classification results aggregated by timestamp, then by head. Only
+//     produces results if the graph's 'use_stream_mode' option is false.
 //
 // Example:
 // node {
 //   calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
 //   input_stream: "AUDIO:audio_in"
 //   input_stream: "SAMPLE_RATE:sample_rate_in"
-//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
+//   output_stream: "CLASSIFICATIONS:classifications"
+//   output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
 //   options {
 //     [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
 //     {
@@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
             .base_options()
             .use_stream_mode();
     ASSIGN_OR_RETURN(
-        auto classification_result_out,
+        auto output_streams,
         BuildAudioClassificationTask(
             sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
             graph[Input<Matrix>(kAudioTag)],
@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
? absl::nullopt ? absl::nullopt
: absl::make_optional(graph[Input<double>(kSampleRateTag)]), : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
graph)); graph));
classification_result_out >> output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationResultTag)]; graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
   //   audio_in: (mediapipe::Matrix) stream to run audio classification on.
   //   sample_rate_in: (double) optional stream of the input audio sample rate.
   //   graph: the mediapipe builder::Graph instance to be updated.
-  absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
+  absl::StatusOr<AudioClassifierOutputStreams> BuildAudioClassificationTask(
       const proto::AudioClassifierGraphOptions& task_options,
       const core::ModelResources& model_resources, Source<Matrix> audio_in,
       absl::optional<Source<double>> sample_rate_in, Graph& graph) {
@@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Time aggregation is only needed for performing audio classification on
-    // audio files. Disables time aggregration by not connecting the
+    // audio files. Disables timestamp aggregation by not connecting the
     // "TIMESTAMPS" streams.
     if (!use_stream_mode) {
       audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
     }

-    // Outputs the aggregated classification result as the subgraph output
-    // stream.
-    return postprocessing[Output<ClassificationResult>(
-        kClassificationResultTag)];
+    // Output both streams as graph output streams.
+    return AudioClassifierOutputStreams{
+        /*classifications=*/postprocessing[Output<ClassificationResult>(
+            kClassificationsTag)],
+        /*timestamped_classifications=*/
+        postprocessing[Output<std::vector<ClassificationResult>>(
+            kTimestampedClassificationsTag)],
+    };
   }
 };

mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc

@@ -32,13 +32,11 @@ limitations under the License.
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
-#include "mediapipe/framework/port/parse_text_proto.h"
-#include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/tasks/cc/audio/core/running_mode.h"
 #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 #include "tensorflow/lite/core/shims/cc/shims_test_util.h"

 namespace mediapipe {
@@ -49,7 +47,6 @@ namespace {
 using ::absl::StatusOr;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 using ::testing::HasSubstr;
 using ::testing::Optional;

@@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) {
   return matrix_mapping.matrix();
 }

-void CheckSpeechClassificationResult(const ClassificationResult& result) {
-  EXPECT_THAT(result.classifications_size(), testing::Eq(1));
-  EXPECT_EQ(result.classifications(0).head_name(), "scores");
-  EXPECT_EQ(result.classifications(0).head_index(), 0);
-  EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5));
+void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
+                       int expected_num_categories = 521) {
+  EXPECT_EQ(result.size(), 5);
+  // Ignore last result, which operates on a too small chunk to return relevant
+  // results.
   std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
   for (int i = 0; i < timestamps_ms.size(); i++) {
-    EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
-                testing::Eq(521));
-    const auto* top_category =
-        &result.classifications(0).entries(0).categories(0);
-    EXPECT_THAT(top_category->category_name(), testing::Eq("Speech"));
-    EXPECT_GT(top_category->score(), 0.9f);
-    EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(),
-              timestamps_ms[i]);
+    EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
+    EXPECT_EQ(result[i].classifications.size(), 1);
+    auto classifications = result[i].classifications[0];
+    EXPECT_EQ(classifications.head_index, 0);
+    EXPECT_EQ(classifications.head_name, "scores");
+    EXPECT_EQ(classifications.categories.size(), expected_num_categories);
+    auto category = classifications.categories[0];
+    EXPECT_EQ(category.index, 0);
+    EXPECT_EQ(category.category_name, "Speech");
+    EXPECT_GT(category.score, 0.9f);
   }
 }

-void CheckTwoHeadsClassificationResult(const ClassificationResult& result) {
-  EXPECT_THAT(result.classifications_size(), testing::Eq(2));
-  // Checks classification head #1.
-  EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification");
-  EXPECT_EQ(result.classifications(0).head_index(), 0);
-  EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
-              testing::Eq(521));
-  const auto* top_category =
-      &result.classifications(0).entries(0).categories(0);
-  EXPECT_THAT(top_category->category_name(),
-              testing::Eq("Environmental noise"));
-  EXPECT_GT(top_category->score(), 0.5f);
-  EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0);
-  if (result.classifications(0).entries_size() == 2) {
-    top_category = &result.classifications(0).entries(1).categories(0);
-    EXPECT_THAT(top_category->category_name(), testing::Eq("Silence"));
-    EXPECT_GT(top_category->score(), 0.99f);
-    EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975);
+void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
+  EXPECT_GE(result.size(), 1);
+  EXPECT_LE(result.size(), 2);
+  // Check first result.
+  EXPECT_EQ(result[0].timestamp_ms, 0);
+  EXPECT_EQ(result[0].classifications.size(), 2);
+  // Check first head.
+  EXPECT_EQ(result[0].classifications[0].head_index, 0);
+  EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
+  EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
+  EXPECT_EQ(result[0].classifications[0].categories[0].index, 508);
+  EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
+            "Environmental noise");
+  EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
+  // Check second head.
+  EXPECT_EQ(result[0].classifications[1].head_index, 1);
+  EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
+  EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
+  EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
+  EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
+            "Chestnut-crowned Antpitta");
+  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
+  // Check second result, if present.
+  if (result.size() == 2) {
+    EXPECT_EQ(result[1].timestamp_ms, 975);
+    EXPECT_EQ(result[1].classifications.size(), 2);
+    // Check first head.
+    EXPECT_EQ(result[1].classifications[0].head_index, 0);
+    EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
+    EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
+    EXPECT_EQ(result[1].classifications[0].categories[0].index, 494);
+    EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
+              "Silence");
+    EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
+    // Check second head.
+    EXPECT_EQ(result[1].classifications[1].head_index, 1);
+    EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
+    EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
+    EXPECT_EQ(result[1].classifications[1].categories[0].index, 1);
+    EXPECT_EQ(result[1].classifications[1].categories[0].category_name,
+              "White-breasted Wood-Wren");
+    EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f);
   }
-  // Checks classification head #2.
-  EXPECT_EQ(result.classifications(1).head_name(), "bird_classification");
-  EXPECT_EQ(result.classifications(1).head_index(), 1);
-  EXPECT_THAT(result.classifications(1).entries(0).categories_size(),
-              testing::Eq(5));
-  top_category = &result.classifications(1).entries(0).categories(0);
-  EXPECT_THAT(top_category->category_name(),
-              testing::Eq("Chestnut-crowned Antpitta"));
-  EXPECT_GT(top_category->score(), 0.9f);
-  EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0);
 }

-ClassificationResult GenerateSpeechClassificationResult() {
-  return ParseTextProtoOrDie<ClassificationResult>(
-      R"pb(classifications {
-             head_index: 0
-             head_name: "scores"
-             entries {
-               categories { index: 0 score: 0.94140625 category_name: "Speech" }
-               timestamp_ms: 0
-             }
-             entries {
-               categories { index: 0 score: 0.9921875 category_name: "Speech" }
-               timestamp_ms: 975
-             }
-             entries {
-               categories { index: 0 score: 0.98828125 category_name: "Speech" }
-               timestamp_ms: 1950
-             }
-             entries {
-               categories { index: 0 score: 0.99609375 category_name: "Speech" }
-               timestamp_ms: 2925
-             }
-             entries {
-               # categories are filtered out due to the low scores.
-               timestamp_ms: 3900
-             }
-           })pb");
-}
-
-void CheckStreamingModeClassificationResult(
-    std::vector<ClassificationResult> outputs) {
-  ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6);
-  auto expected_results = GenerateSpeechClassificationResult();
-  for (int i = 0; i < outputs.size() - 1; ++i) {
-    EXPECT_THAT(outputs[i].classifications(0).entries(0),
-                EqualsProto(expected_results.classifications(0).entries(i)));
+void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
+  EXPECT_EQ(outputs.size(), 5);
+  // Ignore last result, which operates on a too small chunk to return relevant
+  // results.
+  for (int i = 0; i < outputs.size() - 1; i++) {
+    EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
+    EXPECT_EQ(outputs[i].classifications.size(), 1);
+    EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
+    EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
+    EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1);
+    EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0);
+    EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name,
+              "Speech");
+    EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f);
   }
-  int last_elem_index = outputs.size() - 1;
-  EXPECT_EQ(
-      mediapipe::Timestamp::Done().Value() / 1000,
-      outputs[last_elem_index].classifications(0).entries(0).timestamp_ms());
 }

 class CreateFromOptionsTest : public tflite_shims::testing::Test {};
@@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
   options->result_callback =
-      [](absl::StatusOr<ClassificationResult> status_or_result) {};
+      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
       AudioClassifier::Create(std::move(options));
@@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
   options->running_mode = core::RunningMode::AUDIO_STREAM;
   options->result_callback =
-      [](absl::StatusOr<ClassificationResult> status_or_result) {};
+      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
       AudioClassifier::Create(std::move(options));
@@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckSpeechClassificationResult(result);
+  CheckSpeechResult(result);
 }
@@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckSpeechClassificationResult(result);
+  CheckSpeechResult(result);
 }
@@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result_16k_hz,
       audio_classifier->Classify(std::move(audio_buffer_16k_hz),
                                  /*audio_sample_rate=*/16000));
-  CheckSpeechClassificationResult(result_16k_hz);
+  CheckSpeechResult(result_16k_hz);
   MP_ASSERT_OK_AND_ASSIGN(
       auto result_48k_hz,
       audio_classifier->Classify(std::move(audio_buffer_48k_hz),
                                  /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckSpeechClassificationResult(result_48k_hz);
+  CheckSpeechResult(result_48k_hz);
 }
@@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(zero_matrix), 16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result.classifications_size(), testing::Eq(1));
-  EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1));
-  EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
-              testing::Eq(521));
-  EXPECT_THAT(
-      result.classifications(0).entries(0).categories(0).category_name(),
-      testing::Eq("Silence"));
-  EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(),
-              testing::FloatEq(.800781f));
+  EXPECT_EQ(result.size(), 1);
+  EXPECT_EQ(result[0].timestamp_ms, 0);
+  EXPECT_EQ(result[0].classifications.size(), 1);
+  EXPECT_EQ(result[0].classifications[0].head_index, 0);
+  EXPECT_EQ(result[0].classifications[0].head_name, "scores");
+  EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
+  EXPECT_EQ(result[0].classifications[0].categories[0].index, 494);
+  EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
+            "Silence");
+  EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f);
 }
@@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckTwoHeadsClassificationResult(result);
+  CheckTwoHeadsResult(result);
 }
@@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/44100));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckTwoHeadsClassificationResult(result);
+  CheckTwoHeadsResult(result);
 }
@@ -413,13 +402,13 @@ TEST_F(ClassifyTest,
   MP_ASSERT_OK_AND_ASSIGN(
       auto result_44k_hz,
       audio_classifier->Classify(std::move(audio_buffer_44k_hz),
                                  /*audio_sample_rate=*/44100));
-  CheckTwoHeadsClassificationResult(result_44k_hz);
+  CheckTwoHeadsResult(result_44k_hz);
   MP_ASSERT_OK_AND_ASSIGN(
       auto result_16k_hz,
       audio_classifier->Classify(std::move(audio_buffer_16k_hz),
                                  /*audio_sample_rate=*/16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckTwoHeadsClassificationResult(result_16k_hz);
+  CheckTwoHeadsResult(result_16k_hz);
 }
@@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kModelWithMetadata);
   options->classifier_options.max_results = 1;
-  options->classifier_options.score_threshold = 0.35f;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
                           AudioClassifier::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
+  CheckSpeechResult(result, /*expected_num_categories=*/1);
 }
@@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
+  CheckSpeechResult(result, /*expected_num_categories=*/1);
 }
@@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
+  CheckSpeechResult(result, /*expected_num_categories=*/1);
 }
@@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  // All categroies with the "Speech" label are filtered out.
-  EXPECT_THAT(result, EqualsProto(R"pb(classifications {
-                                         head_index: 0
-                                         head_name: "scores"
-                                         entries { timestamp_ms: 0 }
-                                         entries { timestamp_ms: 975 }
-                                         entries { timestamp_ms: 1950 }
-                                         entries { timestamp_ms: 2925 }
-                                         entries { timestamp_ms: 3900 }
-                                       })pb"));
+  // All categories with the "Speech" label are filtered out.
+  std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
+  for (int i = 0; i < timestamps_ms.size(); i++) {
+    EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
+    EXPECT_EQ(result[i].classifications.size(), 1);
+    auto classifications = result[i].classifications[0];
+    EXPECT_EQ(classifications.head_index, 0);
+    EXPECT_EQ(classifications.head_name, "scores");
+    EXPECT_TRUE(classifications.categories.empty());
+  }
 }
class ClassifyAsyncTest : public tflite_shims::testing::Test {};

@@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
   options->sample_rate = kSampleRateHz;
-  std::vector<ClassificationResult> outputs;
+  std::vector<AudioClassifierResult> outputs;
   options->result_callback =
-      [&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
+      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
        MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
       };
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
     start_col += kYamnetNumOfAudioSamples * 3;
   }
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckStreamingModeClassificationResult(outputs);
+  CheckStreamingModeResults(outputs);
 }
@@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
   options->sample_rate = kSampleRateHz;
-  std::vector<ClassificationResult> outputs;
+  std::vector<AudioClassifierResult> outputs;
   options->result_callback =
-      [&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
+      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
        MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
       };
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
     start_col += num_samples;
   }
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckStreamingModeClassificationResult(outputs);
+  CheckStreamingModeResults(outputs);
 }

 }  // namespace