diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index a7d9e0a63..86608285b 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1355,6 +1355,22 @@ cc_test( ], ) +cc_test( + name = "calculator_graph_summary_packet_test", + srcs = ["calculator_graph_summary_packet_test.cc"], + deps = [ + ":calculator_framework", + ":packet", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/stream_handler:immediate_input_stream_handler", + "//mediapipe/framework/tool:sink", + ], +) + cc_test( name = "calculator_runner_test", size = "medium", diff --git a/mediapipe/framework/calculator_graph_summary_packet_test.cc b/mediapipe/framework/calculator_graph_summary_packet_test.cc new file mode 100644 index 000000000..c8d1e7eb7 --- /dev/null +++ b/mediapipe/framework/calculator_graph_summary_packet_test.cc @@ -0,0 +1,327 @@ +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Node; +using ::mediapipe::api2::Output; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Value; + +namespace { + +MATCHER_P2(IntPacket, value, timestamp, "") { + *result_listener << "where object is (value: " << arg.template Get() + << ", timestamp: " << arg.Timestamp() << ")"; + return Value(arg.template Get(), Eq(value)) && + Value(arg.Timestamp(), Eq(timestamp)); +} + +// Calculates and produces sum of all passed inputs when no more packets can be +// expected on the input stream. +class SummaryPacketCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"SUMMARY"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc) { + // Makes sure there are no automatic timestamp bound updates when Process + // is called. + cc->SetTimestampOffset(TimestampDiff::Unset()); + // Currently, only ImmediateInputStreamHandler supports "done" timestamp + // bound update. (ImmediateInputStreamhandler handles multiple input + // streams differently, so, in that case, calculator adjustments may be + // required.) + // TODO: update all input stream handlers to support "done" + // timestamp bound update. + cc->SetInputStreamHandler("ImmediateInputStreamHandler"); + // Enables processing timestamp bound updates. For this use case we are + // specifically interested in "done" timestamp bound update. (E.g. when + // all input packet sources are closed.) + cc->SetProcessTimestampBounds(true); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + if (!kIn(cc).IsEmpty()) { + value_ += kIn(cc).Get(); + } + + if (kOut(cc).IsClosed()) { + // This can happen: + // 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g. + // source calculator finished generating packets sent to kIn) and + // HasNextAllowedInStream() == true (which is an often case). + // 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still + // invoke Process() with Timestamp::Max to indicate "Done" timestamp + // bound update. + return absl::OkStatus(); + } + + // TODO: input stream holding a packet with timestamp that has + // no next timestamp allowed in stream should always result in + // InputStream::IsDone() == true. + if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) { + // kOut(cc).Send(value_) can be used here as well, however in the case of + // source calculator sending inputs into kIn the resulting timestamp is + // not well defined (e.g. it can be the last packet timestamp or + // Timestamp::Max()) + // TODO: last packet from source should always result in + // InputStream::IsDone() == true. + kOut(cc).Send(value_, Timestamp::Max()); + kOut(cc).Close(); + } + return absl::OkStatus(); + } + + private: + int value_ = 0; +}; +MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnClosingAllPacketSources) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp(10)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + send_packet(20, Timestamp(11)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); +} + +TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp(10)); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + send_packet(20, Timestamp::Max()); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnPreStreamTimestamp) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp::PreStream()); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnPostStreamTimestamp) { + std::vector output_packets; + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"pb( + input_stream: 'input' + node { + calculator: "SummaryPacketCalculator" + input_stream: 'IN:input' + output_stream: 'SUMMARY:output' + } + )pb"); + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + auto send_packet = [&graph](int value, Timestamp timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input", MakePacket(value).At(timestamp))); + }; + + send_packet(10, Timestamp::PostStream()); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +class IntGeneratorCalculator : public Node { + public: + static constexpr Output kOut{"INT"}; + + MEDIAPIPE_NODE_CONTRACT(kOut); + + absl::Status Process(CalculatorContext* cc) final { + kOut(cc).Send(20, Timestamp(0)); + kOut(cc).Send(10, Timestamp(1000)); + return tool::StatusStop(); + } +}; +MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnSourceCalculatorCompletion) { + std::vector output_packets; + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"pb( + node { + calculator: "IntGeneratorCalculator" + output_stream: "INT:int_value" + } + node { + calculator: "SummaryPacketCalculator" + input_stream: "IN:int_value" + output_stream: "SUMMARY:output" + } + )pb"); + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_EXPECT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); +} + +class EmitOnCloseCalculator : public Node { + public: + static constexpr Input kIn{"IN"}; + static constexpr Output kOut{"INT"}; + + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } + + absl::Status Close(CalculatorContext* cc) final { + kOut(cc).Send(20, Timestamp(0)); + kOut(cc).Send(10, Timestamp(1000)); + return absl::OkStatus(); + } +}; +MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator); + +TEST(SummaryPacketCalculatorUseCaseTest, + ProducesSummaryPacketOnAnotherCalculatorClosure) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "input" + node { + calculator: "EmitOnCloseCalculator" + input_stream: "IN:input" + output_stream: "INT:int_value" + } + node { + calculator: "SummaryPacketCalculator" + input_stream: "IN:int_value" + output_stream: "SUMMARY:output" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("output", &graph_config, &output_packets); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config, {})); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, IsEmpty()); + + MP_ASSERT_OK(graph.CloseInputStream("input")); + MP_ASSERT_OK(graph.WaitUntilIdle()); + EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max()))); + + output_packets.clear(); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_THAT(output_packets, IsEmpty()); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/timestamp.cc b/mediapipe/framework/timestamp.cc index 05b69747f..4ece74c99 100644 --- a/mediapipe/framework/timestamp.cc +++ b/mediapipe/framework/timestamp.cc @@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const { return *this + 1; } +bool Timestamp::HasNextAllowedInStream() const { + if (*this >= Max() || *this == PreStream()) { + return false; + } + return true; +} + Timestamp Timestamp::PreviousAllowedInStream() const { if (*this <= Min() || *this == PostStream()) { // Indicates that no previous timestamps may occur. diff --git a/mediapipe/framework/timestamp.h b/mediapipe/framework/timestamp.h index 966ec1839..d125d28bb 100644 --- a/mediapipe/framework/timestamp.h +++ b/mediapipe/framework/timestamp.h @@ -186,6 +186,10 @@ class Timestamp { // CHECKs that this->IsAllowedInStream(). Timestamp NextAllowedInStream() const; + // Returns true if there's a next timestamp in the range [Min .. Max] after + // this one. + bool HasNextAllowedInStream() const; + // Returns the previous timestamp in the range [Min .. Max], or // Unstarted() if no Packets may preceed one with this timestamp. Timestamp PreviousAllowedInStream() const; diff --git a/mediapipe/framework/timestamp_test.cc b/mediapipe/framework/timestamp_test.cc index 5f5cc3428..3ba0b5c36 100644 --- a/mediapipe/framework/timestamp_test.cc +++ b/mediapipe/framework/timestamp_test.cc @@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) { Timestamp::PostStream().NextAllowedInStream()); } +TEST(TimestampTest, HasNextAllowedInStream) { + EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream()); + EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream()); + EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream()); + EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream()); + EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream()); + EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream()); + EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream()); + + EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream()); + EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream()); +} + TEST(TimestampTest, SpecialValueDifferences) { { // Lower range const std::vector timestamps = {