diff --git a/mediapipe/framework/tool/switch_container.cc b/mediapipe/framework/tool/switch_container.cc index 24a1e6fe7..daa129928 100644 --- a/mediapipe/framework/tool/switch_container.cc +++ b/mediapipe/framework/tool/switch_container.cc @@ -239,6 +239,15 @@ bool HasTag(const proto_ns::RepeatedPtrField& streams, return tags.count({tag, 0}) > 0; } +// Returns true if a set of "TAG::index" includes a TagIndex. +bool ContainsTag(const proto_ns::RepeatedPtrField& tags, + TagIndex item) { + for (const std::string& t : tags) { + if (ParseTagIndex(t) == item) return true; + } + return false; +} + absl::StatusOr SwitchContainer::GetConfig( const Subgraph::SubgraphOptions& options) { CalculatorGraphConfig config; @@ -263,9 +272,9 @@ absl::StatusOr SwitchContainer::GetConfig( std::string enable_stream = "ENABLE:gate_enable"; // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams. - bool async_selection = - Subgraph::GetOptions(options) - .async_selection(); + const auto& switch_options = + Subgraph::GetOptions(options); + bool async_selection = switch_options.async_selection(); if (HasTag(container_node.input_stream(), "SELECT")) { select_node = BuildTimestampNode(&config, async_selection); select_node->add_input_stream("INPUT:gate_select"); @@ -296,7 +305,7 @@ absl::StatusOr SwitchContainer::GetConfig( mux->add_input_side_packet("SELECT:gate_select"); mux->add_input_side_packet("ENABLE:gate_enable"); - // Add input streams for graph and demux and the timestamper. + // Add input streams for graph and demux. config.add_input_stream("SELECT:gate_select"); config.add_input_stream("ENABLE:gate_enable"); config.add_input_side_packet("SELECT:gate_select"); @@ -306,6 +315,12 @@ absl::StatusOr SwitchContainer::GetConfig( std::string stream = CatStream(p.first, p.second); config.add_input_stream(stream); demux->add_input_stream(stream); + } + + // Add input streams for the timestamper. + auto& tick_streams = switch_options.tick_input_stream(); + for (const auto& p : input_tags) { + if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue; TagIndex tick_tag{"TICK", tick_index++}; if (select_node) { select_node->add_input_stream(CatStream(tick_tag, p.second)); diff --git a/mediapipe/framework/tool/switch_container.proto b/mediapipe/framework/tool/switch_container.proto index 2de5a7abf..16050ef6f 100644 --- a/mediapipe/framework/tool/switch_container.proto +++ b/mediapipe/framework/tool/switch_container.proto @@ -30,4 +30,9 @@ message SwitchContainerOptions { // Use ImmediateInputStreamHandler for channel selection. optional bool async_selection = 6; + + // Specifies an input stream, "TAG:index", that defines the processed + // timestamps. SwitchContainer awaits output at the last processed + // timestamp before advancing from one selected channel to the next. + repeated string tick_input_stream = 7; }