Add a test case for "summary packet" to test failing upstream calculator

PiperOrigin-RevId: 540331486
This commit is contained in:
MediaPipe Team 2023-06-14 11:31:12 -07:00 committed by Copybara-Service
parent 66a29bf371
commit a1be5f3e72
2 changed files with 111 additions and 7 deletions

View File

@ -1368,6 +1368,7 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/status",
],
)

View File

@ -1,3 +1,4 @@
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
@ -15,6 +16,7 @@ using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Value;
@ -57,6 +59,7 @@ class SummaryPacketCalculator : public Node {
absl::Status Process(CalculatorContext* cc) final {
if (!kIn(cc).IsEmpty()) {
value_ += kIn(cc).Get();
value_set_ = true;
}
if (kOut(cc).IsClosed()) {
@ -74,13 +77,19 @@ class SummaryPacketCalculator : public Node {
// no next timestamp allowed in stream should always result in
// InputStream::IsDone() == true.
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
// kOut(cc).Send(value_) can be used here as well, however in the case of
// source calculator sending inputs into kIn the resulting timestamp is
// not well defined (e.g. it can be the last packet timestamp or
// Timestamp::Max())
// TODO: last packet from source should always result in
// InputStream::IsDone() == true.
kOut(cc).Send(value_, Timestamp::Max());
// `Process` may or may not be invoked for "done" timestamp bound when
// upstream calculator fails in `Close`. Hence, extra care is needed to
// identify whether the calculator needs to send output.
// TODO: remove when "done" timestamp bound flakiness fixed.
if (value_set_) {
// kOut(cc).Send(value_) can be used here as well, however in the case
// of source calculator sending inputs into kIn the resulting timestamp
// is not well defined (e.g. it can be the last packet timestamp or
// Timestamp::Max())
// TODO: last packet from source should always result in
// InputStream::IsDone() == true.
kOut(cc).Send(value_, Timestamp::Max());
}
kOut(cc).Close();
}
return absl::OkStatus();
@ -88,6 +97,7 @@ class SummaryPacketCalculator : public Node {
private:
int value_ = 0;
bool value_set_ = false;
};
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
@ -323,5 +333,98 @@ TEST(SummaryPacketCalculatorUseCaseTest,
EXPECT_THAT(output_packets, IsEmpty());
}
class FailureInCloseCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Close(CalculatorContext* cc) final {
return absl::InternalError("error");
}
};
MEDIAPIPE_REGISTER_NODE(FailureInCloseCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInClose) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
node {
calculator: "FailureInCloseCalculator"
input_stream: "IN:input"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
MP_ASSERT_OK(graph.CloseInputStream("input"));
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
EXPECT_THAT(output_packets, IsEmpty());
}
class FailureInProcessCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final {
return absl::InternalError("error");
}
};
MEDIAPIPE_REGISTER_NODE(FailureInProcessCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInProcess) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
node {
calculator: "FailureInProcessCalculator"
input_stream: "IN:input"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp::PostStream());
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
EXPECT_THAT(output_packets, IsEmpty());
}
} // namespace
} // namespace mediapipe