Speed up TimeSeriesFramerCalculator.
Currently, TimeSeriesFramerCalculator constructs a distinct Matrix object for every input sample, which is inefficient. This CL revises buffering to keep each input packet's worth of samples as one grouped Matrix. A benchmark is added, showing a speed up of about 20x. ``` name old new BM_TimeSeriesFramerCalculator 48.45ms 2.26ms ``` PiperOrigin-RevId: 542462618
This commit is contained in:
parent
0d2548cd65
commit
825e3a8af0
|
@ -219,12 +219,10 @@ cc_library(
|
|||
deps = [
|
||||
":time_series_framer_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@eigen_archive//:eigen3",
|
||||
|
@ -319,6 +317,20 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "time_series_framer_calculator_benchmark",
|
||||
srcs = ["time_series_framer_calculator_benchmark.cc"],
|
||||
deps = [
|
||||
":time_series_framer_calculator",
|
||||
":time_series_framer_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "time_series_framer_calculator_test",
|
||||
srcs = ["time_series_framer_calculator_test.cc"],
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
// Defines TimeSeriesFramerCalculator.
|
||||
#include <math.h>
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
|
@ -25,9 +23,8 @@
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/util/time_series_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -88,11 +85,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
|
|||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// Adds input data to the internal buffer.
|
||||
void EnqueueInput(CalculatorContext* cc);
|
||||
// Constructs and emits framed output packets.
|
||||
void FrameOutput(CalculatorContext* cc);
|
||||
|
||||
Timestamp CurrentOutputTimestamp() {
|
||||
if (use_local_timestamp_) {
|
||||
return current_timestamp_;
|
||||
|
@ -106,14 +98,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
|
|||
Timestamp::kTimestampUnitsPerSecond);
|
||||
}
|
||||
|
||||
// Returns the timestamp of a sample on a base, which is usually the time
|
||||
// stamp of a packet.
|
||||
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
|
||||
int64_t number_of_samples) {
|
||||
return timestamp_base + round(number_of_samples / sample_rate_ *
|
||||
Timestamp::kTimestampUnitsPerSecond);
|
||||
}
|
||||
|
||||
// The number of input samples to advance after the current output frame is
|
||||
// emitted.
|
||||
int next_frame_step_samples() const {
|
||||
|
@ -142,61 +126,172 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
|
|||
Timestamp initial_input_timestamp_;
|
||||
// The current timestamp is updated along with the incoming packets.
|
||||
Timestamp current_timestamp_;
|
||||
int num_channels_;
|
||||
|
||||
// Each entry in this deque consists of a single sample, i.e. a
|
||||
// single column vector, and its timestamp.
|
||||
std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
|
||||
// Samples are buffered in a vector of sample blocks.
|
||||
class SampleBlockBuffer {
|
||||
public:
|
||||
// Initializes the buffer.
|
||||
void Init(double sample_rate, int num_channels) {
|
||||
ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate;
|
||||
num_channels_ = num_channels;
|
||||
num_samples_ = 0;
|
||||
first_block_offset_ = 0;
|
||||
}
|
||||
|
||||
// Number of channels, equal to the number of rows in each Matrix.
|
||||
int num_channels() const { return num_channels_; }
|
||||
// Total number of available samples over all blocks.
|
||||
int num_samples() const { return num_samples_; }
|
||||
|
||||
// Pushes a new block of samples on the back of the buffer with `timestamp`
|
||||
// being the input timestamp of the packet containing the Matrix.
|
||||
void Push(const Matrix& samples, Timestamp timestamp);
|
||||
// Copies `count` samples from the front of the buffer. If there are fewer
|
||||
// samples than this, the result is zero padded to have `count` samples.
|
||||
// The timestamp of the last copied sample is written to *last_timestamp.
|
||||
// This output is used below to update `current_timestamp_`, which is only
|
||||
// used when `use_local_timestamp` is true.
|
||||
Matrix CopySamples(int count, Timestamp* last_timestamp) const;
|
||||
// Drops `count` samples from the front of the buffer. If `count` exceeds
|
||||
// `num_samples()`, the buffer is emptied. Returns how many samples were
|
||||
// dropped.
|
||||
int DropSamples(int count);
|
||||
|
||||
private:
|
||||
struct Block {
|
||||
// Matrix of num_channels rows by num_samples columns, a block of possibly
|
||||
// multiple samples.
|
||||
Matrix samples;
|
||||
// Timestamp of the first sample in the Block. This comes from the input
|
||||
// packet's timestamp that contains this Matrix.
|
||||
Timestamp timestamp;
|
||||
|
||||
Block() : timestamp(Timestamp::Unstarted()) {}
|
||||
Block(const Matrix& samples, Timestamp timestamp)
|
||||
: samples(samples), timestamp(timestamp) {}
|
||||
int num_samples() const { return samples.cols(); }
|
||||
};
|
||||
std::vector<Block> blocks_;
|
||||
// Number of timestamp units per sample. Used to compute timestamps as
|
||||
// nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n).
|
||||
double ts_units_per_sample_;
|
||||
// Number of rows in each Matrix.
|
||||
int num_channels_;
|
||||
// The total number of samples over all blocks, equal to
|
||||
// (sum_i blocks_[i].num_samples()) - first_block_offset_.
|
||||
int num_samples_;
|
||||
// The number of samples in the first block that have been discarded. This
|
||||
// way we can cheaply represent "partially discarding" a block.
|
||||
int first_block_offset_;
|
||||
} sample_buffer_;
|
||||
|
||||
bool use_window_;
|
||||
Matrix window_;
|
||||
Eigen::RowVectorXf window_;
|
||||
|
||||
bool use_local_timestamp_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
|
||||
|
||||
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
|
||||
const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
|
||||
|
||||
for (int i = 0; i < input_frame.cols(); ++i) {
|
||||
sample_buffer_.emplace_back(std::make_pair(
|
||||
input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
|
||||
}
|
||||
void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples,
|
||||
Timestamp timestamp) {
|
||||
num_samples_ += samples.cols();
|
||||
blocks_.emplace_back(samples, timestamp);
|
||||
}
|
||||
|
||||
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
|
||||
while (sample_buffer_.size() >=
|
||||
Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples(
|
||||
int count, Timestamp* last_timestamp) const {
|
||||
Matrix copied(num_channels_, count);
|
||||
|
||||
if (!blocks_.empty()) {
|
||||
int num_copied = 0;
|
||||
// First block has an offset for samples that have been discarded.
|
||||
int offset = first_block_offset_;
|
||||
int n;
|
||||
Timestamp last_block_ts;
|
||||
|
||||
for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) {
|
||||
n = std::min(it->num_samples() - offset, count);
|
||||
// Copy `n` samples from the next block.
|
||||
copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n);
|
||||
count -= n;
|
||||
num_copied += n;
|
||||
last_block_ts = it->timestamp;
|
||||
offset = 0; // No samples have been discarded in subsequent blocks.
|
||||
}
|
||||
|
||||
// Compute the timestamp of the last copied sample.
|
||||
*last_timestamp =
|
||||
last_block_ts + std::round(ts_units_per_sample_ * (n - 1));
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
copied.rightCols(count).setZero(); // Zero pad if needed.
|
||||
}
|
||||
|
||||
return copied;
|
||||
}
|
||||
|
||||
int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) {
|
||||
if (blocks_.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto block_it = blocks_.begin();
|
||||
if (first_block_offset_ + count < block_it->num_samples()) {
|
||||
// `count` is less than the remaining samples in the first block.
|
||||
first_block_offset_ += count;
|
||||
num_samples_ -= count;
|
||||
return count;
|
||||
}
|
||||
|
||||
int num_samples_dropped = block_it->num_samples() - first_block_offset_;
|
||||
count -= num_samples_dropped;
|
||||
first_block_offset_ = 0;
|
||||
|
||||
for (++block_it; block_it != blocks_.end(); ++block_it) {
|
||||
if (block_it->num_samples() > count) {
|
||||
break;
|
||||
}
|
||||
num_samples_dropped += block_it->num_samples();
|
||||
count -= block_it->num_samples();
|
||||
}
|
||||
|
||||
blocks_.erase(blocks_.begin(), block_it); // Drop whole blocks.
|
||||
if (!blocks_.empty()) {
|
||||
first_block_offset_ = count; // Drop part of the next block.
|
||||
num_samples_dropped += count;
|
||||
}
|
||||
|
||||
num_samples_ -= num_samples_dropped;
|
||||
return num_samples_dropped;
|
||||
}
|
||||
|
||||
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
|
||||
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
|
||||
initial_input_timestamp_ = cc->InputTimestamp();
|
||||
current_timestamp_ = initial_input_timestamp_;
|
||||
}
|
||||
|
||||
// Add input data to the internal buffer.
|
||||
sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(),
|
||||
cc->InputTimestamp());
|
||||
|
||||
// Construct and emit framed output packets.
|
||||
while (sample_buffer_.num_samples() >=
|
||||
frame_duration_samples_ + samples_still_to_drop_) {
|
||||
while (samples_still_to_drop_ > 0) {
|
||||
sample_buffer_.pop_front();
|
||||
--samples_still_to_drop_;
|
||||
}
|
||||
sample_buffer_.DropSamples(samples_still_to_drop_);
|
||||
Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
|
||||
¤t_timestamp_);
|
||||
const int frame_step_samples = next_frame_step_samples();
|
||||
std::unique_ptr<Matrix> output_frame(
|
||||
new Matrix(num_channels_, frame_duration_samples_));
|
||||
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
|
||||
++i) {
|
||||
output_frame->col(i) = sample_buffer_.front().first;
|
||||
current_timestamp_ = sample_buffer_.front().second;
|
||||
sample_buffer_.pop_front();
|
||||
}
|
||||
const int frame_overlap_samples =
|
||||
frame_duration_samples_ - frame_step_samples;
|
||||
if (frame_overlap_samples > 0) {
|
||||
for (int i = 0; i < frame_overlap_samples; ++i) {
|
||||
output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
|
||||
current_timestamp_ = sample_buffer_[i].second;
|
||||
}
|
||||
} else {
|
||||
samples_still_to_drop_ = -frame_overlap_samples;
|
||||
}
|
||||
samples_still_to_drop_ = frame_step_samples;
|
||||
|
||||
if (use_window_) {
|
||||
*output_frame = (output_frame->array() * window_.array()).matrix();
|
||||
// Apply the window to each row of output_frame.
|
||||
output_frame.array().rowwise() *= window_.array();
|
||||
}
|
||||
|
||||
cc->Outputs().Index(0).Add(output_frame.release(),
|
||||
CurrentOutputTimestamp());
|
||||
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
|
||||
.At(CurrentOutputTimestamp()));
|
||||
++cumulative_output_frames_;
|
||||
cumulative_completed_samples_ += frame_step_samples;
|
||||
}
|
||||
|
@ -206,35 +301,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
|
|||
// fact to enable packet queueing optimizations.
|
||||
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
|
||||
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
|
||||
initial_input_timestamp_ = cc->InputTimestamp();
|
||||
current_timestamp_ = initial_input_timestamp_;
|
||||
}
|
||||
|
||||
EnqueueInput(cc);
|
||||
FrameOutput(cc);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
|
||||
while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
|
||||
sample_buffer_.pop_front();
|
||||
--samples_still_to_drop_;
|
||||
}
|
||||
if (!sample_buffer_.empty() && pad_final_packet_) {
|
||||
std::unique_ptr<Matrix> output_frame(new Matrix);
|
||||
output_frame->setZero(num_channels_, frame_duration_samples_);
|
||||
for (int i = 0; i < sample_buffer_.size(); ++i) {
|
||||
output_frame->col(i) = sample_buffer_[i].first;
|
||||
current_timestamp_ = sample_buffer_[i].second;
|
||||
}
|
||||
sample_buffer_.DropSamples(samples_still_to_drop_);
|
||||
|
||||
cc->Outputs().Index(0).Add(output_frame.release(),
|
||||
CurrentOutputTimestamp());
|
||||
if (sample_buffer_.num_samples() > 0 && pad_final_packet_) {
|
||||
Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
|
||||
¤t_timestamp_);
|
||||
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
|
||||
.At(CurrentOutputTimestamp()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -258,7 +336,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
|
|||
cc->Inputs().Index(0).Header(), &input_header));
|
||||
|
||||
sample_rate_ = input_header.sample_rate();
|
||||
num_channels_ = input_header.num_channels();
|
||||
sample_buffer_.Init(sample_rate_, input_header.num_channels());
|
||||
frame_duration_samples_ = time_series_util::SecondsToSamples(
|
||||
framer_options.frame_duration_seconds(), sample_rate_);
|
||||
RET_CHECK_GT(frame_duration_samples_, 0)
|
||||
|
@ -312,9 +390,8 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
if (use_window_) {
|
||||
window_ = Matrix::Ones(num_channels_, 1) *
|
||||
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
|
||||
frame_duration_samples_)
|
||||
window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(),
|
||||
frame_duration_samples_)
|
||||
.cast<float>();
|
||||
}
|
||||
use_local_timestamp_ = framer_options.use_local_timestamp();
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Benchmark for TimeSeriesFramerCalculator.
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "benchmark/benchmark.h"
|
||||
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
|
||||
using ::mediapipe::Matrix;
|
||||
|
||||
void BM_TimeSeriesFramerCalculator(benchmark::State& state) {
|
||||
constexpr float kSampleRate = 32000.0;
|
||||
constexpr int kNumChannels = 2;
|
||||
constexpr int kFrameDurationSeconds = 5.0;
|
||||
std::mt19937 rng(0 /*seed*/);
|
||||
// Input around a half second's worth of samples at a time.
|
||||
std::uniform_int_distribution<int> input_size_dist(15000, 17000);
|
||||
// Generate a pool of random blocks of samples up front.
|
||||
std::vector<Matrix> sample_pool;
|
||||
sample_pool.reserve(20);
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng)));
|
||||
}
|
||||
std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1);
|
||||
|
||||
mediapipe::CalculatorGraphConfig config;
|
||||
config.add_input_stream("input");
|
||||
config.add_output_stream("output");
|
||||
auto* node = config.add_node();
|
||||
node->set_calculator("TimeSeriesFramerCalculator");
|
||||
node->add_input_stream("input");
|
||||
node->add_output_stream("output");
|
||||
mediapipe::TimeSeriesFramerCalculatorOptions* options =
|
||||
node->mutable_options()->MutableExtension(
|
||||
mediapipe::TimeSeriesFramerCalculatorOptions::ext);
|
||||
options->set_frame_duration_seconds(kFrameDurationSeconds);
|
||||
|
||||
for (auto _ : state) {
|
||||
state.PauseTiming(); // Pause benchmark timing.
|
||||
|
||||
// Prepare input packets of random blocks of samples.
|
||||
std::vector<mediapipe::Packet> input_packets;
|
||||
input_packets.reserve(32);
|
||||
float t = 0;
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
auto samples =
|
||||
std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]);
|
||||
const int num_samples = samples->cols();
|
||||
input_packets.push_back(mediapipe::Adopt(samples.release())
|
||||
.At(mediapipe::Timestamp::FromSeconds(t)));
|
||||
t += num_samples / kSampleRate;
|
||||
}
|
||||
// Initialize graph.
|
||||
mediapipe::CalculatorGraph graph;
|
||||
CHECK_OK(graph.Initialize(config));
|
||||
// Prepare input header.
|
||||
auto header = std::make_unique<mediapipe::TimeSeriesHeader>();
|
||||
header->set_sample_rate(kSampleRate);
|
||||
header->set_num_channels(kNumChannels);
|
||||
|
||||
state.ResumeTiming(); // Resume benchmark timing.
|
||||
|
||||
CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}}));
|
||||
for (auto& packet : input_packets) {
|
||||
CHECK_OK(graph.AddPacketToInputStream("input", packet));
|
||||
}
|
||||
CHECK(!graph.HasError());
|
||||
CHECK_OK(graph.CloseAllInputStreams());
|
||||
CHECK_OK(graph.WaitUntilIdle());
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_TimeSeriesFramerCalculator);
|
||||
|
||||
BENCHMARK_MAIN();
|
Loading…
Reference in New Issue
Block a user