Internal change

PiperOrigin-RevId: 484583911
This commit is contained in:
Hadon Nash 2022-10-28 11:44:31 -07:00 committed by Copybara-Service
parent f4f8b11ffc
commit c5c639d634
6 changed files with 114 additions and 21 deletions

View File

@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options { options {
[mediapipe.SwitchContainerOptions.ext] { [mediapipe.SwitchContainerOptions.ext] {
async_selection: true
contained_node: { calculator: "AppearancesPassThroughSubgraph" } contained_node: { calculator: "AppearancesPassThroughSubgraph" }
} }
} }
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options { options {
[mediapipe.SwitchContainerOptions.ext] { [mediapipe.SwitchContainerOptions.ext] {
async_selection: true
contained_node: { contained_node: {
calculator: "BypassCalculator" calculator: "BypassCalculator"
node_options: { node_options: {

View File

@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
// Returns a PacketSequencerCalculator node. // Returns a PacketSequencerCalculator node.
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config, CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
bool synchronize_io) { bool async_selection) {
CalculatorGraphConfig::Node* result = config->add_node(); CalculatorGraphConfig::Node* result = config->add_node();
*result->mutable_calculator() = "PacketSequencerCalculator"; *result->mutable_calculator() = "PacketSequencerCalculator";
if (synchronize_io) { if (!async_selection) {
*result->mutable_input_stream_handler()->mutable_input_stream_handler() = *result->mutable_input_stream_handler()->mutable_input_stream_handler() =
"DefaultInputStreamHandler"; "DefaultInputStreamHandler";
} }
@ -263,17 +263,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
std::string enable_stream = "ENABLE:gate_enable"; std::string enable_stream = "ENABLE:gate_enable";
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams. // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
bool synchronize_io = bool async_selection =
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options) Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
.synchronize_io(); .async_selection();
if (HasTag(container_node.input_stream(), "SELECT")) { if (HasTag(container_node.input_stream(), "SELECT")) {
select_node = BuildTimestampNode(&config, synchronize_io); select_node = BuildTimestampNode(&config, async_selection);
select_node->add_input_stream("INPUT:gate_select"); select_node->add_input_stream("INPUT:gate_select");
select_node->add_output_stream("OUTPUT:gate_select_timed"); select_node->add_output_stream("OUTPUT:gate_select_timed");
select_stream = "SELECT:gate_select_timed"; select_stream = "SELECT:gate_select_timed";
} }
if (HasTag(container_node.input_stream(), "ENABLE")) { if (HasTag(container_node.input_stream(), "ENABLE")) {
enable_node = BuildTimestampNode(&config, synchronize_io); enable_node = BuildTimestampNode(&config, async_selection);
enable_node->add_input_stream("INPUT:gate_enable"); enable_node->add_input_stream("INPUT:gate_enable");
enable_node->add_output_stream("OUTPUT:gate_enable_timed"); enable_node->add_output_stream("OUTPUT:gate_enable_timed");
enable_stream = "ENABLE:gate_enable_timed"; enable_stream = "ENABLE:gate_enable_timed";

View File

@ -25,6 +25,9 @@ message SwitchContainerOptions {
// Activates channel 1 for enable = true, channel 0 otherwise. // Activates channel 1 for enable = true, channel 0 otherwise.
optional bool enable = 4; optional bool enable = 4;
// Use DefaultInputStreamHandler for muxing & demuxing. // Use DefaultInputStreamHandler for demuxing.
optional bool synchronize_io = 5; optional bool synchronize_io = 5;
// Use ImmediateInputStreamHandler for channel selection.
optional bool async_selection = 6;
} }

View File

@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
input_stream: "INPUT:enable" input_stream: "INPUT:enable"
input_stream: "TICK:foo" input_stream: "TICK:foo"
output_stream: "OUTPUT:switchcontainer__gate_enable_timed" output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
input_stream_handler {
input_stream_handler: "DefaultInputStreamHandler"
}
} }
node { node {
name: "switchcontainer__SwitchDemuxCalculator" name: "switchcontainer__SwitchDemuxCalculator"
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
// Shows the SwitchContainer container runs with a pair of simple subnodes. // Shows the SwitchContainer container runs with a pair of simple subnodes.
TEST(SwitchContainerTest, RunsWithSubnodes) { TEST(SwitchContainerTest, RunsWithSubnodes) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer")); EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
CalculatorGraphConfig supergraph = SubnodeContainerExample(); CalculatorGraphConfig supergraph =
SubnodeContainerExample("async_selection: true");
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph)); MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
RunTestContainer(supergraph); RunTestContainer(supergraph);
} }

View File

@ -14,6 +14,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <queue>
#include <set> #include <set>
#include <string> #include <string>
@ -54,21 +55,47 @@ namespace mediapipe {
// contained subgraph or calculator nodes. // contained subgraph or calculator nodes.
// //
class SwitchDemuxCalculator : public CalculatorBase { class SwitchDemuxCalculator : public CalculatorBase {
static constexpr char kSelectTag[] = "SELECT";
static constexpr char kEnableTag[] = "ENABLE";
public: public:
static absl::Status GetContract(CalculatorContract* cc); static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
private:
absl::Status RecordPackets(CalculatorContext* cc);
int ChannelIndex(Timestamp timestamp);
absl::Status SendActivePackets(CalculatorContext* cc);
private: private:
int channel_index_; int channel_index_;
std::set<std::string> channel_tags_; std::set<std::string> channel_tags_;
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
PacketQueue input_queue_;
std::map<Timestamp, int> channel_history_;
}; };
REGISTER_CALCULATOR(SwitchDemuxCalculator); REGISTER_CALCULATOR(SwitchDemuxCalculator);
namespace {
static constexpr char kSelectTag[] = "SELECT";
static constexpr char kEnableTag[] = "ENABLE";
// Returns the last received timestamp for an input stream.
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
return input.Value().Timestamp();
}
// Returns the last received timestamp for channel selection.
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
Timestamp result = Timestamp::Done();
if (cc->Inputs().HasTag(kEnableTag)) {
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
} else if (cc->Inputs().HasTag(kSelectTag)) {
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
}
return result;
}
} // namespace
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
// Allow any one of kSelectTag, kEnableTag. // Allow any one of kSelectTag, kEnableTag.
cc->Inputs().Tag(kSelectTag).Set<int>().Optional(); cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
channel_tags_ = ChannelTags(cc->Outputs().TagMap()); channel_tags_ = ChannelTags(cc->Outputs().TagMap());
channel_history_[Timestamp::Unstarted()] = channel_index_;
// Relay side packets to all channels. // Relay side packets to all channels.
// Note: This is necessary because Calculator::Open only proceeds when every // Note: This is necessary because Calculator::Open only proceeds when every
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
} }
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
// Update the input channel index if specified. MP_RETURN_IF_ERROR(RecordPackets(cc));
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); MP_RETURN_IF_ERROR(SendActivePackets(cc));
return absl::OkStatus();
}
// Relay packets and timestamps only to channel_index_. // Enqueue all arriving packets and bounds.
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
// Enqueue any new arriving packets.
for (const std::string& tag : channel_tags_) { for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) { for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto& input = cc->Inputs().Get(tag, index); auto input_id = cc->Inputs().GetId(tag, index);
std::string output_tag = tool::ChannelTag(tag, channel_index_); Packet packet = cc->Inputs().Get(input_id).Value();
auto output_id = cc->Outputs().GetId(output_tag, index); if (packet.Timestamp() == cc->InputTimestamp()) {
if (output_id.IsValid()) { input_queue_[input_id].push(packet);
auto& output = cc->Outputs().Get(output_tag, index);
tool::Relay(input, &output);
} }
} }
} }
// Enque any new input channel and its activation timestamp.
Timestamp channel_settled = ChannelSettledTimestamp(cc);
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
if (channel_settled == cc->InputTimestamp() &&
new_channel_index != channel_index_) {
channel_index_ = new_channel_index;
channel_history_[channel_settled] = channel_index_;
}
return absl::OkStatus();
}
// Returns the channel index for a Timestamp.
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
auto it = std::prev(channel_history_.upper_bound(timestamp));
return it->second;
}
// Dispatches all queued input packets with known channels.
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
// Dispatch any queued input packets with a defined channel_index.
Timestamp channel_settled = ChannelSettledTimestamp(cc);
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto input_id = cc->Inputs().GetId(tag, index);
auto& queue = input_queue_[input_id];
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
int channel_index = ChannelIndex(queue.front().Timestamp());
std::string output_tag = tool::ChannelTag(tag, channel_index);
auto output_id = cc->Outputs().GetId(output_tag, index);
if (output_id.IsValid()) {
cc->Outputs().Get(output_id).AddPacket(queue.front());
}
queue.pop();
}
}
}
// Discard all select packets not needed for any remaining input packets.
Timestamp input_settled = Timestamp::Done();
for (const std::string& tag : channel_tags_) {
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
auto input_id = cc->Inputs().GetId(tag, index);
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
if (!input_queue_[input_id].empty()) {
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
stream_settled =
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
}
}
}
Timestamp input_bound = input_settled.NextAllowedInStream();
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
channel_history_.erase(channel_history_.begin(), history_bound);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<mediapipe::SwitchContainerOptions>(); options_ = cc->Options<mediapipe::SwitchContainerOptions>();
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
channel_tags_ = ChannelTags(cc->Inputs().TagMap()); channel_tags_ = ChannelTags(cc->Inputs().TagMap());
channel_history_[Timestamp::Unset()] = channel_index_; channel_history_[Timestamp::Unstarted()] = channel_index_;
// Relay side packets only from channel_index_. // Relay side packets only from channel_index_.
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) { for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {