// Copyright 2019 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "mediapipe/framework/calculator_graph.h" #include #include #include #include #include #include #include #include #include #include "absl/container/fixed_array.h" #include "absl/memory/memory.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/counter_factory.h" #include "mediapipe/framework/executor.h" #include "mediapipe/framework/input_stream_handler.h" #include "mediapipe/framework/lifetime_tracker.h" #include "mediapipe/framework/mediapipe_options.pb.h" #include "mediapipe/framework/output_stream_poller.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/status_handler.h" #include 
"mediapipe/framework/subgraph.h" #include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/thread_pool_executor.pb.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/type_map.h" #include "mediapipe/gpu/gpu_service.h" namespace mediapipe { namespace { constexpr char kCounter2Tag[] = "COUNTER2"; constexpr char kCounter1Tag[] = "COUNTER1"; constexpr char kExtraTag[] = "EXTRA"; constexpr char kWaitSemTag[] = "WAIT_SEM"; constexpr char kPostSemTag[] = "POST_SEM"; constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN"; constexpr char kOutputTag[] = "OUTPUT"; constexpr char kInputTag[] = "INPUT"; constexpr char kSelectTag[] = "SELECT"; using testing::ElementsAre; using testing::HasSubstr; // Pass packets through. Note that it calls SetOffset() in Process() // instead of Open(). class SetOffsetInProcessCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { // Input: arbitrary Packets. // Output: copy of the input. cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); cc->GetCounter("PassThrough")->Increment(); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(SetOffsetInProcessCalculator); // A Calculator that outputs the square of its input packet (an int). 
class SquareIntCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new int(value * value), cc->InputTimestamp()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(SquareIntCalculator); // A Calculator that selects an output stream from "OUTPUT:0", "OUTPUT:1", ..., // using the integer value (0, 1, ...) in the packet on the "SELECT" input // stream, and passes the packet on the "INPUT" input stream to the selected // output stream. // // This calculator is called "Timed" because it sets the next timestamp bound on // the unselected outputs. class DemuxTimedCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); cc->Inputs().Tag(kSelectTag).Set(); PacketType* data_input = &cc->Inputs().Tag(kInputTag); data_input->SetAny(); for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); id < cc->Outputs().EndId("OUTPUT"); ++id) { cc->Outputs().Get(id).SetSameAs(data_input); } cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { select_input_ = cc->Inputs().GetId("SELECT", 0); data_input_ = cc->Inputs().GetId("INPUT", 0); output_base_ = cc->Outputs().GetId("OUTPUT", 0); num_outputs_ = cc->Outputs().NumEntries("OUTPUT"); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { int select = cc->Inputs().Get(select_input_).Get(); RET_CHECK(0 <= select && select < num_outputs_); const Timestamp next_timestamp_bound = cc->InputTimestamp().NextAllowedInStream(); for (int i = 0; i < num_outputs_; ++i) 
{ if (i == select) { cc->Outputs() .Get(output_base_ + i) .AddPacket(cc->Inputs().Get(data_input_).Value()); } else { cc->Outputs() .Get(output_base_ + i) .SetNextTimestampBound(next_timestamp_bound); } } return absl::OkStatus(); } private: CollectionItemId select_input_; CollectionItemId data_input_; CollectionItemId output_base_; int num_outputs_ = 0; }; REGISTER_CALCULATOR(DemuxTimedCalculator); // A Calculator that selects an input stream from "INPUT:0", "INPUT:1", ..., // using the integer value (0, 1, ...) in the packet on the "SELECT" input // stream, and passes the packet on the selected input stream to the "OUTPUT" // output stream. // // This calculator is called "Timed" because it requires next timestamp bound // propagation on the unselected inputs. class MuxTimedCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kSelectTag).Set(); CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); PacketType* data_input0 = &cc->Inputs().Get(data_input_id); data_input0->SetAny(); ++data_input_id; for (; data_input_id < cc->Inputs().EndId("INPUT"); ++data_input_id) { cc->Inputs().Get(data_input_id).SetSameAs(data_input0); } RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); cc->Outputs().Tag(kOutputTag).SetSameAs(data_input0); cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { select_input_ = cc->Inputs().GetId("SELECT", 0); data_input_base_ = cc->Inputs().GetId("INPUT", 0); num_data_inputs_ = cc->Inputs().NumEntries("INPUT"); output_ = cc->Outputs().GetId("OUTPUT", 0); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { int select = cc->Inputs().Get(select_input_).Get(); RET_CHECK(0 <= select && select < num_data_inputs_); cc->Outputs().Get(output_).AddPacket( cc->Inputs().Get(data_input_base_ + select).Value()); return absl::OkStatus(); } private: CollectionItemId select_input_; CollectionItemId 
data_input_base_; int num_data_inputs_ = 0; CollectionItemId output_; }; REGISTER_CALCULATOR(MuxTimedCalculator); // A Calculator that adds the integer values in the packets in all the input // streams and outputs the sum to the output stream. class IntAdderCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { int sum = 0; for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { sum += cc->Inputs().Index(i).Get(); } cc->Outputs().Index(0).Add(new int(sum), cc->InputTimestamp()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntAdderCalculator); // A Calculator that adds the float values in the packets in all the input // streams and outputs the sum to the output stream. class FloatAdderCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { float sum = 0.0; for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { sum += cc->Inputs().Index(i).Get(); } cc->Outputs().Index(0).Add(new float(sum), cc->InputTimestamp()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(FloatAdderCalculator); // A Calculator that multiplies the integer values in the packets in all the // input streams and outputs the product to the output stream. 
class IntMultiplierCalculator : public CalculatorBase {
 public:
  // Any number of int inputs, one int output, zero timestamp offset.
  static absl::Status GetContract(CalculatorContract* cc) {
    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
      cc->Inputs().Index(i).Set<int>();
    }
    cc->Outputs().Index(0).Set<int>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }

  // Emits the product of all inputs at the input timestamp.
  absl::Status Process(CalculatorContext* cc) final {
    int product = 1;
    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
      product *= cc->Inputs().Index(i).Get<int>();
    }
    cc->Outputs().Index(0).Add(new int(product), cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(IntMultiplierCalculator);

// A Calculator that multiplies the float value in an input packet by a float
// constant scalar (specified in a side packet) and outputs the product to the
// output stream.
class FloatScalarMultiplierCalculator : public CalculatorBase {
 public:
  // One float input, one float output, one float input side packet.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<float>();
    cc->Outputs().Index(0).Set<float>();
    cc->InputSidePackets().Index(0).Set<float>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Reads the scalar from the side packet once, before any Process() call.
  absl::Status Open(CalculatorContext* cc) final {
    scalar_ = cc->InputSidePackets().Index(0).Get<float>();
    return absl::OkStatus();
  }

  // Emits scalar * value at the input timestamp.
  absl::Status Process(CalculatorContext* cc) final {
    const float value = cc->Inputs().Index(0).Value().Get<float>();
    cc->Outputs().Index(0).Add(new float(scalar_ * value),
                               cc->InputTimestamp());
    return absl::OkStatus();
  }

 private:
  // Overwritten in Open(); initialized so the member is never indeterminate.
  float scalar_ = 0.0f;
};
REGISTER_CALCULATOR(FloatScalarMultiplierCalculator);

// A Calculator that converts an integer to a float.
class IntToFloatCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new float(static_cast(value)), cc->InputTimestamp()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntToFloatCalculator); template class TypedEmptySourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).SetAny(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new OutputType(), Timestamp::PostStream()); return tool::StatusStop(); } }; typedef TypedEmptySourceCalculator StringEmptySourceCalculator; typedef TypedEmptySourceCalculator IntEmptySourceCalculator; REGISTER_CALCULATOR(StringEmptySourceCalculator); REGISTER_CALCULATOR(IntEmptySourceCalculator); template class TypedSinkCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { return absl::OkStatus(); } }; typedef TypedSinkCalculator StringSinkCalculator; typedef TypedSinkCalculator IntSinkCalculator; REGISTER_CALCULATOR(StringSinkCalculator); REGISTER_CALCULATOR(IntSinkCalculator); // Output kNumOutputPackets packets, the value of each being the next // value in the counter provided as an input side packet. An optional // second input side packet will, if true, cause this calculator to // output the first of the kNumOutputPackets packets during Open. 
class GlobalCountSourceCalculator : public CalculatorBase { public: static const int kNumOutputPackets; static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).Set*>(); if (cc->InputSidePackets().NumEntries() >= 2) { cc->InputSidePackets().Index(1).Set(); } cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { if (cc->InputSidePackets().NumEntries() >= 2 && cc->InputSidePackets().Index(1).Get()) { OutputOne(cc); } return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { OutputOne(cc); if (local_count_ >= kNumOutputPackets) { return tool::StatusStop(); } else { return absl::OkStatus(); } } private: void OutputOne(CalculatorContext* cc) { std::atomic* counter = cc->InputSidePackets().Index(0).Get*>(); int count = counter->fetch_add(1, std::memory_order_relaxed); cc->Outputs().Index(0).Add(new int(count), Timestamp(local_count_)); ++local_count_; } int64_t local_count_ = 0; }; const int GlobalCountSourceCalculator::kNumOutputPackets = 5; REGISTER_CALCULATOR(GlobalCountSourceCalculator); static const int kTestSequenceLength = 15; // Outputs the integers 0, 1, 2, 3, ..., 14, all with timestamp 0. class TestSequence1SourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new int(count_), Timestamp(0)); ++count_; ++num_outputs_; if (num_outputs_ >= kTestSequenceLength) { return tool::StatusStop(); } else { return absl::OkStatus(); } } private: int count_ = 0; int num_outputs_ = 0; }; REGISTER_CALCULATOR(TestSequence1SourceCalculator); // Outputs the integers 1, 2, 3, 4, ..., 15, with decreasing timestamps // 100, 99, 98, 97, .... 
class TestSequence2SourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // Emits the next integer at the current (decreasing) timestamp and stops
  // after kTestSequenceLength outputs.
  absl::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).Add(new int(count_), Timestamp(timestamp_));
    ++count_;
    ++num_outputs_;
    --timestamp_;
    if (num_outputs_ >= kTestSequenceLength) {
      return tool::StatusStop();
    }
    return absl::OkStatus();
  }

 private:
  int count_ = 1;
  int num_outputs_ = 0;
  int timestamp_ = 100;
};
REGISTER_CALCULATOR(TestSequence2SourceCalculator);

// Outputs the integers 0, 1, 2 repeatedly for a total of 15 outputs.
class Modulo3SourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // Emits count_ % 3 as both the value and the timestamp, and stops after
  // kTestSequenceLength outputs.
  absl::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).Add(new int(count_ % 3), Timestamp(count_ % 3));
    ++count_;
    ++num_outputs_;
    if (num_outputs_ >= kTestSequenceLength) {
      return tool::StatusStop();
    }
    return absl::OkStatus();
  }

 private:
  int count_ = 0;
  int num_outputs_ = 0;
};
REGISTER_CALCULATOR(Modulo3SourceCalculator);

// A source calculator that outputs 100 packets all at once and stops. The
// number of output packets (100) is deliberately chosen to be equal to
// max_queue_size, which fills the input streams connected to this source
// calculator and causes the MediaPipe scheduler to throttle this source
// calculator.
class OutputAllSourceCalculator : public CalculatorBase {
 public:
  static constexpr int kNumOutputPackets = 100;

  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // Emits all kNumOutputPackets packets (value 0, timestamps 0..N-1) in a
  // single call, then stops.
  absl::Status Process(CalculatorContext* cc) override {
    for (int i = 0; i < kNumOutputPackets; ++i) {
      cc->Outputs().Index(0).Add(new int(0), Timestamp(i));
    }
    return tool::StatusStop();
  }
};
REGISTER_CALCULATOR(OutputAllSourceCalculator);

// A source calculator that outputs one packet at a time. The total number of
// output packets needs to be large enough to eventually fill an input stream
// connected to this source calculator and to force the MediaPipe scheduler to
// run this source calculator as a throttled source when the graph cannot make
// progress.
class OutputOneAtATimeSourceCalculator : public CalculatorBase {
 public:
  static constexpr int kNumOutputPackets = 1000;

  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // Emits one packet (value 0) per call until kNumOutputPackets have been
  // sent; the call after that stops the source.
  absl::Status Process(CalculatorContext* cc) override {
    if (index_ < kNumOutputPackets) {
      cc->Outputs().Index(0).Add(new int(0), Timestamp(index_));
      ++index_;
      return absl::OkStatus();
    }
    return tool::StatusStop();
  }

 private:
  int index_ = 0;
};
REGISTER_CALCULATOR(OutputOneAtATimeSourceCalculator);

// A calculator that passes through one out of every 101 input packets and
// discards the rest. The decimation ratio (101) is carefully chosen to be
// greater than max_queue_size (100) so that an input stream parallel to the
// input stream connected to this calculator can become full.
class DecimatorCalculator : public CalculatorBase {
 public:
  static constexpr int kDecimationRatio = 101;

  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return absl::OkStatus();
  }

  // Forwards the packet when the running index is a multiple of
  // kDecimationRatio (so the very first packet passes); drops it otherwise.
  absl::Status Process(CalculatorContext* cc) final {
    if (index_ % kDecimationRatio == 0) {
      cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    }
    ++index_;
    return absl::OkStatus();
  }

 private:
  int index_ = 0;
};
REGISTER_CALCULATOR(DecimatorCalculator);

// An error will be produced in Open() if ERROR_ON_OPEN is true. Otherwise,
// this calculator simply passes its input packets through, unchanged.
class ErrorOnOpenCalculator : public CalculatorBase {
 public:
  // One input of any type, a same-typed output, and a bool ERROR_ON_OPEN
  // input side packet.
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
    return absl::OkStatus();
  }

  // Returns NotFoundError("expected error") when the side packet is true.
  absl::Status Open(CalculatorContext* cc) final {
    if (cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
      return absl::NotFoundError("expected error");
    }
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(ErrorOnOpenCalculator);

// A calculator that outputs an initial packet of value 0 at time 0 in the
// Open() method, and then delays each input packet by one time unit in the
// Process() method. The input stream and output stream have the integer type.
class UnitDelayCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(new int(0), Timestamp(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { const Packet& packet = cc->Inputs().Index(0).Value(); cc->Outputs().Index(0).AddPacket( packet.At(packet.Timestamp().NextAllowedInStream())); return absl::OkStatus(); } }; REGISTER_CALCULATOR(UnitDelayCalculator); class UnitDelayUntimedCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(new int(0), Timestamp::Min()); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(UnitDelayUntimedCalculator); class FloatUnitDelayCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(new float(0.0), Timestamp(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { const Packet& packet = cc->Inputs().Index(0).Value(); cc->Outputs().Index(0).AddPacket( packet.At(packet.Timestamp().NextAllowedInStream())); return absl::OkStatus(); } }; REGISTER_CALCULATOR(FloatUnitDelayCalculator); // A sink calculator that asserts its input stream is empty in Open() and // discards input packets in Process(). 
class AssertEmptyInputInOpenCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { RET_CHECK(cc->Inputs().Index(0).Value().IsEmpty()); RET_CHECK_EQ(cc->Inputs().Index(0).Value().Timestamp(), Timestamp::Unset()); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } }; REGISTER_CALCULATOR(AssertEmptyInputInOpenCalculator); // A slow sink calculator that expects 10 input integers with the values // 0, 1, ..., 9. class SlowCountingSinkCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { absl::SleepFor(absl::Milliseconds(10)); int value = cc->Inputs().Index(0).Get(); CHECK_EQ(value, counter_); ++counter_; return absl::OkStatus(); } absl::Status Close(CalculatorContext* cc) override { CHECK_EQ(10, counter_); return absl::OkStatus(); } private: int counter_ = 0; }; REGISTER_CALCULATOR(SlowCountingSinkCalculator); template class TypedStatusHandler : public StatusHandler { public: ~TypedStatusHandler() override = 0; static absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) { input_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, // const absl::Status& pre_run_status) { return absl::OkStatus(); } static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, // const absl::Status& run_status) { return absl::OkStatus(); } }; typedef TypedStatusHandler StringStatusHandler; typedef TypedStatusHandler Uint32StatusHandler; REGISTER_STATUS_HANDLER(StringStatusHandler); 
REGISTER_STATUS_HANDLER(Uint32StatusHandler); // A string generator that will succeed. class StaticCounterStringGenerator : public PacketGenerator { public: static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { for (int i = 0; i < input_side_packets->NumEntries(); ++i) { input_side_packets->Index(i).SetAny(); } output_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status Generate(const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { output_side_packets->Index(0) = MakePacket("fixed_string"); ++num_packets_generated_; return absl::OkStatus(); } static int NumPacketsGenerated() { return num_packets_generated_; } private: static int num_packets_generated_; }; int StaticCounterStringGenerator::num_packets_generated_ = 0; REGISTER_PACKET_GENERATOR(StaticCounterStringGenerator); // A failing PacketGenerator and Calculator to verify that status handlers get // called. Both claim to output strings but instead always fail. class FailingPacketGenerator : public PacketGenerator { public: static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { for (int i = 0; i < input_side_packets->NumEntries(); ++i) { input_side_packets->Index(i).SetAny(); } output_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status Generate(const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { return absl::UnknownError("this always fails."); } }; REGISTER_PACKET_GENERATOR(FailingPacketGenerator); // Passes the integer through if it is positive. 
class EnsurePositivePacketGenerator : public PacketGenerator {
 public:
  // Declares an int output side packet for each int input side packet.
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options,
      PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
    for (int i = 0; i < input_side_packets->NumEntries(); ++i) {
      input_side_packets->Index(i).Set<int>();
      output_side_packets->Index(i).Set<int>();
    }
    return absl::OkStatus();
  }

  // Copies each input to the output at the same index. Returns UnknownError
  // at the first non-positive input; outputs copied before that point stay
  // set.
  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    for (int i = 0; i < input_side_packets.NumEntries(); ++i) {
      if (input_side_packets.Index(i).Get<int>() <= 0) {
        return absl::UnknownError(
            absl::StrCat("Integer ", i, " was not positive."));
      }
      output_side_packets->Index(i) = input_side_packets.Index(i);
    }
    return absl::OkStatus();
  }
};
REGISTER_PACKET_GENERATOR(EnsurePositivePacketGenerator);

// A Status handler which takes an int side packet and fails in pre run
// if that packet is FailableStatusHandler::kFailPreRun and fails post
// run if that value is FailableStatusHandler::kFailPostRun. If the
// int is any other value then no failures occur.
class FailableStatusHandler : public StatusHandler { public: enum { kOk = 0, kFailPreRun = 1, kFailPostRun = 2, }; static absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) { input_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, const absl::Status& pre_run_status) { if (input_side_packets.Index(0).Get() == kFailPreRun) { return absl::UnknownError( "FailableStatusHandler failing pre run as intended."); } else { return absl::OkStatus(); } } static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, const absl::Status& run_status) { if (input_side_packets.Index(0).Get() == kFailPostRun) { return absl::UnknownError( "FailableStatusHandler failing post run as intended."); } else { return absl::OkStatus(); } } }; REGISTER_STATUS_HANDLER(FailableStatusHandler); class FailingSourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { return absl::UnknownError("this always fails."); } }; REGISTER_CALCULATOR(FailingSourceCalculator); // A simple Semaphore for synchronizing test threads. class AtomicSemaphore { public: AtomicSemaphore(int64_t supply) : supply_(supply) {} void Acquire(int64_t amount) { while (supply_.fetch_sub(amount) - amount < 0) { Release(amount); } } void Release(int64_t amount) { supply_ += amount; } private: std::atomic supply_; }; // This calculator posts to a semaphore when it starts its Process method, // and waits on a different semaphore before returning from Process. // This allows a test to run code when the calculator is running Process // without having to depend on any specific timing. 
class SemaphoreCalculator : public CalculatorBase { public: using Semaphore = AtomicSemaphore; static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->InputSidePackets().Tag(kPostSemTag).Set(); cc->InputSidePackets().Tag(kWaitSemTag).Set(); cc->SetTimestampOffset(TimestampDiff(0)); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { cc->InputSidePackets().Tag(kPostSemTag).Get()->Release(1); cc->InputSidePackets().Tag(kWaitSemTag).Get()->Acquire(1); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); return absl::OkStatus(); } }; REGISTER_CALCULATOR(SemaphoreCalculator); // A calculator that has no input streams and output streams, runs only once, // and takes 20 milliseconds to run. class OneShot20MsCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { absl::SleepFor(absl::Milliseconds(20)); return tool::StatusStop(); } }; REGISTER_CALCULATOR(OneShot20MsCalculator); // A source calculator that outputs a packet containing the return value of // pthread_self() (the pthread id of the current thread). class PthreadSelfSourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).AddPacket( MakePacket(pthread_self()).At(Timestamp(0))); return tool::StatusStop(); } }; REGISTER_CALCULATOR(PthreadSelfSourceCalculator); // A source calculator for testing the Calculator::InputTimestamp() method. // It outputs five int packets with timestamps 0, 1, 2, 3, 4. 
class CheckInputTimestampSourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); return absl::OkStatus(); } // InputTimestamp() always returns Timestamp(0) in Process() for source // nodes. absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0)); cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); ++count_; if (count_ >= 5) { return tool::StatusStop(); } else { return absl::OkStatus(); } } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. absl::Status Close(CalculatorContext* cc) final { // Must use CHECK instead of RET_CHECK in Close(), because the framework // may call the Close() method of a source node with .IgnoreError(). CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); return absl::OkStatus(); } private: int count_ = 0; }; REGISTER_CALCULATOR(CheckInputTimestampSourceCalculator); // A sink calculator for testing the Calculator::InputTimestamp() method. // It expects to consume the output of a CheckInputTimestampSourceCalculator. class CheckInputTimestampSinkCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); return absl::OkStatus(); } // InputTimestamp() returns the timestamp of input packets in Process() for // non-source nodes. 
absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), cc->Inputs().Index(0).Value().Timestamp()); RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_)); ++count_; return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. absl::Status Close(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); return absl::OkStatus(); } private: int count_ = 0; }; REGISTER_CALCULATOR(CheckInputTimestampSinkCalculator); // A source calculator for testing the Calculator::InputTimestamp() method. // It outputs int packets with timestamps 0, 1, 2, ... until being closed by // the framework. class CheckInputTimestamp2SourceCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); return absl::OkStatus(); } // InputTimestamp() always returns Timestamp(0) in Process() for source // nodes. absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0)); cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); ++count_; return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. absl::Status Close(CalculatorContext* cc) final { // Must use CHECK instead of RET_CHECK in Close(), because the framework // may call the Close() method of a source node with .IgnoreError(). CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); return absl::OkStatus(); } private: int count_ = 0; }; REGISTER_CALCULATOR(CheckInputTimestamp2SourceCalculator); // A sink calculator for testing the Calculator::InputTimestamp() method. 
// It expects to consume the output of a CheckInputTimestamp2SourceCalculator. // It returns tool::StatusStop() after consuming five input packets. class CheckInputTimestamp2SinkCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); return absl::OkStatus(); } // InputTimestamp() returns the timestamp of input packets in Process() for // non-source nodes. absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), cc->Inputs().Index(0).Value().Timestamp()); RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_)); ++count_; if (count_ >= 5) { return tool::StatusStop(); } else { return absl::OkStatus(); } } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. absl::Status Close(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); return absl::OkStatus(); } private: int count_ = 0; }; REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator); // A calculator checks if either of two input streams contains a packet and // sends the packet to the single output stream with the same timestamp. 
class SimpleMuxCalculator : public CalculatorBase {
 public:
  // Requires exactly two inputs of the same (arbitrary) type and exactly one
  // output of that type.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
    cc->Inputs().Index(0).SetAny();
    cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
    RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return absl::OkStatus();
  }

  // Remembers the id of the first input so Process() can address inputs by
  // offset from it.
  absl::Status Open(CalculatorContext* cc) final {
    data_input_base_ = cc->Inputs().BeginId();
    return absl::OkStatus();
  }

  // Forwards the packet of the lowest-indexed non-empty input, if any. When
  // both inputs are empty nothing is emitted.
  absl::Status Process(CalculatorContext* cc) final {
    for (int offset = 0; offset < 2; ++offset) {
      if (cc->Inputs().Index(offset).IsEmpty()) {
        continue;
      }
      cc->Outputs().Index(0).AddPacket(
          cc->Inputs().Get(data_input_base_ + offset).Value());
      break;
    }
    return absl::OkStatus();
  }

 private:
  // Id of input stream 0, captured in Open().
  CollectionItemId data_input_base_;
};
REGISTER_CALCULATOR(SimpleMuxCalculator);

