Internal change
PiperOrigin-RevId: 539719443
This commit is contained in:
parent
fe0d1b1e83
commit
96cc0fd07b
|
@ -1355,6 +1355,22 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "calculator_graph_summary_packet_test",
|
||||||
|
srcs = ["calculator_graph_summary_packet_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":calculator_framework",
|
||||||
|
":packet",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||||
|
"//mediapipe/framework/tool:sink",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "calculator_runner_test",
|
name = "calculator_runner_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
|
327
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
327
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Node;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::Eq;
|
||||||
|
using ::testing::IsEmpty;
|
||||||
|
using ::testing::Value;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
MATCHER_P2(IntPacket, value, timestamp, "") {
|
||||||
|
*result_listener << "where object is (value: " << arg.template Get<int>()
|
||||||
|
<< ", timestamp: " << arg.Timestamp() << ")";
|
||||||
|
return Value(arg.template Get<int>(), Eq(value)) &&
|
||||||
|
Value(arg.Timestamp(), Eq(timestamp));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculates and produces sum of all passed inputs when no more packets can be
|
||||||
|
// expected on the input stream.
|
||||||
|
class SummaryPacketCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"SUMMARY"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||||
|
// Makes sure there are no automatic timestamp bound updates when Process
|
||||||
|
// is called.
|
||||||
|
cc->SetTimestampOffset(TimestampDiff::Unset());
|
||||||
|
// Currently, only ImmediateInputStreamHandler supports "done" timestamp
|
||||||
|
// bound update. (ImmediateInputStreamhandler handles multiple input
|
||||||
|
// streams differently, so, in that case, calculator adjustments may be
|
||||||
|
// required.)
|
||||||
|
// TODO: update all input stream handlers to support "done"
|
||||||
|
// timestamp bound update.
|
||||||
|
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||||
|
// Enables processing timestamp bound updates. For this use case we are
|
||||||
|
// specifically interested in "done" timestamp bound update. (E.g. when
|
||||||
|
// all input packet sources are closed.)
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
if (!kIn(cc).IsEmpty()) {
|
||||||
|
value_ += kIn(cc).Get();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kOut(cc).IsClosed()) {
|
||||||
|
// This can happen:
|
||||||
|
// 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
|
||||||
|
// source calculator finished generating packets sent to kIn) and
|
||||||
|
// HasNextAllowedInStream() == true (which is an often case).
|
||||||
|
// 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
|
||||||
|
// invoke Process() with Timestamp::Max to indicate "Done" timestamp
|
||||||
|
// bound update.
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: input stream holding a packet with timestamp that has
|
||||||
|
// no next timestamp allowed in stream should always result in
|
||||||
|
// InputStream::IsDone() == true.
|
||||||
|
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
|
||||||
|
// kOut(cc).Send(value_) can be used here as well, however in the case of
|
||||||
|
// source calculator sending inputs into kIn the resulting timestamp is
|
||||||
|
// not well defined (e.g. it can be the last packet timestamp or
|
||||||
|
// Timestamp::Max())
|
||||||
|
// TODO: last packet from source should always result in
|
||||||
|
// InputStream::IsDone() == true.
|
||||||
|
kOut(cc).Send(value_, Timestamp::Max());
|
||||||
|
kOut(cc).Close();
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int value_ = 0;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnClosingAllPacketSources) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp(10));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
send_packet(20, Timestamp(11));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp(10));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
send_packet(20, Timestamp::Max());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnPreStreamTimestamp) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PreStream());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnPostStreamTimestamp) {
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PostStream());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
class IntGeneratorCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
kOut(cc).Send(20, Timestamp(0));
|
||||||
|
kOut(cc).Send(10, Timestamp(1000));
|
||||||
|
return tool::StatusStop();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnSourceCalculatorCompletion) {
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "IntGeneratorCalculator"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
}
|
||||||
|
|
||||||
|
class EmitOnCloseCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||||
|
|
||||||
|
absl::Status Close(CalculatorContext* cc) final {
|
||||||
|
kOut(cc).Send(20, Timestamp(0));
|
||||||
|
kOut(cc).Send(10, Timestamp(1000));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnAnotherCalculatorClosure) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
node {
|
||||||
|
calculator: "EmitOnCloseCalculator"
|
||||||
|
input_stream: "IN:input"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseInputStream("input"));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
|
||||||
return *this + 1;
|
return *this + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Timestamp::HasNextAllowedInStream() const {
|
||||||
|
if (*this >= Max() || *this == PreStream()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
Timestamp Timestamp::PreviousAllowedInStream() const {
|
Timestamp Timestamp::PreviousAllowedInStream() const {
|
||||||
if (*this <= Min() || *this == PostStream()) {
|
if (*this <= Min() || *this == PostStream()) {
|
||||||
// Indicates that no previous timestamps may occur.
|
// Indicates that no previous timestamps may occur.
|
||||||
|
|
|
@ -186,6 +186,10 @@ class Timestamp {
|
||||||
// CHECKs that this->IsAllowedInStream().
|
// CHECKs that this->IsAllowedInStream().
|
||||||
Timestamp NextAllowedInStream() const;
|
Timestamp NextAllowedInStream() const;
|
||||||
|
|
||||||
|
// Returns true if there's a next timestamp in the range [Min .. Max] after
|
||||||
|
// this one.
|
||||||
|
bool HasNextAllowedInStream() const;
|
||||||
|
|
||||||
// Returns the previous timestamp in the range [Min .. Max], or
|
// Returns the previous timestamp in the range [Min .. Max], or
|
||||||
// Unstarted() if no Packets may preceed one with this timestamp.
|
// Unstarted() if no Packets may preceed one with this timestamp.
|
||||||
Timestamp PreviousAllowedInStream() const;
|
Timestamp PreviousAllowedInStream() const;
|
||||||
|
|
|
@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
|
||||||
Timestamp::PostStream().NextAllowedInStream());
|
Timestamp::PostStream().NextAllowedInStream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TimestampTest, HasNextAllowedInStream) {
|
||||||
|
EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
|
||||||
|
|
||||||
|
EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TimestampTest, SpecialValueDifferences) {
|
TEST(TimestampTest, SpecialValueDifferences) {
|
||||||
{ // Lower range
|
{ // Lower range
|
||||||
const std::vector<Timestamp> timestamps = {
|
const std::vector<Timestamp> timestamps = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user