Migrate AudioClassifier C++ to use new ClassificationResult struct.

PiperOrigin-RevId: 486162683
This commit is contained in:
MediaPipe Team 2022-11-04 09:44:07 -07:00 committed by Copybara-Service
parent 93a587a422
commit 8b2c937b9e
5 changed files with 222 additions and 178 deletions

View File

@ -65,6 +65,7 @@ cc_library(
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

View File

@ -18,12 +18,14 @@ limitations under the License.
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@ -38,12 +40,16 @@ namespace audio_classifier {
namespace {
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications_out";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications_out";
constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kSubgraphTypeName[] =
@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig(
}
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
options_proto.get());
subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >>
graph.Out(kClassificationResultTag);
subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
graph.Out(kClassificationsTag);
subgraph.Out(kTimestampedClassificationsTag)
.SetName(kTimestampedClassificationsName) >>
graph.Out(kTimestampedClassificationsTag);
return graph.GetConfig();
}
@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
return options_proto;
}
// Converts the "TIMESTAMPED_CLASSIFICATIONS" output packet of a clips-mode
// graph run into one AudioClassifierResult per timestamped audio chunk.
// Forwards the error status unchanged if the graph invocation failed.
absl::StatusOr<std::vector<AudioClassifierResult>> ConvertOutputPackets(
    absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
  if (!status_or_packets.ok()) {
    return status_or_packets.status();
  }
  // Bind by const reference to avoid copying the packet's payload vector.
  const auto& classification_results =
      status_or_packets.value()[kTimestampedClassificationsName]
          .Get<std::vector<ClassificationResult>>();
  std::vector<AudioClassifierResult> results;
  results.reserve(classification_results.size());
  for (const auto& classification_result : classification_results) {
    results.emplace_back(ConvertToClassificationResult(classification_result));
  }
  return results;
}
// Converts the single "CLASSIFICATIONS" output packet of a stream-mode graph
// run into an AudioClassifierResult, or forwards the error status if the
// graph invocation failed.
absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
    absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
  if (!status_or_packets.ok()) {
    return status_or_packets.status();
  }
  const ClassificationResult& result_proto =
      status_or_packets.value()[kClassificationsName]
          .Get<ClassificationResult>();
  return ConvertToClassificationResult(result_proto);
}
} // namespace
@ -118,7 +143,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
auto result_callback = options->result_callback;
packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
result_callback(ConvertOutputPackets(status_or_packets));
result_callback(ConvertAsyncOutputPackets(status_or_packets));
};
}
return core::AudioTaskApiFactory::Create<AudioClassifier,
@ -128,7 +153,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
std::move(packets_callback));
}
absl::StatusOr<ClassificationResult> AudioClassifier::Classify(
absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
Matrix audio_clip, double audio_sample_rate) {
return ConvertOutputPackets(ProcessAudioClip(
{{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},

View File

@ -18,12 +18,13 @@ limitations under the License.
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
@ -32,6 +33,10 @@ namespace tasks {
namespace audio {
namespace audio_classifier {
// Alias the shared ClassificationResult struct as result type.
using AudioClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a mediapipe audio classifier task.
struct AudioClassifierOptions {
// Base options for configuring Task library, such as specifying the TfLite
@ -59,9 +64,8 @@ struct AudioClassifierOptions {
// The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM.
std::function<void(
absl::StatusOr<components::containers::proto::ClassificationResult>)>
result_callback = nullptr;
std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
nullptr;
};
// Performs audio classification on audio clips or audio stream.
@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// required to provide the corresponding audio sample rate along with the
// input audio clips.
//
// For each audio clip, the output classifications are grouped in a
// ClassificationResult object that has three dimensions:
// Classification head:
// The prediction heads targeting different audio classification tasks
// such as audio event classification and bird sound classification.
// Classification timestamp:
// The start time (in milliseconds) of each audio clip that is sent to the
// model for audio classification. As the audio classification models take
// a fixed number of audio samples, long audio clips will be framed to
// multiple buffers (with the desired number of audio samples) during
// preprocessing.
// Classification category:
// The list of the classification categories that model predicts per
// framed audio clip.
// The input audio clip may be longer than what the model is able to process
// in a single inference. When this occurs, the input audio clip is split into
// multiple chunks starting at different timestamps. For this reason, this
// function returns a vector of ClassificationResult objects, each associated
// with a timestamp corresponding to the start (in milliseconds) of the chunk
// data that was classified, e.g:
//
// ClassificationResult #0 (first chunk of data):
// timestamp_ms: 0 (starts at 0ms)
// classifications #0 (single head model):
// category #0:
// category_name: "Speech"
// score: 0.6
// category #1:
// category_name: "Music"
// score: 0.2
// ClassificationResult #1 (second chunk of data):
// timestamp_ms: 800 (starts at 800ms)
// classifications #0 (single head model):
// category #0:
// category_name: "Speech"
// score: 0.5
// category #1:
// category_name: "Silence"
// score: 0.1
// ...
//
// TODO: Use `sample_rate` in AudioClassifierOptions by default
// and makes `audio_sample_rate` optional.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
mediapipe::Matrix audio_clip, double audio_sample_rate);
// Sends audio data (a block in a continuous audio stream) to perform audio
@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// milliseconds) to indicate the start time of the input audio block. The
// timestamps must be monotonically increasing.
//
// The output classifications are grouped in a ClassificationResult object
// that has three dimensions:
// Classification head:
// The prediction heads targeting different audio classification tasks
// such as audio event classification and bird sound classification.
// Classification timestamp :
// The start time (in milliseconds) of the framed audio block that is sent
// to the model for audio classification.
// Classification category:
// The list of the classification categories that model predicts per
// framed audio clip.
// The input audio block may be longer than what the model is able to process
// in a single inference. When this occurs, the input audio block is split
// into multiple chunks. For this reason, the callback may be called multiple
// times (once per chunk) for each call to this function.
absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
// Shuts down the AudioClassifier when all works are done.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <stdint.h>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kPacketTag[] = "PACKET";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Struct holding the different output streams produced by the audio classifier
// graph.
struct AudioClassifierOutputStreams {
// "CLASSIFICATIONS" stream: classification results aggregated by head,
// produced when the 'use_stream_mode' option is true.
Source<ClassificationResult> classifications;
// "TIMESTAMPED_CLASSIFICATIONS" stream: results aggregated by timestamp,
// then by head, produced when the 'use_stream_mode' option is false.
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
absl::Status SanityCheckOptions(
const proto::AudioClassifierGraphOptions& options) {
if (options.base_options().use_stream_mode() &&
@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator(
// series stream header with sample rate info.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category).
// CLASSIFICATIONS - ClassificationResult @Optional
//   The classification results aggregated by head. Only produced if the
//   'use_stream_mode' option is true.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
//   The classification results aggregated by timestamp, then by head. Only
//   produced if the 'use_stream_mode' option is false.
//
// Example:
// node {
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
// input_stream: "AUDIO:audio_in"
// input_stream: "SAMPLE_RATE:sample_rate_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// output_stream: "CLASSIFICATIONS:classifications"
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options {
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
// {
@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
.base_options()
.use_stream_mode();
ASSIGN_OR_RETURN(
auto classification_result_out,
auto output_streams,
BuildAudioClassificationTask(
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)],
@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
? absl::nullopt
: absl::make_optional(graph[Input<double>(kSampleRateTag)]),
graph));
classification_result_out >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig();
}
@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
// sample_rate_in: (double) optional stream of the input audio sample rate.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
absl::StatusOr<AudioClassifierOutputStreams> BuildAudioClassificationTask(
const proto::AudioClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio classification on
// audio files. Disables timestamp aggregation by not connecting the
// "TIMESTAMPS" streams.
if (!use_stream_mode) {
audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
}
// Outputs the aggregated classification result as the subgraph output
// stream.
return postprocessing[Output<ClassificationResult>(
kClassificationResultTag)];
// Output both streams as graph output streams.
return AudioClassifierOutputStreams{
/*classifications=*/postprocessing[Output<ClassificationResult>(
kClassificationsTag)],
/*timestamped_classifications=*/
postprocessing[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)],
};
}
};

View File

@ -32,13 +32,11 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
@ -49,7 +47,6 @@ namespace {
using ::absl::StatusOr;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) {
return matrix_mapping.matrix();
}
void CheckSpeechClassificationResult(const ClassificationResult& result) {
EXPECT_THAT(result.classifications_size(), testing::Eq(1));
EXPECT_EQ(result.classifications(0).head_name(), "scores");
EXPECT_EQ(result.classifications(0).head_index(), 0);
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5));
void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
int expected_num_categories = 521) {
EXPECT_EQ(result.size(), 5);
// Ignore last result, which operates on a too small chunk to return relevant
// results.
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
for (int i = 0; i < timestamps_ms.size(); i++) {
EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
testing::Eq(521));
const auto* top_category =
&result.classifications(0).entries(0).categories(0);
EXPECT_THAT(top_category->category_name(), testing::Eq("Speech"));
EXPECT_GT(top_category->score(), 0.9f);
EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(),
timestamps_ms[i]);
EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
EXPECT_EQ(result[i].classifications.size(), 1);
auto classifications = result[i].classifications[0];
EXPECT_EQ(classifications.head_index, 0);
EXPECT_EQ(classifications.head_name, "scores");
EXPECT_EQ(classifications.categories.size(), expected_num_categories);
auto category = classifications.categories[0];
EXPECT_EQ(category.index, 0);
EXPECT_EQ(category.category_name, "Speech");
EXPECT_GT(category.score, 0.9f);
}
}
void CheckTwoHeadsClassificationResult(const ClassificationResult& result) {
EXPECT_THAT(result.classifications_size(), testing::Eq(2));
// Checks classification head #1.
EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification");
EXPECT_EQ(result.classifications(0).head_index(), 0);
EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
testing::Eq(521));
const auto* top_category =
&result.classifications(0).entries(0).categories(0);
EXPECT_THAT(top_category->category_name(),
testing::Eq("Environmental noise"));
EXPECT_GT(top_category->score(), 0.5f);
EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0);
if (result.classifications(0).entries_size() == 2) {
top_category = &result.classifications(0).entries(1).categories(0);
EXPECT_THAT(top_category->category_name(), testing::Eq("Silence"));
EXPECT_GT(top_category->score(), 0.99f);
EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975);
void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
EXPECT_GE(result.size(), 1);
EXPECT_LE(result.size(), 2);
// Check first result.
EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_EQ(result[0].classifications.size(), 2);
// Check first head.
EXPECT_EQ(result[0].classifications[0].head_index, 0);
EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
EXPECT_EQ(result[0].classifications[0].categories[0].index, 508);
EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
"Environmental noise");
EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
// Check second head.
EXPECT_EQ(result[0].classifications[1].head_index, 1);
EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
"Chestnut-crowned Antpitta");
EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
// Check second result, if present.
if (result.size() == 2) {
EXPECT_EQ(result[1].timestamp_ms, 975);
EXPECT_EQ(result[1].classifications.size(), 2);
// Check first head.
EXPECT_EQ(result[1].classifications[0].head_index, 0);
EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
EXPECT_EQ(result[1].classifications[0].categories[0].index, 494);
EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
"Silence");
EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
// Check second head.
EXPECT_EQ(result[1].classifications[1].head_index, 1);
EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
EXPECT_EQ(result[1].classifications[1].categories[0].index, 1);
EXPECT_EQ(result[1].classifications[1].categories[0].category_name,
"White-breasted Wood-Wren");
EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f);
}
// Checks classification head #2.
EXPECT_EQ(result.classifications(1).head_name(), "bird_classification");
EXPECT_EQ(result.classifications(1).head_index(), 1);
EXPECT_THAT(result.classifications(1).entries(0).categories_size(),
testing::Eq(5));
top_category = &result.classifications(1).entries(0).categories(0);
EXPECT_THAT(top_category->category_name(),
testing::Eq("Chestnut-crowned Antpitta"));
EXPECT_GT(top_category->score(), 0.9f);
EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0);
}
ClassificationResult GenerateSpeechClassificationResult() {
return ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
head_index: 0
head_name: "scores"
entries {
categories { index: 0 score: 0.94140625 category_name: "Speech" }
timestamp_ms: 0
}
entries {
categories { index: 0 score: 0.9921875 category_name: "Speech" }
timestamp_ms: 975
}
entries {
categories { index: 0 score: 0.98828125 category_name: "Speech" }
timestamp_ms: 1950
}
entries {
categories { index: 0 score: 0.99609375 category_name: "Speech" }
timestamp_ms: 2925
}
entries {
# categories are filtered out due to the low scores.
timestamp_ms: 3900
}
})pb");
}
void CheckStreamingModeClassificationResult(
std::vector<ClassificationResult> outputs) {
ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6);
auto expected_results = GenerateSpeechClassificationResult();
for (int i = 0; i < outputs.size() - 1; ++i) {
EXPECT_THAT(outputs[i].classifications(0).entries(0),
EqualsProto(expected_results.classifications(0).entries(i)));
void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
EXPECT_EQ(outputs.size(), 5);
// Ignore last result, which operates on a too small chunk to return relevant
// results.
for (int i = 0; i < outputs.size() - 1; i++) {
EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
EXPECT_EQ(outputs[i].classifications.size(), 1);
EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1);
EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0);
EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name,
"Speech");
EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f);
}
int last_elem_index = outputs.size() - 1;
EXPECT_EQ(
mediapipe::Timestamp::Done().Value() / 1000,
outputs[last_elem_index].classifications(0).entries(0).timestamp_ms());
}
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {};
[](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options));
@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {};
[](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options));
@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result);
CheckSpeechResult(result);
}
TEST_F(ClassifyTest, SucceedsWithResampling) {
@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result);
CheckSpeechResult(result);
}
TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
auto result_16k_hz,
audio_classifier->Classify(std::move(audio_buffer_16k_hz),
/*audio_sample_rate=*/16000));
CheckSpeechClassificationResult(result_16k_hz);
CheckSpeechResult(result_16k_hz);
MP_ASSERT_OK_AND_ASSIGN(
auto result_48k_hz,
audio_classifier->Classify(std::move(audio_buffer_48k_hz),
/*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result_48k_hz);
CheckSpeechResult(result_48k_hz);
}
TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_classifier->Classify(std::move(zero_matrix), 16000));
MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result.classifications_size(), testing::Eq(1));
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1));
EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
testing::Eq(521));
EXPECT_THAT(
result.classifications(0).entries(0).categories(0).category_name(),
testing::Eq("Silence"));
EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(),
testing::FloatEq(.800781f));
EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_EQ(result[0].classifications.size(), 1);
EXPECT_EQ(result[0].classifications[0].head_index, 0);
EXPECT_EQ(result[0].classifications[0].head_name, "scores");
EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
EXPECT_EQ(result[0].classifications[0].categories[0].index, 494);
EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
"Silence");
EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f);
}
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result);
CheckTwoHeadsResult(result);
}
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/44100));
MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result);
CheckTwoHeadsResult(result);
}
TEST_F(ClassifyTest,
@ -413,13 +402,13 @@ TEST_F(ClassifyTest,
auto result_44k_hz,
audio_classifier->Classify(std::move(audio_buffer_44k_hz),
/*audio_sample_rate=*/44100));
CheckTwoHeadsClassificationResult(result_44k_hz);
CheckTwoHeadsResult(result_44k_hz);
MP_ASSERT_OK_AND_ASSIGN(
auto result_16k_hz,
audio_classifier->Classify(std::move(audio_buffer_16k_hz),
/*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result_16k_hz);
CheckTwoHeadsResult(result_16k_hz);
}
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.35f;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
CheckSpeechResult(result, /*expected_num_categories=*/1);
}
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
CheckSpeechResult(result, /*expected_num_categories=*/1);
}
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
CheckSpeechResult(result, /*expected_num_categories=*/1);
}
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close());
// All categories with the "Speech" label are filtered out.
EXPECT_THAT(result, EqualsProto(R"pb(classifications {
head_index: 0
head_name: "scores"
entries { timestamp_ms: 0 }
entries { timestamp_ms: 975 }
entries { timestamp_ms: 1950 }
entries { timestamp_ms: 2925 }
entries { timestamp_ms: 3900 }
})pb"));
// All categories with the "Speech" label are filtered out.
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
for (int i = 0; i < timestamps_ms.size(); i++) {
EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
EXPECT_EQ(result[i].classifications.size(), 1);
auto classifications = result[i].classifications[0];
EXPECT_EQ(classifications.head_index, 0);
EXPECT_EQ(classifications.head_name, "scores");
EXPECT_TRUE(classifications.categories.empty());
}
}
class ClassifyAsyncTest : public tflite_shims::testing::Test {};
@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz;
std::vector<ClassificationResult> outputs;
std::vector<AudioClassifierResult> outputs;
options->result_callback =
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
[&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
start_col += kYamnetNumOfAudioSamples * 3;
}
MP_ASSERT_OK(audio_classifier->Close());
CheckStreamingModeClassificationResult(outputs);
CheckStreamingModeResults(outputs);
}
TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz;
std::vector<ClassificationResult> outputs;
std::vector<AudioClassifierResult> outputs;
options->result_callback =
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
[&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
start_col += num_samples;
}
MP_ASSERT_OK(audio_classifier->Close());
CheckStreamingModeClassificationResult(outputs);
CheckStreamingModeResults(outputs);
}
} // namespace