diff --git a/mediapipe/calculators/core/bypass_calculator_test.cc b/mediapipe/calculators/core/bypass_calculator_test.cc index 4d1cd8f79..224742a13 100644 --- a/mediapipe/calculators/core/bypass_calculator_test.cc +++ b/mediapipe/calculators/core/bypass_calculator_test.cc @@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb( output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" options { [mediapipe.SwitchContainerOptions.ext] { + async_selection: true contained_node: { calculator: "AppearancesPassThroughSubgraph" } } } @@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb( output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" options { [mediapipe.SwitchContainerOptions.ext] { + async_selection: true contained_node: { calculator: "BypassCalculator" node_options: { diff --git a/mediapipe/framework/tool/switch_container.cc b/mediapipe/framework/tool/switch_container.cc index 9439acf96..24a1e6fe7 100644 --- a/mediapipe/framework/tool/switch_container.cc +++ b/mediapipe/framework/tool/switch_container.cc @@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode( // Returns a PacketSequencerCalculator node. CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config, - bool synchronize_io) { + bool async_selection) { CalculatorGraphConfig::Node* result = config->add_node(); *result->mutable_calculator() = "PacketSequencerCalculator"; - if (synchronize_io) { + if (!async_selection) { *result->mutable_input_stream_handler()->mutable_input_stream_handler() = "DefaultInputStreamHandler"; } @@ -263,17 +263,17 @@ absl::StatusOr SwitchContainer::GetConfig( std::string enable_stream = "ENABLE:gate_enable"; // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams. - bool synchronize_io = + bool async_selection = Subgraph::GetOptions(options) - .synchronize_io(); + .async_selection(); if (HasTag(container_node.input_stream(), "SELECT")) { - select_node = BuildTimestampNode(&config, synchronize_io); + select_node = BuildTimestampNode(&config, async_selection); select_node->add_input_stream("INPUT:gate_select"); select_node->add_output_stream("OUTPUT:gate_select_timed"); select_stream = "SELECT:gate_select_timed"; } if (HasTag(container_node.input_stream(), "ENABLE")) { - enable_node = BuildTimestampNode(&config, synchronize_io); + enable_node = BuildTimestampNode(&config, async_selection); enable_node->add_input_stream("INPUT:gate_enable"); enable_node->add_output_stream("OUTPUT:gate_enable_timed"); enable_stream = "ENABLE:gate_enable_timed"; diff --git a/mediapipe/framework/tool/switch_container.proto b/mediapipe/framework/tool/switch_container.proto index a9c2d9094..2de5a7abf 100644 --- a/mediapipe/framework/tool/switch_container.proto +++ b/mediapipe/framework/tool/switch_container.proto @@ -25,6 +25,9 @@ message SwitchContainerOptions { // Activates channel 1 for enable = true, channel 0 otherwise. optional bool enable = 4; - // Use DefaultInputStreamHandler for muxing & demuxing. + // Use DefaultInputStreamHandler for demuxing. optional bool synchronize_io = 5; + + // Use ImmediateInputStreamHandler for channel selection. + optional bool async_selection = 6; } diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index de4aa0b14..b20979b10 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) { input_stream: "INPUT:enable" input_stream: "TICK:foo" output_stream: "OUTPUT:switchcontainer__gate_enable_timed" + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + } } node { name: "switchcontainer__SwitchDemuxCalculator" @@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) { // Shows the SwitchContainer container runs with a pair of simple subnodes. TEST(SwitchContainerTest, RunsWithSubnodes) { EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer")); - CalculatorGraphConfig supergraph = SubnodeContainerExample(); + CalculatorGraphConfig supergraph = + SubnodeContainerExample("async_selection: true"); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); RunTestContainer(supergraph); } diff --git a/mediapipe/framework/tool/switch_demux_calculator.cc b/mediapipe/framework/tool/switch_demux_calculator.cc index b9ba2a0fb..c066d470a 100644 --- a/mediapipe/framework/tool/switch_demux_calculator.cc +++ b/mediapipe/framework/tool/switch_demux_calculator.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -54,21 +55,47 @@ namespace mediapipe { // contained subgraph or calculator nodes. // class SwitchDemuxCalculator : public CalculatorBase { - static constexpr char kSelectTag[] = "SELECT"; - static constexpr char kEnableTag[] = "ENABLE"; - public: static absl::Status GetContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; + private: + absl::Status RecordPackets(CalculatorContext* cc); + int ChannelIndex(Timestamp timestamp); + absl::Status SendActivePackets(CalculatorContext* cc); + private: int channel_index_; std::set channel_tags_; + using PacketQueue = std::map>; + PacketQueue input_queue_; + std::map channel_history_; }; REGISTER_CALCULATOR(SwitchDemuxCalculator); +namespace { +static constexpr char kSelectTag[] = "SELECT"; +static constexpr char kEnableTag[] = "ENABLE"; + +// Returns the last received timestamp for an input stream. +inline Timestamp SettledTimestamp(const InputStreamShard& input) { + return input.Value().Timestamp(); +} + +// Returns the last received timestamp for channel selection. +inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) { + Timestamp result = Timestamp::Done(); + if (cc->Inputs().HasTag(kEnableTag)) { + result = SettledTimestamp(cc->Inputs().Tag(kEnableTag)); + } else if (cc->Inputs().HasTag(kSelectTag)) { + result = SettledTimestamp(cc->Inputs().Tag(kSelectTag)); + } + return result; +} +} // namespace + absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { // Allow any one of kSelectTag, kEnableTag. cc->Inputs().Tag(kSelectTag).Set().Optional(); @@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Outputs().TagMap()); + channel_history_[Timestamp::Unstarted()] = channel_index_; // Relay side packets to all channels. // Note: This is necessary because Calculator::Open only proceeds when every @@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { } absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { - // Update the input channel index if specified. - channel_index_ = tool::GetChannelIndex(*cc, channel_index_); + MP_RETURN_IF_ERROR(RecordPackets(cc)); + MP_RETURN_IF_ERROR(SendActivePackets(cc)); + return absl::OkStatus(); +} - // Relay packets and timestamps only to channel_index_. +// Enqueue all arriving packets and bounds. +absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) { + // Enqueue any new arriving packets. for (const std::string& tag : channel_tags_) { for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { - auto& input = cc->Inputs().Get(tag, index); - std::string output_tag = tool::ChannelTag(tag, channel_index_); - auto output_id = cc->Outputs().GetId(output_tag, index); - if (output_id.IsValid()) { - auto& output = cc->Outputs().Get(output_tag, index); - tool::Relay(input, &output); + auto input_id = cc->Inputs().GetId(tag, index); + Packet packet = cc->Inputs().Get(input_id).Value(); + if (packet.Timestamp() == cc->InputTimestamp()) { + input_queue_[input_id].push(packet); } } } + + // Enque any new input channel and its activation timestamp. + Timestamp channel_settled = ChannelSettledTimestamp(cc); + int new_channel_index = tool::GetChannelIndex(*cc, channel_index_); + if (channel_settled == cc->InputTimestamp() && + new_channel_index != channel_index_) { + channel_index_ = new_channel_index; + channel_history_[channel_settled] = channel_index_; + } + return absl::OkStatus(); +} + +// Returns the channel index for a Timestamp. +int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) { + auto it = std::prev(channel_history_.upper_bound(timestamp)); + return it->second; +} + +// Dispatches all queued input packets with known channels. +absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) { + // Dispatch any queued input packets with a defined channel_index. + Timestamp channel_settled = ChannelSettledTimestamp(cc); + for (const std::string& tag : channel_tags_) { + for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { + auto input_id = cc->Inputs().GetId(tag, index); + auto& queue = input_queue_[input_id]; + while (!queue.empty() && queue.front().Timestamp() <= channel_settled) { + int channel_index = ChannelIndex(queue.front().Timestamp()); + std::string output_tag = tool::ChannelTag(tag, channel_index); + auto output_id = cc->Outputs().GetId(output_tag, index); + if (output_id.IsValid()) { + cc->Outputs().Get(output_id).AddPacket(queue.front()); + } + queue.pop(); + } + } + } + + // Discard all select packets not needed for any remaining input packets. + Timestamp input_settled = Timestamp::Done(); + for (const std::string& tag : channel_tags_) { + for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { + auto input_id = cc->Inputs().GetId(tag, index); + Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id)); + if (!input_queue_[input_id].empty()) { + Timestamp stream_bound = input_queue_[input_id].front().Timestamp(); + stream_settled = + std::min(stream_settled, stream_bound.PreviousAllowedInStream()); + } + } + } + Timestamp input_bound = input_settled.NextAllowedInStream(); + auto history_bound = std::prev(channel_history_.upper_bound(input_bound)); + channel_history_.erase(channel_history_.begin(), history_bound); return absl::OkStatus(); } diff --git a/mediapipe/framework/tool/switch_mux_calculator.cc b/mediapipe/framework/tool/switch_mux_calculator.cc index 1a3136620..230544b6b 100644 --- a/mediapipe/framework/tool/switch_mux_calculator.cc +++ b/mediapipe/framework/tool/switch_mux_calculator.cc @@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Inputs().TagMap()); - channel_history_[Timestamp::Unset()] = channel_index_; + channel_history_[Timestamp::Unstarted()] = channel_index_; // Relay side packets only from channel_index_. for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {