diff --git a/mediapipe/calculators/core/bypass_calculator_test.cc b/mediapipe/calculators/core/bypass_calculator_test.cc index 4d1cd8f79..224742a13 100644 --- a/mediapipe/calculators/core/bypass_calculator_test.cc +++ b/mediapipe/calculators/core/bypass_calculator_test.cc @@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb( output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" options { [mediapipe.SwitchContainerOptions.ext] { + async_selection: true contained_node: { calculator: "AppearancesPassThroughSubgraph" } } } @@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb( output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" options { [mediapipe.SwitchContainerOptions.ext] { + async_selection: true contained_node: { calculator: "BypassCalculator" node_options: { diff --git a/mediapipe/calculators/core/end_loop_calculator.h b/mediapipe/calculators/core/end_loop_calculator.h index e40301e81..9f56657d0 100644 --- a/mediapipe/calculators/core/end_loop_calculator.h +++ b/mediapipe/calculators/core/end_loop_calculator.h @@ -50,7 +50,7 @@ namespace mediapipe { // calculator: "EndLoopWithOutputCalculator" // input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts // input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts -// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts +// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts // } template class EndLoopCalculator : public CalculatorBase { diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 99b5b3e91..a966dd3fc 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -109,6 +109,56 @@ cc_test( ], ) +mediapipe_proto_library( + name = "tensors_to_audio_calculator_proto", + srcs = ["tensors_to_audio_calculator.proto"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "tensors_to_audio_calculator", + srcs = ["tensors_to_audio_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + ":tensors_to_audio_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_audio_tools//audio/dsp:window_functions", + "@pffft", + ], + alwayslink = 1, +) + +cc_test( + name = "tensors_to_audio_calculator_test", + srcs = ["tensors_to_audio_calculator_test.cc"], + deps = [ + ":audio_to_tensor_calculator", + ":audio_to_tensor_calculator_cc_proto", + ":tensors_to_audio_calculator", + ":tensors_to_audio_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + mediapipe_proto_library( name = "feedback_tensors_calculator_proto", srcs = ["feedback_tensors_calculator.proto"], diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index 59c129191..d0513518a 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -133,7 +133,7 @@ bool IsValidFftSize(int size) { // invocation. In the non-streaming mode, the vector contains all of the // output timestamps for an input audio buffer. // DC_AND_NYQUIST - std::pair @Optional. -// A pair of dc component and nyquest component. Only can be connected when +// A pair of dc component and nyquist component. Only can be connected when // the calculator performs fft (the fft_size is set in the calculator // options). // diff --git a/mediapipe/calculators/tensor/tensors_to_audio_calculator.cc b/mediapipe/calculators/tensor/tensors_to_audio_calculator.cc new file mode 100644 index 000000000..8da29bb69 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_audio_calculator.cc @@ -0,0 +1,197 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "audio/dsp/window_functions.h" +#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "pffft.h" + +namespace mediapipe { +namespace api2 { +namespace { + +std::vector HannWindow(int window_size, bool sqrt_hann) { + std::vector hann_window(window_size); + audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window); + if (sqrt_hann) { + absl::c_transform(hann_window, hann_window.begin(), + [](double x) { return std::sqrt(x); }); + } + return hann_window; +} + +// Note that the InvHannWindow function may only work for 50% overlapping case. +std::vector InvHannWindow(int window_size, bool sqrt_hann) { + std::vector window = HannWindow(window_size, sqrt_hann); + std::vector inv_window(window.size()); + if (sqrt_hann) { + absl::c_copy(window, inv_window.begin()); + } else { + const int kHalfWindowSize = window.size() / 2; + absl::c_transform(window, inv_window.begin(), + [](double x) { return x * x; }); + for (int i = 0; i < kHalfWindowSize; ++i) { + double sum = inv_window[i] + inv_window[kHalfWindowSize + i]; + inv_window[i] = window[i] / sum; + inv_window[kHalfWindowSize + i] = window[kHalfWindowSize + i] / sum; + } + } + return inv_window; +} + +// PFFFT only supports transforms for inputs of length N of the form +// N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT. +bool IsValidFftSize(int size) { + if (size <= 0) { + return false; + } + constexpr int kFactors[] = {2, 3, 5}; + int factorization[] = {0, 0, 0}; + int n = static_cast(size); + for (int i = 0; i < 3; ++i) { + while (n % kFactors[i] == 0) { + n = n / kFactors[i]; + ++factorization[i]; + } + } + return factorization[0] >= 5 && n == 1; +} + +} // namespace + +// Converts 2D MediaPipe float Tensors to audio buffers. +// The calculator will perform ifft on the complex DFT and apply the window +// function (Inverse Hann) afterwards. The input 2D MediaPipe Tensor must +// have the DFT real parts in its first row and the DFT imagery parts in its +// second row. A valid "fft_size" must be set in the CalculatorOptions. +// +// Inputs: +// TENSORS - std::vector +// Vector containing a single Tensor that represents the audio's complex DFT +// results. +// DC_AND_NYQUIST - std::pair +// A pair of dc component and nyquist component. +// +// Outputs: +// AUDIO - mediapipe::Matrix +// The audio data represented as mediapipe::Matrix. +// +// Example: +// node { +// calculator: "TensorsToAudioCalculator" +// input_stream: "TENSORS:tensors" +// input_stream: "DC_AND_NYQUIST:dc_and_nyquist" +// output_stream: "AUDIO:audio" +// options { +// [mediapipe.AudioToTensorCalculatorOptions.ext] { +// fft_size: 256 +// } +// } +// } +class TensorsToAudioCalculator : public Node { + public: + static constexpr Input> kTensorsIn{"TENSORS"}; + static constexpr Input> kDcAndNyquistIn{ + "DC_AND_NYQUIST"}; + static constexpr Output kAudioOut{"AUDIO"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kDcAndNyquistIn, kAudioOut); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + // The internal state of the FFT library. + PFFFT_Setup* fft_state_ = nullptr; + int fft_size_ = 0; + float inverse_fft_size_ = 0; + std::vector> input_dft_; + std::vector inv_fft_window_; + std::vector> fft_input_buffer_; + // pffft requires memory to work with to avoid using the stack. + std::vector> fft_workplace_; + std::vector> fft_output_; +}; + +absl::Status TensorsToAudioCalculator::Open(CalculatorContext* cc) { + const auto& options = + cc->Options(); + RET_CHECK(options.has_fft_size()) << "FFT size must be specified."; + RET_CHECK(IsValidFftSize(options.fft_size())) + << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b " + ">=0 and c >= 0 and a >= 5, the requested fft size is " + << options.fft_size(); + fft_size_ = options.fft_size(); + inverse_fft_size_ = 1.0f / fft_size_; + fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL); + input_dft_.resize(fft_size_); + inv_fft_window_ = InvHannWindow(fft_size_, /* sqrt_hann = */ false); + fft_input_buffer_.resize(fft_size_); + fft_workplace_.resize(fft_size_); + fft_output_.resize(fft_size_); + return absl::OkStatus(); +} + +absl::Status TensorsToAudioCalculator::Process(CalculatorContext* cc) { + if (kTensorsIn(cc).IsEmpty() || kDcAndNyquistIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kTensorsIn(cc); + RET_CHECK_EQ(input_tensors.size(), 1); + RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32); + auto view = input_tensors[0].GetCpuReadView(); + // DC's real part. + input_dft_[0] = kDcAndNyquistIn(cc)->first; + // Nyquist's real part is the penultimate element of the tensor buffer. + // pffft ignores the Nyquist's imagery part. No need to fetch the last value + // from the tensor buffer. + input_dft_[1] = *(view.buffer() + (fft_size_ - 2)); + std::memcpy(input_dft_.data() + 2, view.buffer(), + (fft_size_ - 2) * sizeof(float)); + pffft_transform_ordered(fft_state_, input_dft_.data(), fft_output_.data(), + fft_workplace_.data(), PFFFT_BACKWARD); + // Applies the inverse window function. + std::transform( + fft_output_.begin(), fft_output_.end(), inv_fft_window_.begin(), + fft_output_.begin(), + [this](float a, float b) { return a * b * inverse_fft_size_; }); + Matrix matrix = Eigen::Map(fft_output_.data(), 1, fft_output_.size()); + kAudioOut(cc).Send(std::move(matrix)); + return absl::OkStatus(); +} + +absl::Status TensorsToAudioCalculator::Close(CalculatorContext* cc) { + if (fft_state_) { + pffft_destroy_setup(fft_state_); + } + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(TensorsToAudioCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_audio_calculator.proto b/mediapipe/calculators/tensor/tensors_to_audio_calculator.proto new file mode 100644 index 000000000..907627125 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_audio_calculator.proto @@ -0,0 +1,29 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorsToAudioCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorsToAudioCalculatorOptions ext = 484297136; + } + + // Size of the fft in number of bins. If set, the calculator will do ifft + // on the input tensor. + optional int64 fft_size = 1; +} diff --git a/mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc new file mode 100644 index 000000000..b332381c6 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_audio_calculator_test.cc @@ -0,0 +1,149 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_audio_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +class TensorsToAudioCalculatorFftTest : public ::testing::Test { + protected: + // Creates an audio matrix containing a single sample of 1.0 at a specified + // offset. + Matrix CreateImpulseSignalData(int64 num_samples, int impulse_offset_idx) { + Matrix impulse = Matrix::Zero(1, num_samples); + impulse(0, impulse_offset_idx) = 1.0; + return impulse; + } + + void ConfigGraph(int num_samples, double sample_rate, int fft_size) { + graph_config_ = ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "audio_in" + input_stream: "sample_rate" + output_stream: "audio_out" + node { + calculator: "AudioToTensorCalculator" + input_stream: "AUDIO:audio_in" + input_stream: "SAMPLE_RATE:sample_rate" + output_stream: "TENSORS:tensors" + output_stream: "DC_AND_NYQUIST:dc_and_nyquist" + options { + [mediapipe.AudioToTensorCalculatorOptions.ext] { + num_channels: 1 + num_samples: $0 + num_overlapping_samples: 0 + target_sample_rate: $1 + fft_size: $2 + } + } + } + node { + calculator: "TensorsToAudioCalculator" + input_stream: "TENSORS:tensors" + input_stream: "DC_AND_NYQUIST:dc_and_nyquist" + output_stream: "AUDIO:audio_out" + options { + [mediapipe.TensorsToAudioCalculatorOptions.ext] { + fft_size: $2 + } + } + } + )", + /*$0=*/num_samples, + /*$1=*/sample_rate, + /*$2=*/fft_size)); + tool::AddVectorSink("audio_out", &graph_config_, &audio_out_packets_); + } + + void RunGraph(const Matrix& input_data, double sample_rate) { + MP_ASSERT_OK(graph_.Initialize(graph_config_)); + MP_ASSERT_OK(graph_.StartRun({})); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "sample_rate", MakePacket(sample_rate).At(Timestamp(0)))); + MP_ASSERT_OK(graph_.AddPacketToInputStream( + "audio_in", MakePacket(input_data).At(Timestamp(0)))); + MP_ASSERT_OK(graph_.CloseAllInputStreams()); + MP_ASSERT_OK(graph_.WaitUntilDone()); + } + + std::vector audio_out_packets_; + CalculatorGraphConfig graph_config_; + CalculatorGraph graph_; +}; + +TEST_F(TensorsToAudioCalculatorFftTest, TestInvalidFftSize) { + ConfigGraph(320, 16000, 103); + MP_ASSERT_OK(graph_.Initialize(graph_config_)); + MP_ASSERT_OK(graph_.StartRun({})); + auto status = graph_.WaitUntilIdle(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_THAT(status.message(), + ::testing::HasSubstr("FFT size must be of the form")); +} + +TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtTheCenter) { + constexpr int sample_size = 320; + constexpr double sample_rate = 16000; + ConfigGraph(sample_size, sample_rate, 320); + + Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 2); + RunGraph(impulse_data, sample_rate); + ASSERT_EQ(1, audio_out_packets_.size()); + MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType()); + // The impulse signal at the center is not affected by the window function. + EXPECT_EQ(audio_out_packets_[0].Get(), impulse_data); +} + +TEST_F(TensorsToAudioCalculatorFftTest, TestWindowedImpulseSignal) { + constexpr int sample_size = 320; + constexpr double sample_rate = 16000; + ConfigGraph(sample_size, sample_rate, 320); + Matrix impulse_data = CreateImpulseSignalData(sample_size, sample_size / 4); + RunGraph(impulse_data, sample_rate); + ASSERT_EQ(1, audio_out_packets_.size()); + MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType()); + // As the impulse signal sits at the 1/4 of the hann window, the inverse + // window function reduces it by half. + EXPECT_EQ(audio_out_packets_[0].Get(), impulse_data / 2); +} + +TEST_F(TensorsToAudioCalculatorFftTest, TestImpulseSignalAtBeginning) { + constexpr int sample_size = 320; + constexpr double sample_rate = 16000; + ConfigGraph(sample_size, sample_rate, 320); + Matrix impulse_data = CreateImpulseSignalData(sample_size, 0); + RunGraph(impulse_data, sample_rate); + ASSERT_EQ(1, audio_out_packets_.size()); + MP_ASSERT_OK(audio_out_packets_[0].ValidateAsType()); + // As the impulse signal sits at the beginning of the hann window, the inverse + // window function completely removes it. + EXPECT_EQ(audio_out_packets_[0].Get(), Matrix::Zero(1, sample_size)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/tool/switch_container.cc b/mediapipe/framework/tool/switch_container.cc index 9439acf96..daa129928 100644 --- a/mediapipe/framework/tool/switch_container.cc +++ b/mediapipe/framework/tool/switch_container.cc @@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode( // Returns a PacketSequencerCalculator node. CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config, - bool synchronize_io) { + bool async_selection) { CalculatorGraphConfig::Node* result = config->add_node(); *result->mutable_calculator() = "PacketSequencerCalculator"; - if (synchronize_io) { + if (!async_selection) { *result->mutable_input_stream_handler()->mutable_input_stream_handler() = "DefaultInputStreamHandler"; } @@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField& streams, return tags.count({tag, 0}) > 0; } +// Returns true if a set of "TAG::index" includes a TagIndex. +bool ContainsTag(const proto_ns::RepeatedPtrField& tags, + TagIndex item) { + for (const std::string& t : tags) { + if (ParseTagIndex(t) == item) return true; + } + return false; +} + absl::StatusOr SwitchContainer::GetConfig( const Subgraph::SubgraphOptions& options) { CalculatorGraphConfig config; @@ -263,17 +272,17 @@ absl::StatusOr SwitchContainer::GetConfig( std::string enable_stream = "ENABLE:gate_enable"; // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams. - bool synchronize_io = - Subgraph::GetOptions(options) - .synchronize_io(); + const auto& switch_options = + Subgraph::GetOptions(options); + bool async_selection = switch_options.async_selection(); if (HasTag(container_node.input_stream(), "SELECT")) { - select_node = BuildTimestampNode(&config, synchronize_io); + select_node = BuildTimestampNode(&config, async_selection); select_node->add_input_stream("INPUT:gate_select"); select_node->add_output_stream("OUTPUT:gate_select_timed"); select_stream = "SELECT:gate_select_timed"; } if (HasTag(container_node.input_stream(), "ENABLE")) { - enable_node = BuildTimestampNode(&config, synchronize_io); + enable_node = BuildTimestampNode(&config, async_selection); enable_node->add_input_stream("INPUT:gate_enable"); enable_node->add_output_stream("OUTPUT:gate_enable_timed"); enable_stream = "ENABLE:gate_enable_timed"; @@ -296,7 +305,7 @@ absl::StatusOr SwitchContainer::GetConfig( mux->add_input_side_packet("SELECT:gate_select"); mux->add_input_side_packet("ENABLE:gate_enable"); - // Add input streams for graph and demux and the timestamper. + // Add input streams for graph and demux. config.add_input_stream("SELECT:gate_select"); config.add_input_stream("ENABLE:gate_enable"); config.add_input_side_packet("SELECT:gate_select"); @@ -306,6 +315,12 @@ absl::StatusOr SwitchContainer::GetConfig( std::string stream = CatStream(p.first, p.second); config.add_input_stream(stream); demux->add_input_stream(stream); + } + + // Add input streams for the timestamper. + auto& tick_streams = switch_options.tick_input_stream(); + for (const auto& p : input_tags) { + if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue; TagIndex tick_tag{"TICK", tick_index++}; if (select_node) { select_node->add_input_stream(CatStream(tick_tag, p.second)); diff --git a/mediapipe/framework/tool/switch_container.proto b/mediapipe/framework/tool/switch_container.proto index a9c2d9094..16050ef6f 100644 --- a/mediapipe/framework/tool/switch_container.proto +++ b/mediapipe/framework/tool/switch_container.proto @@ -25,6 +25,14 @@ message SwitchContainerOptions { // Activates channel 1 for enable = true, channel 0 otherwise. optional bool enable = 4; - // Use DefaultInputStreamHandler for muxing & demuxing. + // Use DefaultInputStreamHandler for demuxing. optional bool synchronize_io = 5; + + // Use ImmediateInputStreamHandler for channel selection. + optional bool async_selection = 6; + + // Specifies an input stream, "TAG:index", that defines the processed + // timestamps. SwitchContainer awaits output at the last processed + // timestamp before advancing from one selected channel to the next. + repeated string tick_input_stream = 7; } diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index de4aa0b14..b20979b10 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) { input_stream: "INPUT:enable" input_stream: "TICK:foo" output_stream: "OUTPUT:switchcontainer__gate_enable_timed" + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + } } node { name: "switchcontainer__SwitchDemuxCalculator" @@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) { // Shows the SwitchContainer container runs with a pair of simple subnodes. TEST(SwitchContainerTest, RunsWithSubnodes) { EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer")); - CalculatorGraphConfig supergraph = SubnodeContainerExample(); + CalculatorGraphConfig supergraph = + SubnodeContainerExample("async_selection: true"); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); RunTestContainer(supergraph); } diff --git a/mediapipe/framework/tool/switch_demux_calculator.cc b/mediapipe/framework/tool/switch_demux_calculator.cc index b9ba2a0fb..c066d470a 100644 --- a/mediapipe/framework/tool/switch_demux_calculator.cc +++ b/mediapipe/framework/tool/switch_demux_calculator.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -54,21 +55,47 @@ namespace mediapipe { // contained subgraph or calculator nodes. // class SwitchDemuxCalculator : public CalculatorBase { - static constexpr char kSelectTag[] = "SELECT"; - static constexpr char kEnableTag[] = "ENABLE"; - public: static absl::Status GetContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; + private: + absl::Status RecordPackets(CalculatorContext* cc); + int ChannelIndex(Timestamp timestamp); + absl::Status SendActivePackets(CalculatorContext* cc); + private: int channel_index_; std::set channel_tags_; + using PacketQueue = std::map>; + PacketQueue input_queue_; + std::map channel_history_; }; REGISTER_CALCULATOR(SwitchDemuxCalculator); +namespace { +static constexpr char kSelectTag[] = "SELECT"; +static constexpr char kEnableTag[] = "ENABLE"; + +// Returns the last received timestamp for an input stream. +inline Timestamp SettledTimestamp(const InputStreamShard& input) { + return input.Value().Timestamp(); +} + +// Returns the last received timestamp for channel selection. +inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) { + Timestamp result = Timestamp::Done(); + if (cc->Inputs().HasTag(kEnableTag)) { + result = SettledTimestamp(cc->Inputs().Tag(kEnableTag)); + } else if (cc->Inputs().HasTag(kSelectTag)) { + result = SettledTimestamp(cc->Inputs().Tag(kSelectTag)); + } + return result; +} +} // namespace + absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { // Allow any one of kSelectTag, kEnableTag. cc->Inputs().Tag(kSelectTag).Set().Optional(); @@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Outputs().TagMap()); + channel_history_[Timestamp::Unstarted()] = channel_index_; // Relay side packets to all channels. // Note: This is necessary because Calculator::Open only proceeds when every @@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { } absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { - // Update the input channel index if specified. - channel_index_ = tool::GetChannelIndex(*cc, channel_index_); + MP_RETURN_IF_ERROR(RecordPackets(cc)); + MP_RETURN_IF_ERROR(SendActivePackets(cc)); + return absl::OkStatus(); +} - // Relay packets and timestamps only to channel_index_. +// Enqueue all arriving packets and bounds. +absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) { + // Enqueue any new arriving packets. for (const std::string& tag : channel_tags_) { for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { - auto& input = cc->Inputs().Get(tag, index); - std::string output_tag = tool::ChannelTag(tag, channel_index_); - auto output_id = cc->Outputs().GetId(output_tag, index); - if (output_id.IsValid()) { - auto& output = cc->Outputs().Get(output_tag, index); - tool::Relay(input, &output); + auto input_id = cc->Inputs().GetId(tag, index); + Packet packet = cc->Inputs().Get(input_id).Value(); + if (packet.Timestamp() == cc->InputTimestamp()) { + input_queue_[input_id].push(packet); } } } + + // Enque any new input channel and its activation timestamp. + Timestamp channel_settled = ChannelSettledTimestamp(cc); + int new_channel_index = tool::GetChannelIndex(*cc, channel_index_); + if (channel_settled == cc->InputTimestamp() && + new_channel_index != channel_index_) { + channel_index_ = new_channel_index; + channel_history_[channel_settled] = channel_index_; + } + return absl::OkStatus(); +} + +// Returns the channel index for a Timestamp. +int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) { + auto it = std::prev(channel_history_.upper_bound(timestamp)); + return it->second; +} + +// Dispatches all queued input packets with known channels. +absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) { + // Dispatch any queued input packets with a defined channel_index. + Timestamp channel_settled = ChannelSettledTimestamp(cc); + for (const std::string& tag : channel_tags_) { + for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { + auto input_id = cc->Inputs().GetId(tag, index); + auto& queue = input_queue_[input_id]; + while (!queue.empty() && queue.front().Timestamp() <= channel_settled) { + int channel_index = ChannelIndex(queue.front().Timestamp()); + std::string output_tag = tool::ChannelTag(tag, channel_index); + auto output_id = cc->Outputs().GetId(output_tag, index); + if (output_id.IsValid()) { + cc->Outputs().Get(output_id).AddPacket(queue.front()); + } + queue.pop(); + } + } + } + + // Discard all select packets not needed for any remaining input packets. + Timestamp input_settled = Timestamp::Done(); + for (const std::string& tag : channel_tags_) { + for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { + auto input_id = cc->Inputs().GetId(tag, index); + Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id)); + if (!input_queue_[input_id].empty()) { + Timestamp stream_bound = input_queue_[input_id].front().Timestamp(); + stream_settled = + std::min(stream_settled, stream_bound.PreviousAllowedInStream()); + } + } + } + Timestamp input_bound = input_settled.NextAllowedInStream(); + auto history_bound = std::prev(channel_history_.upper_bound(input_bound)); + channel_history_.erase(channel_history_.begin(), history_bound); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/switch_mux_calculator.cc b/mediapipe/framework/tool/switch_mux_calculator.cc index 1a3136620..230544b6b 100644 --- a/mediapipe/framework/tool/switch_mux_calculator.cc +++ b/mediapipe/framework/tool/switch_mux_calculator.cc @@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Inputs().TagMap()); - channel_history_[Timestamp::Unset()] = channel_index_; + channel_history_[Timestamp::Unstarted()] = channel_index_; // Relay side packets only from channel_index_. for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {