Add DftTensorFormat To TensorsToAudioCalculatorOptions.

PiperOrigin-RevId: 515077766
This commit is contained in:
Jiuqiang Tang 2023-03-08 10:36:17 -08:00 committed by Copybara-Service
parent 16c2e32a0d
commit ddc535e705
3 changed files with 92 additions and 19 deletions

View File

@ -34,6 +34,8 @@ namespace mediapipe {
namespace api2 {
namespace {
using Options = ::mediapipe::TensorsToAudioCalculatorOptions;
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
std::vector<float> hann_window(window_size);
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
@ -138,11 +140,15 @@ class TensorsToAudioCalculator : public Node {
std::vector<float, Eigen::aligned_allocator<float>> prev_fft_output_;
int overlapping_samples_ = -1;
int step_samples_ = -1;
Options::DftTensorFormat dft_tensor_format_;
};
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
const auto& options =
cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
dft_tensor_format_ = options.dft_tensor_format();
RET_CHECK(dft_tensor_format_ != Options::DFT_TENSOR_FORMAT_UNKNOWN)
<< "dft tensor format must be specified.";
RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
RET_CHECK(IsValidFftSize(options.fft_size()))
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
@ -183,14 +189,37 @@ absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
auto view = input_tensors[0].GetCpuReadView();
switch (dft_tensor_format_) {
case Options::WITH_NYQUIST: {
// DC's real part.
input_dft_[0] = kDcAndNyquistIn(cc)->first;
// Nyquist's real part is the penultimate element of the tensor buffer.
// pffft ignores the Nyquist's imagery part. No need to fetch the last value
// from the tensor buffer.
// pffft ignores the Nyquist's imagery part. No need to fetch the last
// value from the tensor buffer.
input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
(fft_size_ - 2) * sizeof(float));
break;
}
case Options::WITH_DC_AND_NYQUIST: {
// DC's real part is the first element of the tensor buffer.
input_dft_[0] = *(view.buffer<float>());
// Nyquist's real part is the penultimate element of the tensor buffer.
input_dft_[1] = *(view.buffer<float>() + fft_size_);
std::memcpy(input_dft_.data() + 2, view.buffer<float>() + 2,
(fft_size_ - 2) * sizeof(float));
break;
}
case Options::WITHOUT_DC_AND_NYQUIST: {
input_dft_[0] = kDcAndNyquistIn(cc)->first;
input_dft_[1] = kDcAndNyquistIn(cc)->second;
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
(fft_size_ - 2) * sizeof(float));
break;
}
default:
return absl::InvalidArgumentError("Unsupported dft tensor format.");
}
pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
fft_workplace_.data(), PFFFT_BACKWARD);
// Applies the inverse window function.

View File

@ -32,4 +32,17 @@ message TensorsToAudioCalculatorOptions {
// The number of overlapping samples between adjacent windows.
optional int64 num_overlapping_samples = 3 [default = 0];
enum DftTensorFormat {
DFT_TENSOR_FORMAT_UNKNOWN = 0;
// The input dft tensor without dc and nyquist components.
WITHOUT_DC_AND_NYQUIST = 1;
// The input dft tensor contains the nyquist component as the last
// two values.
WITH_NYQUIST = 2;
// The input dft tensor contains the dc component as the first two values
// and the nyquist component as the last two values.
WITH_DC_AND_NYQUIST = 3;
}
optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST];
}

View File

@ -30,6 +30,8 @@
namespace mediapipe {
namespace {
using Options = ::mediapipe::TensorsToAudioCalculatorOptions;
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
protected:
// Creates an audio matrix containing a single sample of 1.0 at a specified
@ -40,9 +42,10 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
return impulse;
}
void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
void ConfigGraph(int num_samples, double sample_rate, int fft_size,
Options::DftTensorFormat dft_tensor_format) {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
R"(
input_stream: "audio_in"
input_stream: "sample_rate"
output_stream: "audio_out"
@ -59,6 +62,7 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
num_overlapping_samples: 0
target_sample_rate: $1
fft_size: $2
dft_tensor_format: $3
}
}
}
@ -70,13 +74,15 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
options {
[mediapipe.TensorsToAudioCalculatorOptions.ext] {
fft_size: $2
dft_tensor_format: $3
}
}
}
)",
/*$0=*/num_samples,
/*$1=*/sample_rate,
/*$2=*/fft_size));
/*$2=*/fft_size,
/*$3=*/Options::DftTensorFormat_Name(dft_tensor_format)));
tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
}
@ -97,7 +103,7 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
};
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
ConfigGraph(320, 16000, 103);
ConfigGraph(320, 16000, 103, Options::WITH_NYQUIST);
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
auto status = graph_.WaitUntilIdle();
@ -109,8 +115,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
@ -122,7 +127,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
@ -135,7 +140,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
@ -145,5 +150,31 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
}
TEST_F(TensorsToAudioCalculatorFftTest, TestDftTensorWithDCAndNyquist) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320, Options::WITH_DC_AND_NYQUIST);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// The impulse signal at the center is not affected by the window function.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestDftTensorWithoutDCAndNyquist) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320, Options::WITHOUT_DC_AND_NYQUIST);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// The impulse signal at the center is not affected by the window function.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
}
} // namespace
} // namespace mediapipe