// Mock status handler that reports the number of times HandleStatus was called
// by modifying the int in its input side packet.
class IncrementingStatusHandler : public StatusHandler { public: static absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) { input_side_packets->Tag(kExtraTag).SetAny().Optional(); input_side_packets->Tag(kCounter1Tag).Set>(); input_side_packets->Tag(kCounter2Tag).Set>(); return absl::OkStatus(); } static absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, // const absl::Status& pre_run_status) { int* counter = GetFromUniquePtr(input_side_packets.Tag(kCounter1Tag)); (*counter)++; return pre_run_status_result_; } static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, // const absl::Status& run_status) { int* counter = GetFromUniquePtr(input_side_packets.Tag(kCounter2Tag)); (*counter)++; return post_run_status_result_; } static void SetPreRunStatusResult(const absl::Status& status) { pre_run_status_result_ = status; } static void SetPostRunStatusResult(const absl::Status& status) { post_run_status_result_ = status; } private: // Return values of HandlePreRunStatus() and HandleSTatus(), respectively. static absl::Status pre_run_status_result_; static absl::Status post_run_status_result_; }; absl::Status IncrementingStatusHandler::pre_run_status_result_ = absl::OkStatus(); absl::Status IncrementingStatusHandler::post_run_status_result_ = absl::OkStatus(); REGISTER_STATUS_HANDLER(IncrementingStatusHandler); // A simple executor that runs tasks directly on the current thread. // NOTE: If CurrentThreadExecutor is used, some CalculatorGraph methods may // behave differently. For example, CalculatorGraph::StartRun will run the // graph rather than returning immediately after starting the graph. // Similarly, CalculatorGraph::AddPacketToInputStream will run the graph // (until it's idle) rather than returning immediately after adding the packet // to the graph input stream. 
class CurrentThreadExecutor : public Executor { public: ~CurrentThreadExecutor() override { CHECK(!executing_); CHECK(tasks_.empty()); } void Schedule(std::function task) override { if (executing_) { // Queue the task for later execution (after the currently-running task // returns) rather than running the task immediately. This is especially // important for a source node (which can be rescheduled immediately after // running) to avoid an indefinitely-deep call stack. tasks_.emplace_back(std::move(task)); } else { CHECK(tasks_.empty()); executing_ = true; task(); while (!tasks_.empty()) { task = tasks_.front(); tasks_.pop_front(); task(); } executing_ = false; } } private: // True if the executor is executing tasks. bool executing_ = false; // The tasks to execute. std::deque> tasks_; }; // Returns a CalculatorGraphConfig used by tests. CalculatorGraphConfig GetConfig() { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( # The graph configuration. We list the nodes in an arbitrary (not # topologically-sorted) order to verify that CalculatorGraph can # handle such configurations. 
node { calculator: "RangeCalculator" output_stream: "range3" output_stream: "range3_sum" output_stream: "range3_mean" input_side_packet: "node_3_converted" } node { calculator: "RangeCalculator" output_stream: "range5" output_stream: "range5_sum" output_stream: "range5_mean" input_side_packet: "node_5_converted" } node { calculator: "MergeCalculator" input_stream: "range3" input_stream: "range5_copy" output_stream: "merge" } node { calculator: "MergeCalculator" input_stream: "range3_sum" input_stream: "range5_sum" output_stream: "merge_sum" } node { calculator: "PassThroughCalculator" input_stream: "range3_stddev" input_stream: "range5_stddev" output_stream: "range3_stddev_2" output_stream: "range5_stddev_2" } node { calculator: "PassThroughCalculator" input_stream: "A:range3_stddev_2" input_stream: "range5_stddev_2" output_stream: "A:range3_stddev_3" output_stream: "range5_stddev_3" } node { calculator: "PassThroughCalculator" input_stream: "B:range3_stddev_3" input_stream: "B:1:range5_stddev_3" output_stream: "B:range3_stddev_4" output_stream: "B:1:range5_stddev_4" } node { calculator: "MergeCalculator" input_stream: "range3_stddev_4" input_stream: "range5_stddev_4" output_stream: "merge_stddev" } node { calculator: "StdDevCalculator" input_stream: "DATA:range3" input_stream: "MEAN:range3_mean" output_stream: "range3_stddev" } node { calculator: "StdDevCalculator" input_stream: "DATA:range5" input_stream: "MEAN:range5_mean" output_stream: "range5_stddev" } node { name: "copy_range5" calculator: "PassThroughCalculator" input_stream: "range5" output_stream: "range5_copy" } node { calculator: "SaverCalculator" input_stream: "merge" output_stream: "final" } node { calculator: "SaverCalculator" input_stream: "merge_sum" output_stream: "final_sum" } node { calculator: "SaverCalculator" input_stream: "merge_stddev" output_stream: "final_stddev" } packet_generator { packet_generator: "IntSplitterPacketGenerator" input_side_packet: "node_3" output_side_packet: 
"node_3_converted" } packet_generator { packet_generator: "TaggedIntSplitterPacketGenerator" input_side_packet: "node_5" output_side_packet: "HIGH:unused_high" output_side_packet: "LOW:unused_low" output_side_packet: "PAIR:node_5_converted" } )pb"); return config; } // |graph| points to an empty CalculatorGraph object created by the default // constructor, before CalculatorGraph::Initialize() is called. void RunComprehensiveTest(CalculatorGraph* graph, const CalculatorGraphConfig& the_config, bool define_node_5) { CalculatorGraphConfig proto(the_config); Packet dumped_final_sum_packet; Packet dumped_final_packet; Packet dumped_final_stddev_packet; tool::AddPostStreamPacketSink("final", &proto, &dumped_final_packet); tool::AddPostStreamPacketSink("final_sum", &proto, &dumped_final_sum_packet); tool::AddPostStreamPacketSink("final_stddev", &proto, &dumped_final_stddev_packet); MP_ASSERT_OK(graph->Initialize(proto)); std::map extra_side_packets; extra_side_packets.emplace("node_3", Adopt(new uint64_t((15LL << 32) | 3))); if (define_node_5) { extra_side_packets.emplace("node_5", Adopt(new uint64_t((15LL << 32) | 5))); } // Call graph->Run() several times, to make sure that the appropriate // cleanup happens between iterations. for (int iteration = 0; iteration < 2; ++iteration) { LOG(INFO) << "Loop iteration " << iteration; dumped_final_sum_packet = Packet(); dumped_final_stddev_packet = Packet(); dumped_final_packet = Packet(); MP_ASSERT_OK(graph->Run(extra_side_packets)); // The merger will output the timestamp and all ints output from // the range calculators. The saver will concatenate together the // strings with a '/' deliminator. EXPECT_EQ( "Timestamp(0) 300 500/" "Timestamp(3) 301 empty/" "Timestamp(5) empty 501/" "Timestamp(6) 302 empty/" "Timestamp(9) 303 empty/" "Timestamp(10) empty 502/" "Timestamp(12) 304 empty/" "Timestamp(15) 305 503", dumped_final_packet.Get()); // Verify that the headers got set correctly. 
EXPECT_EQ( "RangeCalculator3 RangeCalculator5", graph->FindOutputStreamManager("merge")->Header().Get()); // Verify that sum packets get correctly processed. // (The first is a sum of all the 3's output and the second of all // the 5's). EXPECT_EQ(absl::StrCat(Timestamp::PostStream().DebugString(), " 1815 2006"), dumped_final_sum_packet.Get()); EXPECT_EQ(4 * (iteration + 1), graph->GetCounterFactory() ->GetCounter("copy_range5-PassThrough") ->Get()); // Verify that stddev packets get correctly processed. // The standard deviation computed as: // sqrt(sum((x-mean(x))**2 / length(x))) // for x = 300:305 is 1.707825 (multiplied by 100 and rounded it is 171) // for x = 500:503 is 1.118034 (multiplied by 100 and rounded it is 112) EXPECT_EQ(absl::StrCat(Timestamp::PostStream().DebugString(), " 171 112"), dumped_final_stddev_packet.Get()); EXPECT_EQ(4 * (iteration + 1), graph->GetCounterFactory() ->GetCounter("copy_range5-PassThrough") ->Get()); } LOG(INFO) << "After Loop Runs."; // Verify that the graph can still run (but not successfully) when // one of the nodes is caused to fail. extra_side_packets.clear(); extra_side_packets.emplace("node_3", Adopt(new uint64_t((15LL << 32) | 0))); if (define_node_5) { extra_side_packets.emplace("node_5", Adopt(new uint64_t((15LL << 32) | 5))); } dumped_final_sum_packet = Packet(); dumped_final_stddev_packet = Packet(); dumped_final_packet = Packet(); LOG(INFO) << "Expect an error to be logged here."; ASSERT_FALSE(graph->Run(extra_side_packets).ok()); LOG(INFO) << "Error should have been logged."; } TEST(CalculatorGraph, BadInitialization) { CalculatorGraphConfig proto = GetConfig(); CalculatorGraph graph; // Force the config to have a missing Calculator. proto.mutable_node(1)->clear_calculator(); ASSERT_FALSE(graph.Initialize(proto).ok()); } TEST(CalculatorGraph, BadRun) { CalculatorGraphConfig proto = GetConfig(); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(proto)); // Don't set the input side packets. 
EXPECT_FALSE(graph.Run().ok()); } TEST(CalculatorGraph, RunsCorrectly) { CalculatorGraph graph; CalculatorGraphConfig proto = GetConfig(); RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } TEST(CalculatorGraph, RunsCorrectlyOnApplicationThread) { CalculatorGraph graph; CalculatorGraphConfig proto = GetConfig(); // Force application thread to be used. proto.set_num_threads(0); RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } TEST(CalculatorGraph, RunsCorrectlyWithExternalExecutor) { CalculatorGraph graph; MP_ASSERT_OK(graph.SetExecutor("", std::make_shared(1))); CalculatorGraphConfig proto = GetConfig(); RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } // This test verifies that the MediaPipe framework calls Executor::AddTask() // without holding any mutex, because CurrentThreadExecutor::AddTask() may // result in a recursive call to itself. TEST(CalculatorGraph, RunsCorrectlyWithCurrentThreadExecutor) { CalculatorGraph graph; MP_ASSERT_OK( graph.SetExecutor("", std::make_shared())); CalculatorGraphConfig proto = GetConfig(); RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } TEST(CalculatorGraph, RunsCorrectlyWithNonDefaultExecutors) { CalculatorGraph graph; // Add executors "second" and "third". MP_ASSERT_OK( graph.SetExecutor("second", std::make_shared(1))); MP_ASSERT_OK( graph.SetExecutor("third", std::make_shared(1))); CalculatorGraphConfig proto = GetConfig(); ExecutorConfig* executor = proto.add_executor(); executor->set_name("second"); executor = proto.add_executor(); executor->set_name("third"); for (int i = 0; i < proto.node_size(); ++i) { switch (i % 3) { case 0: // Use the default executor. 
break; case 1: proto.mutable_node(i)->set_executor("second"); break; case 2: proto.mutable_node(i)->set_executor("third"); break; } } RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } TEST(CalculatorGraph, RunsCorrectlyWithMultipleExecutors) { CalculatorGraph graph; // Add executors "second" and "third". CalculatorGraphConfig proto = GetConfig(); ExecutorConfig* executor = proto.add_executor(); executor->set_name("second"); executor->set_type("ThreadPoolExecutor"); MediaPipeOptions* options = executor->mutable_options(); ThreadPoolExecutorOptions* extension = options->MutableExtension(ThreadPoolExecutorOptions::ext); extension->set_num_threads(1); executor = proto.add_executor(); executor->set_name("third"); executor->set_type("ThreadPoolExecutor"); options = executor->mutable_options(); extension = options->MutableExtension(ThreadPoolExecutorOptions::ext); extension->set_num_threads(1); for (int i = 0; i < proto.node_size(); ++i) { switch (i % 3) { case 0: // Use the default executor. break; case 1: proto.mutable_node(i)->set_executor("second"); break; case 2: proto.mutable_node(i)->set_executor("third"); break; } } RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } // Packet generator for an arbitrary unit64 packet. 
class Uint64PacketGenerator : public PacketGenerator { public: static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { output_side_packets->Index(0).Set(); return absl::OkStatus(); } static absl::Status Generate(const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { output_side_packets->Index(0) = Adopt(new uint64_t(15LL << 32 | 5)); return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(Uint64PacketGenerator); TEST(CalculatorGraph, GeneratePacket) { CalculatorGraph graph; CalculatorGraphConfig proto = GetConfig(); PacketGeneratorConfig* generator = proto.add_packet_generator(); generator->set_packet_generator("Uint64PacketGenerator"); generator->add_output_side_packet("node_5"); RunComprehensiveTest(&graph, proto, false); } TEST(CalculatorGraph, TypeMismatch) { CalculatorGraphConfig config; CalculatorGraphConfig::Node* node = config.add_node(); node->add_output_stream("stream_a"); node = config.add_node(); node->add_input_stream("stream_a"); std::unique_ptr graph; // Type matches, expect success. config.mutable_node(0)->set_calculator("StringEmptySourceCalculator"); config.mutable_node(1)->set_calculator("StringSinkCalculator"); graph.reset(new CalculatorGraph()); MP_ASSERT_OK(graph->Initialize(config)); MP_EXPECT_OK(graph->Run()); graph.reset(nullptr); // Type matches, expect success. config.mutable_node(0)->set_calculator("IntEmptySourceCalculator"); config.mutable_node(1)->set_calculator("IntSinkCalculator"); graph.reset(new CalculatorGraph()); MP_ASSERT_OK(graph->Initialize(config)); MP_EXPECT_OK(graph->Run()); graph.reset(nullptr); // Type mismatch, expect non-crashing failure. 
config.mutable_node(0)->set_calculator("StringEmptySourceCalculator"); config.mutable_node(1)->set_calculator("IntSinkCalculator"); graph.reset(new CalculatorGraph()); MP_ASSERT_OK(graph->Initialize(config)); EXPECT_FALSE(graph->Run().ok()); graph.reset(nullptr); // Type mismatch, expect non-crashing failure. config.mutable_node(0)->set_calculator("IntEmptySourceCalculator"); config.mutable_node(1)->set_calculator("StringSinkCalculator"); graph.reset(new CalculatorGraph()); MP_ASSERT_OK(graph->Initialize(config)); EXPECT_FALSE(graph->Run().ok()); graph.reset(nullptr); } TEST(CalculatorGraph, LayerOrdering) { CalculatorGraphConfig config; CalculatorGraphConfig::Node* node; node = config.add_node(); node->set_calculator("GlobalCountSourceCalculator"); node->add_input_side_packet("global_counter"); node->add_output_stream("count_layer_0_node_0"); node->set_source_layer(0); node = config.add_node(); node->set_calculator("GlobalCountSourceCalculator"); node->add_input_side_packet("global_counter"); node->add_output_stream("count_layer_1_node_0"); node->set_source_layer(1); node = config.add_node(); node->set_calculator("GlobalCountSourceCalculator"); node->add_input_side_packet("global_counter"); node->add_output_stream("count_layer_1_node_1"); node->set_source_layer(1); node = config.add_node(); node->set_calculator("GlobalCountSourceCalculator"); node->add_input_side_packet("global_counter"); node->add_output_stream("count_layer_2_node_0"); node->set_source_layer(2); // Set num threads to 1 because we rely on sequential execution for this test. 
config.set_num_threads(1); std::vector dump_layer_0_node_0; std::vector dump_layer_1_node_0; std::vector dump_layer_1_node_1; std::vector dump_layer_2_node_0; tool::AddVectorSink("count_layer_0_node_0", &config, &dump_layer_0_node_0); tool::AddVectorSink("count_layer_1_node_0", &config, &dump_layer_1_node_0); tool::AddVectorSink("count_layer_1_node_1", &config, &dump_layer_1_node_1); tool::AddVectorSink("count_layer_2_node_0", &config, &dump_layer_2_node_0); auto graph = absl::make_unique(); std::atomic global_counter(0); std::map input_side_packets; input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); MP_ASSERT_OK(graph->Initialize(config)); MP_ASSERT_OK(graph->Run(input_side_packets)); graph.reset(nullptr); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, dump_layer_0_node_0.size()); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, dump_layer_1_node_0.size()); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, dump_layer_1_node_1.size()); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, dump_layer_2_node_0.size()); // Check layer 0. for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { EXPECT_EQ(i, dump_layer_0_node_0[i].Get()); EXPECT_EQ(Timestamp(i), dump_layer_0_node_0[i].Timestamp()); } // Check layer 1 is interleaved (arbitrarily). for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { EXPECT_TRUE(GlobalCountSourceCalculator::kNumOutputPackets + i * 2 == dump_layer_1_node_0[i].Get() || GlobalCountSourceCalculator::kNumOutputPackets + i * 2 + 1 == dump_layer_1_node_0[i].Get()); EXPECT_TRUE(GlobalCountSourceCalculator::kNumOutputPackets + i * 2 == dump_layer_1_node_1[i].Get() || GlobalCountSourceCalculator::kNumOutputPackets + i * 2 + 1 == dump_layer_1_node_1[i].Get()); EXPECT_EQ(Timestamp(i), dump_layer_1_node_0[i].Timestamp()); EXPECT_EQ(Timestamp(i), dump_layer_1_node_1[i].Timestamp()); } // Check layer 2. 
for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { EXPECT_EQ(3 * GlobalCountSourceCalculator::kNumOutputPackets + i, dump_layer_2_node_0[i].Get()); EXPECT_EQ(Timestamp(i), dump_layer_2_node_0[i].Timestamp()); } EXPECT_EQ( 20, input_side_packets["global_counter"].Get*>()->load()); } // Tests for status handler input verification. TEST(CalculatorGraph, StatusHandlerInputVerification) { // Status handlers with all inputs present should be OK. auto graph = absl::make_unique(); CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( packet_generator { packet_generator: "StaticCounterStringGenerator" output_side_packet: "created_by_factory" } packet_generator { packet_generator: "TaggedIntSplitterPacketGenerator" input_side_packet: "a_uint64" output_side_packet: "HIGH:generated_by_generator" output_side_packet: "LOW:unused_low" output_side_packet: "PAIR:unused_pair" } status_handler { status_handler: "Uint32StatusHandler" input_side_packet: "generated_by_generator" } status_handler { status_handler: "StringStatusHandler" input_side_packet: "created_by_factory" } status_handler { status_handler: "StringStatusHandler" input_side_packet: "extra_string" } )pb"); MP_ASSERT_OK(graph->Initialize(config)); Packet extra_string = Adopt(new std::string("foo")); Packet a_uint64 = Adopt(new uint64_t(0)); MP_EXPECT_OK( graph->Run({{"extra_string", extra_string}, {"a_uint64", a_uint64}})); // Should fail verification when missing a required input. The generator is // OK, but the StringStatusHandler is missing its input. EXPECT_FALSE(graph->Run({{"a_uint64", a_uint64}}).ok()); // Should fail verification when the type of an already created packet is // wrong. Here we give the uint64 packet instead of the string packet to the // StringStatusHandler. EXPECT_FALSE( graph->Run({{"extra_string", a_uint64}, {"a_uint64", a_uint64}}).ok()); // Should fail verification when the type of a packet generated by a base // packet factory is wrong. 
Everything is correct except we add a status // handler expecting a uint32 but give it the string from the packet factory. auto* invalid_handler = config.add_status_handler(); invalid_handler->set_status_handler("Uint32StatusHandler"); invalid_handler->add_input_side_packet("created_by_factory"); graph.reset(new CalculatorGraph()); absl::Status status = graph->Initialize(config); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("Uint32StatusHandler"), // The problematic input side packet. testing::HasSubstr("created_by_factory"), // Actual type. testing::HasSubstr("string"), // Expected type. testing::HasSubstr( MediaPipeTypeStringOrDemangled()))); // Should fail verification when the type of a to-be-generated packet is // wrong. The added handler now expects a string but will receive the uint32 // generated by the existing generator. invalid_handler->set_status_handler("StringStatusHandler"); invalid_handler->set_input_side_packet(0, "generated_by_generator"); graph.reset(new CalculatorGraph()); // This is caught earlier, when the type of the PacketGenerator output // is compared to the type of the StatusHandler input. status = graph->Initialize(config); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("StringStatusHandler"), // The problematic input side packet. testing::HasSubstr("generated_by_generator"), // Actual type. testing::HasSubstr( MediaPipeTypeStringOrDemangled()), // Expected type. 
testing::HasSubstr("string"))); } TEST(CalculatorGraph, GenerateInInitialize) { CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" output_side_packet: "foo1" } packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" input_side_packet: "foo1" output_side_packet: "foo2" } packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" input_side_packet: "input_in_run" output_side_packet: "foo3" } packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" input_side_packet: "input_in_run" input_side_packet: "foo3" output_side_packet: "foo4" } )pb"); int initial_count = StaticCounterStringGenerator::NumPacketsGenerated(); MP_ASSERT_OK(graph.Initialize( config, {{"created_by_factory", MakePacket("default string")}, {"input_in_initialize", MakePacket(10)}})); EXPECT_EQ(initial_count + 2, StaticCounterStringGenerator::NumPacketsGenerated()); MP_ASSERT_OK(graph.Run({{"input_in_run", MakePacket(11)}})); EXPECT_EQ(initial_count + 4, StaticCounterStringGenerator::NumPacketsGenerated()); MP_ASSERT_OK(graph.Run({{"input_in_run", MakePacket(12)}})); EXPECT_EQ(initial_count + 6, StaticCounterStringGenerator::NumPacketsGenerated()); } // Resets the counters in the input side packets used in the HandlersRun test. // The value of all these counters will be set to the integer zero, as required // at the beginning of the test. 
void ResetCounters(std::map* input_side_packets) { (*input_side_packets)["no_input_counter1"] = AdoptAsUniquePtr(new int(0)); (*input_side_packets)["no_input_counter2"] = AdoptAsUniquePtr(new int(0)); (*input_side_packets)["available_input_counter1"] = AdoptAsUniquePtr(new int(0)); (*input_side_packets)["available_input_counter2"] = AdoptAsUniquePtr(new int(0)); (*input_side_packets)["unavailable_input_counter1"] = AdoptAsUniquePtr(new int(0)); (*input_side_packets)["unavailable_input_counter2"] = AdoptAsUniquePtr(new int(0)); } // Tests that status handlers run. // - We specify three status handlers: one taking no input side packets, one // taking // an input side packet that is always provided in the call to Run(), and one // that takes the input side packet that will not be produced by the // FailingPacketGenerator. The first two should proccess their PRE-RUN status // but not their POST-RUN status, the third one should not process either of // them since the graph execution fails before the PRE-RUN step. // - We then replace the FailingPacketGenerator with a non-failing generator, // and should have all three handlers running both PRE and POST-RUN (after the // FailingSourceCalculator fails). // - We test that all three status handlers run (with both status) at the end of // a successful graph run. // - Finally, we verify that when the status handler fails (either on PRE or // POST run), but the calculators don't, we still receive errors from the // calculator run. 
TEST(CalculatorGraph, HandlersRun) {
  std::unique_ptr<CalculatorGraph> graph(new CalculatorGraph());
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        packet_generator {
          packet_generator: "FailingPacketGenerator"
          output_side_packet: "unavailable"
        }
        node { calculator: "FailingSourceCalculator" output_stream: "output" }
        status_handler {
          status_handler: "IncrementingStatusHandler"
          input_side_packet: "COUNTER1:no_input_counter1"
          input_side_packet: "COUNTER2:no_input_counter2"
        }
        status_handler {
          status_handler: "IncrementingStatusHandler"
          input_side_packet: "COUNTER1:available_input_counter1"
          input_side_packet: "COUNTER2:available_input_counter2"
          input_side_packet: "EXTRA:available_string"
        }
        status_handler {
          status_handler: "IncrementingStatusHandler"
          input_side_packet: "COUNTER1:unavailable_input_counter1"
          input_side_packet: "COUNTER2:unavailable_input_counter2"
          input_side_packet: "EXTRA:unavailable"
        }
      )pb");
  std::map<std::string, Packet> input_side_packets(
      {{"unused_input", AdoptAsUniquePtr(new int(0))},
       {"no_input_counter1", AdoptAsUniquePtr(new int(0))},
       {"no_input_counter2", AdoptAsUniquePtr(new int(0))},
       {"available_input_counter1", AdoptAsUniquePtr(new int(0))},
       {"available_input_counter2", AdoptAsUniquePtr(new int(0))},
       {"unavailable_input_counter1", AdoptAsUniquePtr(new int(0))},
       {"unavailable_input_counter2", AdoptAsUniquePtr(new int(0))},
       {"available_string", Adopt(new std::string("foo"))}});

  // Reads the current value of a counter side packet. The packet is looked up
  // on every call because ResetCounters() replaces the packets in the map.
  auto counter = [&input_side_packets](const std::string& name) {
    return *GetFromUniquePtr<int>(input_side_packets.at(name));
  };

  // When the graph fails in initialize (even because of a PacketGenerator
  // returning an error), status handlers should not be run.
  ASSERT_THAT(graph->Initialize(config).ToString(),
              testing::HasSubstr("FailingPacketGenerator"));
  EXPECT_EQ(0, counter("no_input_counter1"));
  EXPECT_EQ(0, counter("no_input_counter2"));
  EXPECT_EQ(0, counter("available_input_counter1"));
  EXPECT_EQ(0, counter("available_input_counter2"));
  EXPECT_EQ(0, counter("unavailable_input_counter1"));
  EXPECT_EQ(0, counter("unavailable_input_counter2"));

  // Add an input side packet to the packet generator so that it doesn't
  // run at initialize time.
  config.mutable_packet_generator(0)->add_input_side_packet("unused_input");
  graph.reset(new CalculatorGraph());
  MP_ASSERT_OK(graph->Initialize(config));
  ResetCounters(&input_side_packets);
  EXPECT_THAT(graph->Run(input_side_packets).ToString(),
              testing::HasSubstr("FailingPacketGenerator"));
  EXPECT_EQ(1, counter("no_input_counter1"));
  EXPECT_EQ(0, counter("no_input_counter2"));
  EXPECT_EQ(1, counter("available_input_counter1"));
  EXPECT_EQ(0, counter("available_input_counter2"));
  EXPECT_EQ(0, counter("unavailable_input_counter1"));
  EXPECT_EQ(0, counter("unavailable_input_counter2"));

  // Replace the failing packet generator with something that works. All three
  // status handlers should now process both the PRE and POST-RUN status.
  config.mutable_packet_generator(0)->set_packet_generator(
      "StaticCounterStringGenerator");
  graph.reset(new CalculatorGraph());
  MP_ASSERT_OK(graph->Initialize(config));
  ResetCounters(&input_side_packets);
  // The entire graph should still fail because of the FailingSourceCalculator.
  EXPECT_THAT(graph->Run(input_side_packets).ToString(),
              testing::HasSubstr("FailingSourceCalculator"));
  EXPECT_EQ(1, counter("no_input_counter1"));
  EXPECT_EQ(1, counter("no_input_counter2"));
  EXPECT_EQ(1, counter("available_input_counter1"));
  EXPECT_EQ(1, counter("available_input_counter2"));
  EXPECT_EQ(1, counter("unavailable_input_counter1"));
  EXPECT_EQ(1, counter("unavailable_input_counter2"));

  // Replace the failing calculator with something that works. All three
  // status handlers should still process both PRE and POST-RUN status as part
  // of the successful graph run.
  config.mutable_node(0)->set_calculator("StringEmptySourceCalculator");
  graph.reset(new CalculatorGraph());
  MP_ASSERT_OK(graph->Initialize(config));
  ResetCounters(&input_side_packets);
  MP_EXPECT_OK(graph->Run(input_side_packets));
  EXPECT_EQ(1, counter("no_input_counter1"));
  EXPECT_EQ(1, counter("no_input_counter2"));
  EXPECT_EQ(1, counter("available_input_counter1"));
  EXPECT_EQ(1, counter("available_input_counter2"));
  EXPECT_EQ(1, counter("unavailable_input_counter1"));
  EXPECT_EQ(1, counter("unavailable_input_counter2"));

  // The results configured below live in static members of
  // IncrementingStatusHandler. Restore them to OK when this test exits (even
  // through a failed ASSERT) so the injected failures do not leak into tests
  // that run afterwards.
  struct StatusResultRestorer {
    ~StatusResultRestorer() {
      IncrementingStatusHandler::SetPreRunStatusResult(absl::OkStatus());
      IncrementingStatusHandler::SetPostRunStatusResult(absl::OkStatus());
    }
  } restore_status_results;

  absl::Status run_status;

  // Make status handlers fail. The graph should fail.
  // First, when the PRE_run fails
  IncrementingStatusHandler::SetPreRunStatusResult(
      absl::InternalError("Fail at pre-run"));
  graph.reset(new CalculatorGraph());
  MP_ASSERT_OK(graph->Initialize(config));
  ResetCounters(&input_side_packets);
  run_status = graph->Run(input_side_packets);
  EXPECT_EQ(absl::StatusCode::kInternal, run_status.code());
  EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at pre-run"));
  EXPECT_EQ(1, counter("no_input_counter1"));
  EXPECT_EQ(0, counter("no_input_counter2"));
  EXPECT_EQ(1, counter("available_input_counter1"));
  EXPECT_EQ(0, counter("available_input_counter2"));
  EXPECT_EQ(1, counter("unavailable_input_counter1"));
  EXPECT_EQ(0, counter("unavailable_input_counter2"));

  // Second, when the POST_run fails
  IncrementingStatusHandler::SetPreRunStatusResult(absl::OkStatus());
  IncrementingStatusHandler::SetPostRunStatusResult(
      absl::InternalError("Fail at post-run"));
  graph.reset(new CalculatorGraph());
  MP_ASSERT_OK(graph->Initialize(config));
  ResetCounters(&input_side_packets);
  run_status = graph->Run(input_side_packets);
  EXPECT_EQ(absl::StatusCode::kInternal, run_status.code());
  EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at post-run"));
  EXPECT_EQ(1, counter("no_input_counter1"));
  EXPECT_EQ(1, counter("no_input_counter2"));
  EXPECT_EQ(1, counter("available_input_counter1"));
  EXPECT_EQ(1, counter("available_input_counter2"));
  EXPECT_EQ(1, counter("unavailable_input_counter1"));
  EXPECT_EQ(1, counter("unavailable_input_counter2"));
}

