mediapipe/mediapipe/framework/calculator_graph_stopping_test.cc
MediaPipe Team b899d17f18 Project import generated by Copybara.
GitOrigin-RevId: 8e1da4611d93ccb7d9674713157d43be0348d98f
2021-07-27 22:36:23 -04:00

388 lines
13 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <atomic>
#include <cstdint>
#include <map>
#include <string>
#include <thread>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/framework/tool/status_util.h"
// NOTE(review): this empty re-opening of namespace mediapipe declares nothing
// new here; it looks like a leftover and can probably be removed — confirm.
namespace mediapipe {}
namespace testing_ns {
// MediaPipe names used unqualified throughout this test file.
// AdoptAsUniquePtr is declared explicitly rather than relying on
// argument-dependent lookup at its call sites.
using mediapipe::AdoptAsUniquePtr;
using mediapipe::CalculatorBase;
using mediapipe::CalculatorContext;
using mediapipe::CalculatorContract;
using mediapipe::CalculatorGraph;
using mediapipe::CalculatorGraphConfig;
using mediapipe::GetFromUniquePtr;
using mediapipe::InputStreamShardSet;
using mediapipe::MakePacket;
using mediapipe::OutputStreamShardSet;
using mediapipe::Packet;
using mediapipe::Timestamp;
namespace proto_ns = mediapipe::proto_ns;

// Stream tags used by the test calculators in this file.
constexpr char kEventTag[] = "EVENT";
constexpr char kOutTag[] = "OUT";
// Source calculator (it declares no inputs) that never stops on its own.
// Each Process() call emits the next value of the sequence 0, 1, 2, ... on
// "OUT", stamped with a timestamp equal to the value. Its lifecycle is marked
// on "EVENT": value 1 at Timestamp(1) from Open() and value 2 at Timestamp(2)
// from Close().
class InfiniteSequenceCalculator : public mediapipe::CalculatorBase {
 public:
  static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
    auto& outputs = cc->Outputs();
    outputs.Tag(kOutTag).Set<int>();
    outputs.Tag(kEventTag).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    EmitEvent(cc, 1);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const int value = count_;
    cc->Outputs().Tag(kOutTag).AddPacket(
        MakePacket<int>(value).At(Timestamp(value)));
    ++count_;
    return absl::OkStatus();
  }

  absl::Status Close(CalculatorContext* cc) override {
    EmitEvent(cc, 2);
    return absl::OkStatus();
  }

 private:
  // Emits `value` on the EVENT stream at Timestamp(value).
  static void EmitEvent(CalculatorContext* cc, int value) {
    cc->Outputs().Tag(kEventTag).AddPacket(
        MakePacket<int>(value).At(Timestamp(value)));
  }

  // Next value (and timestamp) to emit on "OUT".
  int count_ = 0;
};
REGISTER_CALCULATOR(::testing_ns::InfiniteSequenceCalculator);
// Calculator that forwards each non-empty untagged input stream i to untagged
// output stream i. Once Process() has been called more than max_count_ times,
// it returns tool::StatusStop(). It also emits 1 at Timestamp(1) on "EVENT"
// from Open() and 2 at Timestamp(2) from Close().
class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    const int num_streams = cc->Inputs().NumEntries("");
    for (int index = 0; index < num_streams; ++index) {
      auto& input_type = cc->Inputs().Get("", index);
      input_type.SetAny();
      cc->Outputs().Get("", index).SetSameAs(&input_type);
    }
    cc->Outputs().Tag(kEventTag).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    auto& events = cc->Outputs().Tag(kEventTag);
    events.AddPacket(MakePacket<int>(1).At(Timestamp(1)));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const int num_streams = cc->Inputs().NumEntries("");
    for (int index = 0; index < num_streams; ++index) {
      auto& input = cc->Inputs().Get("", index);
      if (input.IsEmpty()) {
        continue;
      }
      cc->Outputs().Get("", index).AddPacket(input.Value());
    }
    // Count this call; beyond max_count_ calls, ask the framework to stop.
    ++count_;
    if (count_ > max_count_) {
      return mediapipe::tool::StatusStop();
    }
    return absl::OkStatus();
  }

  absl::Status Close(CalculatorContext* cc) override {
    auto& events = cc->Outputs().Tag(kEventTag);
    events.AddPacket(MakePacket<int>(2).At(Timestamp(2)));
    return absl::OkStatus();
  }

 private:
  // Number of Process() calls so far.
  int count_ = 0;
  // Number of Process() calls that return OK before StatusStop() is returned.
  int max_count_ = 10;
};
REGISTER_CALCULATOR(::testing_ns::StoppingPassThroughCalculator);
// A minimal counting semaphore for synchronizing test threads, built on a
// single atomic counter.
//
// Acquire() busy-waits, yielding the CPU between attempts, until `amount`
// units can be taken without driving the supply negative. Release() returns
// units to the supply and never blocks.
class AtomicSemaphore {
 public:
  // Creates a semaphore holding `supply` units. Explicit so that a bare
  // integer is never silently converted into a semaphore.
  explicit AtomicSemaphore(int64_t supply) : supply_(supply) {}

