Internal change
PiperOrigin-RevId: 539719443
This commit is contained in:
parent
fe0d1b1e83
commit
96cc0fd07b
|
@ -1355,6 +1355,22 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "calculator_graph_summary_packet_test",
|
||||
srcs = ["calculator_graph_summary_packet_test.cc"],
|
||||
deps = [
|
||||
":calculator_framework",
|
||||
":packet",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||
"//mediapipe/framework/tool:sink",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "calculator_runner_test",
|
||||
size = "medium",
|
||||
|
|
327
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
327
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
|
@ -0,0 +1,327 @@
|
|||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Node;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Eq;
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::Value;
|
||||
|
||||
namespace {
|
||||
|
||||
MATCHER_P2(IntPacket, value, timestamp, "") {
|
||||
*result_listener << "where object is (value: " << arg.template Get<int>()
|
||||
<< ", timestamp: " << arg.Timestamp() << ")";
|
||||
return Value(arg.template Get<int>(), Eq(value)) &&
|
||||
Value(arg.Timestamp(), Eq(timestamp));
|
||||
}
|
||||
|
||||
// Calculates and produces sum of all passed inputs when no more packets can be
|
||||
// expected on the input stream.
|
||||
class SummaryPacketCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<int> kIn{"IN"};
|
||||
static constexpr Output<int> kOut{"SUMMARY"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
// Makes sure there are no automatic timestamp bound updates when Process
|
||||
// is called.
|
||||
cc->SetTimestampOffset(TimestampDiff::Unset());
|
||||
// Currently, only ImmediateInputStreamHandler supports "done" timestamp
|
||||
// bound update. (ImmediateInputStreamhandler handles multiple input
|
||||
// streams differently, so, in that case, calculator adjustments may be
|
||||
// required.)
|
||||
// TODO: update all input stream handlers to support "done"
|
||||
// timestamp bound update.
|
||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||
// Enables processing timestamp bound updates. For this use case we are
|
||||
// specifically interested in "done" timestamp bound update. (E.g. when
|
||||
// all input packet sources are closed.)
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!kIn(cc).IsEmpty()) {
|
||||
value_ += kIn(cc).Get();
|
||||
}
|
||||
|
||||
if (kOut(cc).IsClosed()) {
|
||||
// This can happen:
|
||||
// 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
|
||||
// source calculator finished generating packets sent to kIn) and
|
||||
// HasNextAllowedInStream() == true (which is an often case).
|
||||
// 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
|
||||
// invoke Process() with Timestamp::Max to indicate "Done" timestamp
|
||||
// bound update.
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// TODO: input stream holding a packet with timestamp that has
|
||||
// no next timestamp allowed in stream should always result in
|
||||
// InputStream::IsDone() == true.
|
||||
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
|
||||
// kOut(cc).Send(value_) can be used here as well, however in the case of
|
||||
// source calculator sending inputs into kIn the resulting timestamp is
|
||||
// not well defined (e.g. it can be the last packet timestamp or
|
||||
// Timestamp::Max())
|
||||
// TODO: last packet from source should always result in
|
||||
// InputStream::IsDone() == true.
|
||||
kOut(cc).Send(value_, Timestamp::Max());
|
||||
kOut(cc).Close();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
int value_ = 0;
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
|
||||
|
||||
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||
ProducesSummaryPacketOnClosingAllPacketSources) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'input'
|
||||
node {
|
||||
calculator: "SummaryPacketCalculator"
|
||||
input_stream: 'IN:input'
|
||||
output_stream: 'SUMMARY:output'
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input", MakePacket<int>(value).At(timestamp)));
|
||||
};
|
||||
|
||||
send_packet(10, Timestamp(10));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
send_packet(20, Timestamp(11));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||
}
|
||||
|
||||
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'input'
|
||||
node {
|
||||
calculator: "SummaryPacketCalculator"
|
||||
input_stream: 'IN:input'
|
||||
output_stream: 'SUMMARY:output'
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input", MakePacket<int>(value).At(timestamp)));
|
||||
};
|
||||
|
||||
send_packet(10, Timestamp(10));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
send_packet(20, Timestamp::Max());
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||
|
||||
output_packets.clear();
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||
ProducesSummaryPacketOnPreStreamTimestamp) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'input'
|
||||
node {
|
||||
calculator: "SummaryPacketCalculator"
|
||||
input_stream: 'IN:input'
|
||||
output_stream: 'SUMMARY:output'
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input", MakePacket<int>(value).At(timestamp)));
|
||||
};
|
||||
|
||||
send_packet(10, Timestamp::PreStream());
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||
|
||||
output_packets.clear();
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||
ProducesSummaryPacketOnPostStreamTimestamp) {
|
||||
std::vector<Packet> output_packets;
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'input'
|
||||
node {
|
||||
calculator: "SummaryPacketCalculator"
|
||||
input_stream: 'IN:input'
|
||||
output_stream: 'SUMMARY:output'
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"input", MakePacket<int>(value).At(timestamp)));
|
||||
};
|
||||
|
||||
send_packet(10, Timestamp::PostStream());
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||
|
||||
output_packets.clear();
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
}
|
||||
|
||||
class IntGeneratorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Output<int> kOut{"INT"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kOut);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
kOut(cc).Send(20, Timestamp(0));
|
||||
kOut(cc).Send(10, Timestamp(1000));
|
||||
return tool::StatusStop();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
|
||||
|
||||
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||
ProducesSummaryPacketOnSourceCalculatorCompletion) {
|
||||
std::vector<Packet> output_packets;
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "IntGeneratorCalculator"
|
||||
output_stream: "INT:int_value"
|
||||
}
|
||||
node {
|
||||
calculator: "SummaryPacketCalculator"
|
||||
input_stream: "IN:int_value"
|
||||
output_stream: "SUMMARY:output"
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||
}
|
||||
|
||||
class EmitOnCloseCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<int> kIn{"IN"};
|
||||
static constexpr Output<int> kOut{"INT"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||
|
||||
absl::Status Close(CalculatorContext* cc) final {
|
||||
kOut(cc).Send(20, Timestamp(0));
|
||||
kOut(cc).Send(10, Timestamp(1000));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
|
||||
|
||||
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||
ProducesSummaryPacketOnAnotherCalculatorClosure) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: "input"
|
||||
node {
|
||||
calculator: "EmitOnCloseCalculator"
|
||||
input_stream: "IN:input"
|
||||
output_stream: "INT:int_value"
|
||||
}
|
||||
node {
|
||||
calculator: "SummaryPacketCalculator"
|
||||
input_stream: "IN:int_value"
|
||||
output_stream: "SUMMARY:output"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
|
||||
MP_ASSERT_OK(graph.CloseInputStream("input"));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||
|
||||
output_packets.clear();
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
EXPECT_THAT(output_packets, IsEmpty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
|
|||
return *this + 1;
|
||||
}
|
||||
|
||||
bool Timestamp::HasNextAllowedInStream() const {
|
||||
if (*this >= Max() || *this == PreStream()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Timestamp Timestamp::PreviousAllowedInStream() const {
|
||||
if (*this <= Min() || *this == PostStream()) {
|
||||
// Indicates that no previous timestamps may occur.
|
||||
|
|
|
@ -186,6 +186,10 @@ class Timestamp {
|
|||
// CHECKs that this->IsAllowedInStream().
|
||||
Timestamp NextAllowedInStream() const;
|
||||
|
||||
// Returns true if there's a next timestamp in the range [Min .. Max] after
|
||||
// this one.
|
||||
bool HasNextAllowedInStream() const;
|
||||
|
||||
// Returns the previous timestamp in the range [Min .. Max], or
|
||||
// Unstarted() if no Packets may preceed one with this timestamp.
|
||||
Timestamp PreviousAllowedInStream() const;
|
||||
|
|
|
@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
|
|||
Timestamp::PostStream().NextAllowedInStream());
|
||||
}
|
||||
|
||||
TEST(TimestampTest, HasNextAllowedInStream) {
|
||||
EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
|
||||
EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
|
||||
EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
|
||||
EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
|
||||
EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
|
||||
EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
|
||||
EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
|
||||
|
||||
EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
|
||||
EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
|
||||
EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
|
||||
EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
|
||||
EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
|
||||
}
|
||||
|
||||
TEST(TimestampTest, SpecialValueDifferences) {
|
||||
{ // Lower range
|
||||
const std::vector<Timestamp> timestamps = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user