diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index aacf694c1..729e91492 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -727,6 +727,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", ], alwayslink = 1, ) @@ -742,6 +743,7 @@ cc_test( "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:options_util", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc index 311f7d815..686d705dd 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/status/status.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/ret_check.h" @@ -32,6 +33,7 @@ namespace { constexpr char kTagAtPreStream[] = "AT_PRESTREAM"; constexpr char kTagAtPostStream[] = "AT_POSTSTREAM"; constexpr char kTagAtZero[] = "AT_ZERO"; +constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK"; constexpr char kTagAtTick[] = "AT_TICK"; constexpr char kTagTick[] = "TICK"; constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP"; @@ -43,6 +45,7 @@ static std::map* kTimestampMap = []() { res->emplace(kTagAtPostStream, Timestamp::PostStream()); res->emplace(kTagAtZero, Timestamp(0)); res->emplace(kTagAtTick, Timestamp::Unset()); + res->emplace(kTagAtFirstTick, Timestamp::Unset()); res->emplace(kTagAtTimestamp, Timestamp::Unset()); return res; }(); @@ -59,8 +62,8 @@ std::string GetOutputTag(const CC& cc) { // timestamp, depending on the tag used to define output stream(s). (One tag can // be used only.) // -// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP -// and corresponding timestamps are Timestamp::PreStream(), +// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_FIRST_TICK, +// AT_TIMESTAMP and corresponding timestamps are Timestamp::PreStream(), // Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK // input, and timestamp received from a side input. // @@ -96,6 +99,7 @@ class SidePacketToStreamCalculator : public CalculatorBase { private: bool is_tick_processing_ = false; + bool close_on_first_tick_ = false; std::string output_tag_; }; REGISTER_CALCULATOR(SidePacketToStreamCalculator); @@ -103,13 +107,16 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator); absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) { const auto& tags = cc->Outputs().GetTags(); RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1) - << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s)."; - RET_CHECK( - (cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) || - (!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick))) - << "Either both of TICK and AT_TICK should be used or none of them."; + << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, " + "AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to " + "specify output stream(s)."; + const bool has_tick_output = + cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick); + const bool has_tick_input = cc->Inputs().HasTag(kTagTick); + RET_CHECK((has_tick_output && has_tick_input) || + (!has_tick_output && !has_tick_input)) + << "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output " + "should be used or none of them."; RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) && cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) || (!cc->Outputs().HasTag(kTagAtTimestamp) && @@ -148,11 +155,17 @@ absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { // timestamp bound update. cc->SetOffset(TimestampDiff(0)); } + if (output_tag_ == kTagAtFirstTick) { + close_on_first_tick_ = true; + } return absl::OkStatus(); } absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { if (is_tick_processing_) { + if (cc->Outputs().Get(output_tag_, 0).IsClosed()) { + return absl::OkStatus(); + } // TICK input is guaranteed to be non-empty, as it's the only input stream // for this calculator. const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp(); @@ -160,6 +173,9 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { cc->Outputs() .Get(output_tag_, i) .AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); + if (close_on_first_tick_) { + cc->Outputs().Get(output_tag_, i).Close(); + } } return absl::OkStatus(); @@ -170,6 +186,7 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { if (!cc->Outputs().HasTag(kTagAtTick) && + !cc->Outputs().HasTag(kTagAtFirstTick) && !cc->Outputs().HasTag(kTagAtTimestamp)) { const auto& timestamp = kTimestampMap->at(output_tag_); for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) { diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc index 086b73fcd..6c0941b44 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc @@ -27,13 +27,17 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/tool/options_util.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { namespace { -using testing::HasSubstr; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; -TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { +TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTick) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -52,10 +56,35 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { EXPECT_THAT( status.message(), HasSubstr( - "Either both of TICK and AT_TICK should be used or none of them.")); + "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output " + "should be used or none of them.")); } -TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { +TEST(SidePacketToStreamCalculator, + WrongConfigWithMissingTickForFirstTickProcessing) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"pb( + input_stream: "tick" + input_side_packet: "side_packet" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_side_packet: "side_packet" + output_stream: "AT_FIRST_TICK:packet" + } + )pb"); + CalculatorGraph graph; + auto status = graph.Initialize(graph_config); + EXPECT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr( + "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output " + "should be used or none of them.")); +} + +TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTimestampSideInput) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -76,7 +105,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { "or none of them.")); } -TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { +TEST(SidePacketToStreamCalculator, WrongConfigWithNonExistentTag) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -92,14 +121,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s).")); + EXPECT_THAT(status.message(), + HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, " + "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is " + "allowed and required to specify output stream(s).")); } -TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { +TEST(SidePacketToStreamCalculator, WrongConfigWithMixedTags) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -117,14 +145,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { CalculatorGraph graph; auto status = graph.Initialize(graph_config); EXPECT_FALSE(status.ok()); - EXPECT_THAT( - status.message(), - HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " - "AT_TIMESTAMP tags is allowed and required to specify output " - "stream(s).")); + EXPECT_THAT(status.message(), + HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, " + "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is " + "allowed and required to specify output stream(s).")); } -TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { +TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughSidePackets) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -146,7 +173,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { "Same number of input side packets and output streams is required.")); } -TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) { +TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughOutputStreams) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -248,7 +275,50 @@ TEST(SidePacketToStreamCalculator, AtTick) { tick_and_verify(/*at_timestamp=*/1025); } -TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) { +TEST(SidePacketToStreamCalculator, AtFirstTick) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"pb( + input_stream: "tick" + input_side_packet: "side_packet" + output_stream: "packet" + node { + calculator: "SidePacketToStreamCalculator" + input_stream: "TICK:tick" + input_side_packet: "side_packet" + output_stream: "AT_FIRST_TICK:packet" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("packet", &graph_config, &output_packets); + CalculatorGraph graph; + + MP_ASSERT_OK(graph.Initialize(graph_config)); + const int expected_value = 20; + const Timestamp kTestTimestamp(1234); + MP_ASSERT_OK( + graph.StartRun({{"side_packet", MakePacket(expected_value)}})); + + auto insert_tick = [&graph](Timestamp at_timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tick", MakePacket(/*doesn't matter*/ 1).At(at_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + insert_tick(kTestTimestamp); + + EXPECT_THAT(output_packets, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(kTestTimestamp), Eq(expected_value)))); + + output_packets.clear(); + + // Should not result in an additional output. + insert_tick(kTestTimestamp + 1); + EXPECT_THAT(output_packets, IsEmpty()); +} + +TEST(SidePacketToStreamCalculator, AtTickWithMultipleSidePackets) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb( @@ -302,6 +372,62 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) { tick_and_verify(/*at_timestamp=*/1025); } +TEST(SidePacketToStreamCalculator, AtFirstTickWithMultipleSidePackets) { + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie( + R"pb( + input_stream: "tick" + input_side_packet: "side_packet0" + input_side_packet: "side_packet1" + output_stream: "packet0" + output_stream: "packet1" + node { + calculator: "SidePacketToStreamCalculator" + input_stream: "TICK:tick" + input_side_packet: "side_packet0" + input_side_packet: "side_packet1" + output_stream: "AT_FIRST_TICK:0:packet0" + output_stream: "AT_FIRST_TICK:1:packet1" + } + )pb"); + std::vector output_packets0; + tool::AddVectorSink("packet0", &graph_config, &output_packets0); + std::vector output_packets1; + tool::AddVectorSink("packet1", &graph_config, &output_packets1); + CalculatorGraph graph; + + MP_ASSERT_OK(graph.Initialize(graph_config)); + const int expected_value0 = 20; + const int expected_value1 = 128; + const Timestamp kTestTimestamp(1234); + MP_ASSERT_OK( + graph.StartRun({{"side_packet0", MakePacket(expected_value0)}, + {"side_packet1", MakePacket(expected_value1)}})); + + auto insert_tick = [&graph](Timestamp at_timestamp) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tick", MakePacket(/*doesn't matter*/ 1).At(at_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + insert_tick(kTestTimestamp); + + EXPECT_THAT(output_packets0, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(kTestTimestamp), Eq(expected_value0)))); + EXPECT_THAT(output_packets1, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(kTestTimestamp), Eq(expected_value1)))); + + output_packets0.clear(); + output_packets1.clear(); + + // Should not result in an additional output. + insert_tick(kTestTimestamp + 1); + EXPECT_THAT(output_packets0, IsEmpty()); + EXPECT_THAT(output_packets1, IsEmpty()); +} + TEST(SidePacketToStreamCalculator, AtTimestamp) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( @@ -334,7 +460,7 @@ TEST(SidePacketToStreamCalculator, AtTimestamp) { EXPECT_EQ(expected_value, output_packets.back().Get()); } -TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) { +TEST(SidePacketToStreamCalculator, AtTimestampWithMultipleOutputs) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie( R"pb(