  // Blocks until `amount` units are available, then takes them.
  void Acquire(int64_t amount) {
    // Optimistically take the units; if that overdraws the supply, give them
    // back and retry.
    while (supply_.fetch_sub(amount) - amount < 0) {
      Release(amount);
      // Let other threads (e.g. the one expected to call Release) run
      // instead of spinning hot on the counter.
      std::this_thread::yield();
    }
  }

  // Returns `amount` units to the supply.
  void Release(int64_t amount) { supply_ += amount; }

 private:
  std::atomic<int64_t> supply_;
};
// A ProcessFunction that passes packets through: every non-empty input packet
// at index i is added to the output stream at the same index. Empty inputs
// are skipped. Always returns OK.
absl::Status DoProcess(const InputStreamShardSet& inputs,
                       OutputStreamShardSet* outputs) {
  const int num_entries = inputs.NumEntries();
  for (int index = 0; index < num_entries; ++index) {
    const Packet& packet = inputs.Index(index).Value();
    if (packet.IsEmpty()) {
      continue;
    }
    outputs->Index(index).AddPacket(packet);
  }
  return absl::OkStatus();
}
// Callback signature used as ProcessCallbackCalculator's Process step: it
// receives the node's input shards and a pointer to its output shards.
using ProcessFunction = std::function<absl::Status(
    const InputStreamShardSet&, OutputStreamShardSet*)>;
// A Calculator that delegates its Process function to a callback function.
//
// Input side packet 0 must hold a std::unique_ptr<ProcessFunction>. Open()
// copies the function out of it, and every Process() call forwards the node's
// input and output shards to that copy, returning its status. Each input
// stream i accepts any type, and output stream i carries the same type as
// input stream i.
class ProcessCallbackCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
      cc->Inputs().Index(i).SetAny();
      // Output i mirrors input i (not input 0), matching how callbacks such
      // as DoProcess route input i to output i, so streams of different
      // types can pass through side by side.
      cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
    }
    cc->InputSidePackets().Index(0).Set<std::unique_ptr<ProcessFunction>>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    // Copy the callback out of the side packet; Process() uses this copy.
    callback_ =
        *GetFromUniquePtr<ProcessFunction>(cc->InputSidePackets().Index(0));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    return callback_(cc->Inputs(), &(cc->Outputs()));
  }

 private:
  // Callback copied from input side packet 0 in Open().
  ProcessFunction callback_;
};
REGISTER_CALCULATOR(::testing_ns::ProcessCallbackCalculator);
// Tests CloseAllPacketSources().
//
// An InfiniteSequenceCalculator feeds "count" into a
// StoppingPassThroughCalculator alongside the graph input "input". Once
// kNumPackets packets have arrived on "input_out", the observer calls
// CloseAllPacketSources(). The run is then expected to finish with an OK
// status, even though the sequence source never stops on its own. Both
// calculators must also have run Open() and Close(), observed as the EVENT
// values {1, 2}.
TEST(CalculatorGraphStoppingTest, CloseAllPacketSources) {
  CalculatorGraphConfig graph_config;
  ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
max_queue_size: 5
input_stream: 'input'
node {
calculator: 'InfiniteSequenceCalculator'
output_stream: 'OUT:count'
output_stream: 'EVENT:event'
}
node {
calculator: 'StoppingPassThroughCalculator'
input_stream: 'count'
input_stream: 'input'
output_stream: 'count_out'
output_stream: 'input_out'
output_stream: 'EVENT:event_out'
}
package: 'testing_ns'
)",
                                                    &graph_config));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
  // Observe output packets, and call CloseAllPacketSources after kNumPackets.
  std::vector<Packet> out_packets;
  std::vector<Packet> count_packets;
  std::vector<int> event_packets;
  std::vector<int> event_out_packets;
  // Number of packets added to "input", which is also the number of
  // "input_out" packets after which all packet sources are closed.
  constexpr int kNumPackets = 8;
  MP_ASSERT_OK(graph.ObserveOutputStream(  //
      "input_out", [&](const Packet& packet) {
        out_packets.push_back(packet);
        // Compare as size_t to avoid a signed/unsigned comparison.
        if (out_packets.size() >= static_cast<size_t>(kNumPackets)) {
          MP_EXPECT_OK(graph.CloseAllPacketSources());
        }
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.ObserveOutputStream(  //
      "count_out", [&](const Packet& packet) {
        count_packets.push_back(packet);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.ObserveOutputStream(  //
      "event", [&](const Packet& packet) {
        event_packets.push_back(packet.Get<int>());
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.ObserveOutputStream(  //
      "event_out", [&](const Packet& packet) {
        event_out_packets.push_back(packet.Get<int>());
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.StartRun({}));
  for (int i = 0; i < kNumPackets; ++i) {
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "input", MakePacket<int>(i).At(Timestamp(i))));
  }
  // The graph run should complete with no error status.
  MP_EXPECT_OK(graph.WaitUntilDone());
  EXPECT_EQ(static_cast<size_t>(kNumPackets), out_packets.size());
  EXPECT_LE(static_cast<size_t>(kNumPackets), count_packets.size());
  // Open() emits 1 and Close() emits 2 on each calculator's EVENT stream.
  const std::vector<int> expected_events = {1, 2};
  EXPECT_EQ(event_packets, expected_events);
  EXPECT_EQ(event_out_packets, expected_events);
}
// Verifies that, with `report_deadlock: true`, a deadlock caused by input
// throttling is reported as an error: WaitUntilIdle() returns kUnavailable
// with a "Detected a deadlock due to input throttling" message.
//
// Packets are added only to "in_1" while "in_2" receives nothing, so "in_1"
// fills up to max_queue_size (2). NOTE(review): this presumably relies on the
// node's default input handling not processing "in_1" packets until the
// timestamp is settled on "in_2" as well — confirm.
TEST(CalculatorGraphStoppingTest, DeadlockReporting) {
  CalculatorGraphConfig config;
  ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: 'in_1'
input_stream: 'in_2'
max_queue_size: 2
node {
calculator: 'ProcessCallbackCalculator'
input_stream: 'in_1'
input_stream: 'in_2'
output_stream: 'out_1'
output_stream: 'out_2'
input_side_packet: 'callback_1'
}
package: 'testing_ns'
report_deadlock: true
)",
                                                    &config));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  // In WAIT_TILL_NOT_FULL mode an add to a full input stream waits. With
  // report_deadlock set, the third add below is expected to fail instead of
  // waiting forever.
  graph.SetGraphInputStreamAddMode(
      CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL);
  std::vector<Packet> out_packets;
  MP_ASSERT_OK(
      graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) {
        out_packets.push_back(packet);
        return absl::OkStatus();
      }));
  // Process callback: blocks until it can take one unit from `semaphore`,
  // then passes its inputs through with DoProcess.
  AtomicSemaphore semaphore(0);
  ProcessFunction callback_1 = [&semaphore](const InputStreamShardSet& inputs,
                                            OutputStreamShardSet* outputs) {
    semaphore.Acquire(1);
    return DoProcess(inputs, outputs);
  };
  // Adds an int packet with value `i` at Timestamp(i) to input stream `s`.
  auto add_packet = [&graph](std::string s, int i) {
    return graph.AddPacketToInputStream(s, MakePacket<int>(i).At(Timestamp(i)));
  };
  // Start the graph, handing callback_1 to the node as side packet
  // "callback_1".
  MP_ASSERT_OK(graph.StartRun({
      {"callback_1", AdoptAsUniquePtr(new auto(callback_1))},
  }));
  // Try to add 3 packets to "in_1" with no packets on "in_2". This causes
  // throttling and deadlock with max_queue_size 2: the first two adds
  // succeed and the third fails.
  // The 3 units released here keep callback_1 from blocking if Process does
  // run. NOTE(review): presumably that only happens after the inputs are
  // closed below — confirm.
  semaphore.Release(3);
  MP_EXPECT_OK(add_packet("in_1", 1));
  MP_EXPECT_OK(add_packet("in_1", 2));
  EXPECT_FALSE(add_packet("in_1", 3).ok());
  // The deadlock is reported through WaitUntilIdle().
  absl::Status status = graph.WaitUntilIdle();
  EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable);
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr("Detected a deadlock due to input throttling"));
  // Once the deadlock is reported, the run is expected to end in error and
  // nothing should have reached "out_1".
  MP_ASSERT_OK(graph.CloseAllInputStreams());
  EXPECT_FALSE(graph.WaitUntilDone().ok());
  ASSERT_EQ(0, out_packets.size());
}
// Verifies that, without `report_deadlock`, a throttling deadlock is resolved
// by letting the full input stream grow beyond max_queue_size (2) instead of
// failing the run.
//
// The node's callback blocks on `semaphore`, so each Process() call consumes
// exactly one Release()d unit. The Release() amounts below are therefore
// matched to the number of Process() calls each phase is expected to make.
TEST(CalculatorGraphStoppingTest, DeadlockResolution) {
  CalculatorGraphConfig config;
  ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"(
input_stream: 'in_1'
input_stream: 'in_2'
max_queue_size: 2
node {
calculator: 'ProcessCallbackCalculator'
input_stream: 'in_1'
input_stream: 'in_2'
output_stream: 'out_1'
output_stream: 'out_2'
input_side_packet: 'callback_1'
}
package: 'testing_ns'
)",
                                                    &config));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  // Adds to a full input stream wait (WAIT_TILL_NOT_FULL) rather than fail.
  graph.SetGraphInputStreamAddMode(
      CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL);
  std::vector<Packet> out_packets;
  MP_ASSERT_OK(
      graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) {
        out_packets.push_back(packet);
        return absl::OkStatus();
      }));
  // Process callback: blocks until it can take one unit from `semaphore`,
  // then passes its inputs through with DoProcess.
  AtomicSemaphore semaphore(0);
  ProcessFunction callback_1 = [&semaphore](const InputStreamShardSet& inputs,
                                            OutputStreamShardSet* outputs) {
    semaphore.Acquire(1);
    return DoProcess(inputs, outputs);
  };
  // Adds an int packet with value `i` at Timestamp(i) to input stream `s`.
  auto add_packet = [&graph](std::string s, int i) {
    return graph.AddPacketToInputStream(s, MakePacket<int>(i).At(Timestamp(i)));
  };
  // Start the graph, handing callback_1 to the node as side packet
  // "callback_1".
  MP_ASSERT_OK(graph.StartRun({
      {"callback_1", AdoptAsUniquePtr(new auto(callback_1))},
  }));
  // Phase 1: add 9 packets to "in_1" with no packets on "in_2".
  // This grows the input stream "in_1" to max-queue-size 10; every add is
  // expected to succeed even though max_queue_size is 2.
  semaphore.Release(9);
  for (int i = 1; i <= 9; ++i) {
    MP_EXPECT_OK(add_packet("in_1", i));
    MP_ASSERT_OK(graph.WaitUntilIdle());
  }
  // Phase 2: advance the timestamp-bound and flush "in_1" by adding a packet
  // to "in_2" at timestamp 30. NOTE(review): presumably this lets timestamps
  // 1-9 be processed, using 9 of the 10 units released so far and leaving
  // 1 — confirm.
  semaphore.Release(1);
  MP_EXPECT_OK(add_packet("in_2", 30));
  MP_ASSERT_OK(graph.WaitUntilIdle());
  // Phase 3: fill up input stream "in_1" (timestamps 11-22) while the
  // semaphore blocks the callback. NOTE(review): presumably one call
  // completes with the leftover unit and the next blocks in Acquire. The node
  // is then busy rather than idle, so no deadlock resolution (queue growth)
  // applies — confirm.
  for (int i = 11; i < 23; ++i) {
    MP_EXPECT_OK(add_packet("in_1", i));
  }
  // Adding any more packets fails with error "Graph is throttled".
  graph.SetGraphInputStreamAddMode(
      CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL);
  EXPECT_FALSE(add_packet("in_1", 23).ok());
  // Phase 4: release 12 units so the remaining callback_1 calls can finish.
  // NOTE(review): the count appears to be 11 pending calls (timestamps
  // 12-22) plus one call at timestamp 30 once the inputs are closed —
  // confirm.
  semaphore.Release(12);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
  // 9 packets from phase 1 (timestamps 1-9) plus 12 from phase 3
  // (timestamps 11-22) reach "out_1".
  ASSERT_EQ(21, out_packets.size());
}
} // namespace testing_ns