Implement the initial version of TensorsToAudioCalculator that supports ifft and inverse hann windowing.

PiperOrigin-RevId: 484605092
Jiuqiang Tang 2022-10-28 13:16:38 -07:00 committed by Copybara-Service
parent c5c639d634
commit e16be2e8fa
5 changed files with 426 additions and 1 deletion


@@ -109,6 +109,56 @@ cc_test(
],
)
mediapipe_proto_library(
name = "tensors_to_audio_calculator_proto",
srcs = ["tensors_to_audio_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "tensors_to_audio_calculator",
srcs = ["tensors_to_audio_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_audio_tools//audio/dsp:window_functions",
"@pffft",
],
alwayslink = 1,
)
cc_test(
name = "tensors_to_audio_calculator_test",
srcs = ["tensors_to_audio_calculator_test.cc"],
deps = [
":audio_to_tensor_calculator",
":audio_to_tensor_calculator_cc_proto",
":tensors_to_audio_calculator",
":tensors_to_audio_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
name = "feedback_tensors_calculator_proto",
srcs = ["feedback_tensors_calculator.proto"],


@@ -133,7 +133,7 @@ bool IsValidFftSize(int size) {
// invocation. In the non-streaming mode, the vector contains all of the
// output timestamps for an input audio buffer.
// DC_AND_NYQUIST - std::pair<float, float> @Optional.
-// A pair of dc component and nyquest component. Only can be connected when
+// A pair of dc component and nyquist component. Only can be connected when
// the calculator performs fft (the fft_size is set in the calculator
// options).
//


@@ -0,0 +1,197 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstring>
#include <new>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "pffft.h"
namespace mediapipe {
namespace api2 {
namespace {
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
std::vector<float> hann_window(window_size);
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
if (sqrt_hann) {
absl::c_transform(hann_window, hann_window.begin(),
[](double x) { return std::sqrt(x); });
}
return hann_window;
}
// Note that the InvHannWindow function may only work for the 50% overlap case.
std::vector<float> InvHannWindow(int window_size, bool sqrt_hann) {
std::vector<float> window = HannWindow(window_size, sqrt_hann);
std::vector<float> inv_window(window.size());
if (sqrt_hann) {
absl::c_copy(window, inv_window.begin());
} else {
const int kHalfWindowSize = window.size() / 2;
absl::c_transform(window, inv_window.begin(),
[](double x) { return x * x; });
for (int i = 0; i < kHalfWindowSize; ++i) {
double sum = inv_window[i] + inv_window[kHalfWindowSize + i];
inv_window[i] = window[i] / sum;
inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum;
}
}
return inv_window;
}
// PFFFT only supports transforms for inputs of length N of the form
// N = (2^a)*(3^b)*(5^c), where b >= 0, c >= 0, and a >= 5 for the real FFT.
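// For example, 32, 160, 256, 320, and 480 all satisfy the constraint, while
// 103 (used in the unit tests) does not.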
bool IsValidFftSize(int size) {
if (size <= 0) {
return false;
}
constexpr int kFactors[] = {2, 3, 5};
int factorization[] = {0, 0, 0};
int n = static_cast<int>(size);
for (int i = 0; i < 3; ++i) {
while (n % kFactors[i] == 0) {
n = n / kFactors[i];
++factorization[i];
}
}
return factorization[0] >= 5 && n == 1;
}
} // namespace
// Converts 2D MediaPipe float Tensors to audio buffers.
// The calculator will perform an ifft on the complex DFT and apply the
// inverse Hann window function afterwards. The input 2D MediaPipe Tensor must
// have the DFT real parts in its first row and the DFT imaginary parts in its
// second row. A valid "fft_size" must be set in the CalculatorOptions.
//
// Inputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor that represents the audio's complex DFT
// results.
// DC_AND_NYQUIST - std::pair<float, float>
// A pair of dc component and nyquist component.
//
// Outputs:
// AUDIO - mediapipe::Matrix
// The audio data represented as mediapipe::Matrix.
//
// Example:
// node {
// calculator: "TensorsToAudioCalculator"
// input_stream: "TENSORS:tensors"
// input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
// output_stream: "AUDIO:audio"
// options {
// [mediapipe.TensorsToAudioCalculatorOptions.ext] {
// fft_size: 256
// }
// }
// }
class TensorsToAudioCalculator : public Node {
public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Input<std::pair<float, float>> kDcAndNyquistIn{
"DC_AND_NYQUIST"};
static constexpr Output<Matrix> kAudioOut{"AUDIO"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
private:
// The internal state of the FFT library.
PFFFT_Setup* fft_state_ = nullptr;
int fft_size_ = 0;
float inverse_fft_size_ = 0;
std::vector<float, Eigen::aligned_allocator<float>> input_dft_;
std::vector<float> inv_fft_window_;
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
// Scratch memory that pffft requires to avoid using the stack.
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
};
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
const auto& options =
cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
RET_CHECK(IsValidFftSize(options.fft_size()))
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
">=0 and c >= 0 and a >= 5, the requested fft size is "
<< options.fft_size();
fft_size_ = options.fft_size();
inverse_fft_size_ = 1.0f / fft_size_;
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
input_dft_.resize(fft_size_);
inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false);
fft_input_buffer_.resize(fft_size_);
fft_workplace_.resize(fft_size_);
fft_output_.resize(fft_size_);
return absl::OkStatus();
}
absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = *kTensorsIn(cc);
RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
auto view = input_tensors[0].GetCpuReadView();
// DC's real part.
input_dft_[0] = kDcAndNyquistIn(cc)->first;
// Nyquist's real part is the penultimate element of the tensor buffer.
// pffft ignores Nyquist's imaginary part, so there is no need to fetch the
// last value from the tensor buffer.
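// pffft's ordered real-FFT layout is [Re(0), Re(N/2), Re(1), Im(1), ...,
// Re(N/2-1), Im(N/2-1)], which is why the DC and Nyquist real parts occupy
// the first two slots of input_dft_.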
input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
(fft_size_ - 2) * sizeof(float));
pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
fft_workplace_.data(), PFFFT_BACKWARD);
// Applies the inverse window function.
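// The 1/fft_size factor compensates for pffft's unnormalized backward
// transform.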
std::transform(
fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(),
fft_output_.begin(),
[this](float a, float b) { return a * b * inverse_fft_size_; });
Matrix matrix = Eigen::Map<Matrix>(fft_output_.data(), 1, fft_output_.size());
kAudioOut(cc).Send(std::move(matrix));
return absl::OkStatus();
}
absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) {
if (fft_state_) {
pffft_destroy_setup(fft_state_);
}
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator);
} // namespace api2
} // namespace mediapipe
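
For reference, a sketch of the reconstruction math implied by InvHannWindow and the scaling in Process(), assuming the analysis stage applied the same periodic Hann window with 50% overlap (hop = fft_size / 2); this derivation is an editorial note, not part of the change. With N = fft_size and

\[
w[n] = \tfrac{1}{2}\Bigl(1 - \cos\tfrac{2\pi n}{N}\Bigr), \qquad 0 \le n < N,
\]

InvHannWindow (sqrt_hann = false) computes, for 0 \le n < N/2,

\[
\mathrm{inv}[n] = \frac{w[n]}{w[n]^2 + w[n+N/2]^2}, \qquad
\mathrm{inv}[n+N/2] = \frac{w[n+N/2]}{w[n]^2 + w[n+N/2]^2},
\]

so the two frames overlapping at any output sample contribute a combined gain of

\[
w[n]\,\mathrm{inv}[n] + w[n+N/2]\,\mathrm{inv}[n+N/2]
= \frac{w[n]^2 + w[n+N/2]^2}{w[n]^2 + w[n+N/2]^2} = 1,
\]

i.e. overlap-add reconstructs the windowed signal exactly when the spectrum is left unmodified. The extra inverse_fft_size_ = 1/N factor in Process() compensates for pffft's unnormalized backward transform (PFFFT_BACKWARD(PFFFT_FORWARD(x)) = N * x).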


@@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorsToAudioCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsToAudioCalculatorOptions ext = 484297136;
}
// Size of the fft in number of bins. Must be set; the calculator performs an
// ifft on the input tensor.
optional int64 fft_size = 1;
}
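
For illustration, a minimal C++ sketch of setting this option programmatically while assembling a graph config; the helper name MakeTensorsToAudioNode and the stream names are placeholders, not part of this change:

#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"

// Builds a node equivalent to the text-proto example in the calculator's
// header comment. The fft_size must satisfy the (2^a)*(3^b)*(5^c), a >= 5
// constraint that Open() enforces.
mediapipe::CalculatorGraphConfig::Node MakeTensorsToAudioNode(int fft_size) {
  mediapipe::CalculatorGraphConfig::Node node;
  node.set_calculator("TensorsToAudioCalculator");
  node.add_input_stream("TENSORS:tensors");
  node.add_input_stream("DC_AND_NYQUIST:dc_and_nyquist");
  node.add_output_stream("AUDIO:audio");
  // Sets the required fft_size through the proto2 extension declared above.
  node.mutable_options()
      ->MutableExtension(mediapipe::TensorsToAudioCalculatorOptions::ext)
      ->set_fft_size(fft_size);
  return node;
}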


@@ -0,0 +1,149 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <new>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
protected:
// Creates an audio matrix containing a single sample of 1.0 at a specified
// offset.
Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) {
Matrix impulse = Matrix::Zero(1, num_samples);
impulse(0, impulse_offset_idx) = 1.0;
return impulse;
}
void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "audio_in"
input_stream: "sample_rate"
output_stream: "audio_out"
node {
calculator: "AudioToTensorCalculator"
input_stream: "AUDIO:audio_in"
input_stream: "SAMPLE_RATE:sample_rate"
output_stream: "TENSORS:tensors"
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
options {
[mediapipe.AudioToTensorCalculatorOptions.ext] {
num_channels: 1
num_samples: $0
num_overlapping_samples: 0
target_sample_rate: $1
fft_size: $2
}
}
}
node {
calculator: "TensorsToAudioCalculator"
input_stream: "TENSORS:tensors"
input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
output_stream: "AUDIO:audio_out"
options {
[mediapipe.TensorsToAudioCalculatorOptions.ext] {
fft_size: $2
}
}
}
)",
/*$0=*/num_samples,
/*$1=*/sample_rate,
/*$2=*/fft_size));
tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
}
void RunGraph(const Matrix& input_data, double sample_rate) {
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"audio_in", MakePacket<Matrix>(input_data).At(Timestamp(0))));
MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
std::vector<Packet> audio_out_packets_;
CalculatorGraphConfig graph_config_;
CalculatorGraph graph_;
};
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
ConfigGraph(320, 16000, 103);
MP_ASSERT_OK(graph_.Initialize(graph_config_));
MP_ASSERT_OK(graph_.StartRun({}));
auto status = graph_.WaitUntilIdle();
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_THAT(status.message(),
::testing::HasSubstr("FFT size must be of the form"));
}
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// The impulse signal at the center is not affected by the window function.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// As the impulse signal sits at 1/4 of the hann window, the inverse
// window function reduces it by half.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data / 2);
}
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
constexpr int sample_size = 320;
constexpr double sample_rate = 16000;
ConfigGraph(sample_size, sample_rate, 320);
Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
RunGraph(impulse_data, sample_rate);
ASSERT_EQ(1, audio_out_packets_.size());
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
// As the impulse signal sits at the beginning of the hann window, the inverse
// window function completely removes it.
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
}
} // namespace
} // namespace mediapipe
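
As a quick sanity check of the three expectations above (an editorial note, assuming AudioToTensorCalculator applies the periodic Hann analysis window w[n] = 0.5 * (1 - cos(2*pi*n/N)) when fft_size is set, so the round trip scales sample n by w[n] * inv[n], with N = 320):

\[
n = \tfrac{N}{2}: \quad w = 1, \; \mathrm{inv} = \frac{1}{1^2 + 0^2} = 1
\;\Rightarrow\; \text{the impulse is unchanged;}
\]
\[
n = \tfrac{N}{4}: \quad w = \tfrac{1}{2}, \; \mathrm{inv} = \frac{0.5}{0.5^2 + 0.5^2} = 1
\;\Rightarrow\; \text{the impulse is halved;}
\]
\[
n = 0: \quad w = 0
\;\Rightarrow\; \text{the impulse is removed entirely.}
\]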