Migrate AudioClassifier C++ to use new ClassificationResult struct.
PiperOrigin-RevId: 486162683
This commit is contained in:
parent
93a587a422
commit
8b2c937b9e
|
@ -65,6 +65,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
|
|
|
@ -18,12 +18,14 @@ limitations under the License.
|
|||
#include <map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
|
@ -38,12 +40,16 @@ namespace audio_classifier {
|
|||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr char kAudioStreamName[] = "audio_in";
|
||||
constexpr char kAudioTag[] = "AUDIO";
|
||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kClassificationsName[] = "classifications_out";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
constexpr char kTimestampedClassificationsName[] =
|
||||
"timestamped_classifications_out";
|
||||
constexpr char kSampleRateName[] = "sample_rate_in";
|
||||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
|
@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
}
|
||||
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
subgraph.Out(kClassificationResultTag)
|
||||
.SetName(kClassificationResultStreamName) >>
|
||||
graph.Out(kClassificationResultTag);
|
||||
subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
|
||||
graph.Out(kClassificationsTag);
|
||||
subgraph.Out(kTimestampedClassificationsTag)
|
||||
.SetName(kTimestampedClassificationsName) >>
|
||||
graph.Out(kTimestampedClassificationsTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
|||
return options_proto;
|
||||
}
|
||||
|
||||
absl::StatusOr<ClassificationResult> ConvertOutputPackets(
|
||||
absl::StatusOr<std::vector<AudioClassifierResult>> ConvertOutputPackets(
|
||||
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||
if (!status_or_packets.ok()) {
|
||||
return status_or_packets.status();
|
||||
}
|
||||
return status_or_packets.value()[kClassificationResultStreamName]
|
||||
.Get<ClassificationResult>();
|
||||
auto classification_results =
|
||||
status_or_packets.value()[kTimestampedClassificationsName]
|
||||
.Get<std::vector<ClassificationResult>>();
|
||||
std::vector<AudioClassifierResult> results;
|
||||
results.reserve(classification_results.size());
|
||||
for (const auto& classification_result : classification_results) {
|
||||
results.emplace_back(ConvertToClassificationResult(classification_result));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
|
||||
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||
if (!status_or_packets.ok()) {
|
||||
return status_or_packets.status();
|
||||
}
|
||||
return ConvertToClassificationResult(
|
||||
status_or_packets.value()[kClassificationsName]
|
||||
.Get<ClassificationResult>());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -118,7 +143,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
|
|||
auto result_callback = options->result_callback;
|
||||
packets_callback =
|
||||
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||
result_callback(ConvertOutputPackets(status_or_packets));
|
||||
result_callback(ConvertAsyncOutputPackets(status_or_packets));
|
||||
};
|
||||
}
|
||||
return core::AudioTaskApiFactory::Create<AudioClassifier,
|
||||
|
@ -128,7 +153,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
|
|||
std::move(packets_callback));
|
||||
}
|
||||
|
||||
absl::StatusOr<ClassificationResult> AudioClassifier::Classify(
|
||||
absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
|
||||
Matrix audio_clip, double audio_sample_rate) {
|
||||
return ConvertOutputPackets(ProcessAudioClip(
|
||||
{{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},
|
||||
|
|
|
@ -18,12 +18,13 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
|
@ -32,6 +33,10 @@ namespace tasks {
|
|||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
// Alias the shared ClassificationResult struct as result type.
|
||||
using AudioClassifierResult =
|
||||
::mediapipe::tasks::components::containers::ClassificationResult;
|
||||
|
||||
// The options for configuring a mediapipe audio classifier task.
|
||||
struct AudioClassifierOptions {
|
||||
// Base options for configuring Task library, such as specifying the TfLite
|
||||
|
@ -59,9 +64,8 @@ struct AudioClassifierOptions {
|
|||
// The user-defined result callback for processing audio stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::AUDIO_STREAM.
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>)>
|
||||
result_callback = nullptr;
|
||||
std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
|
||||
nullptr;
|
||||
};
|
||||
|
||||
// Performs audio classification on audio clips or audio stream.
|
||||
|
@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|||
// required to provide the corresponding audio sample rate along with the
|
||||
// input audio clips.
|
||||
//
|
||||
// For each audio clip, the output classifications are grouped in a
|
||||
// ClassificationResult object that has three dimensions:
|
||||
// Classification head:
|
||||
// The prediction heads targeting different audio classification tasks
|
||||
// such as audio event classification and bird sound classification.
|
||||
// Classification timestamp:
|
||||
// The start time (in milliseconds) of each audio clip that is sent to the
|
||||
// model for audio classification. As the audio classification models take
|
||||
// a fixed number of audio samples, long audio clips will be framed to
|
||||
// multiple buffers (with the desired number of audio samples) during
|
||||
// preprocessing.
|
||||
// Classification category:
|
||||
// The list of the classification categories that model predicts per
|
||||
// framed audio clip.
|
||||
// The input audio clip may be longer than what the model is able to process
|
||||
// in a single inference. When this occurs, the input audio clip is split into
|
||||
// multiple chunks starting at different timestamps. For this reason, this
|
||||
// function returns a vector of ClassificationResult objects, each associated
|
||||
// with a timestamp corresponding to the start (in milliseconds) of the chunk
|
||||
// data that was classified, e.g.:
|
||||
//
|
||||
// ClassificationResult #0 (first chunk of data):
|
||||
// timestamp_ms: 0 (starts at 0ms)
|
||||
// classifications #0 (single head model):
|
||||
// category #0:
|
||||
// category_name: "Speech"
|
||||
// score: 0.6
|
||||
// category #1:
|
||||
// category_name: "Music"
|
||||
// score: 0.2
|
||||
// ClassificationResult #1 (second chunk of data):
|
||||
// timestamp_ms: 800 (starts at 800ms)
|
||||
// classifications #0 (single head model):
|
||||
// category #0:
|
||||
// category_name: "Speech"
|
||||
// score: 0.5
|
||||
// category #1:
|
||||
// category_name: "Silence"
|
||||
// score: 0.1
|
||||
// ...
|
||||
//
|
||||
// TODO: Use `sample_rate` in AudioClassifierOptions by default
|
||||
// and make `audio_sample_rate` optional.
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||
absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
|
||||
mediapipe::Matrix audio_clip, double audio_sample_rate);
|
||||
|
||||
// Sends audio data (a block in a continuous audio stream) to perform audio
|
||||
|
@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|||
// milliseconds) to indicate the start time of the input audio block. The
|
||||
// timestamps must be monotonically increasing.
|
||||
//
|
||||
// The output classifications are grouped in a ClassificationResult object
|
||||
// that has three dimensions:
|
||||
// Classification head:
|
||||
// The prediction heads targeting different audio classification tasks
|
||||
// such as audio event classification and bird sound classification.
|
||||
// Classification timestamp :
|
||||
// The start time (in milliseconds) of the framed audio block that is sent
|
||||
// to the model for audio classification.
|
||||
// Classification category:
|
||||
// The list of the classification categories that model predicts per
|
||||
// framed audio clip.
|
||||
// The input audio block may be longer than what the model is able to process
|
||||
// in a single inference. When this occurs, the input audio block is split
|
||||
// into multiple chunks. For this reason, the callback may be called multiple
|
||||
// times (once per chunk) for each call to this function.
|
||||
absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
|
||||
|
||||
// Shuts down the AudioClassifier when all work is done.
|
||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include <stdint.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
|||
|
||||
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
|
||||
constexpr char kAudioTag[] = "AUDIO";
|
||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
constexpr char kPacketTag[] = "PACKET";
|
||||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
|
||||
// Struct holding the different output streams produced by the audio classifier
|
||||
// graph.
|
||||
struct AudioClassifierOutputStreams {
|
||||
Source<ClassificationResult> classifications;
|
||||
Source<std::vector<ClassificationResult>> timestamped_classifications;
|
||||
};
|
||||
|
||||
absl::Status SanityCheckOptions(
|
||||
const proto::AudioClassifierGraphOptions& options) {
|
||||
if (options.base_options().use_stream_mode() &&
|
||||
|
@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator(
|
|||
// series stream header with sample rate info.
|
||||
//
|
||||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The aggregated classification result object that has 3 dimensions:
|
||||
// (classification head, classification timestamp, classification category).
|
||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||
// The classification results aggregated by head. Only produces results if
|
||||
// the 'use_stream_mode' option is true.
|
||||
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||
// The classification result aggregated by timestamp, then by head. Only
|
||||
// produces results if the 'use_stream_mode' option is false.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
|
||||
// input_stream: "AUDIO:audio_in"
|
||||
// input_stream: "SAMPLE_RATE:sample_rate_in"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||
// output_stream: "CLASSIFICATIONS:classifications"
|
||||
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
|
||||
// options {
|
||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
|
||||
// {
|
||||
|
@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
.base_options()
|
||||
.use_stream_mode();
|
||||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
auto output_streams,
|
||||
BuildAudioClassificationTask(
|
||||
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
|
||||
graph[Input<Matrix>(kAudioTag)],
|
||||
|
@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
? absl::nullopt
|
||||
: absl::make_optional(graph[Input<double>(kSampleRateTag)]),
|
||||
graph));
|
||||
classification_result_out >>
|
||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
||||
output_streams.classifications >>
|
||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||
output_streams.timestamped_classifications >>
|
||||
graph[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
|
||||
// sample_rate_in: (double) optional stream of the input audio sample rate.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
|
||||
absl::StatusOr<AudioClassifierOutputStreams> BuildAudioClassificationTask(
|
||||
const proto::AudioClassifierGraphOptions& task_options,
|
||||
const core::ModelResources& model_resources, Source<Matrix> audio_in,
|
||||
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
|
||||
|
@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Time aggregation is only needed for performing audio classification on
|
||||
// audio files. Disables time aggregration by not connecting the
|
||||
// audio files. Disables timestamp aggregation by not connecting the
|
||||
// "TIMESTAMPS" streams.
|
||||
if (!use_stream_mode) {
|
||||
audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
|
||||
}
|
||||
|
||||
// Outputs the aggregated classification result as the subgraph output
|
||||
// stream.
|
||||
return postprocessing[Output<ClassificationResult>(
|
||||
kClassificationResultTag)];
|
||||
// Output both streams as graph output streams.
|
||||
return AudioClassifierOutputStreams{
|
||||
/*classifications=*/postprocessing[Output<ClassificationResult>(
|
||||
kClassificationsTag)],
|
||||
/*timestamped_classifications=*/
|
||||
postprocessing[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -32,13 +32,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -49,7 +47,6 @@ namespace {
|
|||
|
||||
using ::absl::StatusOr;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
|
@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) {
|
|||
return matrix_mapping.matrix();
|
||||
}
|
||||
|
||||
void CheckSpeechClassificationResult(const ClassificationResult& result) {
|
||||
EXPECT_THAT(result.classifications_size(), testing::Eq(1));
|
||||
EXPECT_EQ(result.classifications(0).head_name(), "scores");
|
||||
EXPECT_EQ(result.classifications(0).head_index(), 0);
|
||||
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5));
|
||||
void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
|
||||
int expected_num_categories = 521) {
|
||||
EXPECT_EQ(result.size(), 5);
|
||||
// Ignore last result, which operates on too small a chunk to return relevant
|
||||
// results.
|
||||
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
|
||||
for (int i = 0; i < timestamps_ms.size(); i++) {
|
||||
EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
|
||||
testing::Eq(521));
|
||||
const auto* top_category =
|
||||
&result.classifications(0).entries(0).categories(0);
|
||||
EXPECT_THAT(top_category->category_name(), testing::Eq("Speech"));
|
||||
EXPECT_GT(top_category->score(), 0.9f);
|
||||
EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(),
|
||||
timestamps_ms[i]);
|
||||
EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
|
||||
EXPECT_EQ(result[i].classifications.size(), 1);
|
||||
auto classifications = result[i].classifications[0];
|
||||
EXPECT_EQ(classifications.head_index, 0);
|
||||
EXPECT_EQ(classifications.head_name, "scores");
|
||||
EXPECT_EQ(classifications.categories.size(), expected_num_categories);
|
||||
auto category = classifications.categories[0];
|
||||
EXPECT_EQ(category.index, 0);
|
||||
EXPECT_EQ(category.category_name, "Speech");
|
||||
EXPECT_GT(category.score, 0.9f);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckTwoHeadsClassificationResult(const ClassificationResult& result) {
|
||||
EXPECT_THAT(result.classifications_size(), testing::Eq(2));
|
||||
// Checks classification head #1.
|
||||
EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification");
|
||||
EXPECT_EQ(result.classifications(0).head_index(), 0);
|
||||
EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
|
||||
testing::Eq(521));
|
||||
const auto* top_category =
|
||||
&result.classifications(0).entries(0).categories(0);
|
||||
EXPECT_THAT(top_category->category_name(),
|
||||
testing::Eq("Environmental noise"));
|
||||
EXPECT_GT(top_category->score(), 0.5f);
|
||||
EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0);
|
||||
if (result.classifications(0).entries_size() == 2) {
|
||||
top_category = &result.classifications(0).entries(1).categories(0);
|
||||
EXPECT_THAT(top_category->category_name(), testing::Eq("Silence"));
|
||||
EXPECT_GT(top_category->score(), 0.99f);
|
||||
EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975);
|
||||
void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
|
||||
EXPECT_GE(result.size(), 1);
|
||||
EXPECT_LE(result.size(), 2);
|
||||
// Check first result.
|
||||
EXPECT_EQ(result[0].timestamp_ms, 0);
|
||||
EXPECT_EQ(result[0].classifications.size(), 2);
|
||||
// Check first head.
|
||||
EXPECT_EQ(result[0].classifications[0].head_index, 0);
|
||||
EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
|
||||
EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
|
||||
EXPECT_EQ(result[0].classifications[0].categories[0].index, 508);
|
||||
EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
|
||||
"Environmental noise");
|
||||
EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
|
||||
// Check second head.
|
||||
EXPECT_EQ(result[0].classifications[1].head_index, 1);
|
||||
EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
|
||||
EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
|
||||
EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
|
||||
EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
|
||||
"Chestnut-crowned Antpitta");
|
||||
EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
|
||||
// Check second result, if present.
|
||||
if (result.size() == 2) {
|
||||
EXPECT_EQ(result[1].timestamp_ms, 975);
|
||||
EXPECT_EQ(result[1].classifications.size(), 2);
|
||||
// Check first head.
|
||||
EXPECT_EQ(result[1].classifications[0].head_index, 0);
|
||||
EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
|
||||
EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
|
||||
EXPECT_EQ(result[1].classifications[0].categories[0].index, 494);
|
||||
EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
|
||||
"Silence");
|
||||
EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
|
||||
// Check second head.
|
||||
EXPECT_EQ(result[1].classifications[1].head_index, 1);
|
||||
EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
|
||||
EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
|
||||
EXPECT_EQ(result[1].classifications[1].categories[0].index, 1);
|
||||
EXPECT_EQ(result[1].classifications[1].categories[0].category_name,
|
||||
"White-breasted Wood-Wren");
|
||||
EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f);
|
||||
}
|
||||
// Checks classification head #2.
|
||||
EXPECT_EQ(result.classifications(1).head_name(), "bird_classification");
|
||||
EXPECT_EQ(result.classifications(1).head_index(), 1);
|
||||
EXPECT_THAT(result.classifications(1).entries(0).categories_size(),
|
||||
testing::Eq(5));
|
||||
top_category = &result.classifications(1).entries(0).categories(0);
|
||||
EXPECT_THAT(top_category->category_name(),
|
||||
testing::Eq("Chestnut-crowned Antpitta"));
|
||||
EXPECT_GT(top_category->score(), 0.9f);
|
||||
EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0);
|
||||
}
|
||||
|
||||
ClassificationResult GenerateSpeechClassificationResult() {
|
||||
return ParseTextProtoOrDie<ClassificationResult>(
|
||||
R"pb(classifications {
|
||||
head_index: 0
|
||||
head_name: "scores"
|
||||
entries {
|
||||
categories { index: 0 score: 0.94140625 category_name: "Speech" }
|
||||
timestamp_ms: 0
|
||||
}
|
||||
entries {
|
||||
categories { index: 0 score: 0.9921875 category_name: "Speech" }
|
||||
timestamp_ms: 975
|
||||
}
|
||||
entries {
|
||||
categories { index: 0 score: 0.98828125 category_name: "Speech" }
|
||||
timestamp_ms: 1950
|
||||
}
|
||||
entries {
|
||||
categories { index: 0 score: 0.99609375 category_name: "Speech" }
|
||||
timestamp_ms: 2925
|
||||
}
|
||||
entries {
|
||||
# categories are filtered out due to the low scores.
|
||||
timestamp_ms: 3900
|
||||
}
|
||||
})pb");
|
||||
}
|
||||
|
||||
void CheckStreamingModeClassificationResult(
|
||||
std::vector<ClassificationResult> outputs) {
|
||||
ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6);
|
||||
auto expected_results = GenerateSpeechClassificationResult();
|
||||
for (int i = 0; i < outputs.size() - 1; ++i) {
|
||||
EXPECT_THAT(outputs[i].classifications(0).entries(0),
|
||||
EqualsProto(expected_results.classifications(0).entries(i)));
|
||||
void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
|
||||
EXPECT_EQ(outputs.size(), 5);
|
||||
// Ignore last result, which operates on too small a chunk to return relevant
|
||||
// results.
|
||||
for (int i = 0; i < outputs.size() - 1; i++) {
|
||||
EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
|
||||
EXPECT_EQ(outputs[i].classifications.size(), 1);
|
||||
EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
|
||||
EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
|
||||
EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1);
|
||||
EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0);
|
||||
EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name,
|
||||
"Speech");
|
||||
EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f);
|
||||
}
|
||||
int last_elem_index = outputs.size() - 1;
|
||||
EXPECT_EQ(
|
||||
mediapipe::Timestamp::Done().Value() / 1000,
|
||||
outputs[last_elem_index].classifications(0).entries(0).timestamp_ms());
|
||||
}
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
|
@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<ClassificationResult> status_or_result) {};
|
||||
[](absl::StatusOr<AudioClassifierResult> status_or_result) {};
|
||||
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
|
||||
AudioClassifier::Create(std::move(options));
|
||||
|
||||
|
@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
|
|||
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
|
||||
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<ClassificationResult> status_or_result) {};
|
||||
[](absl::StatusOr<AudioClassifierResult> status_or_result) {};
|
||||
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
|
||||
AudioClassifier::Create(std::move(options));
|
||||
|
||||
|
@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/16000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckSpeechClassificationResult(result);
|
||||
CheckSpeechResult(result);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithResampling) {
|
||||
|
@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/48000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckSpeechClassificationResult(result);
|
||||
CheckSpeechResult(result);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
|
||||
|
@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
|
|||
auto result_16k_hz,
|
||||
audio_classifier->Classify(std::move(audio_buffer_16k_hz),
|
||||
/*audio_sample_rate=*/16000));
|
||||
CheckSpeechClassificationResult(result_16k_hz);
|
||||
CheckSpeechResult(result_16k_hz);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result_48k_hz,
|
||||
audio_classifier->Classify(std::move(audio_buffer_48k_hz),
|
||||
/*audio_sample_rate=*/48000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckSpeechClassificationResult(result_48k_hz);
|
||||
CheckSpeechResult(result_48k_hz);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
|
||||
|
@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result, audio_classifier->Classify(std::move(zero_matrix), 16000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
EXPECT_THAT(result.classifications_size(), testing::Eq(1));
|
||||
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1));
|
||||
EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
|
||||
testing::Eq(521));
|
||||
EXPECT_THAT(
|
||||
result.classifications(0).entries(0).categories(0).category_name(),
|
||||
testing::Eq("Silence"));
|
||||
EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(),
|
||||
testing::FloatEq(.800781f));
|
||||
EXPECT_EQ(result.size(), 1);
|
||||
EXPECT_EQ(result[0].timestamp_ms, 0);
|
||||
EXPECT_EQ(result[0].classifications.size(), 1);
|
||||
EXPECT_EQ(result[0].classifications[0].head_index, 0);
|
||||
EXPECT_EQ(result[0].classifications[0].head_name, "scores");
|
||||
EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
|
||||
EXPECT_EQ(result[0].classifications[0].categories[0].index, 494);
|
||||
EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
|
||||
"Silence");
|
||||
EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
|
||||
|
@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/16000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckTwoHeadsClassificationResult(result);
|
||||
CheckTwoHeadsResult(result);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
|
||||
|
@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/44100));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckTwoHeadsClassificationResult(result);
|
||||
CheckTwoHeadsResult(result);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest,
|
||||
|
@ -413,13 +402,13 @@ TEST_F(ClassifyTest,
|
|||
auto result_44k_hz,
|
||||
audio_classifier->Classify(std::move(audio_buffer_44k_hz),
|
||||
/*audio_sample_rate=*/44100));
|
||||
CheckTwoHeadsClassificationResult(result_44k_hz);
|
||||
CheckTwoHeadsResult(result_44k_hz);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result_16k_hz,
|
||||
audio_classifier->Classify(std::move(audio_buffer_16k_hz),
|
||||
/*audio_sample_rate=*/16000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckTwoHeadsClassificationResult(result_16k_hz);
|
||||
CheckTwoHeadsResult(result_16k_hz);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
|
||||
|
@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kModelWithMetadata);
|
||||
options->classifier_options.max_results = 1;
|
||||
options->classifier_options.score_threshold = 0.35f;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||
AudioClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/48000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
|
||||
CheckSpeechResult(result, /*expected_num_categories=*/1);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
|
||||
|
@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/48000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
|
||||
CheckSpeechResult(result, /*expected_num_categories=*/1);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
|
||||
|
@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/48000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
|
||||
CheckSpeechResult(result, /*expected_num_categories=*/1);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
||||
|
@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
|||
auto result, audio_classifier->Classify(std::move(audio_buffer),
|
||||
/*audio_sample_rate=*/48000));
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
// All categroies with the "Speech" label are filtered out.
|
||||
EXPECT_THAT(result, EqualsProto(R"pb(classifications {
|
||||
head_index: 0
|
||||
head_name: "scores"
|
||||
entries { timestamp_ms: 0 }
|
||||
entries { timestamp_ms: 975 }
|
||||
entries { timestamp_ms: 1950 }
|
||||
entries { timestamp_ms: 2925 }
|
||||
entries { timestamp_ms: 3900 }
|
||||
})pb"));
|
||||
// All categories with the "Speech" label are filtered out.
|
||||
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
|
||||
for (int i = 0; i < timestamps_ms.size(); i++) {
|
||||
EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
|
||||
EXPECT_EQ(result[i].classifications.size(), 1);
|
||||
auto classifications = result[i].classifications[0];
|
||||
EXPECT_EQ(classifications.head_index, 0);
|
||||
EXPECT_EQ(classifications.head_name, "scores");
|
||||
EXPECT_TRUE(classifications.categories.empty());
|
||||
}
|
||||
}
|
||||
|
||||
class ClassifyAsyncTest : public tflite_shims::testing::Test {};
|
||||
|
@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
|
|||
options->classifier_options.score_threshold = 0.3f;
|
||||
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||
options->sample_rate = kSampleRateHz;
|
||||
std::vector<ClassificationResult> outputs;
|
||||
std::vector<AudioClassifierResult> outputs;
|
||||
options->result_callback =
|
||||
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
|
||||
[&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
|
||||
};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||
|
@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
|
|||
start_col += kYamnetNumOfAudioSamples * 3;
|
||||
}
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckStreamingModeClassificationResult(outputs);
|
||||
CheckStreamingModeResults(outputs);
|
||||
}
|
||||
|
||||
TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
||||
|
@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
|||
options->classifier_options.score_threshold = 0.3f;
|
||||
options->running_mode = core::RunningMode::AUDIO_STREAM;
|
||||
options->sample_rate = kSampleRateHz;
|
||||
std::vector<ClassificationResult> outputs;
|
||||
std::vector<AudioClassifierResult> outputs;
|
||||
options->result_callback =
|
||||
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
|
||||
[&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
|
||||
};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
|
||||
|
@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
|||
start_col += num_samples;
|
||||
}
|
||||
MP_ASSERT_OK(audio_classifier->Close());
|
||||
CheckStreamingModeClassificationResult(outputs);
|
||||
CheckStreamingModeResults(outputs);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue
Block a user