Add DftTensorFormat To TensorsToAudioCalculatorOptions.
PiperOrigin-RevId: 515077766
This commit is contained in:
		
							parent
							
								
									16c2e32a0d
								
							
						
					
					
						commit
						ddc535e705
					
				| 
						 | 
					@ -34,6 +34,8 @@ namespace mediapipe {
 | 
				
			||||||
namespace api2 {
 | 
					namespace api2 {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using Options = ::mediapipe::TensorsToAudioCalculatorOptions;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
 | 
					std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
 | 
				
			||||||
  std::vector<float> hann_window(window_size);
 | 
					  std::vector<float> hann_window(window_size);
 | 
				
			||||||
  audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
 | 
					  audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
 | 
				
			||||||
| 
						 | 
					@ -138,11 +140,15 @@ class TensorsToAudioCalculator : public Node {
 | 
				
			||||||
  std::vector<float, Eigen::aligned_allocator<float>> prev_fft_output_;
 | 
					  std::vector<float, Eigen::aligned_allocator<float>> prev_fft_output_;
 | 
				
			||||||
  int overlapping_samples_ = -1;
 | 
					  int overlapping_samples_ = -1;
 | 
				
			||||||
  int step_samples_ = -1;
 | 
					  int step_samples_ = -1;
 | 
				
			||||||
 | 
					  Options::DftTensorFormat dft_tensor_format_;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
 | 
					absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
 | 
				
			||||||
  const auto& options =
 | 
					  const auto& options =
 | 
				
			||||||
      cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
 | 
					      cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
 | 
				
			||||||
 | 
					  dft_tensor_format_ = options.dft_tensor_format();
 | 
				
			||||||
 | 
					  RET_CHECK(dft_tensor_format_ != Options::DFT_TENSOR_FORMAT_UNKNOWN)
 | 
				
			||||||
 | 
					      << "dft tensor format must be specified.";
 | 
				
			||||||
  RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
 | 
					  RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
 | 
				
			||||||
  RET_CHECK(IsValidFftSize(options.fft_size()))
 | 
					  RET_CHECK(IsValidFftSize(options.fft_size()))
 | 
				
			||||||
      << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
 | 
					      << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
 | 
				
			||||||
| 
						 | 
					@ -183,14 +189,37 @@ absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
 | 
				
			||||||
  RET_CHECK_EQ(input_tensors.size(), 1);
 | 
					  RET_CHECK_EQ(input_tensors.size(), 1);
 | 
				
			||||||
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
 | 
					  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
 | 
				
			||||||
  auto view = input_tensors[0].GetCpuReadView();
 | 
					  auto view = input_tensors[0].GetCpuReadView();
 | 
				
			||||||
 | 
					  switch (dft_tensor_format_) {
 | 
				
			||||||
 | 
					    case Options::WITH_NYQUIST: {
 | 
				
			||||||
      // DC's real part.
 | 
					      // DC's real part.
 | 
				
			||||||
      input_dft_[0] = kDcAndNyquistIn(cc)->first;
 | 
					      input_dft_[0] = kDcAndNyquistIn(cc)->first;
 | 
				
			||||||
      // Nyquist's real part is the penultimate element of the tensor buffer.
 | 
					      // Nyquist's real part is the penultimate element of the tensor buffer.
 | 
				
			||||||
  // pffft ignores the Nyquist's imagery part. No need to fetch the last value
 | 
					      // pffft ignores the Nyquist's imagery part. No need to fetch the last
 | 
				
			||||||
  // from the tensor buffer.
 | 
					      // value from the tensor buffer.
 | 
				
			||||||
      input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
 | 
					      input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
 | 
				
			||||||
      std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
 | 
					      std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
 | 
				
			||||||
                  (fft_size_ - 2) * sizeof(float));
 | 
					                  (fft_size_ - 2) * sizeof(float));
 | 
				
			||||||
 | 
					      break;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    case Options::WITH_DC_AND_NYQUIST: {
 | 
				
			||||||
 | 
					      // DC's real part is the first element of the tensor buffer.
 | 
				
			||||||
 | 
					      input_dft_[0] = *(view.buffer<float>());
 | 
				
			||||||
 | 
					      // Nyquist's real part is the penultimate element of the tensor buffer.
 | 
				
			||||||
 | 
					      input_dft_[1] = *(view.buffer<float>() + fft_size_);
 | 
				
			||||||
 | 
					      std::memcpy(input_dft_.data() + 2, view.buffer<float>() + 2,
 | 
				
			||||||
 | 
					                  (fft_size_ - 2) * sizeof(float));
 | 
				
			||||||
 | 
					      break;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    case Options::WITHOUT_DC_AND_NYQUIST: {
 | 
				
			||||||
 | 
					      input_dft_[0] = kDcAndNyquistIn(cc)->first;
 | 
				
			||||||
 | 
					      input_dft_[1] = kDcAndNyquistIn(cc)->second;
 | 
				
			||||||
 | 
					      std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
 | 
				
			||||||
 | 
					                  (fft_size_ - 2) * sizeof(float));
 | 
				
			||||||
 | 
					      break;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    default:
 | 
				
			||||||
 | 
					      return absl::InvalidArgumentError("Unsupported dft tensor format.");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
 | 
					  pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
 | 
				
			||||||
                          fft_workplace_.data(), PFFFT_BACKWARD);
 | 
					                          fft_workplace_.data(), PFFFT_BACKWARD);
 | 
				
			||||||
  // Applies the inverse window function.
 | 
					  // Applies the inverse window function.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,4 +32,17 @@ message TensorsToAudioCalculatorOptions {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The number of overlapping samples between adjacent windows.
 | 
					  // The number of overlapping samples between adjacent windows.
 | 
				
			||||||
  optional int64 num_overlapping_samples = 3 [default = 0];
 | 
					  optional int64 num_overlapping_samples = 3 [default = 0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  enum DftTensorFormat {
 | 
				
			||||||
 | 
					    DFT_TENSOR_FORMAT_UNKNOWN = 0;
 | 
				
			||||||
 | 
					    // The input dft tensor without dc and nyquist components.
 | 
				
			||||||
 | 
					    WITHOUT_DC_AND_NYQUIST = 1;
 | 
				
			||||||
 | 
					    // The input dft tensor contains the nyquist component as the last
 | 
				
			||||||
 | 
					    // two values.
 | 
				
			||||||
 | 
					    WITH_NYQUIST = 2;
 | 
				
			||||||
 | 
					    // The input dft tensor contains the dc component as the first two values
 | 
				
			||||||
 | 
					    // and the nyquist component as the last two values.
 | 
				
			||||||
 | 
					    WITH_DC_AND_NYQUIST = 3;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -30,6 +30,8 @@
 | 
				
			||||||
namespace mediapipe {
 | 
					namespace mediapipe {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using Options = ::mediapipe::TensorsToAudioCalculatorOptions;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
					class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  // Creates an audio matrix containing a single sample of 1.0 at a specified
 | 
					  // Creates an audio matrix containing a single sample of 1.0 at a specified
 | 
				
			||||||
| 
						 | 
					@ -40,9 +42,10 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
				
			||||||
    return impulse;
 | 
					    return impulse;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
 | 
					  void ConfigGraph(int num_samples, double sample_rate, int fft_size,
 | 
				
			||||||
    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
 | 
					                   Options::DftTensorFormat dft_tensor_format) {
 | 
				
			||||||
        absl::Substitute(R"(
 | 
					    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
 | 
				
			||||||
 | 
					        R"(
 | 
				
			||||||
        input_stream: "audio_in"
 | 
					        input_stream: "audio_in"
 | 
				
			||||||
        input_stream: "sample_rate"
 | 
					        input_stream: "sample_rate"
 | 
				
			||||||
        output_stream: "audio_out"
 | 
					        output_stream: "audio_out"
 | 
				
			||||||
| 
						 | 
					@ -59,6 +62,7 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
				
			||||||
              num_overlapping_samples: 0
 | 
					              num_overlapping_samples: 0
 | 
				
			||||||
              target_sample_rate: $1
 | 
					              target_sample_rate: $1
 | 
				
			||||||
              fft_size: $2
 | 
					              fft_size: $2
 | 
				
			||||||
 | 
					              dft_tensor_format: $3
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
| 
						 | 
					@ -70,13 +74,15 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
				
			||||||
          options {
 | 
					          options {
 | 
				
			||||||
            [mediapipe.TensorsToAudioCalculatorOptions.ext] {
 | 
					            [mediapipe.TensorsToAudioCalculatorOptions.ext] {
 | 
				
			||||||
              fft_size: $2
 | 
					              fft_size: $2
 | 
				
			||||||
 | 
					              dft_tensor_format: $3
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        )",
 | 
					        )",
 | 
				
			||||||
        /*$0=*/num_samples,
 | 
					        /*$0=*/num_samples,
 | 
				
			||||||
        /*$1=*/sample_rate,
 | 
					        /*$1=*/sample_rate,
 | 
				
			||||||
                         /*$2=*/fft_size));
 | 
					        /*$2=*/fft_size,
 | 
				
			||||||
 | 
					        /*$3=*/Options::DftTensorFormat_Name(dft_tensor_format)));
 | 
				
			||||||
    tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
 | 
					    tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -97,7 +103,7 @@ class TensorsToAudioCalculatorFftTest : public ::testing::Test {
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
 | 
					TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
 | 
				
			||||||
  ConfigGraph(320, 16000, 103);
 | 
					  ConfigGraph(320, 16000, 103, Options::WITH_NYQUIST);
 | 
				
			||||||
  MP_ASSERT_OK(graph_.Initialize(graph_config_));
 | 
					  MP_ASSERT_OK(graph_.Initialize(graph_config_));
 | 
				
			||||||
  MP_ASSERT_OK(graph_.StartRun({}));
 | 
					  MP_ASSERT_OK(graph_.StartRun({}));
 | 
				
			||||||
  auto status = graph_.WaitUntilIdle();
 | 
					  auto status = graph_.WaitUntilIdle();
 | 
				
			||||||
| 
						 | 
					@ -109,8 +115,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
 | 
				
			||||||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
 | 
					TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
 | 
				
			||||||
  constexpr int sample_size = 320;
 | 
					  constexpr int sample_size = 320;
 | 
				
			||||||
  constexpr double sample_rate = 16000;
 | 
					  constexpr double sample_rate = 16000;
 | 
				
			||||||
  ConfigGraph(sample_size, sample_rate, 320);
 | 
					  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
 | 
				
			||||||
 | 
					 | 
				
			||||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
					  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
				
			||||||
  RunGraph(impulse_data, sample_rate);
 | 
					  RunGraph(impulse_data, sample_rate);
 | 
				
			||||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
					  ASSERT_EQ(1, audio_out_packets_.size());
 | 
				
			||||||
| 
						 | 
					@ -122,7 +127,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
 | 
				
			||||||
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
 | 
					TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
 | 
				
			||||||
  constexpr int sample_size = 320;
 | 
					  constexpr int sample_size = 320;
 | 
				
			||||||
  constexpr double sample_rate = 16000;
 | 
					  constexpr double sample_rate = 16000;
 | 
				
			||||||
  ConfigGraph(sample_size, sample_rate, 320);
 | 
					  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
 | 
				
			||||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
 | 
					  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
 | 
				
			||||||
  RunGraph(impulse_data, sample_rate);
 | 
					  RunGraph(impulse_data, sample_rate);
 | 
				
			||||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
					  ASSERT_EQ(1, audio_out_packets_.size());
 | 
				
			||||||
| 
						 | 
					@ -135,7 +140,7 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
 | 
				
			||||||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
 | 
					TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
 | 
				
			||||||
  constexpr int sample_size = 320;
 | 
					  constexpr int sample_size = 320;
 | 
				
			||||||
  constexpr double sample_rate = 16000;
 | 
					  constexpr double sample_rate = 16000;
 | 
				
			||||||
  ConfigGraph(sample_size, sample_rate, 320);
 | 
					  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_NYQUIST);
 | 
				
			||||||
  Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
 | 
					  Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
 | 
				
			||||||
  RunGraph(impulse_data, sample_rate);
 | 
					  RunGraph(impulse_data, sample_rate);
 | 
				
			||||||
  ASSERT_EQ(1, audio_out_packets_.size());
 | 
					  ASSERT_EQ(1, audio_out_packets_.size());
 | 
				
			||||||
| 
						 | 
					@ -145,5 +150,31 @@ TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
 | 
				
			||||||
  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
 | 
					  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(TensorsToAudioCalculatorFftTest, TestDftTensorWithDCAndNyquist) {
 | 
				
			||||||
 | 
					  constexpr int sample_size = 320;
 | 
				
			||||||
 | 
					  constexpr double sample_rate = 16000;
 | 
				
			||||||
 | 
					  ConfigGraph(sample_size, sample_rate, 320, Options::WITH_DC_AND_NYQUIST);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
				
			||||||
 | 
					  RunGraph(impulse_data, sample_rate);
 | 
				
			||||||
 | 
					  ASSERT_EQ(1, audio_out_packets_.size());
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
 | 
				
			||||||
 | 
					  // The impulse signal at the center is not affected by the window function.
 | 
				
			||||||
 | 
					  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(TensorsToAudioCalculatorFftTest, TestDftTensorWithoutDCAndNyquist) {
 | 
				
			||||||
 | 
					  constexpr int sample_size = 320;
 | 
				
			||||||
 | 
					  constexpr double sample_rate = 16000;
 | 
				
			||||||
 | 
					  ConfigGraph(sample_size, sample_rate, 320, Options::WITHOUT_DC_AND_NYQUIST);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
 | 
				
			||||||
 | 
					  RunGraph(impulse_data, sample_rate);
 | 
				
			||||||
 | 
					  ASSERT_EQ(1, audio_out_packets_.size());
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
 | 
				
			||||||
 | 
					  // The impulse signal at the center is not affected by the window function.
 | 
				
			||||||
 | 
					  EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
}  // namespace mediapipe
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user