Internal change
PiperOrigin-RevId: 484583911
This commit is contained in:
parent
f4f8b11ffc
commit
c5c639d634
|
@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
|
|||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||
options {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
async_selection: true
|
||||
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
|
|||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||
options {
|
||||
[mediapipe.SwitchContainerOptions.ext] {
|
||||
async_selection: true
|
||||
contained_node: {
|
||||
calculator: "BypassCalculator"
|
||||
node_options: {
|
||||
|
|
|
@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
|
|||
|
||||
// Returns a PacketSequencerCalculator node.
|
||||
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
|
||||
bool synchronize_io) {
|
||||
bool async_selection) {
|
||||
CalculatorGraphConfig::Node* result = config->add_node();
|
||||
*result->mutable_calculator() = "PacketSequencerCalculator";
|
||||
if (synchronize_io) {
|
||||
if (!async_selection) {
|
||||
*result->mutable_input_stream_handler()->mutable_input_stream_handler() =
|
||||
"DefaultInputStreamHandler";
|
||||
}
|
||||
|
@ -263,17 +263,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
|||
std::string enable_stream = "ENABLE:gate_enable";
|
||||
|
||||
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
|
||||
bool synchronize_io =
|
||||
bool async_selection =
|
||||
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
|
||||
.synchronize_io();
|
||||
.async_selection();
|
||||
if (HasTag(container_node.input_stream(), "SELECT")) {
|
||||
select_node = BuildTimestampNode(&config, synchronize_io);
|
||||
select_node = BuildTimestampNode(&config, async_selection);
|
||||
select_node->add_input_stream("INPUT:gate_select");
|
||||
select_node->add_output_stream("OUTPUT:gate_select_timed");
|
||||
select_stream = "SELECT:gate_select_timed";
|
||||
}
|
||||
if (HasTag(container_node.input_stream(), "ENABLE")) {
|
||||
enable_node = BuildTimestampNode(&config, synchronize_io);
|
||||
enable_node = BuildTimestampNode(&config, async_selection);
|
||||
enable_node->add_input_stream("INPUT:gate_enable");
|
||||
enable_node->add_output_stream("OUTPUT:gate_enable_timed");
|
||||
enable_stream = "ENABLE:gate_enable_timed";
|
||||
|
|
|
@ -25,6 +25,9 @@ message SwitchContainerOptions {
|
|||
// Activates channel 1 for enable = true, channel 0 otherwise.
|
||||
optional bool enable = 4;
|
||||
|
||||
// Use DefaultInputStreamHandler for muxing & demuxing.
|
||||
// Use DefaultInputStreamHandler for demuxing.
|
||||
optional bool synchronize_io = 5;
|
||||
|
||||
// Use ImmediateInputStreamHandler for channel selection.
|
||||
optional bool async_selection = 6;
|
||||
}
|
||||
|
|
|
@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|||
input_stream: "INPUT:enable"
|
||||
input_stream: "TICK:foo"
|
||||
output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
|
||||
input_stream_handler {
|
||||
input_stream_handler: "DefaultInputStreamHandler"
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "switchcontainer__SwitchDemuxCalculator"
|
||||
|
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
|||
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
||||
TEST(SwitchContainerTest, RunsWithSubnodes) {
|
||||
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
||||
CalculatorGraphConfig supergraph = SubnodeContainerExample();
|
||||
CalculatorGraphConfig supergraph =
|
||||
SubnodeContainerExample("async_selection: true");
|
||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||
RunTestContainer(supergraph);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
|
@ -54,21 +55,47 @@ namespace mediapipe {
|
|||
// contained subgraph or calculator nodes.
|
||||
//
|
||||
class SwitchDemuxCalculator : public CalculatorBase {
|
||||
static constexpr char kSelectTag[] = "SELECT";
|
||||
static constexpr char kEnableTag[] = "ENABLE";
|
||||
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status RecordPackets(CalculatorContext* cc);
|
||||
int ChannelIndex(Timestamp timestamp);
|
||||
absl::Status SendActivePackets(CalculatorContext* cc);
|
||||
|
||||
private:
|
||||
int channel_index_;
|
||||
std::set<std::string> channel_tags_;
|
||||
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
|
||||
PacketQueue input_queue_;
|
||||
std::map<Timestamp, int> channel_history_;
|
||||
};
|
||||
REGISTER_CALCULATOR(SwitchDemuxCalculator);
|
||||
|
||||
namespace {
|
||||
static constexpr char kSelectTag[] = "SELECT";
|
||||
static constexpr char kEnableTag[] = "ENABLE";
|
||||
|
||||
// Returns the last received timestamp for an input stream.
|
||||
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
|
||||
return input.Value().Timestamp();
|
||||
}
|
||||
|
||||
// Returns the last received timestamp for channel selection.
|
||||
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
|
||||
Timestamp result = Timestamp::Done();
|
||||
if (cc->Inputs().HasTag(kEnableTag)) {
|
||||
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
|
||||
} else if (cc->Inputs().HasTag(kSelectTag)) {
|
||||
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
||||
// Allow any one of kSelectTag, kEnableTag.
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
||||
|
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
|||
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
channel_tags_ = ChannelTags(cc->Outputs().TagMap());
|
||||
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||
|
||||
// Relay side packets to all channels.
|
||||
// Note: This is necessary because Calculator::Open only proceeds when every
|
||||
|
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
|
||||
// Update the input channel index if specified.
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
MP_RETURN_IF_ERROR(RecordPackets(cc));
|
||||
MP_RETURN_IF_ERROR(SendActivePackets(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Relay packets and timestamps only to channel_index_.
|
||||
// Enqueue all arriving packets and bounds.
|
||||
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
|
||||
// Enqueue any new arriving packets.
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto& input = cc->Inputs().Get(tag, index);
|
||||
std::string output_tag = tool::ChannelTag(tag, channel_index_);
|
||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||
if (output_id.IsValid()) {
|
||||
auto& output = cc->Outputs().Get(output_tag, index);
|
||||
tool::Relay(input, &output);
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
Packet packet = cc->Inputs().Get(input_id).Value();
|
||||
if (packet.Timestamp() == cc->InputTimestamp()) {
|
||||
input_queue_[input_id].push(packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enque any new input channel and its activation timestamp.
|
||||
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
|
||||
if (channel_settled == cc->InputTimestamp() &&
|
||||
new_channel_index != channel_index_) {
|
||||
channel_index_ = new_channel_index;
|
||||
channel_history_[channel_settled] = channel_index_;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns the channel index for a Timestamp.
|
||||
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
|
||||
auto it = std::prev(channel_history_.upper_bound(timestamp));
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Dispatches all queued input packets with known channels.
|
||||
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
|
||||
// Dispatch any queued input packets with a defined channel_index.
|
||||
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
auto& queue = input_queue_[input_id];
|
||||
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
|
||||
int channel_index = ChannelIndex(queue.front().Timestamp());
|
||||
std::string output_tag = tool::ChannelTag(tag, channel_index);
|
||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||
if (output_id.IsValid()) {
|
||||
cc->Outputs().Get(output_id).AddPacket(queue.front());
|
||||
}
|
||||
queue.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Discard all select packets not needed for any remaining input packets.
|
||||
Timestamp input_settled = Timestamp::Done();
|
||||
for (const std::string& tag : channel_tags_) {
|
||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||
auto input_id = cc->Inputs().GetId(tag, index);
|
||||
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
|
||||
if (!input_queue_[input_id].empty()) {
|
||||
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
|
||||
stream_settled =
|
||||
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
|
||||
}
|
||||
}
|
||||
}
|
||||
Timestamp input_bound = input_settled.NextAllowedInStream();
|
||||
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
|
||||
channel_history_.erase(channel_history_.begin(), history_bound);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
|
|||
options_ = cc->Options<mediapipe::SwitchContainerOptions>();
|
||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||
channel_tags_ = ChannelTags(cc->Inputs().TagMap());
|
||||
channel_history_[Timestamp::Unset()] = channel_index_;
|
||||
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||
|
||||
// Relay side packets only from channel_index_.
|
||||
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user