TEST(CalculatorGraph, CalculatorGraphConfigCopyElision) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: 'in'
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          output_stream: 'out'
        }
      )pb");
  // config is consumed and never copied, which avoids copying data.
  MP_ASSERT_OK(graph.Initialize(std::move(config)));
  MP_EXPECT_OK(graph.StartRun({}));
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(1).At(Timestamp(1))));
  MP_EXPECT_OK(graph.CloseInputStream("in"));
  MP_EXPECT_OK(graph.WaitUntilDone());
}

// Test that calling SetOffset() in Calculator::Process() results in the
// absl::StatusCode::kFailedPrecondition error.
TEST(CalculatorGraph, SetOffsetInProcess) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: 'in'
        node {
          calculator: 'SetOffsetInProcessCalculator'
          input_stream: 'in'
          output_stream: 'out'
        }
      )pb");

  MP_ASSERT_OK(graph.Initialize(config));
  MP_EXPECT_OK(graph.StartRun({}));
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(0).At(Timestamp(0))));
  absl::Status status = graph.WaitUntilIdle();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(absl::StatusCode::kFailedPrecondition, status.code());
}

// Test that MediaPipe releases input packets when it is done with them.
TEST(CalculatorGraph, InputPacketLifetime) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: 'in'
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          output_stream: 'mid'
        }
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'mid'
          output_stream: 'out'
        }
      )pb");

  LifetimeTracker tracker;
  Timestamp timestamp = Timestamp(0);
  // Each call wraps a freshly tracked object in a packet at the next
  // timestamp, so tracker.live_count() reflects packets still held.
  auto new_packet = [&timestamp, &tracker] {
    return Adopt(tracker.MakeObject().release()).At(++timestamp);
  };

  MP_ASSERT_OK(graph.Initialize(config));
  MP_EXPECT_OK(graph.StartRun({}));
  MP_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet()));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_EQ(0, tracker.live_count());
  MP_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet()));
  MP_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet()));
  MP_EXPECT_OK(graph.AddPacketToInputStream("in", new_packet()));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_EQ(0, tracker.live_count());
  MP_EXPECT_OK(graph.CloseInputStream("in"));
  MP_EXPECT_OK(graph.WaitUntilDone());
}

// Demonstrate an if-then-else graph.
TEST(CalculatorGraph, IfThenElse) {
  // This graph has an if-then-else structure. The left branch, selected by the
  // select value 0, applies a double (multiply by 2) operation. The right
  // branch, selected by the select value 1, applies a square operation. The
  // left branch also has some no-op PassThroughCalculators to make the lengths
  // of the two branches different.
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: 'in'
        input_stream: 'select'
        node {
          calculator: 'DemuxTimedCalculator'
          input_stream: 'INPUT:in'
          input_stream: 'SELECT:select'
          output_stream: 'OUTPUT:0:left'
          output_stream: 'OUTPUT:1:right'
        }
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'left'
          output_stream: 'left1'
        }
        node {
          calculator: 'DoubleIntCalculator'
          input_stream: 'left1'
          output_stream: 'left2'
        }
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'left2'
          output_stream: 'left3'
        }
        node {
          calculator: 'SquareIntCalculator'
          input_stream: 'right'
          output_stream: 'right1'
        }
        node {
          calculator: 'MuxTimedCalculator'
          input_stream: 'INPUT:0:left3'
          input_stream: 'INPUT:1:right1'
          input_stream: 'SELECT:select'
          output_stream: 'OUTPUT:out'
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("out", &config, &packet_dump);

  Timestamp timestamp = Timestamp(0);
  // Sends one value and its branch selector at the same, freshly incremented
  // timestamp.
  auto send_inputs = [&graph, &timestamp](int input, int select) {
    ++timestamp;
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "in", MakePacket<int>(input).At(timestamp)));
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "select", MakePacket<int>(select).At(timestamp)));
  };

  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));

  // If the "select" input is 0, we apply a double operation. If "select" is 1,
  // we apply a square operation. To make the code easier to understand, define
  // symbolic names for the select values.
  const int kApplyDouble = 0;
  const int kApplySquare = 1;

  send_inputs(1, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(1, packet_dump.size());
  EXPECT_EQ(2, packet_dump[0].Get<int>());

  send_inputs(2, kApplySquare);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(2, packet_dump.size());
  EXPECT_EQ(4, packet_dump[1].Get<int>());

  send_inputs(3, kApplyDouble);
  send_inputs(4, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  // ASSERT (not EXPECT): the indexing below is out of bounds if the size is
  // wrong.
  ASSERT_EQ(4, packet_dump.size());
  EXPECT_EQ(6, packet_dump[2].Get<int>());
  EXPECT_EQ(8, packet_dump[3].Get<int>());

  send_inputs(5, kApplySquare);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(5, packet_dump.size());
  EXPECT_EQ(25, packet_dump[4].Get<int>());

  send_inputs(6, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(6, packet_dump.size());
  EXPECT_EQ(12, packet_dump[5].Get<int>());

  send_inputs(7, kApplySquare);
  send_inputs(8, kApplySquare);
  send_inputs(9, kApplySquare);
  send_inputs(10, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(10, packet_dump.size());
  EXPECT_EQ(49, packet_dump[6].Get<int>());
  EXPECT_EQ(64, packet_dump[7].Get<int>());
  EXPECT_EQ(81, packet_dump[8].Get<int>());
  EXPECT_EQ(20, packet_dump[9].Get<int>());

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
  EXPECT_EQ(10, packet_dump.size());
}

// A simple output selecting test calculator, which omits timestamp bounds
// for the unselected outputs.
class DemuxUntimedCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); cc->Inputs().Tag(kInputTag).SetAny(); cc->Inputs().Tag(kSelectTag).Set(); for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); id < cc->Outputs().EndId("OUTPUT"); ++id) { cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag(kInputTag)); } return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { int index = cc->Inputs().Tag(kSelectTag).Get(); if (!cc->Inputs().Tag(kInputTag).IsEmpty()) { cc->Outputs() .Get("OUTPUT", index) .AddPacket(cc->Inputs().Tag(kInputTag).Value()); } else { cc->Outputs() .Get("OUTPUT", index) .SetNextTimestampBound(cc->InputTimestamp() + 1); } return absl::OkStatus(); } }; REGISTER_CALCULATOR(DemuxUntimedCalculator); // Demonstrate an if-then-else graph. This test differs from the IfThenElse test // in that it uses optional input streams instead of next timestamp bound // propagation. TEST(CalculatorGraph, IfThenElse2) { // This graph has an if-then-else structure. The left branch, selected by the // select value 0, applies a double (multiply by 2) operation. The right // branch, selected by the select value 1, applies a square operation. The // left branch also has some no-op PassThroughCalculators to make the lengths // of the two branches different. 
CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' input_stream: 'select' node { calculator: 'DemuxUntimedCalculator' input_stream: 'INPUT:in' input_stream: 'SELECT:select' output_stream: 'OUTPUT:0:left' output_stream: 'OUTPUT:1:right' } node { calculator: 'PassThroughCalculator' input_stream: 'left' output_stream: 'left1' } node { calculator: 'DoubleIntCalculator' input_stream: 'left1' output_stream: 'left2' } node { calculator: 'PassThroughCalculator' input_stream: 'left2' output_stream: 'left3' } node { calculator: 'SquareIntCalculator' input_stream: 'right' output_stream: 'right1' } node { calculator: 'MuxCalculator' input_stream: 'INPUT:0:left3' input_stream: 'INPUT:1:right1' input_stream: 'SELECT:select' output_stream: 'OUTPUT:out' input_stream_handler { input_stream_handler: 'MuxInputStreamHandler' } } )pb"); std::vector packet_dump; tool::AddVectorSink("out", &config, &packet_dump); Timestamp timestamp = Timestamp(0); auto send_inputs = [&graph, ×tamp](int input, int select) { ++timestamp; MP_EXPECT_OK(graph.AddPacketToInputStream( "in", MakePacket(input).At(timestamp))); MP_EXPECT_OK(graph.AddPacketToInputStream( "select", MakePacket(select).At(timestamp))); }; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.StartRun({})); // If the "select" input is 0, we apply a double operation. If "select" is 1, // we apply a square operation. To make the code easier to understand, define // symbolic names for the select values. 
const int kApplyDouble = 0; const int kApplySquare = 1; send_inputs(1, kApplyDouble); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(1, packet_dump.size()); EXPECT_EQ(2, packet_dump[0].Get()); send_inputs(2, kApplySquare); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(2, packet_dump.size()); EXPECT_EQ(4, packet_dump[1].Get()); send_inputs(3, kApplyDouble); send_inputs(4, kApplyDouble); MP_ASSERT_OK(graph.WaitUntilIdle()); EXPECT_EQ(4, packet_dump.size()); EXPECT_EQ(6, packet_dump[2].Get()); EXPECT_EQ(8, packet_dump[3].Get()); send_inputs(5, kApplySquare); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(5, packet_dump.size()); EXPECT_EQ(25, packet_dump[4].Get()); send_inputs(6, kApplyDouble); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(6, packet_dump.size()); EXPECT_EQ(12, packet_dump[5].Get()); send_inputs(7, kApplySquare); send_inputs(8, kApplySquare); send_inputs(9, kApplySquare); send_inputs(10, kApplyDouble); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(10, packet_dump.size()); EXPECT_EQ(49, packet_dump[6].Get()); EXPECT_EQ(64, packet_dump[7].Get()); EXPECT_EQ(81, packet_dump[8].Get()); EXPECT_EQ(20, packet_dump[9].Get()); MP_ASSERT_OK(graph.CloseAllInputStreams()); MP_ASSERT_OK(graph.WaitUntilDone()); EXPECT_EQ(10, packet_dump.size()); } // A regression test for bug 28321551. The scheduler should be able to run // the calculator graph to completion without hanging. The test merely checks // that CalculatorGraph::Run() returns. TEST(CalculatorGraph, ClosedSourceNodeShouldNotBeUnthrottled) { // This calculator graph has two source nodes. The first source node, // OutputAllSourceCalculator, outputs a lot of packets in one shot and stops. // The second source node, OutputOneAtATimeSourceCalculator, outputs one // packet at a time. But it is connected to a node, DecimatorCalculator, // that discards most of its input packets and only rarely outputs a packet. 
// The sink node, MergeCalculator, receives three input streams, two from // the two source nodes and one from DecimatorCalculator. The two input // streams connected to the two source nodes will become full, and the // MediaPipe scheduler will throttle the source nodes. // // The MediaPipe scheduler should not schedule a closed source node, even if // the source node filled an input stream and the input stream changes from // being "full" to "not full". CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( num_threads: 1 max_queue_size: 100 node { calculator: 'OutputAllSourceCalculator' output_stream: 'first_stream' } node { calculator: 'OutputOneAtATimeSourceCalculator' output_stream: 'second_stream' } node { calculator: 'DecimatorCalculator' input_stream: 'second_stream' output_stream: 'decimated_second_stream' } node { calculator: 'MergeCalculator' input_stream: 'first_stream' input_stream: 'second_stream' input_stream: 'decimated_second_stream' output_stream: 'output' } )pb"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.Run()); } // Tests that a calculator can output a packet in the Open() method. // // The initial output packet generated by UnitDelayCalculator::Open() causes // the following to happen before the scheduler starts to run: // - The downstream PassThroughCalculator becomes ready and is added to the // scheduler queue. // - Since max_queue_size is set to 1, the GlobalCountSourceCalculator is // throttled. // The scheduler should be able to run the graph from this initial state. 
TEST(CalculatorGraph, OutputPacketInOpen) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        max_queue_size: 1
        node {
          calculator: 'GlobalCountSourceCalculator'
          input_side_packet: 'global_counter'
          output_stream: 'integers'
        }
        node {
          calculator: 'UnitDelayCalculator'
          input_stream: 'integers'
          output_stream: 'delayed_integers'
        }
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'delayed_integers'
          output_stream: 'output'
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("output", &config, &packet_dump);
  std::atomic<int> global_counter(1);
  std::map<std::string, Packet> input_side_packets;
  // The side packet carries a pointer to the counter owned by this test.
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(input_side_packets));

  // One extra packet is expected on top of the source's output: the initial
  // packet (value 0 at Timestamp(0)) checked first below.
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets + 1,
            packet_dump.size());
  EXPECT_EQ(0, packet_dump[0].Get<int>());
  EXPECT_EQ(Timestamp(0), packet_dump[0].Timestamp());
  for (int i = 1; i <= GlobalCountSourceCalculator::kNumOutputPackets; ++i) {
    EXPECT_EQ(i, packet_dump[i].Get<int>());
    EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp());
  }
}

