Internal change

PiperOrigin-RevId: 539719443
This commit is contained in:
MediaPipe Team 2023-06-12 11:51:34 -07:00 committed by Copybara-Service
parent fe0d1b1e83
commit 96cc0fd07b
5 changed files with 370 additions and 0 deletions

View File

@ -1355,6 +1355,22 @@ cc_test(
], ],
) )
cc_test(
name = "calculator_graph_summary_packet_test",
srcs = ["calculator_graph_summary_packet_test.cc"],
deps = [
":calculator_framework",
":packet",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
],
)
cc_test( cc_test(
name = "calculator_runner_test", name = "calculator_runner_test",
size = "medium", size = "medium",

View File

@ -0,0 +1,327 @@
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Value;
namespace {
MATCHER_P2(IntPacket, value, timestamp, "") {
*result_listener << "where object is (value: " << arg.template Get<int>()
<< ", timestamp: " << arg.Timestamp() << ")";
return Value(arg.template Get<int>(), Eq(value)) &&
Value(arg.Timestamp(), Eq(timestamp));
}
// Calculates and produces sum of all passed inputs when no more packets can be
// expected on the input stream.
class SummaryPacketCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"SUMMARY"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
static absl::Status UpdateContract(CalculatorContract* cc) {
// Makes sure there are no automatic timestamp bound updates when Process
// is called.
cc->SetTimestampOffset(TimestampDiff::Unset());
// Currently, only ImmediateInputStreamHandler supports "done" timestamp
// bound update. (ImmediateInputStreamhandler handles multiple input
// streams differently, so, in that case, calculator adjustments may be
// required.)
// TODO: update all input stream handlers to support "done"
// timestamp bound update.
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
// Enables processing timestamp bound updates. For this use case we are
// specifically interested in "done" timestamp bound update. (E.g. when
// all input packet sources are closed.)
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
if (!kIn(cc).IsEmpty()) {
value_ += kIn(cc).Get();
}
if (kOut(cc).IsClosed()) {
// This can happen:
// 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
// source calculator finished generating packets sent to kIn) and
// HasNextAllowedInStream() == true (which is an often case).
// 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
// invoke Process() with Timestamp::Max to indicate "Done" timestamp
// bound update.
return absl::OkStatus();
}
// TODO: input stream holding a packet with timestamp that has
// no next timestamp allowed in stream should always result in
// InputStream::IsDone() == true.
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
// kOut(cc).Send(value_) can be used here as well, however in the case of
// source calculator sending inputs into kIn the resulting timestamp is
// not well defined (e.g. it can be the last packet timestamp or
// Timestamp::Max())
// TODO: last packet from source should always result in
// InputStream::IsDone() == true.
kOut(cc).Send(value_, Timestamp::Max());
kOut(cc).Close();
}
return absl::OkStatus();
}
private:
int value_ = 0;
};
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnClosingAllPacketSources) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp(10));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
send_packet(20, Timestamp(11));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
}
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp(10));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
send_packet(20, Timestamp::Max());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnPreStreamTimestamp) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp::PreStream());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnPostStreamTimestamp) {
std::vector<Packet> output_packets;
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp::PostStream());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
class IntGeneratorCalculator : public Node {
public:
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kOut);
absl::Status Process(CalculatorContext* cc) final {
kOut(cc).Send(20, Timestamp(0));
kOut(cc).Send(10, Timestamp(1000));
return tool::StatusStop();
}
};
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnSourceCalculatorCompletion) {
std::vector<Packet> output_packets;
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "IntGeneratorCalculator"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
}
class EmitOnCloseCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Close(CalculatorContext* cc) final {
kOut(cc).Send(20, Timestamp(0));
kOut(cc).Send(10, Timestamp(1000));
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnAnotherCalculatorClosure) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
node {
calculator: "EmitOnCloseCalculator"
input_stream: "IN:input"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
MP_ASSERT_OK(graph.CloseInputStream("input"));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
} // namespace
} // namespace mediapipe

View File

@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
return *this + 1; return *this + 1;
} }
bool Timestamp::HasNextAllowedInStream() const {
if (*this >= Max() || *this == PreStream()) {
return false;
}
return true;
}
Timestamp Timestamp::PreviousAllowedInStream() const { Timestamp Timestamp::PreviousAllowedInStream() const {
if (*this <= Min() || *this == PostStream()) { if (*this <= Min() || *this == PostStream()) {
// Indicates that no previous timestamps may occur. // Indicates that no previous timestamps may occur.

View File

@ -186,6 +186,10 @@ class Timestamp {
// CHECKs that this->IsAllowedInStream(). // CHECKs that this->IsAllowedInStream().
Timestamp NextAllowedInStream() const; Timestamp NextAllowedInStream() const;
// Returns true if there's a next timestamp in the range [Min .. Max] after
// this one.
bool HasNextAllowedInStream() const;
// Returns the previous timestamp in the range [Min .. Max], or // Returns the previous timestamp in the range [Min .. Max], or
// Unstarted() if no Packets may preceed one with this timestamp. // Unstarted() if no Packets may preceed one with this timestamp.
Timestamp PreviousAllowedInStream() const; Timestamp PreviousAllowedInStream() const;

View File

@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
Timestamp::PostStream().NextAllowedInStream()); Timestamp::PostStream().NextAllowedInStream());
} }
TEST(TimestampTest, HasNextAllowedInStream) {
EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
}
TEST(TimestampTest, SpecialValueDifferences) { TEST(TimestampTest, SpecialValueDifferences) {
{ // Lower range { // Lower range
const std::vector<Timestamp> timestamps = { const std::vector<Timestamp> timestamps = {