Add AT_FIRST_TICK processing to SidePacketToStreamCalculator.
PiperOrigin-RevId: 578824863
This commit is contained in:
parent
e81fc5d0aa
commit
f8197651e8
|
@ -727,6 +727,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -742,6 +743,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:options_util",
|
||||
"//mediapipe/util:packet_test_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -32,6 +33,7 @@ namespace {
|
|||
constexpr char kTagAtPreStream[] = "AT_PRESTREAM";
|
||||
constexpr char kTagAtPostStream[] = "AT_POSTSTREAM";
|
||||
constexpr char kTagAtZero[] = "AT_ZERO";
|
||||
constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK";
|
||||
constexpr char kTagAtTick[] = "AT_TICK";
|
||||
constexpr char kTagTick[] = "TICK";
|
||||
constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP";
|
||||
|
@ -43,6 +45,7 @@ static std::map<std::string, Timestamp>* kTimestampMap = []() {
|
|||
res->emplace(kTagAtPostStream, Timestamp::PostStream());
|
||||
res->emplace(kTagAtZero, Timestamp(0));
|
||||
res->emplace(kTagAtTick, Timestamp::Unset());
|
||||
res->emplace(kTagAtFirstTick, Timestamp::Unset());
|
||||
res->emplace(kTagAtTimestamp, Timestamp::Unset());
|
||||
return res;
|
||||
}();
|
||||
|
@ -59,8 +62,8 @@ std::string GetOutputTag(const CC& cc) {
|
|||
// timestamp, depending on the tag used to define output stream(s). (One tag can
|
||||
// be used only.)
|
||||
//
|
||||
// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_TIMESTAMP
|
||||
// and corresponding timestamps are Timestamp::PreStream(),
|
||||
// Valid tags are AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, AT_FIRST_TICK,
|
||||
// AT_TIMESTAMP and corresponding timestamps are Timestamp::PreStream(),
|
||||
// Timestamp::PostStream(), Timestamp(0), timestamp of a packet received in TICK
|
||||
// input, and timestamp received from a side input.
|
||||
//
|
||||
|
@ -96,6 +99,7 @@ class SidePacketToStreamCalculator : public CalculatorBase {
|
|||
|
||||
private:
|
||||
bool is_tick_processing_ = false;
|
||||
bool close_on_first_tick_ = false;
|
||||
std::string output_tag_;
|
||||
};
|
||||
REGISTER_CALCULATOR(SidePacketToStreamCalculator);
|
||||
|
@ -103,13 +107,16 @@ REGISTER_CALCULATOR(SidePacketToStreamCalculator);
|
|||
absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
|
||||
const auto& tags = cc->Outputs().GetTags();
|
||||
RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1)
|
||||
<< "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
|
||||
"AT_TIMESTAMP tags is allowed and required to specify output "
|
||||
"stream(s).";
|
||||
RET_CHECK(
|
||||
(cc->Outputs().HasTag(kTagAtTick) && cc->Inputs().HasTag(kTagTick)) ||
|
||||
(!cc->Outputs().HasTag(kTagAtTick) && !cc->Inputs().HasTag(kTagTick)))
|
||||
<< "Either both of TICK and AT_TICK should be used or none of them.";
|
||||
<< "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, "
|
||||
"AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to "
|
||||
"specify output stream(s).";
|
||||
const bool has_tick_output =
|
||||
cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick);
|
||||
const bool has_tick_input = cc->Inputs().HasTag(kTagTick);
|
||||
RET_CHECK((has_tick_output && has_tick_input) ||
|
||||
(!has_tick_output && !has_tick_input))
|
||||
<< "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
|
||||
"should be used or none of them.";
|
||||
RET_CHECK((cc->Outputs().HasTag(kTagAtTimestamp) &&
|
||||
cc->InputSidePackets().HasTag(kTagSideInputTimestamp)) ||
|
||||
(!cc->Outputs().HasTag(kTagAtTimestamp) &&
|
||||
|
@ -148,11 +155,17 @@ absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) {
|
|||
// timestamp bound update.
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
}
|
||||
if (output_tag_ == kTagAtFirstTick) {
|
||||
close_on_first_tick_ = true;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
|
||||
if (is_tick_processing_) {
|
||||
if (cc->Outputs().Get(output_tag_, 0).IsClosed()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// TICK input is guaranteed to be non-empty, as it's the only input stream
|
||||
// for this calculator.
|
||||
const auto& timestamp = cc->Inputs().Tag(kTagTick).Value().Timestamp();
|
||||
|
@ -160,6 +173,9 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
|
|||
cc->Outputs()
|
||||
.Get(output_tag_, i)
|
||||
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
|
||||
if (close_on_first_tick_) {
|
||||
cc->Outputs().Get(output_tag_, i).Close();
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -170,6 +186,7 @@ absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
|
||||
if (!cc->Outputs().HasTag(kTagAtTick) &&
|
||||
!cc->Outputs().HasTag(kTagAtFirstTick) &&
|
||||
!cc->Outputs().HasTag(kTagAtTimestamp)) {
|
||||
const auto& timestamp = kTimestampMap->at(output_tag_);
|
||||
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
|
||||
|
|
|
@ -27,13 +27,17 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/options_util.h"
|
||||
#include "mediapipe/util/packet_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using testing::HasSubstr;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Eq;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::IsEmpty;
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
|
||||
TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTick) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -52,10 +56,35 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTick) {
|
|||
EXPECT_THAT(
|
||||
status.message(),
|
||||
HasSubstr(
|
||||
"Either both of TICK and AT_TICK should be used or none of them."));
|
||||
"Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
|
||||
"should be used or none of them."));
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
|
||||
TEST(SidePacketToStreamCalculator,
|
||||
WrongConfigWithMissingTickForFirstTickProcessing) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
input_stream: "tick"
|
||||
input_side_packet: "side_packet"
|
||||
output_stream: "packet"
|
||||
node {
|
||||
calculator: "SidePacketToStreamCalculator"
|
||||
input_side_packet: "side_packet"
|
||||
output_stream: "AT_FIRST_TICK:packet"
|
||||
}
|
||||
)pb");
|
||||
CalculatorGraph graph;
|
||||
auto status = graph.Initialize(graph_config);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
HasSubstr(
|
||||
"Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
|
||||
"should be used or none of them."));
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfigWithMissingTimestampSideInput) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -76,7 +105,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MissingTimestampSideInput) {
|
|||
"or none of them."));
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
|
||||
TEST(SidePacketToStreamCalculator, WrongConfigWithNonExistentTag) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -92,14 +121,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NonExistentTag) {
|
|||
CalculatorGraph graph;
|
||||
auto status = graph.Initialize(graph_config);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
|
||||
"AT_TIMESTAMP tags is allowed and required to specify output "
|
||||
"stream(s)."));
|
||||
EXPECT_THAT(status.message(),
|
||||
HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
|
||||
"AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
|
||||
"allowed and required to specify output stream(s)."));
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
|
||||
TEST(SidePacketToStreamCalculator, WrongConfigWithMixedTags) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -117,14 +145,13 @@ TEST(SidePacketToStreamCalculator, WrongConfig_MixedTags) {
|
|||
CalculatorGraph graph;
|
||||
auto status = graph.Initialize(graph_config);
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and "
|
||||
"AT_TIMESTAMP tags is allowed and required to specify output "
|
||||
"stream(s)."));
|
||||
EXPECT_THAT(status.message(),
|
||||
HasSubstr("Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, "
|
||||
"AT_TICK, AT_FIRST_TICK and AT_TIMESTAMP tags is "
|
||||
"allowed and required to specify output stream(s)."));
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
|
||||
TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughSidePackets) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -146,7 +173,7 @@ TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughSidePackets) {
|
|||
"Same number of input side packets and output streams is required."));
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, WrongConfig_NotEnoughOutputStreams) {
|
||||
TEST(SidePacketToStreamCalculator, WrongConfigWithNotEnoughOutputStreams) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -248,7 +275,50 @@ TEST(SidePacketToStreamCalculator, AtTick) {
|
|||
tick_and_verify(/*at_timestamp=*/1025);
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
|
||||
TEST(SidePacketToStreamCalculator, AtFirstTick) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
input_stream: "tick"
|
||||
input_side_packet: "side_packet"
|
||||
output_stream: "packet"
|
||||
node {
|
||||
calculator: "SidePacketToStreamCalculator"
|
||||
input_stream: "TICK:tick"
|
||||
input_side_packet: "side_packet"
|
||||
output_stream: "AT_FIRST_TICK:packet"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("packet", &graph_config, &output_packets);
|
||||
CalculatorGraph graph;
|
||||
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
const int expected_value = 20;
|
||||
const Timestamp kTestTimestamp(1234);
|
||||
MP_ASSERT_OK(
|
||||
graph.StartRun({{"side_packet", MakePacket<int>(expected_value)}}));
|
||||
|
||||
auto insert_tick = [&graph](Timestamp at_timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
};
|
||||
|
||||
insert_tick(kTestTimestamp);
|
||||
|
||||
EXPECT_THAT(output_packets,
|
||||
ElementsAre(PacketContainsTimestampAndPayload<int>(
|
||||
Eq(kTestTimestamp), Eq(expected_value))));
|
||||
|
||||
output_packets.clear();
|
||||
|
||||
// Should not result in an additional output.
|
||||
insert_tick(kTestTimestamp + 1);
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, AtTickWithMultipleSidePackets) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
@ -302,6 +372,62 @@ TEST(SidePacketToStreamCalculator, AtTick_MultipleSidePackets) {
|
|||
tick_and_verify(/*at_timestamp=*/1025);
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, AtFirstTickWithMultipleSidePackets) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
input_stream: "tick"
|
||||
input_side_packet: "side_packet0"
|
||||
input_side_packet: "side_packet1"
|
||||
output_stream: "packet0"
|
||||
output_stream: "packet1"
|
||||
node {
|
||||
calculator: "SidePacketToStreamCalculator"
|
||||
input_stream: "TICK:tick"
|
||||
input_side_packet: "side_packet0"
|
||||
input_side_packet: "side_packet1"
|
||||
output_stream: "AT_FIRST_TICK:0:packet0"
|
||||
output_stream: "AT_FIRST_TICK:1:packet1"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets0;
|
||||
tool::AddVectorSink("packet0", &graph_config, &output_packets0);
|
||||
std::vector<Packet> output_packets1;
|
||||
tool::AddVectorSink("packet1", &graph_config, &output_packets1);
|
||||
CalculatorGraph graph;
|
||||
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||
const int expected_value0 = 20;
|
||||
const int expected_value1 = 128;
|
||||
const Timestamp kTestTimestamp(1234);
|
||||
MP_ASSERT_OK(
|
||||
graph.StartRun({{"side_packet0", MakePacket<int>(expected_value0)},
|
||||
{"side_packet1", MakePacket<int>(expected_value1)}}));
|
||||
|
||||
auto insert_tick = [&graph](Timestamp at_timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"tick", MakePacket<int>(/*doesn't matter*/ 1).At(at_timestamp)));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
};
|
||||
|
||||
insert_tick(kTestTimestamp);
|
||||
|
||||
EXPECT_THAT(output_packets0,
|
||||
ElementsAre(PacketContainsTimestampAndPayload<int>(
|
||||
Eq(kTestTimestamp), Eq(expected_value0))));
|
||||
EXPECT_THAT(output_packets1,
|
||||
ElementsAre(PacketContainsTimestampAndPayload<int>(
|
||||
Eq(kTestTimestamp), Eq(expected_value1))));
|
||||
|
||||
output_packets0.clear();
|
||||
output_packets1.clear();
|
||||
|
||||
// Should not result in an additional output.
|
||||
insert_tick(kTestTimestamp + 1);
|
||||
EXPECT_THAT(output_packets0, IsEmpty());
|
||||
EXPECT_THAT(output_packets1, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, AtTimestamp) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
|
@ -334,7 +460,7 @@ TEST(SidePacketToStreamCalculator, AtTimestamp) {
|
|||
EXPECT_EQ(expected_value, output_packets.back().Get<int>());
|
||||
}
|
||||
|
||||
TEST(SidePacketToStreamCalculator, AtTimestamp_MultipleOutputs) {
|
||||
TEST(SidePacketToStreamCalculator, AtTimestampWithMultipleOutputs) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
|
|
Loading…
Reference in New Issue
Block a user