Merge branch 'google:master' into gesture-recognizer-python
This commit is contained in:
commit
6f485ae3dd
|
@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
|
|||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||
options {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
async_selection: true
|
||||
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
|
|||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||
options {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
async_selection: true
|
||||
contained_node: {
|
||||
calculator: "BypassCalculator"
|
||||
node_options: {
|
||||
|
|
|
@ -50,7 +50,7 @@ namespace mediapipe {
|
|||
// calculator: "EndLoopWithOutputCalculator"
|
||||
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
|
||||
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
|
||||
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
|
||||
// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts
|
||||
// }
|
||||
template <typename IterableT>
|
||||
class EndLoopCalculator : public CalculatorBase {
|
||||
|
|
|
@ -109,6 +109,56 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "tensors_to_audio_calculator_proto",
|
||||
srcs = ["tensors_to_audio_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensors_to_audio_calculator",
|
||||
srcs = ["tensors_to_audio_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":tensors_to_audio_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||
"@pffft",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensors_to_audio_calculator_test",
|
||||
srcs = ["tensors_to_audio_calculator_test.cc"],
|
||||
deps = [
|
||||
":audio_to_tensor_calculator",
|
||||
":audio_to_tensor_calculator_cc_proto",
|
||||
":tensors_to_audio_calculator",
|
||||
":tensors_to_audio_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "feedback_tensors_calculator_proto",
|
||||
srcs = ["feedback_tensors_calculator.proto"],
|
||||
|
|
|
@ -133,7 +133,7 @@ bool IsValidFftSize(int size) {
|
|||
// invocation. In the non-streaming mode, the vector contains all of the
|
||||
// output timestamps for an input audio buffer.
|
||||
// DC_AND_NYQUIST - std::pair<float, float> @Optional.
|
||||
// A pair of dc component and nyquest component. Only can be connected when
|
||||
// A pair of dc component and nyquist component. Only can be connected when
|
||||
// the calculator performs fft (the fft_size is set in the calculator
|
||||
// options).
|
||||
//
|
||||
|
|
197
mediapipe/calculators/tensor/tensors_to_audio_calculator.cc
Normal file
197
mediapipe/calculators/tensor/tensors_to_audio_calculator.cc
Normal file
|
@ -0,0 +1,197 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <new>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "audio/dsp/window_functions.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "pffft.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
|
||||
std::vector<float> hann_window(window_size);
|
||||
audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window);
|
||||
if (sqrt_hann) {
|
||||
absl::c_transform(hann_window, hann_window.begin(),
|
||||
[](double x) { return std::sqrt(x); });
|
||||
}
|
||||
return hann_window;
|
||||
}
|
||||
|
||||
// Note that the InvHannWindow function may only work for 50% overlapping case.
|
||||
std::vector<float> InvHannWindow(int window_size, bool sqrt_hann) {
|
||||
std::vector<float> window = HannWindow(window_size, sqrt_hann);
|
||||
std::vector<float> inv_window(window.size());
|
||||
if (sqrt_hann) {
|
||||
absl::c_copy(window, inv_window.begin());
|
||||
} else {
|
||||
const int kHalfWindowSize = window.size() / 2;
|
||||
absl::c_transform(window, inv_window.begin(),
|
||||
[](double x) { return x * x; });
|
||||
for (int i = 0; i < kHalfWindowSize; ++i) {
|
||||
double sum = inv_window[i] + inv_window[kHalfWindowSize + i];
|
||||
inv_window[i] = window[i] / sum;
|
||||
inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum;
|
||||
}
|
||||
}
|
||||
return inv_window;
|
||||
}
|
||||
|
||||
// PFFFT only supports transforms for inputs of length N of the form
|
||||
// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
|
||||
bool IsValidFftSize(int size) {
|
||||
if (size <= 0) {
|
||||
return false;
|
||||
}
|
||||
constexpr int kFactors[] = {2, 3, 5};
|
||||
int factorization[] = {0, 0, 0};
|
||||
int n = static_cast<int>(size);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
while (n % kFactors[i] == 0) {
|
||||
n = n / kFactors[i];
|
||||
++factorization[i];
|
||||
}
|
||||
}
|
||||
return factorization[0] >= 5 && n == 1;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Converts 2D MediaPipe float Tensors to audio buffers.
|
||||
// The calculator will perform ifft on the complex DFT and apply the window
|
||||
// function (Inverse Hann) afterwards. The input 2D MediaPipe Tensor must
|
||||
// have the DFT real parts in its first row and the DFT imagery parts in its
|
||||
// second row. A valid "fft_size" must be set in the CalculatorOptions.
|
||||
//
|
||||
// Inputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing a single Tensor that represents the audio's complex DFT
|
||||
// results.
|
||||
// DC_AND_NYQUIST - std::pair<float, float>
|
||||
// A pair of dc component and nyquist component.
|
||||
//
|
||||
// Outputs:
|
||||
// AUDIO - mediapipe::Matrix
|
||||
// The audio data represented as mediapipe::Matrix.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "TensorsToAudioCalculator"
|
||||
// input_stream: "TENSORS:tensors"
|
||||
// input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||
// output_stream: "AUDIO:audio"
|
||||
// options {
|
||||
// [mediapipe.AudioToTensorCalculatorOptions.ext] {
|
||||
// fft_size: 256
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TensorsToAudioCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
||||
static constexpr Input<std::pair<float, float>> kDcAndNyquistIn{
|
||||
"DC_AND_NYQUIST"};
|
||||
static constexpr Output<Matrix> kAudioOut{"AUDIO"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
// The internal state of the FFT library.
|
||||
PFFFT_Setup* fft_state_ = nullptr;
|
||||
int fft_size_ = 0;
|
||||
float inverse_fft_size_ = 0;
|
||||
std::vector<float, Eigen::aligned_allocator<float>> input_dft_;
|
||||
std::vector<float> inv_fft_window_;
|
||||
std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_;
|
||||
// pffft requires memory to work with to avoid using the stack.
|
||||
std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_;
|
||||
std::vector<float, Eigen::aligned_allocator<float>> fft_output_;
|
||||
};
|
||||
|
||||
absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) {
|
||||
const auto& options =
|
||||
cc->Options<mediapipe::TensorsToAudioCalculatorOptions>();
|
||||
RET_CHECK(options.has_fft_size()) << "FFT size must be specified.";
|
||||
RET_CHECK(IsValidFftSize(options.fft_size()))
|
||||
<< "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b "
|
||||
">=0 and c >= 0 and a >= 5, the requested fft size is "
|
||||
<< options.fft_size();
|
||||
fft_size_ = options.fft_size();
|
||||
inverse_fft_size_ = 1.0f / fft_size_;
|
||||
fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL);
|
||||
input_dft_.resize(fft_size_);
|
||||
inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false);
|
||||
fft_input_buffer_.resize(fft_size_);
|
||||
fft_workplace_.resize(fft_size_);
|
||||
fft_output_.resize(fft_size_);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) {
|
||||
if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
const auto& input_tensors = *kTensorsIn(cc);
|
||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
||||
auto view = input_tensors[0].GetCpuReadView();
|
||||
// DC's real part.
|
||||
input_dft_[0] = kDcAndNyquistIn(cc)->first;
|
||||
// Nyquist's real part is the penultimate element of the tensor buffer.
|
||||
// pffft ignores the Nyquist's imagery part. No need to fetch the last value
|
||||
// from the tensor buffer.
|
||||
input_dft_[1] = *(view.buffer<float>() + (fft_size_ - 2));
|
||||
std::memcpy(input_dft_.data() + 2, view.buffer<float>(),
|
||||
(fft_size_ - 2) * sizeof(float));
|
||||
pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(),
|
||||
fft_workplace_.data(), PFFFT_BACKWARD);
|
||||
// Applies the inverse window function.
|
||||
std::transform(
|
||||
fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(),
|
||||
fft_output_.begin(),
|
||||
[this](float a, float b) { return a * b * inverse_fft_size_; });
|
||||
Matrix matrix = Eigen::Map<Matrix>(fft_output_.data(), 1, fft_output_.size());
|
||||
kAudioOut(cc).Send(std::move(matrix));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) {
|
||||
if (fft_state_) {
|
||||
pffft_destroy_setup(fft_state_);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message TensorsToAudioCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TensorsToAudioCalculatorOptions ext = 484297136;
|
||||
}
|
||||
|
||||
// Size of the fft in number of bins. If set, the calculator will do ifft
|
||||
// on the input tensor.
|
||||
optional int64 fft_size = 1;
|
||||
}
|
149
mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc
Normal file
149
mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc
Normal file
|
@ -0,0 +1,149 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <new>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
class TensorsToAudioCalculatorFftTest : public ::testing::Test {
|
||||
protected:
|
||||
// Creates an audio matrix containing a single sample of 1.0 at a specified
|
||||
// offset.
|
||||
Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) {
|
||||
Matrix impulse = Matrix::Zero(1, num_samples);
|
||||
impulse(0, impulse_offset_idx) = 1.0;
|
||||
return impulse;
|
||||
}
|
||||
|
||||
void ConfigGraph(int num_samples, double sample_rate, int fft_size) {
|
||||
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(R"(
|
||||
input_stream: "audio_in"
|
||||
input_stream: "sample_rate"
|
||||
output_stream: "audio_out"
|
||||
node {
|
||||
calculator: "AudioToTensorCalculator"
|
||||
input_stream: "AUDIO:audio_in"
|
||||
input_stream: "SAMPLE_RATE:sample_rate"
|
||||
output_stream: "TENSORS:tensors"
|
||||
output_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||
options {
|
||||
[mediapipe.AudioToTensorCalculatorOptions.ext] {
|
||||
num_channels: 1
|
||||
num_samples: $0
|
||||
num_overlapping_samples: 0
|
||||
target_sample_rate: $1
|
||||
fft_size: $2
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
calculator: "TensorsToAudioCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
input_stream: "DC_AND_NYQUIST:dc_and_nyquist"
|
||||
output_stream: "AUDIO:audio_out"
|
||||
options {
|
||||
[mediapipe.TensorsToAudioCalculatorOptions.ext] {
|
||||
fft_size: $2
|
||||
}
|
||||
}
|
||||
}
|
||||
)",
|
||||
/*$0=*/num_samples,
|
||||
/*$1=*/sample_rate,
|
||||
/*$2=*/fft_size));
|
||||
tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_);
|
||||
}
|
||||
|
||||
void RunGraph(const Matrix& input_data, double sample_rate) {
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||
"audio_in", MakePacket<Matrix>(input_data).At(Timestamp(0))));
|
||||
MP_ASSERT_OK(graph_.CloseAllInputStreams());
|
||||
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||
}
|
||||
|
||||
std::vector<Packet> audio_out_packets_;
|
||||
CalculatorGraphConfig graph_config_;
|
||||
CalculatorGraph graph_;
|
||||
};
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) {
|
||||
ConfigGraph(320, 16000, 103);
|
||||
MP_ASSERT_OK(graph_.Initialize(graph_config_));
|
||||
MP_ASSERT_OK(graph_.StartRun({}));
|
||||
auto status = graph_.WaitUntilIdle();
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(status.message(),
|
||||
::testing::HasSubstr("FFT size must be of the form"));
|
||||
}
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) {
|
||||
constexpr int sample_size = 320;
|
||||
constexpr double sample_rate = 16000;
|
||||
ConfigGraph(sample_size, sample_rate, 320);
|
||||
|
||||
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2);
|
||||
RunGraph(impulse_data, sample_rate);
|
||||
ASSERT_EQ(1, audio_out_packets_.size());
|
||||
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
|
||||
// The impulse signal at the center is not affected by the window function.
|
||||
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data);
|
||||
}
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) {
|
||||
constexpr int sample_size = 320;
|
||||
constexpr double sample_rate = 16000;
|
||||
ConfigGraph(sample_size, sample_rate, 320);
|
||||
Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4);
|
||||
RunGraph(impulse_data, sample_rate);
|
||||
ASSERT_EQ(1, audio_out_packets_.size());
|
||||
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
|
||||
// As the impulse signal sits at the 1/4 of the hann window, the inverse
|
||||
// window function reduces it by half.
|
||||
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), impulse_data / 2);
|
||||
}
|
||||
|
||||
TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) {
|
||||
constexpr int sample_size = 320;
|
||||
constexpr double sample_rate = 16000;
|
||||
ConfigGraph(sample_size, sample_rate, 320);
|
||||
Matrix impulse_data = CreateImpulseSignalData(sample_size, 0);
|
||||
RunGraph(impulse_data, sample_rate);
|
||||
ASSERT_EQ(1, audio_out_packets_.size());
|
||||
MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType<Matrix>());
|
||||
// As the impulse signal sits at the beginning of the hann window, the inverse
|
||||
// window function completely removes it.
|
||||
EXPECT_EQ(audio_out_packets_[0].Get<Matrix>(), Matrix::Zero(1, sample_size));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
|
|||
|
||||
// Returns a PacketSequencerCalculator node.
|
||||
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
|
||||
bool synchronize_io) {
|
||||
bool async_selection) {
|
||||
CalculatorGraphConfig::Node* result = config->add_node();
|
||||
*result->mutable_calculator() = "PacketSequencerCalculator";
|
||||
if (synchronize_io) {
|
||||
if (!async_selection) {
|
||||
*result->mutable_input_stream_handler()->mutable_input_stream_handler() =
|
||||
"DefaultInputStreamHandler";
|
||||
}
|
||||
|
@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField<std::string>& streams,
|
|||
return tags.count({tag, 0}) > 0;
|
||||
}
|
||||
|
||||
// Returns true if a set of "TAG::index" includes a TagIndex.
|
||||
bool ContainsTag(const proto_ns::RepeatedPtrField<std::string>& tags,
|
||||
TagIndex item) {
|
||||
for (const std::string& t : tags) {
|
||||
if (ParseTagIndex(t) == item) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
||||
const Subgraph::SubgraphOptions& options) {
|
||||
CalculatorGraphConfig config;
|
||||
|
@ -263,17 +272,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
std::string enable_stream = "ENABLE:gate_enable";
|
||||
|
||||
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
|
||||
bool synchronize_io =
|
||||
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
|
||||
.synchronize_io();
|
||||
const auto& switch_options =
|
||||
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options);
|
||||
bool async_selection = switch_options.async_selection();
|
||||
if (HasTag(container_node.input_stream(), "SELECT")) {
|
||||
select_node = BuildTimestampNode(&config, synchronize_io);
|
||||
select_node = BuildTimestampNode(&config, async_selection);
|
||||
select_node->add_input_stream("INPUT:gate_select");
|
||||
select_node->add_output_stream("OUTPUT:gate_select_timed");
|
||||
select_stream = "SELECT:gate_select_timed";
|
||||
}
|
||||
if (HasTag(container_node.input_stream(), "ENABLE")) {
|
||||
enable_node = BuildTimestampNode(&config, synchronize_io);
|
||||
enable_node = BuildTimestampNode(&config, async_selection);
|
||||
enable_node->add_input_stream("INPUT:gate_enable");
|
||||
enable_node->add_output_stream("OUTPUT:gate_enable_timed");
|
||||
enable_stream = "ENABLE:gate_enable_timed";
|
||||
|
@ -296,7 +305,7 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
mux->add_input_side_packet("SELECT:gate_select");
|
||||
mux->add_input_side_packet("ENABLE:gate_enable");
|
||||
|
||||
// Add input streams for graph and demux and the timestamper.
|
||||
// Add input streams for graph and demux.
|
||||
config.add_input_stream("SELECT:gate_select");
|
||||
config.add_input_stream("ENABLE:gate_enable");
|
||||
config.add_input_side_packet("SELECT:gate_select");
|
||||
|
@ -306,6 +315,12 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
std::string stream = CatStream(p.first, p.second);
|
||||
config.add_input_stream(stream);
|
||||
demux->add_input_stream(stream);
|
||||
}
|
||||
|
||||
// Add input streams for the timestamper.
|
||||
auto& tick_streams = switch_options.tick_input_stream();
|
||||
for (const auto& p : input_tags) {
|
||||
if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue;
|
||||
TagIndex tick_tag{"TICK", tick_index++};
|
||||
if (select_node) {
|
||||
select_node->add_input_stream(CatStream(tick_tag, p.second));
|
||||
|
|
|
@ -25,6 +25,14 @@ message SwitchContainerOptions {
|
|||
// Activates channel 1 for enable = true, channel 0 otherwise.
|
||||
optional bool enable = 4;
|
||||
|
||||
// Use DefaultInputStreamHandler for muxing & demuxing.
|
||||
// Use DefaultInputStreamHandler for demuxing.
|
||||
optional bool synchronize_io = 5;
|
||||
|
||||
// Use ImmediateInputStreamHandler for channel selection.
|
||||
optional bool async_selection = 6;
|
||||
|
||||
// Specifies an input stream, "TAG:index", that defines the processed
|
||||
// timestamps. SwitchContainer awaits output at the last processed
|
||||
// timestamp before advancing from one selected channel to the next.
|
||||
repeated string tick_input_stream = 7;
|
||||
}
|
||||
|
|
|
@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|||
input_stream: "INPUT:enable"
|
||||
input_stream: "TICK:foo"
|
||||
output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
|
||||
input_stream_handler {
|
||||
input_stream_handler: "DefaultInputStreamHandler"
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "switchcontainer__SwitchDemuxCalculator"
|
||||
|
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|||
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
||||
TEST(SwitchContainerTest, RunsWithSubnodes) {
|
||||
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
||||
CalculatorGraphConfig supergraph = SubnodeContainerExample();
|
||||
CalculatorGraphConfig supergraph =
|
||||
SubnodeContainerExample("async_selection: true");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
RunTestContainer(supergraph);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
|
@ -54,21 +55,47 @@ namespace mediapipe {
|
|||
// contained subgraph or calculator nodes.
|
||||
//
|
||||
class SwitchDemuxCalculator : public CalculatorBase {
|
||||
static constexpr char kSelectTag[] = "SELECT";
|
||||
static constexpr char kEnableTag[] = "ENABLE";
|
||||
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status RecordPackets(CalculatorContext* cc);
|
||||
int ChannelIndex(Timestamp timestamp);
|
||||
absl::Status SendActivePackets(CalculatorContext* cc);
|
||||
|
||||
private:
|
||||
int channel_index_;
|
||||
std::set<std::string> channel_tags_;
|
||||
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
|
||||
PacketQueue input_queue_;
|
||||
std::map<Timestamp, int> channel_history_;
|
||||
};
|
||||
REGISTER_CALCULATOR(SwitchDemuxCalculator);
|
||||
|
||||
namespace {
|
||||
static constexpr char kSelectTag[] = "SELECT";
|
||||
static constexpr char kEnableTag[] = "ENABLE";
|
||||
|
||||
// Returns the last received timestamp for an input stream.
|
||||
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
|
||||
return input.Value().Timestamp();
|
||||
}
|
||||
|
||||
// Returns the last received timestamp for channel selection.
|
||||
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
|
||||
Timestamp result = Timestamp::Done();
|
||||
if (cc->Inputs().HasTag(kEnableTag)) {
|
||||
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
|
||||
} else if (cc->Inputs().HasTag(kSelectTag)) {
|
||||
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
||||
// Allow any one of kSelectTag, kEnableTag.
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
||||
|
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
|||
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
channel_tags_ = ChannelTags(cc->Outputs().TagMap());
|
||||
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||
|
||||
// Relay side packets to all channels.
|
||||
// Note: This is necessary because Calculator::Open only proceeds when every
|
||||
|
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
|
||||
// Update the input channel index if specified.
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
MP_RETURN_IF_ERROR(RecordPackets(cc));
|
||||
MP_RETURN_IF_ERROR(SendActivePackets(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Relay packets and timestamps only to channel_index_.
|
||||
// Enqueue all arriving packets and bounds.
|
||||
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
|
||||
// Enqueue any new arriving packets.
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto& input = cc->Inputs().Get(tag, index);
|
||||
std::string output_tag = tool::ChannelTag(tag, channel_index_);
|
||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||
if (output_id.IsValid()) {
|
||||
auto& output = cc->Outputs().Get(output_tag, index);
|
||||
tool::Relay(input, &output);
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
Packet packet = cc->Inputs().Get(input_id).Value();
|
||||
if (packet.Timestamp() == cc->InputTimestamp()) {
|
||||
input_queue_[input_id].push(packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enque any new input channel and its activation timestamp.
|
||||
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
|
||||
if (channel_settled == cc->InputTimestamp() &&
|
||||
new_channel_index != channel_index_) {
|
||||
channel_index_ = new_channel_index;
|
||||
channel_history_[channel_settled] = channel_index_;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns the channel index for a Timestamp.
|
||||
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
|
||||
auto it = std::prev(channel_history_.upper_bound(timestamp));
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Dispatches all queued input packets with known channels.
|
||||
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
|
||||
// Dispatch any queued input packets with a defined channel_index.
|
||||
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
auto& queue = input_queue_[input_id];
|
||||
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
|
||||
int channel_index = ChannelIndex(queue.front().Timestamp());
|
||||
std::string output_tag = tool::ChannelTag(tag, channel_index);
|
||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||
if (output_id.IsValid()) {
|
||||
cc->Outputs().Get(output_id).AddPacket(queue.front());
|
||||
}
|
||||
queue.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Discard all select packets not needed for any remaining input packets.
|
||||
Timestamp input_settled = Timestamp::Done();
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
|
||||
if (!input_queue_[input_id].empty()) {
|
||||
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
|
||||
stream_settled =
|
||||
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
|
||||
}
|
||||
}
|
||||
}
|
||||
Timestamp input_bound = input_settled.NextAllowedInStream();
|
||||
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
|
||||
channel_history_.erase(channel_history_.begin(), history_bound);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
|
|||
options_ = cc->Options<mediapipe::SwitchContainerOptions>();
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
channel_tags_ = ChannelTags(cc->Inputs().TagMap());
|
||||
channel_history_[Timestamp::Unset()] = channel_index_;
|
||||
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||
|
||||
// Relay side packets only from channel_index_.
|
||||
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user