diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 45bace271..5d0594de9 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -85,75 +85,6 @@ std::string SourceString(Timestamp t) { : absl::StrCat("Timestamp(", t.DebugString(), ")"); } -template -std::string SourceString(Packet packet) { - std::ostringstream oss; - if (packet.IsEmpty()) { - oss << "Packet()"; - } else { - oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" - << packet.Get() << ")"; - } - oss << ".At(" << SourceString(packet.Timestamp()) << ")"; - return oss.str(); -} - -template -class PacketsEqMatcher - : public ::testing::MatcherInterface { - public: - PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} - void DescribeTo(::std::ostream* os) const override { - *os << "The expected packet contents: \n"; - Print(packets_, os); - } - bool MatchAndExplain( - const PacketContainer& value, - ::testing::MatchResultListener* listener) const override { - if (!Equals(packets_, value)) { - if (listener->IsInterested()) { - *listener << "The actual packet contents: \n"; - Print(value, listener->stream()); - } - return false; - } - return true; - } - - private: - bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { - if (c1.size() != c2.size()) { - return false; - } - for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { - Packet p1 = *i1, p2 = *i2; - if (p1.Timestamp() != p2.Timestamp() || p1.IsEmpty() != p2.IsEmpty() || - (!p1.IsEmpty() && - p1.Get() != p2.Get())) { - return false; - } - } - return true; - } - void Print(const PacketContainer& packets, ::std::ostream* os) const { - for (auto it = packets.begin(); it != packets.end(); ++it) { - const Packet& packet = *it; - *os << (it == packets.begin() ? "{" : ""); - *os << SourceString(packet); - *os << (std::next(it) == packets.end() ? "}" : ", "); - } - } - - const PacketContainer packets_; -}; - -template -::testing::Matcher PacketsEq( - const PacketContainer& packets) { - return MakeMatcher( - new PacketsEqMatcher(packets)); -} - // A Calculator::Process callback function. typedef std::function @@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -839,13 +767,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers(expected_output))); + // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2)); + EXPECT_THAT(out_2_packets, + ElementsAreArray(PacketMatchers(expected_output_2))); // Validate the ALLOW stream output. std::vector expected_allow = { @@ -871,7 +802,8 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { MakePacket(true).At(Timestamp(190000)), MakePacket(false).At(Timestamp(200000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers(expected_allow))); } std::vector StripBoundsUpdates(const std::vector& packets, @@ -891,9 +823,6 @@ std::vector StripBoundsUpdates(const std::vector& packets, // Shows how FlowLimiterCalculator releases auxiliary input packets. // In this test, auxiliary input packets arrive at twice the primary rate. TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -1011,7 +940,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket(6).At(Timestamp(60000)), Packet().At(Timestamp(80000)), }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers(expected_output))); // Packets following input packets 2 and 6, and not input packets 4 and 8. std::vector expected_auxiliary_output = { @@ -1031,12 +961,13 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { }; std::vector actual_2 = StripBoundsUpdates(out_2_packets, Timestamp(90000)); - EXPECT_THAT(actual_2, IntPacketsEq(expected_auxiliary_output)); + EXPECT_THAT(actual_2, + ElementsAreArray(PacketMatchers(expected_auxiliary_output))); std::vector expected_3 = StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999)); std::vector actual_3 = StripBoundsUpdates(out_3_packets, Timestamp(39999)); - EXPECT_THAT(actual_3, IntPacketsEq(expected_3)); + EXPECT_THAT(actual_3, ElementsAreArray(PacketMatchers(expected_3))); // Validate the ALLOW stream output. std::vector expected_allow = { @@ -1045,7 +976,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket(true).At(Timestamp(60000)), MakePacket(false).At(Timestamp(80000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers(expected_allow))); } } // anonymous namespace diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 19c51853c..8ccdac3b9 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1469,6 +1469,7 @@ cc_test( "//mediapipe/framework/stream_handler:mux_input_stream_handler", "//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/framework/tool:sink", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index c17a2e1e2..526a74835 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -98,14 +98,13 @@ void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) { manager_->LockIntroData(); } +void CalculatorGraph::GraphInputStream::SetNextTimestampBound( + Timestamp timestamp) { + shard_.SetNextTimestampBound(timestamp); +} + void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() { - // Since GraphInputStream doesn't allow SetOffset() and - // SetNextTimestampBound(), the timestamp bound to propagate is only - // determined by the timestamp of the output packets. - CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name() - << "\" failed"; - manager_->PropagateUpdatesToMirrors( - shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_); + manager_->PropagateUpdatesToMirrors(shard_.NextTimestampBound(), &shard_); } void CalculatorGraph::GraphInputStream::Close() { @@ -868,6 +867,19 @@ absl::Status CalculatorGraph::AddPacketToInputStream( return AddPacketToInputStreamInternal(stream_name, std::move(packet)); } +absl::Status CalculatorGraph::SetInputStreamTimestampBound( + const std::string& stream_name, Timestamp timestamp) { + std::unique_ptr* stream = + mediapipe::FindOrNull(graph_input_streams_, stream_name); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "SetInputStreamTimestampBound called on input stream \"$0\" which is not " + "a graph input stream.", + stream_name); + (*stream)->SetNextTimestampBound(timestamp); + (*stream)->PropagateUpdatesToMirrors(); + return absl::OkStatus(); +} + // We avoid having two copies of this code for AddPacketToInputStream( // const Packet&) and AddPacketToInputStream(Packet &&) by having this // internal-only templated version. T&& is a forwarding reference here, so diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index c51476102..04f9de45f 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -257,6 +257,10 @@ class CalculatorGraph { absl::Status AddPacketToInputStream(const std::string& stream_name, Packet&& packet); + // Indicates that input will arrive no earlier than a certain timestamp. + absl::Status SetInputStreamTimestampBound(const std::string& stream_name, + Timestamp timestamp); + // Sets the queue size of a graph input stream, overriding the graph default. absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name, int max_queue_size); @@ -425,6 +429,8 @@ class CalculatorGraph { void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); } + void SetNextTimestampBound(Timestamp timestamp); + void PropagateUpdatesToMirrors(); void Close(); diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index b55f9459d..d149337cc 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "absl/strings/str_replace.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,6 +26,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { namespace { @@ -1536,7 +1539,7 @@ class EmptyPacketCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(EmptyPacketCalculator); -// This test shows that an output timestamp bound can be specified by outputing +// This test shows that an output timestamp bound can be specified by outputting // an empty packet with a settled timestamp. TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { // OffsetAndBoundCalculator runs on parallel threads and sends ts @@ -1580,6 +1583,195 @@ TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); } + // Shut down the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows that input timestamp bounds can be specified using +// CalculatorGraph::SetInputStreamTimestampBound. +TEST(CalculatorGraphBoundsTest, SetInputStreamTimestampBound) { + std::string config_str = R"( + input_stream: "input_0" + node { + calculator: "ProcessBoundToPacketCalculator" + input_stream: "input_0" + output_stream: "output_0" + } + )"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector output_0_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { + output_0_packets.push_back(p); + return absl::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send in timestamp bounds. + for (int i = 0; i < 9; ++i) { + const int ts = 10 + i * 10; + MP_ASSERT_OK(graph.SetInputStreamTimestampBound( + "input_0", Timestamp(ts).NextAllowedInStream())); + MP_ASSERT_OK(graph.WaitUntilIdle()); + } + + // 9 timestamp bounds are converted to packets. + EXPECT_EQ(output_0_packets.size(), 9); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); + } + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows how an input stream with infrequent packets, such as +// configuration protobufs, can be consumed while processing more frequent +// packets, such as video frames. +TEST(CalculatorGraphBoundsTest, TimestampBoundsForInfrequentInput) { + // PassThroughCalculator consuming two input streams, with default ISH. + std::string config_str = R"pb( + input_stream: "INFREQUENT:config" + input_stream: "FREQUENT:frame" + node { + calculator: "PassThroughCalculator" + input_stream: "CONFIG:config" + input_stream: "VIDEO:frame" + output_stream: "VIDEO:output_frame" + output_stream: "CONFIG:output_config" + } + )pb"; + + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector frame_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_frame", + [&](const Packet& p) { + frame_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + std::vector config_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_config", + [&](const Packet& p) { + config_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Utility functions to send packets or timestamp bounds. + auto send_fn = [&](std::string stream, std::string value, int ts) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + stream, + MakePacket(absl::StrCat(value)).At(Timestamp(ts)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + auto bound_fn = [&](std::string stream, int ts) { + MP_ASSERT_OK(graph.SetInputStreamTimestampBound(stream, Timestamp(ts))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + // Send in a frame packet. + send_fn("frame", "frame_0", 0); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, ElementsAreArray(PacketMatchers({}))); + bound_fn("config", 10000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_1", 20000); + // The frame is not processed yet. + // The PassThroughCalculator with TimestampOffset 0 now propagates + // Timestamp bound 10000 to both "output_frame" and "output_config", + // which appears here as Packet().At(Timestamp(9999). The timestamp + // bounds at 29999 and 50000 are propagated similarly. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + }))); + bound_fn("config", 30000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_2", 40000); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + }))); + send_fn("config", "config_1", 50000); + // The frame is processed after a fresh config arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_3", 60000); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + }))); + bound_fn("config", 70000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers({ + MakePacket("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + MakePacket("frame_3").At(Timestamp(60000)), + }))); + + // One config packet is deleivered. + EXPECT_THAT(config_packets, + ElementsAreArray(PacketMatchers({ + Packet().At(Timestamp(0)), + Packet().At(Timestamp(9999)), + Packet().At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + Packet().At(Timestamp(40000)), + MakePacket("config_1").At(Timestamp(50000)), + Packet().At(Timestamp(60000)), + }))); + // Shutdown the graph. MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); diff --git a/mediapipe/util/packet_test_util.h b/mediapipe/util/packet_test_util.h index 106d7f8d4..61e9322e1 100644 --- a/mediapipe/util/packet_test_util.h +++ b/mediapipe/util/packet_test_util.h @@ -32,30 +32,29 @@ namespace mediapipe { namespace internal { template -class PacketMatcher : public ::testing::MatcherInterface { +class PacketMatcher : public testing::MatcherInterface { public: template explicit PacketMatcher(InnerMatcher inner_matcher) : inner_matcher_( - ::testing::SafeMatcherCast(inner_matcher)) {} + testing::SafeMatcherCast(inner_matcher)) {} // Returns true iff the packet contains value of PayloadType satisfying // the inner matcher. - bool MatchAndExplain( - const Packet& packet, - ::testing::MatchResultListener* listener) const override { + bool MatchAndExplain(const Packet& packet, + testing::MatchResultListener* listener) const override { if (!packet.ValidateAsType().ok()) { *listener << packet.DebugString() << " does not contain expected type " << ExpectedTypeName(); return false; } - ::testing::StringMatchResultListener match_listener; + testing::StringMatchResultListener match_listener; const PayloadType& payload = packet.Get(); const bool matches = inner_matcher_.MatchAndExplain(payload, &match_listener); const std::string explanation = match_listener.str(); *listener << packet.DebugString() << " containing value " - << ::testing::PrintToString(payload); + << testing::PrintToString(payload); if (!explanation.empty()) { *listener << ", which " << explanation; } @@ -78,9 +77,28 @@ class PacketMatcher : public ::testing::MatcherInterface { return ::mediapipe::Demangle(typeid(PayloadType).name()); } - const ::testing::Matcher inner_matcher_; + const testing::Matcher inner_matcher_; }; +inline std::string SourceString(Timestamp t) { + return (t.IsSpecialValue()) + ? t.DebugString() + : absl::StrCat("Timestamp(", t.DebugString(), ")"); +} + +template +std::string SourceString(Packet packet) { + std::ostringstream oss; + if (packet.IsEmpty()) { + oss << "Packet()"; + } else { + oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" + << packet.Get() << ")"; + } + oss << ".At(" << SourceString(packet.Timestamp()) << ")"; + return oss.str(); +} + } // namespace internal // Creates matcher validating that the packet contains value of expected type @@ -91,9 +109,9 @@ class PacketMatcher : public ::testing::MatcherInterface { // // EXPECT_THAT(MakePacket(42), PacketContains(Eq(42))) template -inline ::testing::Matcher PacketContains( +inline testing::Matcher PacketContains( InnerMatcher inner_matcher) { - return ::testing::MakeMatcher( + return testing::MakeMatcher( new internal::PacketMatcher(inner_matcher)); } @@ -110,7 +128,7 @@ inline ::testing::Matcher PacketContains( // Eq(42))) template -inline ::testing::Matcher PacketContainsTimestampAndPayload( +inline testing::Matcher PacketContainsTimestampAndPayload( TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) { return testing::AllOf( testing::Property("Packet::Timestamp", &Packet::Timestamp, @@ -118,6 +136,46 @@ inline ::testing::Matcher PacketContainsTimestampAndPayload( PacketContains(content_matcher)); } +template +class PacketEqMatcher : public testing::MatcherInterface { + public: + PacketEqMatcher(Packet packet) : packet_(packet) {} + void DescribeTo(::std::ostream* os) const override { + *os << "The expected packet: " << internal::SourceString(packet_); + } + bool MatchAndExplain(Packet value, + testing::MatchResultListener* listener) const override { + bool unequal = (value.Timestamp() != packet_.Timestamp() || + value.IsEmpty() != packet_.IsEmpty() || + (!value.IsEmpty() && value.Get() != packet_.Get())); + if (unequal && listener->IsInterested()) { + *listener << "The actual packet: " << internal::SourceString(value); + } + return !unequal; + } + const Packet packet_; +}; + +template +testing::Matcher PacketEq(Packet packet) { + return MakeMatcher(new PacketEqMatcher(packet)); +} + +template +std::vector> PacketMatchers( + std::vector packets) { + std::vector> result; + for (const auto& packet : packets) { + result.push_back(PacketEq(packet)); + } + return result; +} + +} // namespace mediapipe + +namespace mediapipe { +using mediapipe::PacketContains; +using mediapipe::PacketContainsTimestampAndPayload; } // namespace mediapipe #endif // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_