// Tests that a calculator can output a packet in the Open() method.
//
// The initial output packet generated by UnitDelayCalculator::Open() causes
// the following to happen before the scheduler starts to run:
// - The downstream MergeCalculator does not become ready because its second
//   input stream has no packet.
// - Since max_queue_size is set to 1, the GlobalCountSourceCalculator is
//   throttled.
// The scheduler must schedule a throttled source node from the beginning.
TEST(CalculatorGraph, OutputPacketInOpen2) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( max_queue_size: 1 node { calculator: 'GlobalCountSourceCalculator' input_side_packet: 'global_counter' output_stream: 'integers' } node { calculator: 'UnitDelayCalculator' input_stream: 'integers' output_stream: 'delayed_integers' } node { calculator: 'MergeCalculator' input_stream: 'delayed_integers' input_stream: 'integers' output_stream: 'output' } )pb"); std::vector packet_dump; tool::AddVectorSink("output", &config, &packet_dump); std::atomic global_counter(1); std::map input_side_packets; input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.Run(input_side_packets)); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets + 1, packet_dump.size()); int i; for (i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) { std::string expected = absl::Substitute("Timestamp($0) $1 $2", packet_dump[i].Timestamp().DebugString(), i, i + 1); EXPECT_EQ(expected, packet_dump[i].Get()); EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); } std::string expected = absl::Substitute( "Timestamp($0) $1 empty", packet_dump[i].Timestamp().DebugString(), i); EXPECT_EQ(expected, packet_dump[i].Get()); EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp()); } // Tests that no packets are available on input streams in Open(), even if the // upstream calculator outputs a packet in Open(). TEST(CalculatorGraph, EmptyInputInOpen) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( max_queue_size: 1 node { calculator: 'GlobalCountSourceCalculator' input_side_packet: 'global_counter' output_stream: 'integers' } # UnitDelayCalculator outputs a packet during Open(). 
node { calculator: 'UnitDelayCalculator' input_stream: 'integers' output_stream: 'delayed_integers' } node { calculator: 'AssertEmptyInputInOpenCalculator' input_stream: 'delayed_integers' } node { calculator: 'AssertEmptyInputInOpenCalculator' input_stream: 'integers' } )pb"); std::atomic global_counter(1); std::map input_side_packets; input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_EXPECT_OK(graph.Run(input_side_packets)); } // Test for b/33568859. TEST(CalculatorGraph, UnthrottleRespectsLayers) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( max_queue_size: 1 node { calculator: 'GlobalCountSourceCalculator' input_side_packet: 'global_counter' output_stream: 'integers0' source_layer: 0 } node { calculator: 'GlobalCountSourceCalculator' input_side_packet: 'global_counter' input_side_packet: 'output_in_open' output_stream: 'integers1' source_layer: 1 } node { calculator: 'PassThroughCalculator' input_stream: 'integers1' output_stream: 'integers1passthrough' } )pb"); std::vector layer0_packets; std::vector layer1_packets; tool::AddVectorSink("integers0", &config, &layer0_packets); tool::AddVectorSink("integers1passthrough", &config, &layer1_packets); std::atomic global_counter(0); std::map input_side_packets; input_side_packets["global_counter"] = Adopt(new auto(&global_counter)); // TODO: Set this value to true. When the calculator outputs a // packet in Open, it will trigget b/33568859, and the test will fail. Use // this test to verify that b/33568859 is fixed. 
constexpr bool kOutputInOpen = true; input_side_packets["output_in_open"] = MakePacket(kOutputInOpen); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.Run(input_side_packets)); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, layer0_packets.size()); ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, layer1_packets.size()); // Check that we ran things in the expected order. int count = 0; if (kOutputInOpen) { EXPECT_EQ(count, layer1_packets[0].Get()); EXPECT_EQ(Timestamp(0), layer1_packets[0].Timestamp()); ++count; } for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i, ++count) { EXPECT_EQ(count, layer0_packets[i].Get()); EXPECT_EQ(Timestamp(i), layer0_packets[i].Timestamp()); } for (int i = kOutputInOpen ? 1 : 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i, ++count) { EXPECT_EQ(count, layer1_packets[i].Get()); EXPECT_EQ(Timestamp(i), layer1_packets[i].Timestamp()); } } // The graph calculates the sum of all the integers output by the source node // so far. The graph has one cycle. 
TEST(CalculatorGraph, Cycle) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: 'GlobalCountSourceCalculator'
          input_side_packet: 'global_counter'
          output_stream: 'integers'
        }
        node {
          calculator: 'IntAdderCalculator'
          input_stream: 'integers'
          input_stream: 'old_sum'
          input_stream_info: {
            tag_index: ':1'  # 'old_sum'
            back_edge: true
          }
          output_stream: 'sum'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
        node {
          calculator: 'UnitDelayCalculator'
          input_stream: 'sum'
          output_stream: 'old_sum'
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("sum", &config, &packet_dump);
  std::atomic<int> global_counter(1);
  std::map<std::string, Packet> input_side_packets;
  // The side packet carries a pointer to the counter owned by this test.
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(input_side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size());
  // Packet i is expected to hold the running total 1 + 2 + ... + (i + 1).
  int sum = 0;
  for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) {
    sum += i + 1;
    EXPECT_EQ(sum, packet_dump[i].Get<int>());
    EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp());
  }
}

// The graph calculates the sum of all the integers output by the source node
// so far. The graph has one cycle.
//
// The difference from the "Cycle" test is that the graph is scheduled with
// packet timestamps ignored.
TEST(CalculatorGraph, CycleUntimed) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream_handler {
          input_stream_handler: 'BarrierInputStreamHandler'
        }
        node {
          calculator: 'GlobalCountSourceCalculator'
          input_side_packet: 'global_counter'
          output_stream: 'integers'
        }
        node {
          calculator: 'IntAdderCalculator'
          input_stream: 'integers'
          input_stream: 'old_sum'
          input_stream_info: {
            tag_index: ':1'  # 'old_sum'
            back_edge: true
          }
          output_stream: 'sum'
        }
        node {
          calculator: 'UnitDelayUntimedCalculator'
          input_stream: 'sum'
          output_stream: 'old_sum'
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("sum", &config, &packet_dump);
  std::atomic<int> global_counter(1);
  std::map<std::string, Packet> input_side_packets;
  // The side packet carries a pointer to the counter owned by this test.
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(input_side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size());
  // Only the values are checked here; packet timestamps are not asserted.
  int sum = 0;
  for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) {
    sum += i + 1;
    EXPECT_EQ(sum, packet_dump[i].Get<int>());
  }
}

// This unit test is a direct form I implementation of Example 6.2 of
// Discrete-Time Signal Processing, 3rd Ed., shown in Figure 6.6. The system
// function of the linear time-invariant (LTI) system is
//     H(z) = (1 + 2 * z^-1) / (1 - 1.5 * z^-1 + 0.9 * z^-2)
// The graph has two cycles.
TEST(CalculatorGraph, DirectFormI) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: 'GlobalCountSourceCalculator'
          input_side_packet: 'global_counter'
          output_stream: 'integers'
        }
        node {
          calculator: 'IntToFloatCalculator'
          input_stream: 'integers'
          output_stream: 'x'
        }
        node {
          calculator: 'FloatUnitDelayCalculator'
          input_stream: 'x'
          output_stream: 'a'
        }
        node {
          calculator: 'FloatScalarMultiplierCalculator'
          input_stream: 'a'
          output_stream: 'b'
          input_side_packet: 'b1'
        }
        node {
          calculator: 'FloatAdderCalculator'
          input_stream: 'x'
          input_stream: 'b'
          output_stream: 'c'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
        node {
          calculator: 'FloatAdderCalculator'
          input_stream: 'c'
          input_stream: 'f'
          input_stream_info: {
            tag_index: ':1'  # 'f'
            back_edge: true
          }
          output_stream: 'y'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
        node {
          calculator: 'FloatUnitDelayCalculator'
          input_stream: 'y'
          output_stream: 'd'
        }
        node {
          calculator: 'FloatScalarMultiplierCalculator'
          input_stream: 'd'
          output_stream: 'e'
          input_side_packet: 'a1'
        }
        node {
          calculator: 'FloatUnitDelayCalculator'
          input_stream: 'd'
          output_stream: 'g'
        }
        node {
          calculator: 'FloatScalarMultiplierCalculator'
          input_stream: 'g'
          output_stream: 'h'
          input_side_packet: 'a2'
        }
        node {
          calculator: 'FloatAdderCalculator'
          input_stream: 'e'
          input_stream: 'h'
          output_stream: 'f'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("y", &config, &packet_dump);
  std::atomic<int> global_counter(1);
  std::map<std::string, Packet> input_side_packets;
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));
  // Filter coefficients: a1 and a2 scale the fed-back output, b1 scales the
  // delayed input.
  input_side_packets["a2"] = Adopt(new float(-0.9));
  input_side_packets["a1"] = Adopt(new float(1.5));
  input_side_packets["b1"] = Adopt(new float(2.0));

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(input_side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size());
  // The expected values below are hard-coded for exactly five packets.
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, 5);
  EXPECT_FLOAT_EQ(1.0, packet_dump[0].Get<float>());
  EXPECT_FLOAT_EQ(5.5, packet_dump[1].Get<float>());
  EXPECT_FLOAT_EQ(14.35, packet_dump[2].Get<float>());
  EXPECT_FLOAT_EQ(26.575, packet_dump[3].Get<float>());
  EXPECT_FLOAT_EQ(39.9475, packet_dump[4].Get<float>());
  for (int i = 0; i < 5; ++i) {
    EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp());
  }
}

// This unit test is a direct form II implementation of Example 6.2 of
// Discrete-Time Signal Processing, 3rd Ed., shown in Figure 6.7. The system
// function of the linear time-invariant (LTI) system is
//     H(z) = (1 + 2 * z^-1) / (1 - 1.5 * z^-1 + 0.9 * z^-2)
// The graph has two cycles.
TEST(CalculatorGraph, DirectFormII) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: 'GlobalCountSourceCalculator'
          input_side_packet: 'global_counter'
          output_stream: 'integers'
        }
        node {
          calculator: 'IntToFloatCalculator'
          input_stream: 'integers'
          output_stream: 'x'
        }
        node {
          calculator: 'FloatAdderCalculator'
          input_stream: 'x'
          input_stream: 'f'
          input_stream_info: {
            tag_index: ':1'  # 'f'
            back_edge: true
          }
          output_stream: 'a'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
        node {
          calculator: 'FloatUnitDelayCalculator'
          input_stream: 'a'
          output_stream: 'b'
        }
        node {
          calculator: 'FloatScalarMultiplierCalculator'
          input_stream: 'b'
          output_stream: 'd'
          input_side_packet: 'a1'
        }
        node {
          calculator: 'FloatUnitDelayCalculator'
          input_stream: 'b'
          output_stream: 'c'
        }
        node {
          calculator: 'FloatScalarMultiplierCalculator'
          input_stream: 'c'
          output_stream: 'e'
          input_side_packet: 'a2'
        }
        node {
          calculator: 'FloatAdderCalculator'
          input_stream: 'd'
          input_stream: 'e'
          output_stream: 'f'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
        node {
          calculator: 'FloatScalarMultiplierCalculator'
          input_stream: 'b'
          output_stream: 'g'
          input_side_packet: 'b1'
        }
        node {
          calculator: 'FloatAdderCalculator'
          input_stream: 'a'
          input_stream: 'g'
          output_stream: 'y'
          input_stream_handler {
            input_stream_handler: 'EarlyCloseInputStreamHandler'
          }
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("y", &config, &packet_dump);
  std::atomic<int> global_counter(1);
  std::map<std::string, Packet> input_side_packets;
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));
  // Same coefficients as DirectFormI, so the same outputs are expected.
  input_side_packets["a2"] = Adopt(new float(-0.9));
  input_side_packets["a1"] = Adopt(new float(1.5));
  input_side_packets["b1"] = Adopt(new float(2.0));

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(input_side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, packet_dump.size());
  // The expected values below are hard-coded for exactly five packets.
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, 5);
  EXPECT_FLOAT_EQ(1.0, packet_dump[0].Get<float>());
  EXPECT_FLOAT_EQ(5.5, packet_dump[1].Get<float>());
  EXPECT_FLOAT_EQ(14.35, packet_dump[2].Get<float>());
  EXPECT_FLOAT_EQ(26.575, packet_dump[3].Get<float>());
  EXPECT_FLOAT_EQ(39.9475, packet_dump[4].Get<float>());
  for (int i = 0; i < 5; ++i) {
    EXPECT_EQ(Timestamp(i), packet_dump[i].Timestamp());
  }
}

// Calculates the dot products of two streams of three-dimensional vectors.
TEST(CalculatorGraph, DotProduct) {
  // The use of BarrierInputStreamHandler in this graph aligns the input
  // packets to a calculator by arrival order rather than by timestamp.
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream_handler {
          input_stream_handler: 'BarrierInputStreamHandler'
        }
        node {
          calculator: 'TestSequence1SourceCalculator'
          output_stream: 'test_sequence_1'
        }
        node {
          calculator: 'TestSequence2SourceCalculator'
          output_stream: 'test_sequence_2'
        }
        node {
          calculator: 'Modulo3SourceCalculator'
          output_stream: 'select_0_1_2'
        }
        node {
          calculator: 'DemuxUntimedCalculator'
          input_stream: 'INPUT:test_sequence_1'
          input_stream: 'SELECT:select_0_1_2'
          output_stream: 'OUTPUT:0:x_1'
          output_stream: 'OUTPUT:1:y_1'
          output_stream: 'OUTPUT:2:z_1'
        }
        node {
          calculator: 'DemuxUntimedCalculator'
          input_stream: 'INPUT:test_sequence_2'
          input_stream: 'SELECT:select_0_1_2'
          output_stream: 'OUTPUT:0:x_2'
          output_stream: 'OUTPUT:1:y_2'
          output_stream: 'OUTPUT:2:z_2'
        }
        node {
          calculator: 'IntMultiplierCalculator'
          input_stream: 'x_1'
          input_stream: 'x_2'
          output_stream: 'x_product'
        }
        node {
          calculator: 'IntMultiplierCalculator'
          input_stream: 'y_1'
          input_stream: 'y_2'
          output_stream: 'y_product'
        }
        node {
          calculator: 'IntMultiplierCalculator'
          input_stream: 'z_1'
          input_stream: 'z_2'
          output_stream: 'z_product'
        }
        node {
          calculator: 'IntAdderCalculator'
          input_stream: 'x_product'
          input_stream: 'y_product'
          input_stream: 'z_product'
          output_stream: 'dot_product'
        }
      )pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("dot_product", &config, &packet_dump);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run());

  // The calculator graph performs the following computation:
  //   test_sequence_1 is split into x_1, y_1, z_1.
  //   test_sequence_2 is split into x_2, y_2, z_2.
  //   x_product = x_1 * x_2
  //   y_product = y_1 * y_2
  //   z_product = z_1 * z_2
  //   dot_product = x_product + y_product + z_product
  //
  // The values in these streams are:
  //   test_sequence_1: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
  //   test_sequence_2: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  //   x_1: 0, 3, 6, 9, 12
  //   x_2: 1, 4, 7, 10, 13
  //   x_product: 0, 12, 42, 90, 156
  //   y_1: 1, 4, 7, 10, 13
  //   y_2: 2, 5, 8, 11, 14
  //   y_product: 2, 20, 56, 110, 182
  //   z_1: 2, 5, 8, 11, 14
  //   z_2: 3, 6, 9, 12, 15
  //   z_product: 6, 30, 72, 132, 210
  //   dot_product: 8, 62, 170, 332, 548

  ASSERT_EQ(kTestSequenceLength / 3, packet_dump.size());
  const int expected[] = {8, 62, 170, 332, 548};
  // size_t index avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < packet_dump.size(); ++i) {
    EXPECT_EQ(expected[i], packet_dump[i].Get<int>());
  }
}

// Cancels a run while graph input stream 'in_b' is still open and checks that
// WaitUntilDone() returns with kCancelled.
TEST(CalculatorGraph, TerminatesOnCancelWithOpenGraphInputStreams) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in_a'
          input_stream: 'in_b'
          output_stream: 'out_a'
          output_stream: 'out_b'
        }
        input_stream: 'in_a'
        input_stream: 'in_b'
      )pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_a", MakePacket<int>(1).At(Timestamp(1))));
  MP_EXPECT_OK(graph.CloseInputStream("in_a"));
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in_b", MakePacket<int>(2).At(Timestamp(2))));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  graph.Cancel();
  // This tests that the graph doesn't deadlock on WaitUntilDone (because
  // the scheduler thread is sleeping).
  absl::Status status = graph.WaitUntilDone();
  EXPECT_EQ(status.code(), absl::StatusCode::kCancelled);
}

// Cancels a paused run that has pending work and checks that WaitUntilDone()
// returns with kCancelled.
TEST(CalculatorGraph, TerminatesOnCancelAfterPause) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: 'PassThroughCalculator'
          input_stream: 'in'
          output_stream: 'out'
        }
        input_stream: 'in'
      )pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));
  graph.Pause();
  // Make the PassThroughCalculator runnable while the scheduler is paused.
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(1).At(Timestamp(1))));
  // Now cancel the graph run. A non-empty scheduler queue should not prevent
  // the scheduler from terminating.
  graph.Cancel();
  // Any attempt to pause the scheduler after the graph run is cancelled should
  // be ignored.
  graph.Pause();
  // This tests that the graph doesn't deadlock on WaitUntilDone (because
  // the scheduler thread is sleeping).
  absl::Status status = graph.WaitUntilDone();
  EXPECT_EQ(status.code(), absl::StatusCode::kCancelled);
}

// A PacketGenerator that simply passes its input Packets through
// unchanged.  The inputs may be specified by tag or index.  The outputs
// must match the inputs exactly.  Any options may be specified and will
// also be ignored.
class PassThroughGenerator : public PacketGenerator { public: static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs, PacketTypeSet* outputs) { if (!inputs->TagMap()->SameAs(*outputs->TagMap())) { return absl::InvalidArgumentError( "Input and outputs to PassThroughGenerator must use the same tags " "and indexes."); } for (CollectionItemId id = inputs->BeginId(); id < inputs->EndId(); ++id) { inputs->Get(id).SetAny(); outputs->Get(id).SetSameAs(&inputs->Get(id)); } return absl::OkStatus(); } static absl::Status Generate(const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { for (CollectionItemId id = input_side_packets.BeginId(); id < input_side_packets.EndId(); ++id) { output_side_packets->Get(id) = input_side_packets.Get(id); } return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(PassThroughGenerator); TEST(CalculatorGraph, RecoverAfterRunError) { PacketGeneratorGraph generator_graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( node { name: 'calculator1' calculator: 'CountingSourceCalculator' output_stream: 'count1' input_side_packet: 'MAX_COUNT:max_count2' input_side_packet: 'ERROR_COUNT:max_error2' } packet_generator { packet_generator: 'EnsurePositivePacketGenerator' input_side_packet: 'max_count1' output_side_packet: 'max_count2' input_side_packet: 'max_error1' output_side_packet: 'max_error2' } status_handler { status_handler: 'FailableStatusHandler' input_side_packet: 'status_handler_command' } )pb"); int packet_count = 0; CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config, {})); MP_ASSERT_OK(graph.ObserveOutputStream("count1", [&packet_count](const Packet& packet) { ++packet_count; return absl::OkStatus(); })); // Set ERROR_COUNT higher than MAX_COUNT and hence the calculator will // finish successfully. 
packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); // Fail in PacketGenerator::Generate(). // Negative max_count1 will cause EnsurePositivePacketGenerator to fail. ASSERT_FALSE(graph .Run({{"max_count1", MakePacket(-1)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}}) .ok()); packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); // Fail in PacketGenerator::Generate() also fail in StatusHandler. ASSERT_FALSE(graph .Run({{"max_count1", MakePacket(-1)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kFailPreRun)}}) .ok()); packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); ASSERT_FALSE( graph .Run({{"max_count1", MakePacket(-1)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kFailPostRun)}}) .ok()); packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); // Fail in Calculator::Process(). ASSERT_FALSE(graph .Run({{"max_count1", MakePacket(1000)}, {"max_error1", MakePacket(10)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}}) .ok()); packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); // Fail in Calculator::Process() also fail in StatusHandler. 
ASSERT_FALSE(graph .Run({{"max_count1", MakePacket(1000)}, {"max_error1", MakePacket(10)}, {"status_handler_command", MakePacket(FailableStatusHandler::kFailPreRun)}}) .ok()); packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); ASSERT_FALSE( graph .Run({{"max_count1", MakePacket(1000)}, {"max_error1", MakePacket(10)}, {"status_handler_command", MakePacket(FailableStatusHandler::kFailPostRun)}}) .ok()); packet_count = 0; MP_ASSERT_OK(graph.Run({{"max_count1", MakePacket(10)}, {"max_error1", MakePacket(20)}, {"status_handler_command", MakePacket(FailableStatusHandler::kOk)}})); EXPECT_EQ(packet_count, 10); } TEST(CalculatorGraph, SetInputStreamMaxQueueSizeWorksSlowCalculator) { using Semaphore = SemaphoreCalculator::Semaphore; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( node { calculator: 'SemaphoreCalculator' input_stream: 'in' output_stream: 'out' input_side_packet: 'POST_SEM:post_sem' input_side_packet: 'WAIT_SEM:wait_sem' } node { calculator: 'SemaphoreCalculator' input_stream: 'in_2' output_stream: 'out_2' input_side_packet: 'POST_SEM:post_sem_busy' input_side_packet: 'WAIT_SEM:wait_sem_busy' } input_stream: 'in' input_stream: 'in_2' max_queue_size: 100 )pb"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); graph.SetGraphInputStreamAddMode( CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); MP_ASSERT_OK(graph.SetInputStreamMaxQueueSize("in", 1)); Semaphore calc_entered_process(0); Semaphore calc_can_exit_process(0); Semaphore calc_entered_process_busy(0); Semaphore calc_can_exit_process_busy(0); MP_ASSERT_OK(graph.StartRun({ {"post_sem", MakePacket(&calc_entered_process)}, {"wait_sem", MakePacket(&calc_can_exit_process)}, {"post_sem_busy", MakePacket(&calc_entered_process_busy)}, {"wait_sem_busy", MakePacket(&calc_can_exit_process_busy)}, })); Timestamp 
timestamp(0); // Prevent deadlock resolution by running the "busy" SemaphoreCalculator // for the duration of the test. MP_EXPECT_OK( graph.AddPacketToInputStream("in_2", MakePacket(0).At(timestamp))); MP_EXPECT_OK( graph.AddPacketToInputStream("in", MakePacket(0).At(timestamp++))); for (int i = 1; i < 20; ++i, ++timestamp) { // Wait for the calculator to begin its Process call. calc_entered_process.Acquire(1); // Now the calculator is stuck processing a packet. We can queue up // another one. MP_EXPECT_OK( graph.AddPacketToInputStream("in", MakePacket(i).At(timestamp))); // We should be prevented from adding another, since the queue is now full. absl::Status status = graph.AddPacketToInputStream( "in", MakePacket(i).At(timestamp + 1)); EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); // Allow calculator to complete its Process call. calc_can_exit_process.Release(1); } // Allow the final Process call to complete. calc_can_exit_process.Release(1); calc_can_exit_process_busy.Release(1); MP_ASSERT_OK(graph.CloseAllInputStreams()); MP_ASSERT_OK(graph.WaitUntilDone()); } // Verify the scheduler unthrottles the graph input stream to avoid a deadlock, // and won't enter a busy loop. TEST(CalculatorGraph, AddPacketNoBusyLoop) { // The DecimatorCalculator ouputs 1 out of every 101 input packets and drops // the rest, without setting the next timestamp bound on its output. As a // result, the MergeCalculator is not runnable in between and packets on its // "in" input stream will be queued and exceed the max queue size. 
// // in // | // / \ // / \ // / \ // | \ // v | // +---------+ | // 101:1 |Decimator| | <== Packet buildup // +---------+ | // | | // v v // +----------+ // | Merge | // +----------+ // | // v // out // CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' max_queue_size: 1 node { calculator: 'DecimatorCalculator' input_stream: 'in' output_stream: 'decimated_in' } node { calculator: 'MergeCalculator' input_stream: 'decimated_in' input_stream: 'in' output_stream: 'out' } )pb"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); graph.SetGraphInputStreamAddMode( CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL); std::vector out_packets; // Packets from the output stream "out". MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { out_packets.push_back(packet); return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); const int kDecimationRatio = DecimatorCalculator::kDecimationRatio; // To leave the graph input stream "in" in the throttled state, kNumPackets // can be any value other than a multiple of kDecimationRatio plus one. const int kNumPackets = 2 * kDecimationRatio; for (int i = 0; i < kNumPackets; ++i) { MP_EXPECT_OK(graph.AddPacketToInputStream( "in", MakePacket(i).At(Timestamp(i)))); } // The graph input stream "in" is throttled. Wait until the graph is idle. MP_ASSERT_OK(graph.WaitUntilIdle()); // Check that Pause() does not block forever trying to acquire a mutex. // This is a regression test for an old bug. graph.Pause(); graph.Resume(); MP_ASSERT_OK(graph.CloseAllInputStreams()); MP_ASSERT_OK(graph.WaitUntilDone()); // The expected output packets are: // "Timestamp(0) 0 0" // "Timestamp(1) empty 1" // ... // "Timestamp(100) empty 100" // "Timestamp(101) 101 101" // "Timestamp(102) empty 102" // ... 
// "Timestamp(201) empty 201" ASSERT_EQ(kNumPackets, out_packets.size()); for (int i = 0; i < out_packets.size(); ++i) { std::string format = (i % kDecimationRatio == 0) ? "Timestamp($0) $0 $0" : "Timestamp($0) empty $0"; std::string expected = absl::Substitute(format, i); EXPECT_EQ(expected, out_packets[i].Get()); EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp()); } } namespace nested_ns { typedef std::function ProcessFunction; // A Calculator that delegates its Process function to a callback function. class ProcessCallbackCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0)); } cc->InputSidePackets().Index(0).Set>(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { callback_ = *GetFromUniquePtr(cc->InputSidePackets().Index(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { return callback_(cc->Inputs(), &(cc->Outputs())); } private: ProcessFunction callback_; }; REGISTER_CALCULATOR(::mediapipe::nested_ns::ProcessCallbackCalculator); } // namespace nested_ns TEST(CalculatorGraph, CalculatorInNamepsace) { CalculatorGraphConfig config; CHECK(proto_ns::TextFormat::ParseFromString(R"( input_stream: 'in_a' node { calculator: 'mediapipe.nested_ns.ProcessCallbackCalculator' input_stream: 'in_a' output_stream: 'out_a' input_side_packet: 'callback_1' } )", &config)); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); nested_ns::ProcessFunction callback_1; MP_ASSERT_OK( graph.StartRun({{"callback_1", AdoptAsUniquePtr(new auto(callback_1))}})); MP_EXPECT_OK(graph.WaitUntilIdle()); } // A ProcessFunction that passes through all packets. 
absl::Status DoProcess(const InputStreamShardSet& inputs,
                       OutputStreamShardSet* outputs) {
  for (int i = 0; i < inputs.NumEntries(); ++i) {
    // Inputs that hold no packet are skipped; every other input's packet is
    // added to the output with the same index.
    if (!inputs.Index(i).Value().IsEmpty()) {
      outputs->Index(i).AddPacket(inputs.Index(i).Value());
    }
  }
  return absl::OkStatus();
}

// Observes one internal stream ("count") and one unconnected stream ("out")
// and checks that both deliver max_count packets whose value and timestamp
// equal their position in the sequence.
TEST(CalculatorGraph, ObserveOutputStream) {
  const int max_count = 10;
  CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(
      R"pb( node { calculator: 'CountingSourceCalculator' output_stream: 'count' input_side_packet: 'MAX_COUNT:max_count' } node { calculator: 'PassThroughCalculator' input_stream: 'count' output_stream: 'mid' } node { calculator: 'PassThroughCalculator' input_stream: 'mid' output_stream: 'out' } )pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket(max_count)}}));
  // Observe the internal output stream "count" and the unconnected output
  // stream "out".
  std::vector count_packets;  // Packets from the output stream "count".
  std::vector out_packets;    // Packets from the output stream "out".
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "count", [&count_packets](const Packet& packet) {
        count_packets.push_back(packet);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(
      graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) {
        out_packets.push_back(packet);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.Run());
  ASSERT_EQ(max_count, count_packets.size());
  for (int i = 0; i < count_packets.size(); ++i) {
    EXPECT_EQ(i, count_packets[i].Get());
    EXPECT_EQ(Timestamp(i), count_packets[i].Timestamp());
  }
  ASSERT_EQ(max_count, out_packets.size());
  for (int i = 0; i < out_packets.size(); ++i) {
    EXPECT_EQ(i, out_packets[i].Get());
    EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp());
  }
}

// A subgraph whose config is a single PassThroughCalculator connecting the
// INPUT stream to the OUTPUT stream.
class PassThroughSubgraph : public Subgraph {
 public:
  // Returns the fixed config above; `options` is not read.
  absl::StatusOr GetConfig(
      const SubgraphOptions& options) override {
    CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(
        R"pb( input_stream: 'INPUT:input' output_stream: 'OUTPUT:output' node { calculator: 'PassThroughCalculator' input_stream: 'input' output_stream: 'output' } )pb");
    return config;
  }
};
REGISTER_MEDIAPIPE_GRAPH(PassThroughSubgraph);

// Same check as ObserveOutputStream, but the observed stream "out" is the
// output of a PassThroughSubgraph node.
TEST(CalculatorGraph, ObserveOutputStreamSubgraph) {
  const int max_count = 10;
  CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(
      R"pb( node { calculator: 'CountingSourceCalculator' output_stream: 'count' input_side_packet: 'MAX_COUNT:max_count' } node { calculator: 'PassThroughSubgraph' input_stream: 'INPUT:count' output_stream: 'OUTPUT:out' } )pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket(max_count)}}));
  // Observe the unconnected output stream "out".
  std::vector out_packets;  // Packets from the output stream "out".
  MP_ASSERT_OK(
      graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) {
        out_packets.push_back(packet);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.Run());
  ASSERT_EQ(max_count, out_packets.size());
  for (int i = 0; i < out_packets.size(); ++i) {
    EXPECT_EQ(i, out_packets[i].Get());
    EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp());
  }
}

// The "count" observer returns an error once it has received fail_count
// packets. Run() must report that error's message, and the observer must have
// received exactly fail_count packets.
TEST(CalculatorGraph, ObserveOutputStreamError) {
  const int max_count = 10;
  const int fail_count = 6;
  CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(
      R"pb( node { calculator: 'CountingSourceCalculator' output_stream: 'count' input_side_packet: 'MAX_COUNT:max_count' } node { calculator: 'PassThroughCalculator' input_stream: 'count' output_stream: 'mid' } node { calculator: 'PassThroughCalculator' input_stream: 'mid' output_stream: 'out' } )pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket(max_count)}}));
  // Observe the internal output stream "count" and the unconnected output
  // stream "out".
  std::vector count_packets;  // Packets from the output stream "count".
  std::vector out_packets;    // Packets from the output stream "out".
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "count", [&count_packets](const Packet& packet) {
        count_packets.push_back(packet);
        // The packet is recorded before the error is returned, so the failing
        // packet is included in count_packets.
        if (count_packets.size() >= fail_count) {
          return absl::UnknownError("Expected. MagicString-eatnhuea");
        } else {
          return absl::OkStatus();
        }
      }));
  MP_ASSERT_OK(
      graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) {
        out_packets.push_back(packet);
        return absl::OkStatus();
      }));
  absl::Status status = graph.Run();
  ASSERT_THAT(status.message(), testing::HasSubstr("MagicString-eatnhuea"));
  ASSERT_EQ(fail_count, count_packets.size());
  for (int i = 0; i < count_packets.size(); ++i) {
    EXPECT_EQ(i, count_packets[i].Get());
    EXPECT_EQ(Timestamp(i), count_packets[i].Timestamp());
  }
}

// Observing a stream name that the graph config does not define must fail
// with kNotFound and mention the requested name.
TEST(CalculatorGraph, ObserveOutputStreamNonexistent) {
  const int max_count = 10;
  CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(
      R"pb( node { calculator: 'CountingSourceCalculator' output_stream: 'count' input_side_packet: 'MAX_COUNT:max_count' } node { calculator: 'PassThroughCalculator' input_stream: 'count' output_stream: 'mid' } node { calculator: 'PassThroughCalculator' input_stream: 'mid' output_stream: 'out' } )pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket(max_count)}}));
  // Try to observe "not_found", which is not a stream in the config above.
  std::vector count_packets;  // Filled by the callback passed below.
  absl::Status status = graph.ObserveOutputStream(
      "not_found", [&count_packets](const Packet& packet) {
        count_packets.push_back(packet);
        return absl::OkStatus();
      });
  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(status.message(), testing::HasSubstr("not_found"));
}

// Verify that after a fast source node is closed, a slow sink node can
// consume all the accumulated input packets. In other words, closing an
// output stream still allows its mirrors to process all the received packets.
TEST(CalculatorGraph, FastSourceSlowSink) {
  // Number of packets requested from the source via MAX_COUNT.
  const int kMaxCount = 10;
  CalculatorGraphConfig graph_config = mediapipe::ParseTextProtoOrDie(
      R"pb( num_threads: 2 max_queue_size: 100 node { calculator: 'CountingSourceCalculator' output_stream: 'out' input_side_packet: 'MAX_COUNT:max_count' } node { calculator: 'SlowCountingSinkCalculator' input_stream: 'out' } )pb");

  CalculatorGraph calculator_graph;
  MP_ASSERT_OK(calculator_graph.Initialize(
      graph_config, {{"max_count", MakePacket(kMaxCount)}}));
  // The whole run, including the sink draining its queue, must succeed.
  MP_EXPECT_OK(calculator_graph.Run());
}

TEST(CalculatorGraph, GraphFinishesWhilePaused) {
  // The graph has a single node that runs exactly once. The intended timeline
  // (milliseconds since StartRun) is:
  //
  //          Application thread    Worker thread
  //
  //   T=0    graph.StartRun        OneShot20MsCalculator::Process starts
  //   T=10   graph.Pause
  //   T=20                         OneShot20MsCalculator::Process ends.
  //                                So graph finishes running while paused.
  //   T=30   graph.Resume
  //
  // graph.WaitUntilDone must not block forever.
  const absl::Duration delay_before_pause = absl::Milliseconds(10);
  const absl::Duration time_spent_paused = absl::Milliseconds(20);

  CalculatorGraphConfig graph_config = mediapipe::ParseTextProtoOrDie(
      R"pb( node { calculator: 'OneShot20MsCalculator' } )pb");
  CalculatorGraph calculator_graph;
  MP_ASSERT_OK(calculator_graph.Initialize(graph_config));
  MP_EXPECT_OK(calculator_graph.StartRun({}));

  absl::SleepFor(delay_before_pause);
  calculator_graph.Pause();
  absl::SleepFor(time_spent_paused);
  calculator_graph.Resume();

  MP_EXPECT_OK(calculator_graph.WaitUntilDone());
}

// There should be no memory leaks, no error messages (requires manual
// inspection of the test log), etc.
TEST(CalculatorGraph, ConstructAndDestruct) {
  // Constructed and destroyed without ever being initialized.
  CalculatorGraph never_initialized_graph;
}

// A regression test for b/36364314. UnitDelayCalculator outputs a packet in
// Open(). ErrorOnOpenCalculator fails in Open() if ERROR_ON_OPEN is true.
TEST(CalculatorGraph, RecoverAfterPreviousFailInOpen) { const int max_count = 10; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( node { calculator: 'CountingSourceCalculator' output_stream: 'a' input_side_packet: 'MAX_COUNT:max_count' } node { calculator: 'UnitDelayCalculator' input_stream: 'a' output_stream: 'b' } node { calculator: 'PassThroughCalculator' input_stream: 'b' output_stream: 'c' } node { calculator: 'ErrorOnOpenCalculator' input_stream: 'c' output_stream: 'd' input_side_packet: 'ERROR_ON_OPEN:fail' } node { calculator: 'IntSinkCalculator' input_stream: 'd' } )pb"); CalculatorGraph graph; MP_ASSERT_OK( graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); for (int i = 0; i < 2; ++i) { EXPECT_FALSE(graph.Run({{"fail", MakePacket(true)}}).ok()); MP_EXPECT_OK(graph.Run({{"fail", MakePacket(false)}})); } } TEST(CalculatorGraph, ReuseValidatedGraphConfig) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" output_side_packet: "foo1" } packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" input_side_packet: "foo1" output_side_packet: "foo2" } packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" input_side_packet: "input_in_run" output_side_packet: "foo3" } packet_generator { packet_generator: "StaticCounterStringGenerator" input_side_packet: "created_by_factory" input_side_packet: "input_in_initialize" input_side_packet: "input_in_run" input_side_packet: "foo3" output_side_packet: "foo4" } node { calculator: "GlobalCountSourceCalculator" input_side_packet: "global_counter" output_stream: "unused" } )pb"); ValidatedGraphConfig validated_graph; 
MP_ASSERT_OK(validated_graph.Initialize(config)); std::atomic global_counter(0); Packet global_counter_packet = Adopt(new auto(&global_counter)); absl::FixedArray graphs(30); for (int i = 0; i < graphs.size(); ++i) { CalculatorGraph& graph = graphs[i]; int initial_generator_count = StaticCounterStringGenerator::NumPacketsGenerated(); int initial_calculator_count = global_counter.load(); MP_ASSERT_OK(graph.Initialize( config, {{"created_by_factory", MakePacket("default string")}, {"input_in_initialize", MakePacket(10)}, {"global_counter", global_counter_packet}})); EXPECT_EQ(initial_generator_count + 2, StaticCounterStringGenerator::NumPacketsGenerated()); EXPECT_EQ(initial_calculator_count, global_counter.load()); } for (int k = 0; k < 10; ++k) { for (int i = 0; i < graphs.size(); ++i) { CalculatorGraph& graph = graphs[i]; int initial_generator_count = StaticCounterStringGenerator::NumPacketsGenerated(); int initial_calculator_count = global_counter.load(); MP_ASSERT_OK(graph.Run({{"input_in_run", MakePacket(11)}})); EXPECT_EQ(initial_generator_count + 2, StaticCounterStringGenerator::NumPacketsGenerated()); EXPECT_EQ(initial_calculator_count + GlobalCountSourceCalculator::kNumOutputPackets, global_counter.load()); } } } class TestRangeStdDevSubgraph : public Subgraph { public: absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_side_packet: 'node_converted' output_stream: 'DATA:range' output_stream: 'SUM:range_sum' output_stream: 'MEAN:range_mean' output_stream: 'STDDEV:range_stddev' node { calculator: 'RangeCalculator' output_stream: 'range' output_stream: 'range_sum' output_stream: 'range_mean' input_side_packet: 'node_converted' } node { calculator: 'StdDevCalculator' input_stream: 'DATA:range' input_stream: 'MEAN:range_mean' output_stream: 'range_stddev' } )pb"); return config; } }; REGISTER_MEDIAPIPE_GRAPH(TestRangeStdDevSubgraph); class TestMergeSaverSubgraph : public 
Subgraph { public: absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'DATA1:range1' input_stream: 'DATA2:range2' output_stream: 'MERGE:merge' output_stream: 'FINAL:final' node { name: 'merger' calculator: 'MergeCalculator' input_stream: 'range1' input_stream: 'range2' output_stream: 'merge' } node { calculator: 'SaverCalculator' input_stream: 'merge' output_stream: 'final' } )pb"); return config; } }; REGISTER_MEDIAPIPE_GRAPH(TestMergeSaverSubgraph); CalculatorGraphConfig GetConfigWithSubgraphs() { CalculatorGraphConfig proto = mediapipe::ParseTextProtoOrDie(R"pb( # Ensure stream name for FindOutputStreamManager output_stream: 'MERGE:merge' packet_generator { packet_generator: 'IntSplitterPacketGenerator' input_side_packet: 'node_3' output_side_packet: 'node_3_converted' } packet_generator { packet_generator: 'TaggedIntSplitterPacketGenerator' input_side_packet: 'node_5' output_side_packet: 'HIGH:unused_high' output_side_packet: 'LOW:unused_low' output_side_packet: 'PAIR:node_5_converted' } node { calculator: 'TestRangeStdDevSubgraph' input_side_packet: 'node_3_converted' output_stream: 'DATA:range3' output_stream: 'SUM:range3_sum' output_stream: 'MEAN:range3_mean' output_stream: 'STDDEV:range3_stddev' } node { calculator: 'TestRangeStdDevSubgraph' input_side_packet: 'node_5_converted' output_stream: 'DATA:range5' output_stream: 'SUM:range5_sum' output_stream: 'MEAN:range5_mean' output_stream: 'STDDEV:range5_stddev' } node { name: 'copy_range5' calculator: 'PassThroughCalculator' input_stream: 'range5' output_stream: 'range5_copy' } node { calculator: 'TestMergeSaverSubgraph' input_stream: 'DATA1:range3' input_stream: 'DATA2:range5_copy' output_stream: 'MERGE:merge' output_stream: 'FINAL:final' } node { calculator: 'TestMergeSaverSubgraph' input_stream: 'DATA1:range3_sum' input_stream: 'DATA2:range5_sum' output_stream: 'FINAL:final_sum' } node { calculator: 
'TestMergeSaverSubgraph' input_stream: 'DATA1:range3_stddev' input_stream: 'DATA2:range5_stddev' output_stream: 'FINAL:final_stddev' } )pb"); return proto; } TEST(CalculatorGraph, RunsCorrectlyWithSubgraphs) { CalculatorGraph graph; CalculatorGraphConfig proto = GetConfigWithSubgraphs(); RunComprehensiveTest(&graph, proto, /*define_node_5=*/true); } TEST(CalculatorGraph, SetExecutorTwice) { // SetExecutor must not be called more than once for the same executor name. CalculatorGraph graph; MP_EXPECT_OK( graph.SetExecutor("xyz", std::make_shared(1))); MP_EXPECT_OK( graph.SetExecutor("abc", std::make_shared(1))); absl::Status status = graph.SetExecutor("xyz", std::make_shared(1)); EXPECT_EQ(status.code(), absl::StatusCode::kAlreadyExists); EXPECT_THAT(status.message(), testing::HasSubstr("xyz")); } TEST(CalculatorGraph, ReservedNameSetExecutor) { // A reserved executor name such as "__gpu" must not be used. CalculatorGraph graph; absl::Status status = graph.SetExecutor("__gpu", std::make_shared(1)); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"))); } TEST(CalculatorGraph, ReservedNameExecutorConfig) { // A reserved executor name such as "__gpu" must not be used. CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' executor { name: '__gpu' type: 'ThreadPoolExecutor' options { [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } } } node { calculator: 'PassThroughCalculator' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"))); } TEST(CalculatorGraph, ReservedNameNodeExecutor) { // A reserved executor name such as "__gpu" must not be used. 
CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' node { calculator: 'PassThroughCalculator' executor: '__gpu' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"))); } TEST(CalculatorGraph, NonExistentExecutor) { // Any executor used by a calculator node must either be created by the // graph (which requires an ExecutorConfig with a "type" field) or be // provided to the graph with a CalculatorGraph::SetExecutor() call. CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' node { calculator: 'PassThroughCalculator' executor: 'xyz' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("not declared"))); } TEST(CalculatorGraph, UndeclaredExecutor) { // Any executor used by a calculator node must be declared in an // ExecutorConfig, even if the executor is provided to the graph with a // CalculatorGraph::SetExecutor() call. 
CalculatorGraph graph; MP_ASSERT_OK( graph.SetExecutor("xyz", std::make_shared(1))); CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' node { calculator: 'PassThroughCalculator' executor: 'xyz' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("not declared"))); } TEST(CalculatorGraph, UntypedExecutorDeclaredButNotSet) { // If an executor is declared without a "type" field, it must be provided to // the graph with a CalculatorGraph::SetExecutor() call. CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' executor { name: 'xyz' } node { calculator: 'PassThroughCalculator' executor: 'xyz' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("SetExecutor"))); } TEST(CalculatorGraph, DuplicateExecutorConfig) { // More than one ExecutorConfig cannot have the same name. 
CalculatorGraph graph; MP_ASSERT_OK( graph.SetExecutor("xyz", std::make_shared(1))); CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' executor { name: 'xyz' } executor { name: 'xyz' } node { calculator: 'PassThroughCalculator' executor: 'xyz' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("duplicate"))); } TEST(CalculatorGraph, TypedExecutorDeclaredAndSet) { // If an executor is declared with a "type" field, it must not be provided // to the graph with a CalculatorGraph::SetExecutor() call. CalculatorGraph graph; MP_ASSERT_OK( graph.SetExecutor("xyz", std::make_shared(1))); CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'in' executor { name: 'xyz' type: 'ThreadPoolExecutor' options { [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } } } node { calculator: 'PassThroughCalculator' executor: 'xyz' input_stream: 'in' output_stream: 'out' } )pb"); absl::Status status = graph.Initialize(config); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("SetExecutor"))); } // The graph-level num_threads field and the ExecutorConfig for the default // executor must not both be specified. 
TEST(CalculatorGraph, NumThreadsAndDefaultExecutorConfig) {
  // The executor entry below has no name, i.e. it configures the default
  // executor, while num_threads is also set at graph level.
  CalculatorGraphConfig graph_config = mediapipe::ParseTextProtoOrDie(
      R"pb( input_stream: 'in' num_threads: 1 executor { type: 'ThreadPoolExecutor' options { [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } } } node { calculator: 'PassThroughCalculator' input_stream: 'in' output_stream: 'mid' } node { calculator: 'PassThroughCalculator' input_stream: 'mid' output_stream: 'out' } )pb");

  CalculatorGraph calculator_graph;
  const absl::Status init_status = calculator_graph.Initialize(graph_config);

  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  // The error message must name both conflicting settings.
  EXPECT_THAT(init_status.message(),
              testing::AllOf(HasSubstr("num_threads"),
                             HasSubstr("default executor")));
}

// The graph-level num_threads field and the ExecutorConfig for a non-default
// executor may coexist.
TEST(CalculatorGraph, NumThreadsAndNonDefaultExecutorConfig) {
  // Here the executor entry is named 'xyz' and only the second node uses it.
  CalculatorGraphConfig graph_config = mediapipe::ParseTextProtoOrDie(
      R"pb( input_stream: 'in' num_threads: 1 executor { name: 'xyz' type: 'ThreadPoolExecutor' options { [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 } } } node { calculator: 'PassThroughCalculator' input_stream: 'in' output_stream: 'mid' } node { calculator: 'PassThroughCalculator' executor: 'xyz' input_stream: 'mid' output_stream: 'out' } )pb");

  CalculatorGraph calculator_graph;
  MP_EXPECT_OK(calculator_graph.Initialize(graph_config));
}

// Verifies that the application thread is used only when
// "ApplicationThreadExecutor" is specified. In this test
// "ApplicationThreadExecutor" is specified in the ExecutorConfig for the
// default executor.
TEST(CalculatorGraph, RunWithNumThreadsInExecutorConfig) { const struct { std::string executor_type; int num_threads; bool use_app_thread_is_expected; } cases[] = {{"ApplicationThreadExecutor", 0, true}, {"", 0, false}, {"ThreadPoolExecutor", 1, false}}; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( executor { options { [mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 0 } } } node { calculator: 'PthreadSelfSourceCalculator' output_stream: 'out' } )pb"); ThreadPoolExecutorOptions* default_executor_options = config.mutable_executor(0)->mutable_options()->MutableExtension( ThreadPoolExecutorOptions::ext); for (int i = 0; i < ABSL_ARRAYSIZE(cases); ++i) { default_executor_options->set_num_threads(cases[i].num_threads); config.mutable_executor(0)->clear_type(); if (cases[i].executor_type != "") { config.mutable_executor(0)->set_type(cases[i].executor_type); } CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); Packet out_packet; MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packet](const Packet& packet) { out_packet = packet; return absl::OkStatus(); })); MP_ASSERT_OK(graph.Run()); EXPECT_EQ(cases[i].use_app_thread_is_expected, out_packet.Get() == pthread_self()) << "for case " << i; } } TEST(CalculatorGraph, CalculatorGraphNotInitialized) { CalculatorGraph graph; EXPECT_FALSE(graph.Run().ok()); } TEST(CalculatorGraph, SimulateAssertFailure) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( num_threads: 2 node { calculator: 'PassThroughCalculator' input_stream: 'in_a' input_stream: 'in_b' output_stream: 'out_a' output_stream: 'out_b' } input_stream: 'in_a' input_stream: 'in_b' )pb"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.StartRun({})); MP_EXPECT_OK(graph.WaitUntilIdle()); // End the test here to simulate an ASSERT_ failure, which will skip the // rest of the test and exit the test function immediately. 
The test should // not hang in the CalculatorGraph destructor. } // Verifies Calculator::InputTimestamp() returns the expected value in Open(), // Process(), and Close() for both source and non-source nodes. In this test // the source node stops the graph. TEST(CalculatorGraph, CheckInputTimestamp) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( node { calculator: 'CheckInputTimestampSourceCalculator' output_stream: 'integer' } node { calculator: 'CheckInputTimestampSinkCalculator' input_stream: 'integer' } )pb"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.Run()); } // Verifies Calculator::InputTimestamp() returns the expected value in Open(), // Process(), and Close() for both source and non-source nodes. In this test // the sink node stops the graph, which causes the framework to close the // source node. TEST(CalculatorGraph, CheckInputTimestamp2) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( node { calculator: 'CheckInputTimestamp2SourceCalculator' output_stream: 'integer' } node { calculator: 'CheckInputTimestamp2SinkCalculator' input_stream: 'integer' } )pb"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.Run()); } TEST(CalculatorGraph, GraphInputStreamWithTag) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "VIDEO_METADATA:video_metadata" input_stream: "max_count" node { calculator: "PassThroughCalculator" input_stream: "FIRST_INPUT:video_metadata" input_stream: "max_count" output_stream: "FIRST_INPUT:output_0" output_stream: "output_1" } )pb"); std::vector packet_dump; tool::AddVectorSink("output_0", &config, &packet_dump); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.StartRun({})); for (int i = 0; i < 5; ++i) { MP_ASSERT_OK(graph.AddPacketToInputStream( "video_metadata", MakePacket(i).At(Timestamp(i)))); } MP_ASSERT_OK(graph.CloseAllPacketSources()); 
MP_ASSERT_OK(graph.WaitUntilDone()); ASSERT_EQ(5, packet_dump.size()); } TEST(CalculatorGraph, GraphInputStreamBeforeStartRun) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: "VIDEO_METADATA:video_metadata" input_stream: "max_count" node { calculator: "PassThroughCalculator" input_stream: "FIRST_INPUT:video_metadata" input_stream: "max_count" output_stream: "FIRST_INPUT:output_0" output_stream: "output_1" } )pb"); std::vector packet_dump; tool::AddVectorSink("output_0", &config, &packet_dump); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); ASSERT_EQ(graph .AddPacketToInputStream("video_metadata", MakePacket(0).At(Timestamp(0))) .code(), absl::StatusCode::kFailedPrecondition); } // Returns the first packet of the input stream. class FirstPacketFilterCalculator : public CalculatorBase { public: FirstPacketFilterCalculator() {} ~FirstPacketFilterCalculator() override {} static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { if (!seen_first_packet_) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); cc->Outputs().Index(0).Close(); seen_first_packet_ = true; } return absl::OkStatus(); } private: bool seen_first_packet_ = false; }; REGISTER_CALCULATOR(FirstPacketFilterCalculator); constexpr int kDefaultMaxCount = 1000; TEST(CalculatorGraph, TestPollPacket) { CalculatorGraphConfig config; CalculatorGraphConfig::Node* node = config.add_node(); node->set_calculator("CountingSourceCalculator"); node->add_output_stream("output"); node->add_input_side_packet("MAX_COUNT:max_count"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); auto status_or_poller = graph.AddOutputStreamPoller("output"); ASSERT_TRUE(status_or_poller.ok()); OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK( 
graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); Packet packet; int num_packets = 0; while (poller.Next(&packet)) { EXPECT_EQ(num_packets, packet.Get()); ++num_packets; } MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); EXPECT_FALSE(poller.Next(&packet)); EXPECT_EQ(kDefaultMaxCount, num_packets); } TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) { CalculatorGraphConfig config; CalculatorGraphConfig::Node* node = config.add_node(); node->set_calculator("CountingSourceCalculator"); node->add_output_stream("output"); node->add_input_side_packet("MAX_COUNT:max_count"); for (int queue_size = 1; queue_size < 10; ++queue_size) { CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); auto status_or_poller = graph.AddOutputStreamPoller("output"); ASSERT_TRUE(status_or_poller.ok()); OutputStreamPoller poller = std::move(status_or_poller.value()); poller.SetMaxQueueSize(queue_size); MP_ASSERT_OK( graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); Packet packet; int num_packets = 0; while (poller.Next(&packet)) { EXPECT_EQ(num_packets, packet.Get()); ++num_packets; } MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); EXPECT_FALSE(poller.Next(&packet)); EXPECT_EQ(kDefaultMaxCount, num_packets); } } TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) { CalculatorGraphConfig config; CalculatorGraphConfig::Node* node1 = config.add_node(); node1->set_calculator("CountingSourceCalculator"); node1->add_output_stream("stream1"); node1->add_input_side_packet("MAX_COUNT:max_count"); CalculatorGraphConfig::Node* node2 = config.add_node(); node2->set_calculator("PassThroughCalculator"); node2->add_input_stream("stream1"); node2->add_output_stream("stream2"); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); auto status_or_poller1 = graph.AddOutputStreamPoller("stream1"); ASSERT_TRUE(status_or_poller1.ok()); OutputStreamPoller poller1 = 
std::move(status_or_poller1.value()); auto status_or_poller2 = graph.AddOutputStreamPoller("stream2"); ASSERT_TRUE(status_or_poller2.ok()); OutputStreamPoller poller2 = std::move(status_or_poller2.value()); MP_ASSERT_OK( graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); Packet packet1; Packet packet2; int num_packets1 = 0; int num_packets2 = 0; int running_pollers = 2; while (running_pollers > 0) { if (poller1.Next(&packet1)) { EXPECT_EQ(num_packets1++, packet1.Get()); } else { --running_pollers; } if (poller2.Next(&packet2)) { EXPECT_EQ(num_packets2++, packet2.Get()); } else { --running_pollers; } } MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); EXPECT_FALSE(poller1.Next(&packet1)); EXPECT_FALSE(poller2.Next(&packet2)); EXPECT_EQ(kDefaultMaxCount, num_packets1); EXPECT_EQ(kDefaultMaxCount, num_packets2); } class TimestampBoundTestCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) final { if (count_ % 50 == 1) { // Outputs packets at t10 and t60. 
cc->Outputs().Index(0).AddPacket( MakePacket(count_).At(Timestamp(count_))); } else if (count_ % 15 == 7) { cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_)); } absl::SleepFor(absl::Milliseconds(3)); ++count_; if (count_ == 110) { return tool::StatusStop(); } return absl::OkStatus(); } private: int count_ = 0; }; REGISTER_CALCULATOR(TimestampBoundTestCalculator); TEST(CalculatorGraph, TestPollPacketsWithTimestampNotification) { std::string config_str = R"( node { calculator: "TimestampBoundTestCalculator" output_stream: "foo" } )"; CalculatorGraphConfig graph_config = mediapipe::ParseTextProtoOrDie(config_str); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(graph_config)); auto status_or_poller = graph.AddOutputStreamPoller("foo", /*observe_timestamp_bounds=*/true); ASSERT_TRUE(status_or_poller.ok()); OutputStreamPoller poller = std::move(status_or_poller.value()); Packet packet; std::vector timestamps; std::vector values; MP_ASSERT_OK(graph.StartRun({})); while (poller.Next(&packet)) { if (packet.IsEmpty()) { timestamps.push_back(packet.Timestamp().Value()); } else { values.push_back(packet.Get()); } } MP_ASSERT_OK(graph.WaitUntilDone()); ASSERT_FALSE(poller.Next(&packet)); ASSERT_FALSE(timestamps.empty()); int prev_t = 0; for (auto t : timestamps) { EXPECT_TRUE(t > prev_t && t < 110); prev_t = t; } ASSERT_EQ(3, values.size()); EXPECT_EQ(1, values[0]); EXPECT_EQ(51, values[1]); EXPECT_EQ(101, values[2]); } // Ensure that when a custom input stream handler is used to handle packets from // input streams, an error message is outputted with the appropriate link to // resolve the issue when the calculator doesn't handle inputs in monotonically // increasing order of timestamps. 
TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) { CalculatorGraph graph; CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"pb( input_stream: 'input0' input_stream: 'input1' node { calculator: 'SimpleMuxCalculator' input_stream: 'input0' input_stream: 'input1' input_stream_handler { input_stream_handler: "ImmediateInputStreamHandler" } output_stream: 'output' } )pb"); std::vector packet_dump; tool::AddVectorSink("output", &config, &packet_dump); MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.StartRun({})); // Send packets to input stream "input0" at timestamps 0 and 1 consecutively. Timestamp input0_timestamp = Timestamp(0); MP_EXPECT_OK(graph.AddPacketToInputStream( "input0", MakePacket(1).At(input0_timestamp))); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(1, packet_dump.size()); EXPECT_EQ(1, packet_dump[0].Get()); ++input0_timestamp; MP_EXPECT_OK(graph.AddPacketToInputStream( "input0", MakePacket(3).At(input0_timestamp))); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(2, packet_dump.size()); EXPECT_EQ(3, packet_dump[1].Get()); // Send a packet to input stream "input1" at timestamp 0 after sending two // packets at timestamps 0 and 1 to input stream "input0". This will result // in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle // inputs from all streams in monotonically increasing order of timestamps. Timestamp input1_timestamp = Timestamp(0); MP_EXPECT_OK(graph.AddPacketToInputStream( "input1", MakePacket(2).At(input1_timestamp))); absl::Status run_status = graph.WaitUntilIdle(); EXPECT_THAT( run_status.ToString(), testing::AllOf( // The core problem. testing::HasSubstr("timestamp mismatch on a calculator"), testing::HasSubstr( "timestamps that are not strictly monotonically increasing"), // Link to the possible solution. 
testing::HasSubstr("ImmediateInputStreamHandler class comment"))); } void DoTestMultipleGraphRuns(absl::string_view input_stream_handler, bool select_packet) { std::string graph_proto = absl::StrFormat(R"( input_stream: 'input' input_stream: 'select' node { calculator: 'PassThroughCalculator' input_stream: 'input' input_stream: 'select' input_stream_handler { input_stream_handler: "%s" } output_stream: 'output' output_stream: 'select_out' } )", input_stream_handler.data()); CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(graph_proto); std::vector packet_dump; tool::AddVectorSink("output", &config, &packet_dump); CalculatorGraph graph; MP_ASSERT_OK(graph.Initialize(config)); struct Run { Timestamp timestamp; int value; }; std::vector runs = {{.timestamp = Timestamp(2000), .value = 2}, {.timestamp = Timestamp(1000), .value = 1}}; for (const Run& run : runs) { MP_ASSERT_OK(graph.StartRun({})); if (select_packet) { MP_EXPECT_OK(graph.AddPacketToInputStream( "select", MakePacket(0).At(run.timestamp))); } MP_EXPECT_OK(graph.AddPacketToInputStream( "input", MakePacket(run.value).At(run.timestamp))); MP_ASSERT_OK(graph.WaitUntilIdle()); ASSERT_EQ(1, packet_dump.size()); EXPECT_EQ(run.value, packet_dump[0].Get()); EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp()); MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); packet_dump.clear(); } } TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) { DoTestMultipleGraphRuns("BarrierInputStreamHandler", true); DoTestMultipleGraphRuns("DefaultInputStreamHandler", true); DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true); DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true); DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false); DoTestMultipleGraphRuns("MuxInputStreamHandler", true); DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true); DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true); } } // namespace } // namespace 
mediapipe