Add AT_FIRST_TICK processing to SidePacketToStreamCalculator.

PiperOrigin-RevId: 578824863
This commit is contained in:
MediaPipe Team 2023-11-02 05:56:09 -07:00 committed by Copybara-Service
parent e81fc5d0aa
commit f8197651e8
3 changed files with 174 additions and 29 deletions

View File

@ -727,6 +727,7 @@ cc_library(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -742,6 +743,7 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util", "//mediapipe/framework/tool:options_util",
"//mediapipe/util:packet_test_util",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],

View File

@ -17,6 +17,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -32,6 +33,7 @@ namespace {
constexpr char kTagAtPreStream[] = "AT_PRESTREAM"; constexpr char kTagAtPreStream[] = "AT_PRESTREAM";
constexpr char kTagAtPostStream[] = "AT_POSTSTREAM"; constexpr char kTagAtPostStream[] = "AT_POSTSTREAM";
constexpr char kTagAtZero[] = "AT_ZERO"; constexpr char kTagAtZero[] = "AT_ZERO";
constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK";
constexpr char kTagAtTick[] = "AT_TICK"; constexpr char kTagAtTick[] = "AT_TICK";
constexpr char kTagTick[] = "TICK"; constexpr char kTagTick[] = "TICK";
constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP"; constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP";
@ -43,6 +45,7 @@ static std::map<std::string, Timestamp>* kTimestampMap = []() {
res->emplace(kTagAtPostStream, Timestamp::PostStream()); res->emplace(kTagAtPostStream, Timestamp::PostStream());
res->emplace(kTagAtZero, Timestamp(0)); res->emplace(kTagAtZero, Timestamp(0));
res->emplace(kTagAtTick, Timestamp::Unset()); res->emplace(kTagAtTick, Timestamp::Unset());
res->emplace(kTagAtFirstTick, Timestamp::Unset());
res->emplace(kTagAtTimestamp, Timestamp::Unset()); res->emplace(kTagAtTimestamp, Timestamp::Unset());
return res; return res;
}(); }();
@ -59,8 +62,8 @@ std::string GetOutputTag(const CC& cc) {
// timestamp, depending on the tag used to define output stream(s). (One tag can // timestamp, depending on the tag used to define output stream(s). (One tag can
// be used only.) // be used only.)
// //
// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP // Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_FIRST_TICK,
// and corresponding timestamps are Timestamp::PreStream(), // AT_TIMESTAMP and corresponding timestamps are Timestamp::PreStream(),
// Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK // Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK
// input, and timestamp received from a side input. // input, and timestamp received from a side input.
// //
@ -96,6 +99,7 @@ class SidePacketToStreamCalculator : public CalculatorBase {
private: private:
bool is_tick_processing_ = false; bool is_tick_processing_ = false;
bool close_on_first_tick_ = false;
std::string output_tag_; std::string output_tag_;
}; };
REGISTER_CALCULATOR(SidePacketToStreamCalculator); REGISTER_CALCULATOR(SidePacketToStreamCalculator);
@ -103,13 +107,16 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator);
absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) { absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
const auto& tags = cc->Outputs().GetTags(); const auto& tags = cc->Outputs().GetTags();
RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1) RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1)
<< "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, "
"AT_TIMESTAMP tags is allowed and required to specify output " "AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to "
"stream(s)."; "specify output stream(s).";
RET_CHECK( const bool has_tick_output =
(cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) || cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick);
(!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick))) const bool has_tick_input = cc->Inputs().HasTag(kTagTick);
<< "Either both of TICK and AT_TICK should be used or none of them."; RET_CHECK((has_tick_output && has_tick_input) ||
(!has_tick_output && !has_tick_input))
<< "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
"should be used or none of them.";
RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) && RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) &&
cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) || cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) ||
(!cc->Outputs().HasTag(kTagAtTimestamp) && (!cc->Outputs().HasTag(kTagAtTimestamp) &&
@ -148,11 +155,17 @@ absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) {
// timestamp bound update. // timestamp bound update.
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
} }
if (output_tag_ == kTagAtFirstTick) {
close_on_first_tick_ = true;
}
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
if (is_tick_processing_) { if (is_tick_processing_) {
if (cc->Outputs().Get(output_tag_, 0).IsClosed()) {
return absl::OkStatus();
}
// TICK input is guaranteed to be non-empty, as it's the only input stream // TICK input is guaranteed to be non-empty, as it's the only input stream
// for this calculator. // for this calculator.
const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp(); const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp();
@ -160,6 +173,9 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
cc->Outputs() cc->Outputs()
.Get(output_tag_, i) .Get(output_tag_, i)
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); .AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
if (close_on_first_tick_) {
cc->Outputs().Get(output_tag_, i).Close();
}
} }
return absl::OkStatus(); return absl::OkStatus();
@ -170,6 +186,7 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
if (!cc->Outputs().HasTag(kTagAtTick) && if (!cc->Outputs().HasTag(kTagAtTick) &&
!cc->Outputs().HasTag(kTagAtFirstTick) &&
!cc->Outputs().HasTag(kTagAtTimestamp)) { !cc->Outputs().HasTag(kTagAtTimestamp)) {
const auto& timestamp = kTimestampMap->at(output_tag_); const auto& timestamp = kTimestampMap->at(output_tag_);
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) { for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {

View File

@ -27,13 +27,17 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_util.h" #include "mediapipe/framework/tool/options_util.h"
#include "mediapipe/util/packet_test_util.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using testing::HasSubstr; using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) { TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTick) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -52,10 +56,35 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
EXPECT_THAT( EXPECT_THAT(
status.message(), status.message(),
HasSubstr( HasSubstr(
"Either both of TICK and AT_TICK should be used or none of them.")); "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
"should be used or none of them."));
} }
TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) { TEST(SidePacketToStreamCalculator,
WrongConfigWithMissingTickForFirstTickProcessing) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "tick"
input_side_packet: "side_packet"
output_stream: "packet"
node {
calculator: "SidePacketToStreamCalculator"
input_side_packet: "side_packet"
output_stream: "AT_FIRST_TICK:packet"
}
)pb");
CalculatorGraph graph;
auto status = graph.Initialize(graph_config);
EXPECT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
HasSubstr(
"Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
"should be used or none of them."));
}
TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTimestampSideInput) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -76,7 +105,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
"or none of them.")); "or none of them."));
} }
TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) { TEST(SidePacketToStreamCalculator, WrongConfigWithNonExistentTag) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -92,14 +121,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
CalculatorGraph graph; CalculatorGraph graph;
auto status = graph.Initialize(graph_config); auto status = graph.Initialize(graph_config);
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
EXPECT_THAT( EXPECT_THAT(status.message(),
status.message(), HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
"AT_TIMESTAMP tags is allowed and required to specify output " "allowed and required to specify output stream(s)."));
"stream(s)."));
} }
TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) { TEST(SidePacketToStreamCalculator, WrongConfigWithMixedTags) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -117,14 +145,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
CalculatorGraph graph; CalculatorGraph graph;
auto status = graph.Initialize(graph_config); auto status = graph.Initialize(graph_config);
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
EXPECT_THAT( EXPECT_THAT(status.message(),
status.message(), HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " "AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
"AT_TIMESTAMP tags is allowed and required to specify output " "allowed and required to specify output stream(s)."));
"stream(s)."));
} }
TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) { TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughSidePackets) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -146,7 +173,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
"Same number of input side packets and output streams is required.")); "Same number of input side packets and output streams is required."));
} }
TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) { TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughOutputStreams) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -248,7 +275,50 @@ TEST(SidePacketToStreamCalculator, AtTick) {
tick_and_verify(/*at_timestamp=*/1025); tick_and_verify(/*at_timestamp=*/1025);
} }
TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) { TEST(SidePacketToStreamCalculator, AtFirstTick) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "tick"
input_side_packet: "side_packet"
output_stream: "packet"
node {
calculator: "SidePacketToStreamCalculator"
input_stream: "TICK:tick"
input_side_packet: "side_packet"
output_stream: "AT_FIRST_TICK:packet"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("packet", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
const int expected_value = 20;
const Timestamp kTestTimestamp(1234);
MP_ASSERT_OK(
graph.StartRun({{"side_packet", MakePacket<int>(expected_value)}}));
auto insert_tick = [&graph](Timestamp at_timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
};
insert_tick(kTestTimestamp);
EXPECT_THAT(output_packets,
ElementsAre(PacketContainsTimestampAndPayload<int>(
Eq(kTestTimestamp), Eq(expected_value))));
output_packets.clear();
// Should not result in an additional output.
insert_tick(kTestTimestamp + 1);
EXPECT_THAT(output_packets, IsEmpty());
}
TEST(SidePacketToStreamCalculator, AtTickWithMultipleSidePackets) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(
@ -302,6 +372,62 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
tick_and_verify(/*at_timestamp=*/1025); tick_and_verify(/*at_timestamp=*/1025);
} }
TEST(SidePacketToStreamCalculator, AtFirstTickWithMultipleSidePackets) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "tick"
input_side_packet: "side_packet0"
input_side_packet: "side_packet1"
output_stream: "packet0"
output_stream: "packet1"
node {
calculator: "SidePacketToStreamCalculator"
input_stream: "TICK:tick"
input_side_packet: "side_packet0"
input_side_packet: "side_packet1"
output_stream: "AT_FIRST_TICK:0:packet0"
output_stream: "AT_FIRST_TICK:1:packet1"
}
)pb");
std::vector<Packet> output_packets0;
tool::AddVectorSink("packet0", &graph_config, &output_packets0);
std::vector<Packet> output_packets1;
tool::AddVectorSink("packet1", &graph_config, &output_packets1);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
const int expected_value0 = 20;
const int expected_value1 = 128;
const Timestamp kTestTimestamp(1234);
MP_ASSERT_OK(
graph.StartRun({{"side_packet0", MakePacket<int>(expected_value0)},
{"side_packet1", MakePacket<int>(expected_value1)}}));
auto insert_tick = [&graph](Timestamp at_timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
};
insert_tick(kTestTimestamp);
EXPECT_THAT(output_packets0,
ElementsAre(PacketContainsTimestampAndPayload<int>(
Eq(kTestTimestamp), Eq(expected_value0))));
EXPECT_THAT(output_packets1,
ElementsAre(PacketContainsTimestampAndPayload<int>(
Eq(kTestTimestamp), Eq(expected_value1))));
output_packets0.clear();
output_packets1.clear();
// Should not result in an additional output.
insert_tick(kTestTimestamp + 1);
EXPECT_THAT(output_packets0, IsEmpty());
EXPECT_THAT(output_packets1, IsEmpty());
}
TEST(SidePacketToStreamCalculator, AtTimestamp) { TEST(SidePacketToStreamCalculator, AtTimestamp) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
@ -334,7 +460,7 @@ TEST(SidePacketToStreamCalculator, AtTimestamp) {
EXPECT_EQ(expected_value, output_packets.back().Get<int>()); EXPECT_EQ(expected_value, output_packets.back().Get<int>());
} }
TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) { TEST(SidePacketToStreamCalculator, AtTimestampWithMultipleOutputs) {
CalculatorGraphConfig graph_config = CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>( ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb( R"pb(