mediapipe/mediapipe/framework/tool/switch_mux_calculator_test.cc
MediaPipe Team 4dc4b19ddb Project import generated by Copybara.
GitOrigin-RevId: 1e13be30e2c6838d4a2ff768a39c414bc80534bb
2022-09-06 21:46:17 +00:00

330 lines
14 KiB
C++

// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/container_util.h"
namespace mediapipe {
namespace {
// Returns a CalculatorGraph to run a single calculator.
CalculatorGraph BuildCalculatorGraph(CalculatorGraphConfig::Node node_config) {
CalculatorGraphConfig config;
*config.add_node() = node_config;
*config.mutable_input_stream() = node_config.input_stream();
*config.mutable_output_stream() = node_config.output_stream();
*config.mutable_input_side_packet() = node_config.input_side_packet();
*config.mutable_output_side_packet() = node_config.output_side_packet();
return CalculatorGraph(config);
}
// Creates a string packet.
Packet pack(std::string data, int timestamp) {
return MakePacket<std::string>(data).At(Timestamp(timestamp));
}
// Creates an int packet.
Packet pack(int data, int timestamp) {
return MakePacket<int>(data).At(Timestamp(timestamp));
}
// Tests showing packet channel synchronization through SwitchMuxCalculator.
class SwitchMuxCalculatorTest : public ::testing::Test {
protected:
SwitchMuxCalculatorTest() {}
~SwitchMuxCalculatorTest() override {}
void SetUp() override {}
void TearDown() override {}
// Defines a SwitchMuxCalculator CalculatorGraphConfig::Node.
CalculatorGraphConfig::Node BuildNodeConfig() {
CalculatorGraphConfig::Node result;
*result.mutable_calculator() = "SwitchMuxCalculator";
*result.add_input_stream() = "SELECT:select";
for (int c = 0; c < 3; ++c) {
*result.add_input_stream() =
absl::StrCat(tool::ChannelTag("FRAME", c), ":frame_", c);
*result.add_input_stream() =
absl::StrCat(tool::ChannelTag("MASK", c), ":mask_", c);
}
*result.add_output_stream() = "FRAME:frame";
*result.add_output_stream() = "MASK:mask";
return result;
}
};
// Shows the SwitchMuxCalculator is available.
TEST_F(SwitchMuxCalculatorTest, IsRegistered) {
EXPECT_TRUE(CalculatorBaseRegistry::IsRegistered("SwitchMuxCalculator"));
}
// Shows that channels are queued until packets arrive.
TEST_F(SwitchMuxCalculatorTest, ChannelEarly) {
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
CalculatorGraph graph = BuildCalculatorGraph(node_config);
std::vector<Packet> output_frames;
MP_ASSERT_OK(graph.ObserveOutputStream("frame", [&](const Packet& p) {
output_frames.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
// Finalize input for the "mask" input stream.
MP_ASSERT_OK(graph.CloseInputStream("mask_0"));
MP_ASSERT_OK(graph.CloseInputStream("mask_1"));
MP_ASSERT_OK(graph.CloseInputStream("mask_2"));
// All channels are specified before any frame packets arrive.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 1)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 10)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 11)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 21)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 30)));
MP_ASSERT_OK(graph.WaitUntilIdle());
// The packet at timestamp 10 is passed from channel 0.
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t10", 10)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
EXPECT_EQ(output_frames[0].Get<std::string>(), "p0_t10");
// The packet at timestamp 20 is passed from channel 1.
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_1", pack("p1_t20", 20)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 2);
EXPECT_EQ(output_frames[1].Get<std::string>(), "p1_t20");
// The packet at timestamp 30 is passed from channel 0.
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t30", 30)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 3);
EXPECT_EQ(output_frames[2].Get<std::string>(), "p0_t30");
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows that packets are queued until channels are specified.
TEST_F(SwitchMuxCalculatorTest, ChannelsLate) {
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
CalculatorGraph graph = BuildCalculatorGraph(node_config);
std::vector<Packet> output_frames;
MP_ASSERT_OK(graph.ObserveOutputStream("frame", [&](const Packet& p) {
output_frames.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
// Finalize input for the "mask" input stream.
MP_ASSERT_OK(graph.CloseInputStream("mask_0"));
MP_ASSERT_OK(graph.CloseInputStream("mask_1"));
MP_ASSERT_OK(graph.CloseInputStream("mask_2"));
// All frame packets arrive before any channels are specified.
// All packets are queued awaiting channel choices.
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t10", 10)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_1", pack("p1_t20", 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t30", 30)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 0);
// The packet at timestamp 10 is released from channel 0.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 1)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 10)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
EXPECT_EQ(output_frames[0].Get<std::string>(), "p0_t10");
// The packet at timestamp 20 is released from channel 1.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 11)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 20)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 2);
EXPECT_EQ(output_frames[1].Get<std::string>(), "p1_t20");
// The packet at timestamp 30 is released from channel 0.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 21)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 30)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 3);
EXPECT_EQ(output_frames[2].Get<std::string>(), "p0_t30");
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows that no packets are queued when channels are specified just in time.
TEST_F(SwitchMuxCalculatorTest, ChannelsOnTime) {
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
CalculatorGraph graph = BuildCalculatorGraph(node_config);
std::vector<Packet> output_frames;
MP_ASSERT_OK(graph.ObserveOutputStream("frame", [&](const Packet& p) {
output_frames.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
// Finalize input for the "mask" input stream.
MP_ASSERT_OK(graph.CloseInputStream("mask_0"));
MP_ASSERT_OK(graph.CloseInputStream("mask_1"));
MP_ASSERT_OK(graph.CloseInputStream("mask_2"));
// Channel 0 is selected just before corresponding packets arrive.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 1)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 10)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t10", 10)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
EXPECT_EQ(output_frames[0].Get<std::string>(), "p0_t10");
// Channel 1 is selected just before corresponding packets arrive.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 11)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_1", pack("p1_t20", 20)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 2);
EXPECT_EQ(output_frames[1].Get<std::string>(), "p1_t20");
// Channel 0 is selected just before corresponding packets arrive.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 21)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 30)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t30", 30)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 3);
EXPECT_EQ(output_frames[2].Get<std::string>(), "p0_t30");
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows what happens if the last packet from a channel never arrives.
TEST_F(SwitchMuxCalculatorTest, ChannelNeverCompletes) {
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
CalculatorGraph graph = BuildCalculatorGraph(node_config);
std::vector<Packet> output_frames;
MP_ASSERT_OK(graph.ObserveOutputStream("frame", [&](const Packet& p) {
output_frames.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
// Finalize input for the "mask" input stream.
MP_ASSERT_OK(graph.CloseInputStream("mask_0"));
MP_ASSERT_OK(graph.CloseInputStream("mask_1"));
MP_ASSERT_OK(graph.CloseInputStream("mask_2"));
// Channel 0 is selected, but it's closing packet never arrives.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 1)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 10)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 0);
// Channel 1 is selected, but we still wait for channel 0 to finish.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 11)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_1", pack("p1_t20", 20)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 0);
// Finally channel 0 advances, and channel 1 can be delivered.
// Note that "p0_t15" is discarded because its channel is deselected.
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t15", 15)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
EXPECT_EQ(output_frames[0].Get<std::string>(), "p1_t20");
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows how output is sequenced when one channel is delayed substantially.
// Queues up "SELECT" packets for channel 0, 1, and 2.
// Queues up "frame" packets for channel 0 and 2.
// The output packets from channel 1, 2, and 0 wait for channel 1.
TEST_F(SwitchMuxCalculatorTest, OneChannelIsSlow) {
CalculatorGraphConfig::Node node_config = BuildNodeConfig();
CalculatorGraph graph = BuildCalculatorGraph(node_config);
std::vector<Packet> output_frames;
MP_ASSERT_OK(graph.ObserveOutputStream("frame", [&](const Packet& p) {
output_frames.push_back(p);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
// Finalize input for the "mask" input stream.
MP_ASSERT_OK(graph.CloseInputStream("mask_0"));
MP_ASSERT_OK(graph.CloseInputStream("mask_1"));
MP_ASSERT_OK(graph.CloseInputStream("mask_2"));
// Channel 0 is selected, and some packets arrive.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 1)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 10)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t20", 10)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
// Channel 1 is selected, but its packets are delayed.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 11)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t20", 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_2", pack("p2_t20", 20)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
// Channel 2 is selected, packets arrive, but wait for channel 1.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(2, 21)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(2, 30)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_2", pack("p2_t30", 30)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
// Channel 0 is selected again, packets arrive, but wait for channel 1.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 31)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(0, 40)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_0", pack("p0_t40", 40)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
// Channel 1 is selected again, but its packets are still delayed.
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 41)));
MP_ASSERT_OK(graph.AddPacketToInputStream("select", pack(1, 50)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 1);
// Finally, the delayed packets from channel 1 arrive.
// And all packets for all five "SELECT"" inetervals are delivered.
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_1", pack("p1_t20", 20)));
MP_ASSERT_OK(graph.AddPacketToInputStream("frame_1", pack("p1_t50", 50)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(output_frames.size(), 5);
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe