// Copyright 2019 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_graph.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/core_proto_inc.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/status_util.h" namespace mediapipe {} namespace testing_ns { using mediapipe::CalculatorBase; using mediapipe::CalculatorContext; using mediapipe::CalculatorContract; using mediapipe::CalculatorGraphConfig; using mediapipe::GetFromUniquePtr; using mediapipe::InputStreamShardSet; using mediapipe::MakePacket; using mediapipe::OutputStreamShardSet; using mediapipe::Timestamp; namespace proto_ns = mediapipe::proto_ns; using mediapipe::CalculatorGraph; using mediapipe::Packet; class InfiniteSequenceCalculator : public mediapipe::CalculatorBase { public: static absl::Status GetContract(mediapipe::CalculatorContract* cc) { cc->Outputs().Tag("OUT").Set(); cc->Outputs().Tag("EVENT").Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(1).At(Timestamp(1))); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Tag("OUT").AddPacket( MakePacket(count_).At(Timestamp(count_))); count_++; return absl::OkStatus(); } absl::Status Close(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(2).At(Timestamp(2))); return absl::OkStatus(); } private: int count_ = 0; }; REGISTER_CALCULATOR(::testing_ns::InfiniteSequenceCalculator); class StoppingPassThroughCalculator : public mediapipe::CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { cc->Inputs().Get("", i).SetAny(); cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); } cc->Outputs().Tag("EVENT").Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(1).At(Timestamp(1))); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { if (!cc->Inputs().Get("", i).IsEmpty()) { cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value()); } } return (++count_ <= max_count_) ? absl::OkStatus() : mediapipe::tool::StatusStop(); } absl::Status Close(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(2).At(Timestamp(2))); return absl::OkStatus(); } private: int count_ = 0; int max_count_ = 10; }; REGISTER_CALCULATOR(::testing_ns::StoppingPassThroughCalculator); // A simple Semaphore for synchronizing test threads. class AtomicSemaphore { public: AtomicSemaphore(int64_t supply) : supply_(supply) {} void Acquire(int64_t amount) { while (supply_.fetch_sub(amount) - amount < 0) { Release(amount); } } void Release(int64_t amount) { supply_ += amount; } private: std::atomic supply_; }; // A ProcessFunction that passes through all packets. absl::Status DoProcess(const InputStreamShardSet& inputs, OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } return absl::OkStatus(); } typedef std::function ProcessFunction; // A Calculator that delegates its Process function to a callback function. class ProcessCallbackCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0)); } cc->InputSidePackets().Index(0).Set>(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { callback_ = *GetFromUniquePtr(cc->InputSidePackets().Index(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { return callback_(cc->Inputs(), &(cc->Outputs())); } private: ProcessFunction callback_; }; REGISTER_CALCULATOR(::testing_ns::ProcessCallbackCalculator); // Tests CloseAllPacketSources. TEST(CalculatorGraphStoppingTest, CloseAllPacketSources) { CalculatorGraphConfig graph_config; ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( max_queue_size: 5 input_stream: 'input' node { calculator: 'InfiniteSequenceCalculator' output_stream: 'OUT:count' output_stream: 'EVENT:event' } node { calculator: 'StoppingPassThroughCalculator' input_stream: 'count' input_stream: 'input' output_stream: 'count_out' output_stream: 'input_out' output_stream: 'EVENT:event_out' } package: 'testing_ns' )", &graph_config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(graph_config, {})); // Observe output packets, and call CloseAllPacketSources after kNumPackets. std::vector out_packets; std::vector count_packets; std::vector event_packets; std::vector event_out_packets; int kNumPackets = 8; MP_ASSERT_OK(graph.ObserveOutputStream( // "input_out", [&](const Packet& packet) { out_packets.push_back(packet); if (out_packets.size() >= kNumPackets) { MP_EXPECT_OK(graph.CloseAllPacketSources()); } return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream( // "count_out", [&](const Packet& packet) { count_packets.push_back(packet); return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream( // "event", [&](const Packet& packet) { event_packets.push_back(packet.Get()); return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream( // "event_out", [&](const Packet& packet) { event_out_packets.push_back(packet.Get()); return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); for (int i = 0; i < kNumPackets; ++i) { MP_EXPECT_OK(graph.AddPacketToInputStream( "input", MakePacket(i).At(Timestamp(i)))); } // The graph run should complete with no error status. MP_EXPECT_OK(graph.WaitUntilDone()); EXPECT_EQ(kNumPackets, out_packets.size()); EXPECT_LE(kNumPackets, count_packets.size()); std::vector expected_events = {1, 2}; EXPECT_EQ(event_packets, expected_events); EXPECT_EQ(event_out_packets, expected_events); } // Verify that deadlock due to throttling can be reported. TEST(CalculatorGraphStoppingTest, DeadlockReporting) { CalculatorGraphConfig config; ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( input_stream: 'in_1' input_stream: 'in_2' max_queue_size: 2 node { calculator: 'ProcessCallbackCalculator' input_stream: 'in_1' input_stream: 'in_2' output_stream: 'out_1' output_stream: 'out_2' input_side_packet: 'callback_1' } package: 'testing_ns' report_deadlock: true )", &config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); graph.SetGraphInputStreamAddMode( CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL); std::vector out_packets; MP_ASSERT_OK( graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) { out_packets.push_back(packet); return absl::OkStatus(); })); // Lambda that waits for a local semaphore. AtomicSemaphore semaphore(0); ProcessFunction callback_1 = [&semaphore](const InputStreamShardSet& inputs, OutputStreamShardSet* outputs) { semaphore.Acquire(1); return DoProcess(inputs, outputs); }; // Lambda that adds a packet to the calculator graph. auto add_packet = [&graph](std::string s, int i) { return graph.AddPacketToInputStream(s, MakePacket(i).At(Timestamp(i))); }; // Start the graph. MP_ASSERT_OK(graph.StartRun({ {"callback_1", AdoptAsUniquePtr(new auto(callback_1))}, })); // Add 3 packets to "in_1" with no packets on "in_2". // This causes throttling and deadlock with max_queue_size 2. semaphore.Release(3); MP_EXPECT_OK(add_packet("in_1", 1)); MP_EXPECT_OK(add_packet("in_1", 2)); EXPECT_FALSE(add_packet("in_1", 3).ok()); absl::Status status = graph.WaitUntilIdle(); EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); EXPECT_THAT( status.message(), testing::HasSubstr("Detected a deadlock due to input throttling")); MP_ASSERT_OK(graph.CloseAllInputStreams()); EXPECT_FALSE(graph.WaitUntilDone().ok()); ASSERT_EQ(0, out_packets.size()); } // Verify that input streams grow due to deadlock resolution. TEST(CalculatorGraphStoppingTest, DeadlockResolution) { CalculatorGraphConfig config; ASSERT_TRUE(proto_ns::TextFormat::ParseFromString(R"( input_stream: 'in_1' input_stream: 'in_2' max_queue_size: 2 node { calculator: 'ProcessCallbackCalculator' input_stream: 'in_1' input_stream: 'in_2' output_stream: 'out_1' output_stream: 'out_2' input_side_packet: 'callback_1' } package: 'testing_ns' )", &config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); graph.SetGraphInputStreamAddMode( CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL); std::vector out_packets; MP_ASSERT_OK( graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) { out_packets.push_back(packet); return absl::OkStatus(); })); // Lambda that waits for a local semaphore. AtomicSemaphore semaphore(0); ProcessFunction callback_1 = [&semaphore](const InputStreamShardSet& inputs, OutputStreamShardSet* outputs) { semaphore.Acquire(1); return DoProcess(inputs, outputs); }; // Lambda that adds a packet to the calculator graph. auto add_packet = [&graph](std::string s, int i) { return graph.AddPacketToInputStream(s, MakePacket(i).At(Timestamp(i))); }; // Start the graph. MP_ASSERT_OK(graph.StartRun({ {"callback_1", AdoptAsUniquePtr(new auto(callback_1))}, })); // Add 9 packets to "in_1" with no packets on "in_2". // This grows the input stream "in_1" to max-queue-size 10. semaphore.Release(9); for (int i = 1; i <= 9; ++i) { MP_EXPECT_OK(add_packet("in_1", i)); MP_ASSERT_OK(graph.WaitUntilIdle()); } // Advance the timestamp-bound and flush "in_1". semaphore.Release(1); MP_EXPECT_OK(add_packet("in_2", 30)); MP_ASSERT_OK(graph.WaitUntilIdle()); // Fill up input stream "in_1", with the semaphore blocked and deadlock // resolution disabled. for (int i = 11; i < 23; ++i) { MP_EXPECT_OK(add_packet("in_1", i)); } // Adding any more packets fails with error "Graph is throttled". graph.SetGraphInputStreamAddMode( CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); EXPECT_FALSE(add_packet("in_1", 23).ok()); // Allow the 12 blocked calls to "callback_1" to complete. semaphore.Release(12); MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.CloseAllInputStreams()); MP_ASSERT_OK(graph.WaitUntilDone()); ASSERT_EQ(21, out_packets.size()); } } // namespace testing_ns