mediapipe/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc
Sebastian Schmidt 9f59d4d01b Remove cosineSimilarity() from AudioEmbedder
PiperOrigin-RevId: 512671255
2023-02-27 11:13:20 -08:00

268 lines
12 KiB
C++

/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/audio/audio_embedder/audio_embedder.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <new>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace audio {
namespace audio_embedder {
namespace {
using ::absl::StatusOr;
using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
using ::testing::Optional;

// Test assets. The tests below join these under "./" with JoinPath.
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/audio";
constexpr char kModelWithMetadata[] = "yamnet_embedding_metadata.tflite";

// WAV inputs: the same speech clip at two sample rates, plus a second clip.
constexpr char k16kTestWavFilename[] = "speech_16000_hz_mono.wav";
constexpr char k48kTestWavFilename[] = "speech_48000_hz_mono.wav";
constexpr char k16kTestWavForTwoHeadsFilename[] = "two_heads_16000_hz_mono.wav";

// Sample count and sample rate used to size one YAMNet input window.
constexpr int kYamnetNumOfAudioSamples = 15600;
constexpr int kYamnetAudioSampleRate = 16000;

// Converts sample offsets to millisecond timestamps.
constexpr int kMilliSecondsPerSecond = 1000;
// Loads `filename` from the audio test-data directory and returns its samples
// as a matrix with one row and one column per sample read.
//
// If the WAV file cannot be read, this records a non-fatal test failure and
// returns an empty matrix. The failed result is never dereferenced, which
// would crash the whole test binary instead of failing only the current test.
Matrix GetAudioData(absl::string_view filename) {
  std::string wav_file_path = JoinPath("./", kTestDataDirectory, filename);
  int buffer_size = 0;
  auto audio_data = internal::ReadWavFile(wav_file_path, &buffer_size);
  if (!audio_data.ok()) {
    ADD_FAILURE() << "Failed to read " << wav_file_path << ": "
                  << audio_data.status();
    return Matrix();
  }
  // The Map only views the buffer owned by `audio_data`. Returning a Matrix
  // copies the samples out before that buffer is released.
  Eigen::Map<Matrix> matrix_mapping(audio_data->get(), 1, buffer_size);
  return matrix_mapping.matrix();
}
// Fixture for tests that only exercise AudioEmbedder::Create() validation.
struct CreateFromOptionsTest : tflite_shims::testing::Test {};
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  // Default-constructed options name no model asset, so creation is rejected.
  auto empty_options = std::make_unique<AudioEmbedderOptions>();
  const auto embedder_or = AudioEmbedder::Create(std::move(empty_options));
  const absl::Status& status = embedder_or.status();

  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      status.message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
  auto options = std::make_unique<AudioEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
                          AudioEmbedder::Create(std::move(options)));
  // As in the other success-path tests, also check that shutdown succeeds.
  MP_EXPECT_OK(audio_embedder->Close());
}
TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInAudioClipsMode) {
  // In AUDIO_CLIPS mode Embed() returns results directly, so a result
  // callback is a configuration error.
  auto clips_options = std::make_unique<AudioEmbedderOptions>();
  clips_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  clips_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  clips_options->result_callback = [](absl::StatusOr<AudioEmbedderResult>) {};

  const auto embedder_or = AudioEmbedder::Create(std::move(clips_options));
  const absl::Status& status = embedder_or.status();

  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("a user-defined result callback shouldn't be provided"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
  // AUDIO_STREAM mode delivers results only through the callback, so leaving
  // it unset must be rejected.
  auto stream_options = std::make_unique<AudioEmbedderOptions>();
  stream_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  stream_options->running_mode = core::RunningMode::AUDIO_STREAM;

  const auto embedder_or = AudioEmbedder::Create(std::move(stream_options));
  const absl::Status& status = embedder_or.status();

  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("a user-defined result callback must be provided"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
// Fixture for synchronous (AUDIO_CLIPS mode) Embed() tests.
struct EmbedTest : tflite_shims::testing::Test {};
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
  auto options = std::make_unique<AudioEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  options->running_mode = core::RunningMode::AUDIO_CLIPS;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
                          AudioEmbedder::Create(std::move(options)));
  // One window of all-zero samples at the model's native rate.
  Matrix silent_data(1, kYamnetNumOfAudioSamples);
  silent_data.setZero();
  MP_ASSERT_OK_AND_ASSIGN(
      auto result, audio_embedder->Embed(silent_data, kYamnetAudioSampleRate));
  // These size checks must be fatal (ASSERT): the element accesses below
  // would read out of bounds if any container came back smaller than expected.
  ASSERT_EQ(result.size(), 1);
  ASSERT_FALSE(result[0].embeddings.empty());
  const auto& embedding = result[0].embeddings[0].float_embedding;
  ASSERT_EQ(embedding.size(), 1024);
  constexpr float kValueDiffTolerance = 3e-6;
  EXPECT_NEAR(embedding[0], 2.07613f, kValueDiffTolerance);
  EXPECT_NEAR(embedding[1], 0.392721f, kValueDiffTolerance);
  EXPECT_NEAR(embedding[2], 0.543622f, kValueDiffTolerance);
  MP_EXPECT_OK(audio_embedder->Close());
}
TEST_F(EmbedTest, SucceedsWithSameAudioAtDifferentSampleRates) {
  // Per the test name, these files hold the same speech at 16 kHz and 48 kHz.
  // Both should produce the same number of results.
  const Matrix speech_16k = GetAudioData(k16kTestWavFilename);
  const Matrix speech_48k = GetAudioData(k48kTestWavFilename);

  auto clips_options = std::make_unique<AudioEmbedderOptions>();
  clips_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  clips_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(clips_options)));

  MP_ASSERT_OK_AND_ASSIGN(auto results_16k, embedder->Embed(speech_16k, 16000));
  MP_ASSERT_OK_AND_ASSIGN(auto results_48k, embedder->Embed(speech_48k, 48000));

  const int kExpectedResultCount = 5;
  ASSERT_EQ(results_16k.size(), kExpectedResultCount);
  ASSERT_EQ(results_48k.size(), kExpectedResultCount);
  MP_EXPECT_OK(embedder->Close());
}
TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
  // Two different 16 kHz clips; each yields its own number of results.
  const Matrix speech_clip = GetAudioData(k16kTestWavFilename);
  const Matrix two_heads_clip = GetAudioData(k16kTestWavForTwoHeadsFilename);

  auto clips_options = std::make_unique<AudioEmbedderOptions>();
  clips_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  clips_options->running_mode = core::RunningMode::AUDIO_CLIPS;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> embedder,
                          AudioEmbedder::Create(std::move(clips_options)));

  MP_ASSERT_OK_AND_ASSIGN(
      auto speech_results, embedder->Embed(speech_clip, kYamnetAudioSampleRate));
  MP_ASSERT_OK_AND_ASSIGN(
      auto two_heads_results,
      embedder->Embed(two_heads_clip, kYamnetAudioSampleRate));

  ASSERT_EQ(speech_results.size(), 5);
  ASSERT_EQ(two_heads_results.size(), 1);
  MP_EXPECT_OK(embedder->Close());
}
class EmbedAsyncTest : public tflite_shims::testing::Test {
 protected:
  // Streams `audio_file_name` through an AUDIO_STREAM-mode embedder and
  // appends every result delivered to the result callback to `*result`.
  //
  // Each chunk holds one YAMNet window's worth of samples at
  // `sample_rate_hz`, plus 0-9 pseudo-random extra samples, so chunk
  // boundaries drift relative to the model window. The generator is a
  // freshly seeded std::mt19937, whose output sequence is fixed by the C++
  // standard. Chunking is therefore identical on every call and platform,
  // regardless of which tests ran earlier. The previous version used a
  // function-local `static` seed with POSIX-only rand_r.
  void RunAudioEmbedderInStreamMode(std::string audio_file_name,
                                    int sample_rate_hz,
                                    std::vector<AudioEmbedderResult>* result) {
    auto audio_buffer = GetAudioData(audio_file_name);
    auto options = std::make_unique<AudioEmbedderOptions>();
    options->base_options.model_asset_path =
        JoinPath("./", kTestDataDirectory, kModelWithMetadata);
    options->running_mode = core::RunningMode::AUDIO_STREAM;
    options->result_callback =
        [result](absl::StatusOr<AudioEmbedderResult> status_or_result) {
          MP_ASSERT_OK_AND_ASSIGN(result->emplace_back(), status_or_result);
        };
    MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
                            AudioEmbedder::Create(std::move(options)));
    // Use 64-bit intermediates so high sample rates cannot overflow int.
    const int window_samples = static_cast<int>(
        static_cast<int64_t>(kYamnetNumOfAudioSamples) * sample_rate_hz /
        kYamnetAudioSampleRate);
    std::mt19937 rng(/*seed=*/0);
    int start_col = 0;
    while (start_col < audio_buffer.cols()) {
      const int remaining = static_cast<int>(audio_buffer.cols() - start_col);
      const int num_samples =
          std::min(remaining, window_samples + static_cast<int>(rng() % 10));
      // Compute in 64 bits: start_col * 1000 overflows int for long clips.
      const int64_t timestamp_ms = static_cast<int64_t>(start_col) *
                                   kMilliSecondsPerSecond / sample_rate_hz;
      MP_ASSERT_OK(audio_embedder->EmbedAsync(
          audio_buffer.block(0, start_col, 1, num_samples), sample_rate_hz,
          timestamp_ms));
      start_col += num_samples;
    }
    MP_ASSERT_OK(audio_embedder->Close());
  }
};
TEST_F(EmbedAsyncTest, FailsWithOutOfOrderInputTimestamps) {
  auto options = std::make_unique<AudioEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kModelWithMetadata);
  options->running_mode = core::RunningMode::AUDIO_STREAM;
  // Stream mode requires a callback; this test ignores the results.
  options->result_callback = [](absl::StatusOr<AudioEmbedderResult>) {};
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioEmbedder> audio_embedder,
                          AudioEmbedder::Create(std::move(options)));
  // Matrix(rows, cols) leaves its coefficients uninitialized. Zero them so
  // the model never reads indeterminate values.
  Matrix silent_data(1, kYamnetNumOfAudioSamples);
  silent_data.setZero();
  MP_ASSERT_OK(
      audio_embedder->EmbedAsync(silent_data, kYamnetAudioSampleRate, 100));
  // A timestamp earlier than the previous one (100 ms) must be rejected.
  auto status =
      audio_embedder->EmbedAsync(silent_data, kYamnetAudioSampleRate, 0);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(),
              HasSubstr("timestamp must be monotonically increasing"));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInvalidTimestampError))));
  MP_ASSERT_OK(audio_embedder->Close());
}
TEST_F(EmbedAsyncTest, SucceedsWithSameAudioAtDifferentSampleRates) {
  // Streaming the same speech at 16 kHz and at 48 kHz should produce the
  // same number of results.
  std::vector<AudioEmbedderResult> results_16k;
  std::vector<AudioEmbedderResult> results_48k;
  RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &results_16k);
  RunAudioEmbedderInStreamMode(k48kTestWavFilename, 48000, &results_48k);

  const int kExpectedResultCount = 5;
  ASSERT_EQ(results_16k.size(), kExpectedResultCount);
  ASSERT_EQ(results_48k.size(), kExpectedResultCount);
}
TEST_F(EmbedAsyncTest, SucceedsWithDifferentAudios) {
  // Two different 16 kHz clips streamed independently.
  std::vector<AudioEmbedderResult> speech_results;
  std::vector<AudioEmbedderResult> two_heads_results;
  RunAudioEmbedderInStreamMode(k16kTestWavFilename, 16000, &speech_results);
  RunAudioEmbedderInStreamMode(k16kTestWavForTwoHeadsFilename, 16000,
                               &two_heads_results);

  ASSERT_EQ(speech_results.size(), 5);
  ASSERT_EQ(two_heads_results.size(), 1);
}
} // namespace
} // namespace audio_embedder
} // namespace audio
} // namespace tasks
} // namespace mediapipe