Add DftTensorFormat To TensorsToAudioCalculatorOptions.
PiperOrigin-RevId: 515077766
This commit is contained in:
		
							parent
							
								
									16c2e32a0d
								
							
						
					
					
						commit
						ddc535e705
					
				| 
						 | 
				
			
			@ -34,6 +34,8 @@ namespace mediapipe {
 | 
			
		|||
namespace api2 {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using Options = ::mediapipe::TensorsToAudioCalculatorOptions;
 | 
			
		||||
 | 
			
		||||
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
 | 
			
		||||
  std::vector<float> hann_window(window_size);
 | 
			
		||||
  audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
 | 
			
		||||
| 
						 | 
				
			
			@ -138,11 +140,15 @@ class TensorsToAudioCalculator : public Node {
 | 
			
		|||
  std::vector<float, Eigen::aligned_allocator<float>> prev_fft_output_;
 | 
			
		||||
  int overlapping_samples_ = -1;
 | 
			
		||||
  int step_samples_ = -1;
 | 
			
		||||
  Options::DftTensorFormat dft_tensor_format_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
 | 
			
		||||
  const auto& options =
 | 
			
		||||
      cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
 | 
			
		||||
  dft_tensor_format_ = options.dft_tensor_format();
 | 
			
		||||
  RET_CHECK(dft_tensor_format_ != Options::DFT_TENSOR_FORMAT_UNKNOWN)
 | 
			
		||||
      << "dft tensor format must be specified.";
 | 
			
		||||
  RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
 | 
			
		||||
  RET_CHECK(IsValidFftSize(options.fft_size()))
 | 
			
		||||
      << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
 | 
			
		||||
| 
						 | 
				
			
			@ -183,14 +189,37 @@ absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
 | 
			
		|||
  RET_CHECK_EQ(input_tensors.size(), 1);
 | 
			
		||||
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
 | 
			
		||||
  auto view = input_tensors[0].GetCpuReadView();
 | 
			
		||||
  // DC's real part.
 | 
			
		||||
  input_dft_[0] = kDcAndNyquistIn(cc)->first;
 | 
			
		||||
  // Nyquist's real part is the penultimate element of the tensor buffer.
 | 
			
		||||
  // pffft ignores the Nyquist's imaginary part. No need to fetch the last value
 | 
			
		||||
  // from the tensor buffer.
 | 
			
		||||
  input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
 | 
			
		||||
  std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
 | 
			
		||||
              (fft_size_ - 2) * sizeof(float));
 | 
			
		||||
  switch (dft_tensor_format_) {
 | 
			
		||||
    case Options::WITH_NYQUIST: {
 | 
			
		||||
      // DC's real part.
 | 
			
		||||
      input_dft_[0] = kDcAndNyquistIn(cc)->first;
 | 
			
		||||
      // Nyquist's real part is the penultimate element of the tensor buffer.
 | 
			
		||||
      // pffft ignores the Nyquist's imaginary part. No need to fetch the last
 | 
			
		||||
      // value from the tensor buffer.
 | 
			
		||||
      input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
 | 
			
		||||
      std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
 | 
			
		||||
                  (fft_size_ - 2) * sizeof(float));
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case Options::WITH_DC_AND_NYQUIST: {
 | 
			
		||||
      // DC's real part is the first element of the tensor buffer.
 | 
			
		||||
      input_dft_[0] = *(view.buffer<float>());
 | 
			
		||||
      // Nyquist's real part is the penultimate element of the tensor buffer.
 | 
			
		||||
      input_dft_[1] = *(view.buffer<float>() + fft_size_);
 | 
			
		||||
      std::memcpy(input_dft_.data() + 2, view.buffer<float>() + 2,
 | 
			
		||||
                  (fft_size_ - 2) * sizeof(float));
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case Options::WITHOUT_DC_AND_NYQUIST: {
 | 
			
		||||
      input_dft_[0] = kDcAndNyquistIn(cc)->first;
 | 
			
		||||
      input_dft_[1] = kDcAndNyquistIn(cc)->second;
 | 
			
		||||
      std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
 | 
			
		||||
                  (fft_size_ - 2) * sizeof(float));
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    default:
 | 
			
		||||
      return absl::InvalidArgumentError("Unsupported dft tensor format.");
 | 
			
		||||
  }
 | 
			
		||||
  pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
 | 
			
		||||
                          fft_workplace_.data(), PFFFT_BACKWARD);
 | 
			
		||||
  // Applies the inverse window function.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -32,4 +32,17 @@ message TensorsToAudioCalculatorOptions {
 | 
			
		|||
 | 
			
		||||
  // The number of overlapping samples between adjacent windows.
 | 
			
		||||
  optional int64 num_overlapping_samples = 3 [default = 0];
 | 
			
		||||
 | 
			
		||||
  enum DftTensorFormat {
 | 
			
		||||
    DFT_TENSOR_FORMAT_UNKNOWN = 0;
 | 
			
		||||
    // The input dft tensor without dc and nyquist components.
 | 
			
		||||
    WITHOUT_DC_AND_NYQUIST = 1;
 | 
			
		||||
    // The input dft tensor contains the nyquist component as the last
 | 
			
		||||
    // two values.
 | 
			
		||||
    WITH_NYQUIST = 2;
 | 
			
		||||
    // The input dft tensor contains the dc component as the first two values
 | 
			
		||||
    // and the nyquist component as the last two values.
 | 
			
		||||
    WITH_DC_AND_NYQUIST = 3;
 | 
			
		||||
  }
 | 
			
		||||
  optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST];
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,6 +30,8 @@
 | 
			
		|||
namespace mediapipe {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using Options = ::mediapipe::TensorsToAudioCalculatorOptions;
 | 
			
		||||
 | 
			
		||||
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  // Creates an audio matrix containing a single sample of 1.0 at a specified
 | 
			
		||||
| 
						 | 
				
			
			@ -40,9 +42,10 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
			
		|||
    return impulse;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
 | 
			
		||||
    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
 | 
			
		||||
        absl::Substitute(R"(
 | 
			
		||||
  void ConfigGraph(int num_samples, double sample_rate, int fft_size,
 | 
			
		||||
                   Options::DftTensorFormat dft_tensor_format) {
 | 
			
		||||
    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
 | 
			
		||||
        R"(
 | 
			
		||||
        input_stream: "audio_in"
 | 
			
		||||
        input_stream: "sample_rate"
 | 
			
		||||
        output_stream: "audio_out"
 | 
			
		||||
| 
						 | 
				
			
			@ -59,6 +62,7 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
			
		|||
              num_overlapping_samples: 0
 | 
			
		||||
              target_sample_rate: $1
 | 
			
		||||
              fft_size: $2
 | 
			
		||||
              dft_tensor_format: $3
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -70,13 +74,15 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
			
		|||
          options {
 | 
			
		||||
            [mediapipe.TensorsToAudioCalculatorOptions.ext] {
 | 
			
		||||
              fft_size: $2
 | 
			
		||||
              dft_tensor_format: $3
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        )",
 | 
			
		||||
                         /*$0=*/num_samples,
 | 
			
		||||
                         /*$1=*/sample_rate,
 | 
			
		||||
                         /*$2=*/fft_size));
 | 
			
		||||
        /*$0=*/num_samples,
 | 
			
		||||
        /*$1=*/sample_rate,
 | 
			
		||||
        /*$2=*/fft_size,
 | 
			
		||||
        /*$3=*/Options::DftTensorFormat_Name(dft_tensor_format)));
 | 
			
		||||
    tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -97,7 +103,7 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
			
		|||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
 | 
			
		||||
  ConfigGraph(320, 16000, 103);
 | 
			
		||||
  ConfigGraph(320, 16000, 103, Options::WITH_NYQUIST);
 | 
			
		||||
  MP_ASSERT_OK(graph_.Initialize(graph_config_));
 | 
			
		||||
  MP_ASSERT_OK(graph_.StartRun({}));
 | 
			
		||||
  auto status = graph_.WaitUntilIdle();
 | 
			
		||||
| 
						 | 
				
			
			@ -109,8 +115,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
 | 
			
		|||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
 | 
			
		||||
  constexpr int sample_size = 320;
 | 
			
		||||
  constexpr double sample_rate = 16000;
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320);
 | 
			
		||||
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
 | 
			
		||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
			
		||||
  RunGraph(impulse_data, sample_rate);
 | 
			
		||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
			
		||||
| 
						 | 
				
			
			@ -122,7 +127,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
 | 
			
		|||
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
 | 
			
		||||
  constexpr int sample_size = 320;
 | 
			
		||||
  constexpr double sample_rate = 16000;
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320);
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
 | 
			
		||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
 | 
			
		||||
  RunGraph(impulse_data, sample_rate);
 | 
			
		||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
			
		||||
| 
						 | 
				
			
			@ -135,7 +140,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
 | 
			
		|||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
 | 
			
		||||
  constexpr int sample_size = 320;
 | 
			
		||||
  constexpr double sample_rate = 16000;
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320);
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
 | 
			
		||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
 | 
			
		||||
  RunGraph(impulse_data, sample_rate);
 | 
			
		||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
			
		||||
| 
						 | 
				
			
			@ -145,5 +150,31 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
 | 
			
		|||
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(TensorsToAudioCalculatorFftTest, TestDftTensorWithDCAndNyquist) {
 | 
			
		||||
  constexpr int sample_size = 320;
 | 
			
		||||
  constexpr double sample_rate = 16000;
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_DC_AND_NYQUIST);
 | 
			
		||||
 | 
			
		||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
			
		||||
  RunGraph(impulse_data, sample_rate);
 | 
			
		||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
			
		||||
  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
 | 
			
		||||
  // The impulse signal at the center is not affected by the window function.
 | 
			
		||||
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(TensorsToAudioCalculatorFftTest, TestDftTensorWithoutDCAndNyquist) {
 | 
			
		||||
  constexpr int sample_size = 320;
 | 
			
		||||
  constexpr double sample_rate = 16000;
 | 
			
		||||
  ConfigGraph(sample_size, sample_rate, 320, Options::WITHOUT_DC_AND_NYQUIST);
 | 
			
		||||
 | 
			
		||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
			
		||||
  RunGraph(impulse_data, sample_rate);
 | 
			
		||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
			
		||||
  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
 | 
			
		||||
  // The impulse signal at the center is not affected by the window function.
 | 
			
		||||
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user