Internal change
PiperOrigin-RevId: 484583911
This commit is contained in:
parent
f4f8b11ffc
commit
c5c639d634
|
@ -75,6 +75,7 @@ constexpr char kTestGraphConfig2[] = R"pb(
|
||||||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||||
options {
|
options {
|
||||||
[mediapipe.SwitchContainerOptions.ext] {
|
[mediapipe.SwitchContainerOptions.ext] {
|
||||||
|
async_selection: true
|
||||||
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
|
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -101,6 +102,7 @@ constexpr char kTestGraphConfig3[] = R"pb(
|
||||||
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
|
||||||
options {
|
options {
|
||||||
[mediapipe.SwitchContainerOptions.ext] {
|
[mediapipe.SwitchContainerOptions.ext] {
|
||||||
|
async_selection: true
|
||||||
contained_node: {
|
contained_node: {
|
||||||
calculator: "BypassCalculator"
|
calculator: "BypassCalculator"
|
||||||
node_options: {
|
node_options: {
|
||||||
|
|
|
@ -105,10 +105,10 @@ CalculatorGraphConfig::Node* BuildMuxNode(
|
||||||
|
|
||||||
// Returns a PacketSequencerCalculator node.
|
// Returns a PacketSequencerCalculator node.
|
||||||
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
|
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
|
||||||
bool synchronize_io) {
|
bool async_selection) {
|
||||||
CalculatorGraphConfig::Node* result = config->add_node();
|
CalculatorGraphConfig::Node* result = config->add_node();
|
||||||
*result->mutable_calculator() = "PacketSequencerCalculator";
|
*result->mutable_calculator() = "PacketSequencerCalculator";
|
||||||
if (synchronize_io) {
|
if (!async_selection) {
|
||||||
*result->mutable_input_stream_handler()->mutable_input_stream_handler() =
|
*result->mutable_input_stream_handler()->mutable_input_stream_handler() =
|
||||||
"DefaultInputStreamHandler";
|
"DefaultInputStreamHandler";
|
||||||
}
|
}
|
||||||
|
@ -263,17 +263,17 @@ absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
|
||||||
std::string enable_stream = "ENABLE:gate_enable";
|
std::string enable_stream = "ENABLE:gate_enable";
|
||||||
|
|
||||||
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
|
// Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
|
||||||
bool synchronize_io =
|
bool async_selection =
|
||||||
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
|
Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options)
|
||||||
.synchronize_io();
|
.async_selection();
|
||||||
if (HasTag(container_node.input_stream(), "SELECT")) {
|
if (HasTag(container_node.input_stream(), "SELECT")) {
|
||||||
select_node = BuildTimestampNode(&config, synchronize_io);
|
select_node = BuildTimestampNode(&config, async_selection);
|
||||||
select_node->add_input_stream("INPUT:gate_select");
|
select_node->add_input_stream("INPUT:gate_select");
|
||||||
select_node->add_output_stream("OUTPUT:gate_select_timed");
|
select_node->add_output_stream("OUTPUT:gate_select_timed");
|
||||||
select_stream = "SELECT:gate_select_timed";
|
select_stream = "SELECT:gate_select_timed";
|
||||||
}
|
}
|
||||||
if (HasTag(container_node.input_stream(), "ENABLE")) {
|
if (HasTag(container_node.input_stream(), "ENABLE")) {
|
||||||
enable_node = BuildTimestampNode(&config, synchronize_io);
|
enable_node = BuildTimestampNode(&config, async_selection);
|
||||||
enable_node->add_input_stream("INPUT:gate_enable");
|
enable_node->add_input_stream("INPUT:gate_enable");
|
||||||
enable_node->add_output_stream("OUTPUT:gate_enable_timed");
|
enable_node->add_output_stream("OUTPUT:gate_enable_timed");
|
||||||
enable_stream = "ENABLE:gate_enable_timed";
|
enable_stream = "ENABLE:gate_enable_timed";
|
||||||
|
|
|
@ -25,6 +25,9 @@ message SwitchContainerOptions {
|
||||||
// Activates channel 1 for enable = true, channel 0 otherwise.
|
// Activates channel 1 for enable = true, channel 0 otherwise.
|
||||||
optional bool enable = 4;
|
optional bool enable = 4;
|
||||||
|
|
||||||
// Use DefaultInputStreamHandler for muxing & demuxing.
|
// Use DefaultInputStreamHandler for demuxing.
|
||||||
optional bool synchronize_io = 5;
|
optional bool synchronize_io = 5;
|
||||||
|
|
||||||
|
// Use ImmediateInputStreamHandler for channel selection.
|
||||||
|
optional bool async_selection = 6;
|
||||||
}
|
}
|
||||||
|
|
|
@ -252,6 +252,9 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
||||||
input_stream: "INPUT:enable"
|
input_stream: "INPUT:enable"
|
||||||
input_stream: "TICK:foo"
|
input_stream: "TICK:foo"
|
||||||
output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
|
output_stream: "OUTPUT:switchcontainer__gate_enable_timed"
|
||||||
|
input_stream_handler {
|
||||||
|
input_stream_handler: "DefaultInputStreamHandler"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name: "switchcontainer__SwitchDemuxCalculator"
|
name: "switchcontainer__SwitchDemuxCalculator"
|
||||||
|
@ -306,7 +309,8 @@ TEST(SwitchContainerTest, ApplyToSubnodes) {
|
||||||
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
// Shows the SwitchContainer container runs with a pair of simple subnodes.
|
||||||
TEST(SwitchContainerTest, RunsWithSubnodes) {
|
TEST(SwitchContainerTest, RunsWithSubnodes) {
|
||||||
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
EXPECT_TRUE(SubgraphRegistry::IsRegistered("SwitchContainer"));
|
||||||
CalculatorGraphConfig supergraph = SubnodeContainerExample();
|
CalculatorGraphConfig supergraph =
|
||||||
|
SubnodeContainerExample("async_selection: true");
|
||||||
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph));
|
||||||
RunTestContainer(supergraph);
|
RunTestContainer(supergraph);
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
@ -54,21 +55,47 @@ namespace mediapipe {
|
||||||
// contained subgraph or calculator nodes.
|
// contained subgraph or calculator nodes.
|
||||||
//
|
//
|
||||||
class SwitchDemuxCalculator : public CalculatorBase {
|
class SwitchDemuxCalculator : public CalculatorBase {
|
||||||
static constexpr char kSelectTag[] = "SELECT";
|
|
||||||
static constexpr char kEnableTag[] = "ENABLE";
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc);
|
static absl::Status GetContract(CalculatorContract* cc);
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override;
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
absl::Status Process(CalculatorContext* cc) override;
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
absl::Status RecordPackets(CalculatorContext* cc);
|
||||||
|
int ChannelIndex(Timestamp timestamp);
|
||||||
|
absl::Status SendActivePackets(CalculatorContext* cc);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int channel_index_;
|
int channel_index_;
|
||||||
std::set<std::string> channel_tags_;
|
std::set<std::string> channel_tags_;
|
||||||
|
using PacketQueue = std::map<CollectionItemId, std::queue<Packet>>;
|
||||||
|
PacketQueue input_queue_;
|
||||||
|
std::map<Timestamp, int> channel_history_;
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(SwitchDemuxCalculator);
|
REGISTER_CALCULATOR(SwitchDemuxCalculator);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
static constexpr char kSelectTag[] = "SELECT";
|
||||||
|
static constexpr char kEnableTag[] = "ENABLE";
|
||||||
|
|
||||||
|
// Returns the last received timestamp for an input stream.
|
||||||
|
inline Timestamp SettledTimestamp(const InputStreamShard& input) {
|
||||||
|
return input.Value().Timestamp();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the last received timestamp for channel selection.
|
||||||
|
inline Timestamp ChannelSettledTimestamp(CalculatorContext* cc) {
|
||||||
|
Timestamp result = Timestamp::Done();
|
||||||
|
if (cc->Inputs().HasTag(kEnableTag)) {
|
||||||
|
result = SettledTimestamp(cc->Inputs().Tag(kEnableTag));
|
||||||
|
} else if (cc->Inputs().HasTag(kSelectTag)) {
|
||||||
|
result = SettledTimestamp(cc->Inputs().Tag(kSelectTag));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
||||||
// Allow any one of kSelectTag, kEnableTag.
|
// Allow any one of kSelectTag, kEnableTag.
|
||||||
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
||||||
|
@ -125,6 +152,7 @@ absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) {
|
||||||
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||||
channel_tags_ = ChannelTags(cc->Outputs().TagMap());
|
channel_tags_ = ChannelTags(cc->Outputs().TagMap());
|
||||||
|
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||||
|
|
||||||
// Relay side packets to all channels.
|
// Relay side packets to all channels.
|
||||||
// Note: This is necessary because Calculator::Open only proceeds when every
|
// Note: This is necessary because Calculator::Open only proceeds when every
|
||||||
|
@ -164,21 +192,77 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
|
absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) {
|
||||||
// Update the input channel index if specified.
|
MP_RETURN_IF_ERROR(RecordPackets(cc));
|
||||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
MP_RETURN_IF_ERROR(SendActivePackets(cc));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
// Relay packets and timestamps only to channel_index_.
|
// Enqueue all arriving packets and bounds.
|
||||||
|
absl::Status SwitchDemuxCalculator::RecordPackets(CalculatorContext* cc) {
|
||||||
|
// Enqueue any new arriving packets.
|
||||||
for (const std::string& tag : channel_tags_) {
|
for (const std::string& tag : channel_tags_) {
|
||||||
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||||
auto& input = cc->Inputs().Get(tag, index);
|
auto input_id = cc->Inputs().GetId(tag, index);
|
||||||
std::string output_tag = tool::ChannelTag(tag, channel_index_);
|
Packet packet = cc->Inputs().Get(input_id).Value();
|
||||||
auto output_id = cc->Outputs().GetId(output_tag, index);
|
if (packet.Timestamp() == cc->InputTimestamp()) {
|
||||||
if (output_id.IsValid()) {
|
input_queue_[input_id].push(packet);
|
||||||
auto& output = cc->Outputs().Get(output_tag, index);
|
|
||||||
tool::Relay(input, &output);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enque any new input channel and its activation timestamp.
|
||||||
|
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||||
|
int new_channel_index = tool::GetChannelIndex(*cc, channel_index_);
|
||||||
|
if (channel_settled == cc->InputTimestamp() &&
|
||||||
|
new_channel_index != channel_index_) {
|
||||||
|
channel_index_ = new_channel_index;
|
||||||
|
channel_history_[channel_settled] = channel_index_;
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the channel index for a Timestamp.
|
||||||
|
int SwitchDemuxCalculator::ChannelIndex(Timestamp timestamp) {
|
||||||
|
auto it = std::prev(channel_history_.upper_bound(timestamp));
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatches all queued input packets with known channels.
|
||||||
|
absl::Status SwitchDemuxCalculator::SendActivePackets(CalculatorContext* cc) {
|
||||||
|
// Dispatch any queued input packets with a defined channel_index.
|
||||||
|
Timestamp channel_settled = ChannelSettledTimestamp(cc);
|
||||||
|
for (const std::string& tag : channel_tags_) {
|
||||||
|
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||||
|
auto input_id = cc->Inputs().GetId(tag, index);
|
||||||
|
auto& queue = input_queue_[input_id];
|
||||||
|
while (!queue.empty() && queue.front().Timestamp() <= channel_settled) {
|
||||||
|
int channel_index = ChannelIndex(queue.front().Timestamp());
|
||||||
|
std::string output_tag = tool::ChannelTag(tag, channel_index);
|
||||||
|
auto output_id = cc->Outputs().GetId(output_tag, index);
|
||||||
|
if (output_id.IsValid()) {
|
||||||
|
cc->Outputs().Get(output_id).AddPacket(queue.front());
|
||||||
|
}
|
||||||
|
queue.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discard all select packets not needed for any remaining input packets.
|
||||||
|
Timestamp input_settled = Timestamp::Done();
|
||||||
|
for (const std::string& tag : channel_tags_) {
|
||||||
|
for (int index = 0; index < cc->Inputs().NumEntries(tag); ++index) {
|
||||||
|
auto input_id = cc->Inputs().GetId(tag, index);
|
||||||
|
Timestamp stream_settled = SettledTimestamp(cc->Inputs().Get(input_id));
|
||||||
|
if (!input_queue_[input_id].empty()) {
|
||||||
|
Timestamp stream_bound = input_queue_[input_id].front().Timestamp();
|
||||||
|
stream_settled =
|
||||||
|
std::min(stream_settled, stream_bound.PreviousAllowedInStream());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Timestamp input_bound = input_settled.NextAllowedInStream();
|
||||||
|
auto history_bound = std::prev(channel_history_.upper_bound(input_bound));
|
||||||
|
channel_history_.erase(channel_history_.begin(), history_bound);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -164,7 +164,7 @@ absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) {
|
||||||
options_ = cc->Options<mediapipe::SwitchContainerOptions>();
|
options_ = cc->Options<mediapipe::SwitchContainerOptions>();
|
||||||
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
channel_index_ = tool::GetChannelIndex(*cc, channel_index_);
|
||||||
channel_tags_ = ChannelTags(cc->Inputs().TagMap());
|
channel_tags_ = ChannelTags(cc->Inputs().TagMap());
|
||||||
channel_history_[Timestamp::Unset()] = channel_index_;
|
channel_history_[Timestamp::Unstarted()] = channel_index_;
|
||||||
|
|
||||||
// Relay side packets only from channel_index_.
|
// Relay side packets only from channel_index_.
|
||||||
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {
|
for (const std::string& tag : ChannelTags(cc->InputSidePackets().TagMap())) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user