mediapipe/mediapipe/framework/calculator_graph_test.cc
MediaPipe Team 6c68d4c358 Internal change
PiperOrigin-RevId: 514421618
2023-03-06 08:42:19 -08:00

4717 lines
168 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_graph.h"
#include <pthread.h>

#include <atomic>
#include <ctime>
#include <deque>
#include <map>
#include <memory>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/input_stream_handler.h"
#include "mediapipe/framework/lifetime_tracker.h"
#include "mediapipe/framework/mediapipe_options.pb.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/status_handler.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/framework/thread_pool_executor.h"
#include "mediapipe/framework/thread_pool_executor.pb.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/type_map.h"
#include "mediapipe/gpu/gpu_service.h"
namespace mediapipe {
namespace {
// Tag names used by the calculators and graph configurations in this file.
constexpr char kInputTag[] = "INPUT";
constexpr char kOutputTag[] = "OUTPUT";
constexpr char kSelectTag[] = "SELECT";
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
constexpr char kPostSemTag[] = "POST_SEM";
constexpr char kWaitSemTag[] = "WAIT_SEM";
constexpr char kExtraTag[] = "EXTRA";
constexpr char kCounter1Tag[] = "COUNTER1";
constexpr char kCounter2Tag[] = "COUNTER2";
// gMock matchers brought into scope for the tests in this file.
using testing::ElementsAre;
using testing::HasSubstr;
// Forwards every input packet unchanged. Unlike most calculators in this file,
// it declares its timestamp offset by calling SetOffset() from Process()
// rather than from Open() or GetContract().
class SetOffsetInProcessCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    PacketType& input = cc->Inputs().Index(0);
    input.SetAny();
    cc->Outputs().Index(0).SetSameAs(&input);
    return absl::OkStatus();
  }

  // Copies the input stream header onto the output stream.
  absl::Status Open(CalculatorContext* cc) final {
    const Packet& header = cc->Inputs().Index(0).Header();
    cc->Outputs().Index(0).SetHeader(header);
    return absl::OkStatus();
  }

  // Sets a zero offset, increments the "PassThrough" counter and re-emits
  // the current input packet.
  absl::Status Process(CalculatorContext* cc) final {
    cc->SetOffset(TimestampDiff(0));
    cc->GetCounter("PassThrough")->Increment();
    const Packet& packet = cc->Inputs().Index(0).Value();
    cc->Outputs().Index(0).AddPacket(packet);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(SetOffsetInProcessCalculator);
// Emits the square of each int input packet, stamped with the input
// timestamp.
class SquareIntCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    PacketType& in = cc->Inputs().Index(0);
    in.Set<int>();
    cc->Outputs().Index(0).SetSameAs(&in);
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }

  absl::Status Process(CalculatorContext* cc) final {
    const int x = cc->Inputs().Index(0).Value().Get<int>();
    const int squared = x * x;
    cc->Outputs().Index(0).Add(new int(squared), cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(SquareIntCalculator);
// Routes the packet on the "INPUT" stream to exactly one of the "OUTPUT:0",
// "OUTPUT:1", ... streams. The destination is chosen by the int value
// (0, 1, ...) of the "SELECT" packet read in the same Process() call.
//
// This calculator is called "Timed" because it advances the next timestamp
// bound of every unselected output stream.
class DemuxTimedCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Expects exactly two input streams: SELECT and INPUT.
    RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
    cc->Inputs().Tag(kSelectTag).Set<int>();
    PacketType* data_input = &cc->Inputs().Tag(kInputTag);
    data_input->SetAny();
    // Every OUTPUT:i stream carries the same type as INPUT.
    for (CollectionItemId id = cc->Outputs().BeginId(kOutputTag);
         id < cc->Outputs().EndId(kOutputTag); ++id) {
      cc->Outputs().Get(id).SetSameAs(data_input);
    }
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Caches the stream ids so Process() does not repeat tag lookups.
  absl::Status Open(CalculatorContext* cc) final {
    select_input_ = cc->Inputs().GetId(kSelectTag, 0);
    data_input_ = cc->Inputs().GetId(kInputTag, 0);
    output_base_ = cc->Outputs().GetId(kOutputTag, 0);
    num_outputs_ = cc->Outputs().NumEntries(kOutputTag);
    return absl::OkStatus();
  }

  // Returns an error if the SELECT value is out of range.
  absl::Status Process(CalculatorContext* cc) final {
    int select = cc->Inputs().Get(select_input_).Get<int>();
    RET_CHECK(0 <= select && select < num_outputs_);
    const Timestamp next_timestamp_bound =
        cc->InputTimestamp().NextAllowedInStream();
    for (int i = 0; i < num_outputs_; ++i) {
      if (i == select) {
        cc->Outputs()
            .Get(output_base_ + i)
            .AddPacket(cc->Inputs().Get(data_input_).Value());
      } else {
        // Tell downstream nodes that this output will have nothing at the
        // current timestamp.
        cc->Outputs()
            .Get(output_base_ + i)
            .SetNextTimestampBound(next_timestamp_bound);
      }
    }
    return absl::OkStatus();
  }

 private:
  CollectionItemId select_input_;
  CollectionItemId data_input_;
  CollectionItemId output_base_;
  int num_outputs_ = 0;
};
REGISTER_CALCULATOR(DemuxTimedCalculator);
// Forwards to the "OUTPUT" stream the packet from one of the "INPUT:0",
// "INPUT:1", ... streams. The source is chosen by the int value (0, 1, ...)
// of the "SELECT" packet read in the same Process() call.
//
// This calculator is called "Timed" because it requires next timestamp bound
// propagation on the unselected inputs.
class MuxTimedCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag(kSelectTag).Set<int>();
    // INPUT:0 defines the type shared by all data inputs and the output, so
    // at least one INPUT stream must exist before it is dereferenced below.
    RET_CHECK_GT(cc->Inputs().NumEntries(kInputTag), 0);
    CollectionItemId data_input_id = cc->Inputs().BeginId(kInputTag);
    PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
    data_input0->SetAny();
    ++data_input_id;
    for (; data_input_id < cc->Inputs().EndId(kInputTag); ++data_input_id) {
      cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
    }
    RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
    cc->Outputs().Tag(kOutputTag).SetSameAs(data_input0);
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Caches the stream ids so Process() does not repeat tag lookups.
  absl::Status Open(CalculatorContext* cc) final {
    select_input_ = cc->Inputs().GetId(kSelectTag, 0);
    data_input_base_ = cc->Inputs().GetId(kInputTag, 0);
    num_data_inputs_ = cc->Inputs().NumEntries(kInputTag);
    output_ = cc->Outputs().GetId(kOutputTag, 0);
    return absl::OkStatus();
  }

  // Returns an error if the SELECT value is out of range.
  absl::Status Process(CalculatorContext* cc) final {
    int select = cc->Inputs().Get(select_input_).Get<int>();
    RET_CHECK(0 <= select && select < num_data_inputs_);
    cc->Outputs().Get(output_).AddPacket(
        cc->Inputs().Get(data_input_base_ + select).Value());
    return absl::OkStatus();
  }

 private:
  CollectionItemId select_input_;
  CollectionItemId data_input_base_;
  int num_data_inputs_ = 0;
  CollectionItemId output_;
};
REGISTER_CALCULATOR(MuxTimedCalculator);
// Sums the int packets arriving on all input streams and emits the total at
// the input timestamp.
class IntAdderCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    const int num_inputs = cc->Inputs().NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      cc->Inputs().Index(index).Set<int>();
    }
    cc->Outputs().Index(0).Set<int>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }

  absl::Status Process(CalculatorContext* cc) final {
    int total = 0;
    const int num_inputs = cc->Inputs().NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      total += cc->Inputs().Index(index).Get<int>();
    }
    cc->Outputs().Index(0).Add(new int(total), cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(IntAdderCalculator);
// Sums the float packets arriving on all input streams and emits the total
// at the input timestamp.
class FloatAdderCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    const int num_inputs = cc->Inputs().NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      cc->Inputs().Index(index).Set<float>();
    }
    cc->Outputs().Index(0).Set<float>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }

  absl::Status Process(CalculatorContext* cc) final {
    float total = 0.0f;
    const int num_inputs = cc->Inputs().NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      total += cc->Inputs().Index(index).Get<float>();
    }
    cc->Outputs().Index(0).Add(new float(total), cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(FloatAdderCalculator);
// Multiplies the int packets arriving on all input streams and emits the
// product at the input timestamp.
class IntMultiplierCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    const int num_inputs = cc->Inputs().NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      cc->Inputs().Index(index).Set<int>();
    }
    cc->Outputs().Index(0).Set<int>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }

  absl::Status Process(CalculatorContext* cc) final {
    int accumulated = 1;
    const int num_inputs = cc->Inputs().NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      accumulated *= cc->Inputs().Index(index).Get<int>();
    }
    cc->Outputs().Index(0).Add(new int(accumulated), cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(IntMultiplierCalculator);
// Multiplies each float input packet by a constant float scalar and emits the
// product at the input timestamp. The scalar comes from input side packet 0.
class FloatScalarMultiplierCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<float>();
    cc->Outputs().Index(0).Set<float>();
    cc->InputSidePackets().Index(0).Set<float>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  // Caches the scalar side packet for use in Process().
  absl::Status Open(CalculatorContext* cc) final {
    scalar_ = cc->InputSidePackets().Index(0).Get<float>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    float value = cc->Inputs().Index(0).Value().Get<float>();
    cc->Outputs().Index(0).Add(new float(scalar_ * value),
                               cc->InputTimestamp());
    return absl::OkStatus();
  }

 private:
  // Overwritten in Open(). It is initialized in-class so the member never
  // holds an indeterminate value.
  float scalar_ = 0.0f;
};
REGISTER_CALCULATOR(FloatScalarMultiplierCalculator);
// Converts each int input packet into a float packet at the same timestamp.
class IntToFloatCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<float>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }

  absl::Status Process(CalculatorContext* cc) final {
    const int as_int = cc->Inputs().Index(0).Value().Get<int>();
    float* as_float = new float(static_cast<float>(as_int));
    cc->Outputs().Index(0).Add(as_float, cc->InputTimestamp());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(IntToFloatCalculator);
// Source calculator that emits one value-initialized OutputType packet at
// Timestamp::PostStream() and then stops. The output stream itself is
// declared as accepting any type.
template <typename OutputType>
class TypedEmptySourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).SetAny();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    OutputType* empty_value = new OutputType();
    cc->Outputs().Index(0).Add(empty_value, Timestamp::PostStream());
    return tool::StatusStop();
  }
};
using StringEmptySourceCalculator = TypedEmptySourceCalculator<std::string>;
using IntEmptySourceCalculator = TypedEmptySourceCalculator<int>;
REGISTER_CALCULATOR(StringEmptySourceCalculator);
REGISTER_CALCULATOR(IntEmptySourceCalculator);
// Sink calculator that accepts packets of type InputType and discards them.
template <typename InputType>
class TypedSinkCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<InputType>();
    return absl::OkStatus();
  }

  // Intentionally ignores the input.
  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }
};
using StringSinkCalculator = TypedSinkCalculator<std::string>;
using IntSinkCalculator = TypedSinkCalculator<int>;
REGISTER_CALCULATOR(StringSinkCalculator);
REGISTER_CALCULATOR(IntSinkCalculator);
// Source calculator that outputs kNumOutputPackets int packets at timestamps
// 0, 1, ..., in order. Each packet holds the value returned by fetch_add(1) on
// the shared std::atomic<int> counter passed as input side packet 0.
//
// An optional second input side packet (bool) causes the first of these
// packets to be emitted from Open() when it is true. The remaining packets
// are emitted from Process().
class GlobalCountSourceCalculator : public CalculatorBase {
 public:
  // In-class constexpr, matching the other source calculators in this file.
  static constexpr int kNumOutputPackets = 5;

  static absl::Status GetContract(CalculatorContract* cc) {
    cc->InputSidePackets().Index(0).Set<std::atomic<int>*>();
    if (cc->InputSidePackets().NumEntries() >= 2) {
      cc->InputSidePackets().Index(1).Set<bool>();
    }
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    if (cc->InputSidePackets().NumEntries() >= 2 &&
        cc->InputSidePackets().Index(1).Get<bool>()) {
      OutputOne(cc);
    }
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    OutputOne(cc);
    if (local_count_ >= kNumOutputPackets) {
      return tool::StatusStop();
    } else {
      return absl::OkStatus();
    }
  }

 private:
  // Emits the shared counter's pre-increment value at Timestamp(local_count_),
  // then advances the local packet count.
  void OutputOne(CalculatorContext* cc) {
    std::atomic<int>* counter =
        cc->InputSidePackets().Index(0).Get<std::atomic<int>*>();
    int count = counter->fetch_add(1, std::memory_order_relaxed);
    cc->Outputs().Index(0).Add(new int(count), Timestamp(local_count_));
    ++local_count_;
  }

  // Number of packets this instance has emitted so far.
  int64_t local_count_ = 0;
};
REGISTER_CALCULATOR(GlobalCountSourceCalculator);
// Number of packets emitted by the TestSequence*/Modulo3 source calculators.
constexpr int kTestSequenceLength = 15;
// Source calculator that emits the ints 0, 1, ..., kTestSequenceLength - 1
// (0 through 14), every one stamped with Timestamp(0), and then stops.
// NOTE(review): the repeated timestamp looks deliberate (presumably to
// exercise timestamp-ordering errors); confirm against the tests using it.
class TestSequence1SourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).Add(new int(count_++), Timestamp(0));
    ++num_outputs_;
    return num_outputs_ < kTestSequenceLength ? absl::OkStatus()
                                              : tool::StatusStop();
  }

 private:
  int count_ = 0;        // Value of the next packet.
  int num_outputs_ = 0;  // Packets emitted so far.
};
REGISTER_CALCULATOR(TestSequence1SourceCalculator);
// Source calculator that emits the ints 1, 2, ..., 15 at strictly decreasing
// timestamps 100, 99, ..., 86, and then stops.
class TestSequence2SourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).Add(new int(count_), Timestamp(timestamp_));
    count_ += 1;
    num_outputs_ += 1;
    timestamp_ -= 1;
    return num_outputs_ < kTestSequenceLength ? absl::OkStatus()
                                              : tool::StatusStop();
  }

 private:
  int count_ = 1;         // Value of the next packet.
  int num_outputs_ = 0;   // Packets emitted so far.
  int timestamp_ = 100;   // Timestamp of the next packet.
};
REGISTER_CALCULATOR(TestSequence2SourceCalculator);
// Source calculator that emits 0, 1, 2, 0, 1, 2, ... for kTestSequenceLength
// (15) packets in total. Each packet's timestamp equals its value, so the
// timestamps repeat as well.
class Modulo3SourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const int residue = count_ % 3;
    cc->Outputs().Index(0).Add(new int(residue), Timestamp(residue));
    ++count_;
    ++num_outputs_;
    if (num_outputs_ < kTestSequenceLength) {
      return absl::OkStatus();
    }
    return tool::StatusStop();
  }

 private:
  int count_ = 0;
  int num_outputs_ = 0;
};
REGISTER_CALCULATOR(Modulo3SourceCalculator);
// Source calculator that emits kNumOutputPackets (100) zero-valued packets at
// timestamps 0..99 in a single Process() call, then stops. The count is
// chosen to equal max_queue_size, which fills the downstream input streams
// and makes the MediaPipe scheduler throttle this source.
class OutputAllSourceCalculator : public CalculatorBase {
 public:
  static constexpr int kNumOutputPackets = 100;

  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    auto& output = cc->Outputs().Index(0);
    for (int t = 0; t < kNumOutputPackets; ++t) {
      output.Add(new int(0), Timestamp(t));
    }
    return tool::StatusStop();
  }
};
REGISTER_CALCULATOR(OutputAllSourceCalculator);
// Source calculator that emits one zero-valued packet per Process() call, at
// timestamps 0, 1, ..., kNumOutputPackets - 1, and stops on the following
// call.
//
// kNumOutputPackets is large enough to eventually fill a connected input
// stream. That forces the scheduler to run this node as a throttled source
// when the graph cannot otherwise make progress.
class OutputOneAtATimeSourceCalculator : public CalculatorBase {
 public:
  static constexpr int kNumOutputPackets = 1000;

  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    if (index_ >= kNumOutputPackets) {
      return tool::StatusStop();
    }
    cc->Outputs().Index(0).Add(new int(0), Timestamp(index_));
    ++index_;
    return absl::OkStatus();
  }

 private:
  int index_ = 0;  // Timestamp of the next packet to emit.
};
REGISTER_CALCULATOR(OutputOneAtATimeSourceCalculator);
// Forwards the 1st, 102nd, 203rd, ... input packets (one in every
// kDecimationRatio) and drops the rest. The ratio (101) is deliberately larger
// than max_queue_size (100), so that an input stream running parallel to this
// node's input can become full.
class DecimatorCalculator : public CalculatorBase {
 public:
  static constexpr int kDecimationRatio = 101;

  static absl::Status GetContract(CalculatorContract* cc) {
    PacketType& input = cc->Inputs().Index(0);
    input.SetAny();
    cc->Outputs().Index(0).SetSameAs(&input);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    const bool keep = (index_ % kDecimationRatio) == 0;
    ++index_;
    if (keep) {
      cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    }
    return absl::OkStatus();
  }

 private:
  int index_ = 0;  // Number of input packets seen so far.
};
REGISTER_CALCULATOR(DecimatorCalculator);
// Pass-through calculator whose Open() fails with NotFoundError("expected
// error") when the ERROR_ON_OPEN bool side packet is true. Otherwise every
// input packet is forwarded unchanged.
class ErrorOnOpenCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    PacketType& input = cc->Inputs().Index(0);
    input.SetAny();
    cc->Outputs().Index(0).SetSameAs(&input);
    cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    const bool should_fail =
        cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>();
    return should_fail ? absl::NotFoundError("expected error")
                       : absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    const Packet& packet = cc->Inputs().Index(0).Value();
    cc->Outputs().Index(0).AddPacket(packet);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(ErrorOnOpenCalculator);
// Delays an int stream by one timestamp unit. Open() seeds the output with a
// 0 at Timestamp(0). Process() re-emits each input packet at the next
// timestamp allowed after its own.
class UnitDelayCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    cc->Outputs().Index(0).Add(new int(0), Timestamp(0));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    const Packet& in = cc->Inputs().Index(0).Value();
    const Timestamp delayed = in.Timestamp().NextAllowedInStream();
    cc->Outputs().Index(0).AddPacket(in.At(delayed));
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(UnitDelayCalculator);
// Untimed variant of a unit delay for int streams. Open() emits an initial 0
// at Timestamp::Min(). Process() forwards each input packet unchanged,
// keeping its original timestamp.
class UnitDelayUntimedCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    cc->Outputs().Index(0).Add(new int(0), Timestamp::Min());
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    const Packet& packet = cc->Inputs().Index(0).Value();
    cc->Outputs().Index(0).AddPacket(packet);
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(UnitDelayUntimedCalculator);
// Float counterpart of a unit delay. Open() seeds the output with 0.0 at
// Timestamp(0). Process() re-emits each input packet at the next timestamp
// allowed after its own.
class FloatUnitDelayCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<float>();
    cc->Outputs().Index(0).Set<float>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    cc->Outputs().Index(0).Add(new float(0.0f), Timestamp(0));
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    const Packet& in = cc->Inputs().Index(0).Value();
    const Timestamp delayed = in.Timestamp().NextAllowedInStream();
    cc->Outputs().Index(0).AddPacket(in.At(delayed));
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(FloatUnitDelayCalculator);
// Sink calculator that fails Open() unless its input stream's current value is
// an empty packet with an unset timestamp. Process() ignores all input.
class AssertEmptyInputInOpenCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    // The RET_CHECK expressions are kept verbatim because their text is part
    // of the error message.
    RET_CHECK(cc->Inputs().Index(0).Value().IsEmpty());
    RET_CHECK_EQ(cc->Inputs().Index(0).Value().Timestamp(), Timestamp::Unset());
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(AssertEmptyInputInOpenCalculator);
// Slow sink calculator: it sleeps 10 ms per packet. It CHECK-fails unless it
// receives exactly the ints 0, 1, ..., 9, in that order, before Close().
class SlowCountingSinkCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    absl::SleepFor(absl::Milliseconds(10));
    const int value = cc->Inputs().Index(0).Get<int>();
    CHECK_EQ(value, counter_);
    counter_ += 1;
    return absl::OkStatus();
  }

  absl::Status Close(CalculatorContext* cc) override {
    CHECK_EQ(10, counter_);
    return absl::OkStatus();
  }

 private:
  int counter_ = 0;  // Next expected value; also the count of packets seen.
};
REGISTER_CALCULATOR(SlowCountingSinkCalculator);
template <typename InputType>
class TypedStatusHandler : public StatusHandler {
public:
~TypedStatusHandler() override = 0;
static absl::Status FillExpectations(
const MediaPipeOptions& extendable_options,
PacketTypeSet* input_side_packets) {
input_side_packets->Index(0).Set<InputType>();
return absl::OkStatus();
}
static absl::Status HandlePreRunStatus(
const MediaPipeOptions& extendable_options,
const PacketSet& input_side_packets, //
const absl::Status& pre_run_status) {
return absl::OkStatus();
}
static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
const PacketSet& input_side_packets, //
const absl::Status& run_status) {
return absl::OkStatus();
}
};
typedef TypedStatusHandler<std::string> StringStatusHandler;
typedef TypedStatusHandler<uint32> Uint32StatusHandler;
REGISTER_STATUS_HANDLER(StringStatusHandler);
REGISTER_STATUS_HANDLER(Uint32StatusHandler);
// Packet generator that always succeeds. It accepts any input side packets
// and produces the string "fixed_string" as output side packet 0. A static
// counter records how many times Generate() has run.
class StaticCounterStringGenerator : public PacketGenerator {
 public:
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options,
      PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
    const int num_inputs = input_side_packets->NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      input_side_packets->Index(index).SetAny();
    }
    output_side_packets->Index(0).Set<std::string>();
    return absl::OkStatus();
  }

  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    Packet generated = MakePacket<std::string>("fixed_string");
    output_side_packets->Index(0) = generated;
    num_packets_generated_ += 1;
    return absl::OkStatus();
  }

  // Number of Generate() calls so far, across all uses of this generator.
  static int NumPacketsGenerated() { return num_packets_generated_; }

 private:
  static int num_packets_generated_;
};
int StaticCounterStringGenerator::num_packets_generated_ = 0;
REGISTER_PACKET_GENERATOR(StaticCounterStringGenerator);
// Packet generator used (together with a failing calculator) to verify that
// status handlers get called. It declares a std::string output side packet
// but Generate() always returns UnknownError.
class FailingPacketGenerator : public PacketGenerator {
 public:
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options,
      PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
    const int num_inputs = input_side_packets->NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      input_side_packets->Index(index).SetAny();
    }
    output_side_packets->Index(0).Set<std::string>();
    return absl::OkStatus();
  }

  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    return absl::UnknownError("this always fails.");
  }
};
REGISTER_PACKET_GENERATOR(FailingPacketGenerator);
// Packet generator that copies each int input side packet to the output side
// packet with the same index. It fails with UnknownError at the first value
// that is not strictly positive.
class EnsurePositivePacketGenerator : public PacketGenerator {
 public:
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options,
      PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
    const int num_inputs = input_side_packets->NumEntries();
    for (int index = 0; index < num_inputs; ++index) {
      input_side_packets->Index(index).Set<int>();
      output_side_packets->Index(index).Set<int>();
    }
    return absl::OkStatus();
  }

  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    for (int i = 0; i < input_side_packets.NumEntries(); ++i) {
      const Packet& packet = input_side_packets.Index(i);
      if (packet.Get<int>() <= 0) {
        return absl::UnknownError(
            absl::StrCat("Integer ", i, " was not positive."));
      }
      output_side_packets->Index(i) = packet;
    }
    return absl::OkStatus();
  }
};
REGISTER_PACKET_GENERATOR(EnsurePositivePacketGenerator);
// Status handler driven by an int input side packet:
//   kFailPreRun  -> HandlePreRunStatus() returns UnknownError;
//   kFailPostRun -> HandleStatus() returns UnknownError;
//   any other value -> both hooks return OK.
class FailableStatusHandler : public StatusHandler {
 public:
  // Unscoped on purpose: the values are used directly as ints.
  enum {
    kOk = 0,
    kFailPreRun = 1,
    kFailPostRun = 2,
  };

  static absl::Status FillExpectations(
      const MediaPipeOptions& extendable_options,
      PacketTypeSet* input_side_packets) {
    input_side_packets->Index(0).Set<int>();
    return absl::OkStatus();
  }

  static absl::Status HandlePreRunStatus(
      const MediaPipeOptions& extendable_options,
      const PacketSet& input_side_packets, const absl::Status& pre_run_status) {
    const bool should_fail =
        input_side_packets.Index(0).Get<int>() == kFailPreRun;
    return should_fail
               ? absl::UnknownError(
                     "FailableStatusHandler failing pre run as intended.")
               : absl::OkStatus();
  }

  static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
                                   const PacketSet& input_side_packets,
                                   const absl::Status& run_status) {
    const bool should_fail =
        input_side_packets.Index(0).Get<int>() == kFailPostRun;
    return should_fail
               ? absl::UnknownError(
                     "FailableStatusHandler failing post run as intended.")
               : absl::OkStatus();
  }
};
REGISTER_STATUS_HANDLER(FailableStatusHandler);
// Source calculator that declares a std::string output but whose Process()
// always returns UnknownError.
class FailingSourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<std::string>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    return absl::UnknownError("this always fails.");
  }
};
REGISTER_CALCULATOR(FailingSourceCalculator);
// A simple counting semaphore for synchronizing test threads.
//
// Acquire() waits until `amount` units are available and takes them in a
// single atomic step, so the count never goes negative. While waiting, the
// caller yields the processor instead of tight-spinning. Release() returns
// units.
class AtomicSemaphore {
 public:
  AtomicSemaphore(int64_t supply) : supply_(supply) {}

  // Spins (yielding between attempts) until `amount` units can be taken.
  void Acquire(int64_t amount) {
    int64_t available = supply_.load();
    while (true) {
      if (available >= amount) {
        // On failure compare_exchange_weak reloads `available`; just retry.
        if (supply_.compare_exchange_weak(available, available - amount)) {
          return;
        }
      } else {
        std::this_thread::yield();
        available = supply_.load();
      }
    }
  }

  // Returns `amount` units to the semaphore.
  void Release(int64_t amount) { supply_ += amount; }

 private:
  std::atomic<int64_t> supply_;
};
// Pass-through calculator for coordinating with a test thread. On entering
// Process() it releases the POST_SEM semaphore. It then blocks on WAIT_SEM
// before forwarding the input packet. This lets a test act while Process() is
// running, without relying on timing.
class SemaphoreCalculator : public CalculatorBase {
 public:
  using Semaphore = AtomicSemaphore;

  static absl::Status GetContract(CalculatorContract* cc) {
    PacketType& input = cc->Inputs().Index(0);
    input.SetAny();
    cc->Outputs().Index(0).SetSameAs(&input);
    cc->InputSidePackets().Tag(kPostSemTag).Set<Semaphore*>();
    cc->InputSidePackets().Tag(kWaitSemTag).Set<Semaphore*>();
    cc->SetTimestampOffset(TimestampDiff(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }

  absl::Status Process(CalculatorContext* cc) override {
    Semaphore* post_sem =
        cc->InputSidePackets().Tag(kPostSemTag).Get<Semaphore*>();
    Semaphore* wait_sem =
        cc->InputSidePackets().Tag(kWaitSemTag).Get<Semaphore*>();
    post_sem->Release(1);
    wait_sem->Acquire(1);
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(SemaphoreCalculator);
// Calculator with no input or output streams. Its single Process() call
// sleeps for 20 milliseconds and then stops the node.
class OneShot20MsCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Nothing to declare.
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const absl::Duration run_time = absl::Milliseconds(20);
    absl::SleepFor(run_time);
    return tool::StatusStop();
  }
};
REGISTER_CALCULATOR(OneShot20MsCalculator);
// Source calculator that emits a single packet at Timestamp(0) holding the
// pthread_self() id of the thread running Process(), then stops.
class PthreadSelfSourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<pthread_t>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const pthread_t current_thread = pthread_self();
    Packet packet = MakePacket<pthread_t>(current_thread).At(Timestamp(0));
    cc->Outputs().Index(0).AddPacket(std::move(packet));
    return tool::StatusStop();
  }
};
REGISTER_CALCULATOR(PthreadSelfSourceCalculator);
// Source calculator for testing CalculatorContext::InputTimestamp(). It emits
// five int packets with values and timestamps 0..4, and checks the reported
// input timestamp in each phase:
//   Open()    -> Timestamp::Unstarted()
//   Process() -> Timestamp(0)  (always, for a source node)
//   Close()   -> Timestamp::Done()
class CheckInputTimestampSourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted());
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0));
    cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_));
    ++count_;
    return count_ < 5 ? absl::OkStatus() : tool::StatusStop();
  }

  absl::Status Close(CalculatorContext* cc) final {
    // CHECK rather than RET_CHECK: the framework may call Close() on a source
    // node with .IgnoreError(), which would hide a returned error.
    CHECK_EQ(cc->InputTimestamp(), Timestamp::Done());
    return absl::OkStatus();
  }

 private:
  int count_ = 0;  // Packets emitted so far; also the next value/timestamp.
};
REGISTER_CALCULATOR(CheckInputTimestampSourceCalculator);
// A sink calculator for testing CalculatorContext::InputTimestamp().
// It consumes the output of a CheckInputTimestampSourceCalculator and checks
// that its N-th input arrives at Timestamp(N).
class CheckInputTimestampSinkCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // During Open(), InputTimestamp() is Timestamp::Unstarted() for source and
  // non-source nodes alike.
  absl::Status Open(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted());
    return absl::OkStatus();
  }

  // For a non-source node, InputTimestamp() in Process() equals the timestamp
  // of the input packet being processed.
  absl::Status Process(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(),
                 cc->Inputs().Index(0).Value().Timestamp());
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_));
    count_ += 1;
    return absl::OkStatus();
  }

  // During Close(), InputTimestamp() is Timestamp::Done() for source and
  // non-source nodes alike.
  absl::Status Close(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done());
    return absl::OkStatus();
  }

 private:
  // Number of packets consumed so far; also the next expected timestamp.
  int count_ = 0;
};
REGISTER_CALCULATOR(CheckInputTimestampSinkCalculator);
// A source calculator for testing CalculatorContext::InputTimestamp().
// It keeps emitting int packets with values and timestamps 0, 1, 2, ... and
// never stops by itself; the framework is expected to close it.
class CheckInputTimestamp2SourceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // During Open(), InputTimestamp() is Timestamp::Unstarted() for source and
  // non-source nodes alike.
  absl::Status Open(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted());
    return absl::OkStatus();
  }

  // For a source node, InputTimestamp() in Process() is always Timestamp(0).
  absl::Status Process(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0));
    const int value = count_++;
    cc->Outputs().Index(0).Add(new int(value), Timestamp(value));
    return absl::OkStatus();
  }

  // During Close(), InputTimestamp() is Timestamp::Done() for source and
  // non-source nodes alike.
  absl::Status Close(CalculatorContext* cc) final {
    // A RET_CHECK would be lost here: the framework may invoke a source node's
    // Close() with .IgnoreError(), so a hard CHECK is required instead.
    CHECK_EQ(cc->InputTimestamp(), Timestamp::Done());
    return absl::OkStatus();
  }

 private:
  // Number of packets emitted so far; also the next value/timestamp.
  int count_ = 0;
};
REGISTER_CALCULATOR(CheckInputTimestamp2SourceCalculator);
// A sink calculator for testing CalculatorContext::InputTimestamp().
// It consumes the output of a CheckInputTimestamp2SourceCalculator, checks
// that its N-th input arrives at Timestamp(N), and returns tool::StatusStop()
// once five packets have been consumed.
class CheckInputTimestamp2SinkCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).Set<int>();
    return absl::OkStatus();
  }

  // During Open(), InputTimestamp() is Timestamp::Unstarted() for source and
  // non-source nodes alike.
  absl::Status Open(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted());
    return absl::OkStatus();
  }

  // For a non-source node, InputTimestamp() in Process() equals the timestamp
  // of the input packet being processed.
  absl::Status Process(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(),
                 cc->Inputs().Index(0).Value().Timestamp());
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_));
    ++count_;
    // Ask the graph to stop after the fifth packet.
    return count_ < 5 ? absl::OkStatus() : tool::StatusStop();
  }

  // During Close(), InputTimestamp() is Timestamp::Done() for source and
  // non-source nodes alike.
  absl::Status Close(CalculatorContext* cc) final {
    RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done());
    return absl::OkStatus();
  }

 private:
  // Number of packets consumed so far; also the next expected timestamp.
  int count_ = 0;
};
REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator);
// A two-to-one mux. At each input timestamp it forwards the packet from
// whichever of its two input streams is non-empty. Input 0 wins when both
// carry a packet. The forwarded packet keeps its timestamp.
class SimpleMuxCalculator : public CalculatorBase {
 public:
  // Exactly two inputs sharing one (arbitrary) type, and one output of that
  // same type.
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
    cc->Inputs().Index(0).SetAny();
    cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
    RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    // Process() addresses the inputs by offset from the first input's id.
    data_input_base_ = cc->Inputs().BeginId();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    for (int offset = 0; offset < 2; ++offset) {
      if (cc->Inputs().Index(offset).IsEmpty()) continue;
      cc->Outputs().Index(0).AddPacket(
          cc->Inputs().Get(data_input_base_ + offset).Value());
      break;  // Forward at most one packet per timestamp.
    }
    return absl::OkStatus();
  }

 private:
  // Id of the first input stream; set in Open().
  CollectionItemId data_input_base_;
};
REGISTER_CALCULATOR(SimpleMuxCalculator);
// Mock status handler that records how many times each hook ran. Every call
// increments the int owned by one of its input side packets:
// HandlePreRunStatus() bumps COUNTER1 and HandleStatus() bumps COUNTER2. The
// status each hook returns can be configured through the static setters.
class IncrementingStatusHandler : public StatusHandler {
 public:
  // Declares an optional EXTRA side packet of any type, plus two required
  // std::unique_ptr<int> counters (COUNTER1 and COUNTER2).
  static absl::Status FillExpectations(
      const MediaPipeOptions& extendable_options,
      PacketTypeSet* input_side_packets) {
    input_side_packets->Tag(kExtraTag).SetAny().Optional();
    input_side_packets->Tag(kCounter1Tag).Set<std::unique_ptr<int>>();
    input_side_packets->Tag(kCounter2Tag).Set<std::unique_ptr<int>>();
    return absl::OkStatus();
  }

  // Increments the COUNTER1 int and returns the configured pre-run result.
  // |pre_run_status| is ignored.
  static absl::Status HandlePreRunStatus(
      const MediaPipeOptions& extendable_options,
      const PacketSet& input_side_packets,  //
      const absl::Status& pre_run_status) {
    int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter1Tag));
    (*counter)++;
    return pre_run_status_result_;
  }

  // Increments the COUNTER2 int and returns the configured post-run result.
  // |run_status| is ignored.
  static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
                                   const PacketSet& input_side_packets,  //
                                   const absl::Status& run_status) {
    int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter2Tag));
    (*counter)++;
    return post_run_status_result_;
  }

  // Sets the value later returned by HandlePreRunStatus().
  static void SetPreRunStatusResult(const absl::Status& status) {
    pre_run_status_result_ = status;
  }

  // Sets the value later returned by HandleStatus().
  static void SetPostRunStatusResult(const absl::Status& status) {
    post_run_status_result_ = status;
  }

 private:
  // Return values of HandlePreRunStatus() and HandleStatus(), respectively.
  // Both start as absl::OkStatus(). They are static, so a value set by one
  // test persists into later tests unless it is reset.
  static absl::Status pre_run_status_result_;
  static absl::Status post_run_status_result_;
};
absl::Status IncrementingStatusHandler::pre_run_status_result_ =
    absl::OkStatus();
absl::Status IncrementingStatusHandler::post_run_status_result_ =
    absl::OkStatus();
REGISTER_STATUS_HANDLER(IncrementingStatusHandler);
// A simple executor that runs tasks directly on the calling thread.
// NOTE: If CurrentThreadExecutor is used, some CalculatorGraph methods may
// behave differently. For example, CalculatorGraph::StartRun will run the
// graph rather than returning immediately after starting the graph.
// Similarly, CalculatorGraph::AddPacketToInputStream will run the graph
// (until it's idle) rather than returning immediately after adding the packet
// to the graph input stream.
// Not thread-safe: it has no synchronization of its own.
class CurrentThreadExecutor : public Executor {
 public:
  // The executor must be idle, with no pending tasks, when it is destroyed.
  ~CurrentThreadExecutor() override {
    CHECK(!executing_);
    CHECK(tasks_.empty());
  }

  // Runs |task| right away unless a task is already running on this executor.
  // In that case |task| is queued and run after the current one returns.
  void Schedule(std::function<void()> task) override {
    if (executing_) {
      // Queue the task for later execution (after the currently-running task
      // returns) rather than running the task immediately. This is especially
      // important for a source node (which can be rescheduled immediately after
      // running) to avoid an indefinitely-deep call stack.
      tasks_.emplace_back(std::move(task));
      return;
    }
    CHECK(tasks_.empty());
    executing_ = true;
    task();
    // Drain everything the tasks above (transitively) scheduled, in FIFO order.
    while (!tasks_.empty()) {
      // Move the task out before popping it, so the std::function and its
      // captured state are not copied.
      std::function<void()> next = std::move(tasks_.front());
      tasks_.pop_front();
      next();
    }
    executing_ = false;
  }

 private:
  // True if the executor is executing tasks.
  bool executing_ = false;
  // Tasks scheduled while another task was running, in FIFO order.
  std::deque<std::function<void()>> tasks_;
};
// Builds the graph configuration shared by the RunComprehensiveTest() tests.
//
// Two RangeCalculator nodes read the "node_3_converted" and "node_5_converted"
// input side packets. Their outputs pass through merge, standard-deviation and
// pass-through nodes into three SaverCalculator sinks, which write to
// "final", "final_sum" and "final_stddev". Packet generators derive the two
// converted side packets from "node_3" and "node_5". Those two side packets
// must therefore be supplied to Run() or produced by another generator.
CalculatorGraphConfig GetConfig() {
  return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
# The graph configuration. We list the nodes in an arbitrary (not
# topologically-sorted) order to verify that CalculatorGraph can
# handle such configurations.
node {
calculator: "RangeCalculator"
output_stream: "range3"
output_stream: "range3_sum"
output_stream: "range3_mean"
input_side_packet: "node_3_converted"
}
node {
calculator: "RangeCalculator"
output_stream: "range5"
output_stream: "range5_sum"
output_stream: "range5_mean"
input_side_packet: "node_5_converted"
}
node {
calculator: "MergeCalculator"
input_stream: "range3"
input_stream: "range5_copy"
output_stream: "merge"
}
node {
calculator: "MergeCalculator"
input_stream: "range3_sum"
input_stream: "range5_sum"
output_stream: "merge_sum"
}
node {
calculator: "PassThroughCalculator"
input_stream: "range3_stddev"
input_stream: "range5_stddev"
output_stream: "range3_stddev_2"
output_stream: "range5_stddev_2"
}
node {
calculator: "PassThroughCalculator"
input_stream: "A:range3_stddev_2"
input_stream: "range5_stddev_2"
output_stream: "A:range3_stddev_3"
output_stream: "range5_stddev_3"
}
node {
calculator: "PassThroughCalculator"
input_stream: "B:range3_stddev_3"
input_stream: "B:1:range5_stddev_3"
output_stream: "B:range3_stddev_4"
output_stream: "B:1:range5_stddev_4"
}
node {
calculator: "MergeCalculator"
input_stream: "range3_stddev_4"
input_stream: "range5_stddev_4"
output_stream: "merge_stddev"
}
node {
calculator: "StdDevCalculator"
input_stream: "DATA:range3"
input_stream: "MEAN:range3_mean"
output_stream: "range3_stddev"
}
node {
calculator: "StdDevCalculator"
input_stream: "DATA:range5"
input_stream: "MEAN:range5_mean"
output_stream: "range5_stddev"
}
node {
name: "copy_range5"
calculator: "PassThroughCalculator"
input_stream: "range5"
output_stream: "range5_copy"
}
node {
calculator: "SaverCalculator"
input_stream: "merge"
output_stream: "final"
}
node {
calculator: "SaverCalculator"
input_stream: "merge_sum"
output_stream: "final_sum"
}
node {
calculator: "SaverCalculator"
input_stream: "merge_stddev"
output_stream: "final_stddev"
}
packet_generator {
packet_generator: "IntSplitterPacketGenerator"
input_side_packet: "node_3"
output_side_packet: "node_3_converted"
}
packet_generator {
packet_generator: "TaggedIntSplitterPacketGenerator"
input_side_packet: "node_5"
output_side_packet: "HIGH:unused_high"
output_side_packet: "LOW:unused_low"
output_side_packet: "PAIR:node_5_converted"
}
)pb");
}
// Runs |the_config| (normally GetConfig(), possibly modified) on |graph| and
// checks the outputs of that pipeline.
//
// |graph| points to an empty CalculatorGraph object created by the default
// constructor, before CalculatorGraph::Initialize() is called.
// |the_config| is copied, and packet sinks are attached to the copy's "final",
// "final_sum" and "final_stddev" streams before initialization.
// |define_node_5| controls whether the "node_5" input side packet is passed to
// Run(). When it is false, the config itself is expected to produce "node_5",
// for example with a packet generator.
//
// The graph is run successfully twice, to verify cleanup between runs. It is
// then run once more with a "node_3" value that is expected to make it fail.
void RunComprehensiveTest(CalculatorGraph* graph,
                          const CalculatorGraphConfig& the_config,
                          bool define_node_5) {
  CalculatorGraphConfig proto(the_config);
  Packet dumped_final_sum_packet;
  Packet dumped_final_packet;
  Packet dumped_final_stddev_packet;
  tool::AddPostStreamPacketSink("final", &proto, &dumped_final_packet);
  tool::AddPostStreamPacketSink("final_sum", &proto, &dumped_final_sum_packet);
  tool::AddPostStreamPacketSink("final_stddev", &proto,
                                &dumped_final_stddev_packet);
  MP_ASSERT_OK(graph->Initialize(proto));
  std::map<std::string, Packet> extra_side_packets;
  extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 3)));
  if (define_node_5) {
    extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5)));
  }
  // Call graph->Run() several times, to make sure that the appropriate
  // cleanup happens between iterations.
  for (int iteration = 0; iteration < 2; ++iteration) {
    LOG(INFO) << "Loop iteration " << iteration;
    dumped_final_sum_packet = Packet();
    dumped_final_stddev_packet = Packet();
    dumped_final_packet = Packet();
    MP_ASSERT_OK(graph->Run(extra_side_packets));
    // The merger will output the timestamp and all ints output from
    // the range calculators. The saver will concatenate together the
    // strings with a '/' delimiter.
    EXPECT_EQ(
        "Timestamp(0) 300 500/"
        "Timestamp(3) 301 empty/"
        "Timestamp(5) empty 501/"
        "Timestamp(6) 302 empty/"
        "Timestamp(9) 303 empty/"
        "Timestamp(10) empty 502/"
        "Timestamp(12) 304 empty/"
        "Timestamp(15) 305 503",
        dumped_final_packet.Get<std::string>());
    // Verify that the headers got set correctly.
    EXPECT_EQ(
        "RangeCalculator3 RangeCalculator5",
        graph->FindOutputStreamManager("merge")->Header().Get<std::string>());
    // Verify that sum packets get correctly processed.
    // (The first is a sum of all the 3's output and the second of all
    // the 5's).
    EXPECT_EQ(absl::StrCat(Timestamp::PostStream().DebugString(), " 1815 2006"),
              dumped_final_sum_packet.Get<std::string>());
    // The "copy_range5" node forwards the four range5 packets on each run, and
    // its counter accumulates across runs.
    EXPECT_EQ(4 * (iteration + 1), graph->GetCounterFactory()
                                       ->GetCounter("copy_range5-PassThrough")
                                       ->Get());
    // Verify that stddev packets get correctly processed.
    // The standard deviation computed as:
    // sqrt(sum((x-mean(x))**2 / length(x)))
    // for x = 300:305 is 1.707825 (multiplied by 100 and rounded it is 171)
    // for x = 500:503 is 1.118034 (multiplied by 100 and rounded it is 112)
    EXPECT_EQ(absl::StrCat(Timestamp::PostStream().DebugString(), " 171 112"),
              dumped_final_stddev_packet.Get<std::string>());
  }
  LOG(INFO) << "After Loop Runs.";
  // Verify that the graph can still run (but not successfully) when
  // one of the nodes is caused to fail. A "node_3" value whose low word is 0
  // is expected to make that branch fail.
  extra_side_packets.clear();
  extra_side_packets.emplace("node_3", Adopt(new uint64((15LL << 32) | 0)));
  if (define_node_5) {
    extra_side_packets.emplace("node_5", Adopt(new uint64((15LL << 32) | 5)));
  }
  dumped_final_sum_packet = Packet();
  dumped_final_stddev_packet = Packet();
  dumped_final_packet = Packet();
  LOG(INFO) << "Expect an error to be logged here.";
  ASSERT_FALSE(graph->Run(extra_side_packets).ok());
  LOG(INFO) << "Error should have been logged.";
}
TEST(CalculatorGraph, BadInitialization) {
  // Strip the calculator name from one node; Initialize() must reject it.
  CalculatorGraphConfig config = GetConfig();
  config.mutable_node(1)->clear_calculator();
  CalculatorGraph graph;
  const absl::Status init_status = graph.Initialize(config);
  ASSERT_FALSE(init_status.ok());
}
TEST(CalculatorGraph, BadRun) {
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(GetConfig()));
  // Run() without the required input side packets must fail.
  EXPECT_FALSE(graph.Run().ok());
}
TEST(CalculatorGraph, RunsCorrectly) {
  // The shared config as-is, with no executor overrides.
  CalculatorGraph graph;
  RunComprehensiveTest(&graph, GetConfig(), /*define_node_5=*/true);
}
TEST(CalculatorGraph, RunsCorrectlyOnApplicationThread) {
  CalculatorGraphConfig config = GetConfig();
  // num_threads == 0 forces the application thread to be used.
  config.set_num_threads(0);
  CalculatorGraph graph;
  RunComprehensiveTest(&graph, config, /*define_node_5=*/true);
}
TEST(CalculatorGraph, RunsCorrectlyWithExternalExecutor) {
  CalculatorGraph graph;
  // Install a caller-owned ThreadPoolExecutor (1 thread) as the default ("")
  // executor.
  auto external_executor = std::make_shared<ThreadPoolExecutor>(1);
  MP_ASSERT_OK(graph.SetExecutor("", external_executor));
  RunComprehensiveTest(&graph, GetConfig(), /*define_node_5=*/true);
}
// This test verifies that the MediaPipe framework calls Executor::Schedule()
// without holding any mutex, because CurrentThreadExecutor::Schedule() may
// result in a recursive call to itself.
TEST(CalculatorGraph, RunsCorrectlyWithCurrentThreadExecutor) {
  CalculatorGraph graph;
  // Every task runs inline on whichever thread schedules it.
  auto inline_executor = std::make_shared<CurrentThreadExecutor>();
  MP_ASSERT_OK(graph.SetExecutor("", inline_executor));
  RunComprehensiveTest(&graph, GetConfig(), /*define_node_5=*/true);
}
TEST(CalculatorGraph, RunsCorrectlyWithNonDefaultExecutors) {
  CalculatorGraph graph;
  // Register executor instances "second" and "third" with the graph; the
  // config only declares them by name.
  for (const char* name : {"second", "third"}) {
    MP_ASSERT_OK(
        graph.SetExecutor(name, std::make_shared<ThreadPoolExecutor>(1)));
  }
  CalculatorGraphConfig config = GetConfig();
  config.add_executor()->set_name("second");
  config.add_executor()->set_name("third");
  // Spread the nodes round-robin over the default, "second" and "third"
  // executors.
  for (int i = 0; i < config.node_size(); ++i) {
    const int slot = i % 3;
    if (slot == 1) {
      config.mutable_node(i)->set_executor("second");
    } else if (slot == 2) {
      config.mutable_node(i)->set_executor("third");
    }
    // slot == 0: leave the node on the default executor.
  }
  RunComprehensiveTest(&graph, config, /*define_node_5=*/true);
}
TEST(CalculatorGraph, RunsCorrectlyWithMultipleExecutors) {
  CalculatorGraph graph;
  CalculatorGraphConfig config = GetConfig();
  // Declare executors "second" and "third" entirely in the config, each of
  // type ThreadPoolExecutor with num_threads = 1.
  for (const char* name : {"second", "third"}) {
    ExecutorConfig* executor_config = config.add_executor();
    executor_config->set_name(name);
    executor_config->set_type("ThreadPoolExecutor");
    executor_config->mutable_options()
        ->MutableExtension(ThreadPoolExecutorOptions::ext)
        ->set_num_threads(1);
  }
  // Spread the nodes round-robin over the default, "second" and "third"
  // executors.
  for (int i = 0; i < config.node_size(); ++i) {
    const int slot = i % 3;
    if (slot == 1) {
      config.mutable_node(i)->set_executor("second");
    } else if (slot == 2) {
      config.mutable_node(i)->set_executor("third");
    }
    // slot == 0: leave the node on the default executor.
  }
  RunComprehensiveTest(&graph, config, /*define_node_5=*/true);
}
// Packet generator that emits a single uint64 output side packet with value
// (15 << 32) | 5, i.e. high word 15 and low word 5. This is the same value
// RunComprehensiveTest() passes as "node_5" when define_node_5 is true.
class Uint64PacketGenerator : public PacketGenerator {
 public:
  // No input side packets; one uint64 output side packet.
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options,
      PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
    output_side_packets->Index(0).Set<uint64>();
    return absl::OkStatus();
  }

  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    // Parenthesized for readability only: << already binds tighter than |.
    output_side_packets->Index(0) = Adopt(new uint64((15LL << 32) | 5));
    return absl::OkStatus();
  }
};
REGISTER_PACKET_GENERATOR(Uint64PacketGenerator);
TEST(CalculatorGraph, GeneratePacket) {
  CalculatorGraphConfig config = GetConfig();
  // Produce "node_5" inside the graph instead of passing it to Run().
  PacketGeneratorConfig* node_5_generator = config.add_packet_generator();
  node_5_generator->set_packet_generator("Uint64PacketGenerator");
  node_5_generator->add_output_side_packet("node_5");
  CalculatorGraph graph;
  RunComprehensiveTest(&graph, config, /*define_node_5=*/false);
}
TEST(CalculatorGraph, TypeMismatch) {
  // A two-node graph: node 0 writes "stream_a" and node 1 reads it. Each case
  // plugs in a different source/sink calculator pair.
  CalculatorGraphConfig config;
  config.add_node()->add_output_stream("stream_a");
  config.add_node()->add_input_stream("stream_a");

  struct TypeCase {
    const char* source_calculator;
    const char* sink_calculator;
    // True iff the source and sink packet types match.
    bool expect_run_ok;
  };
  const TypeCase kCases[] = {
      {"StringEmptySourceCalculator", "StringSinkCalculator", true},
      {"IntEmptySourceCalculator", "IntSinkCalculator", true},
      {"StringEmptySourceCalculator", "IntSinkCalculator", false},
      {"IntEmptySourceCalculator", "StringSinkCalculator", false},
  };
  for (const TypeCase& test_case : kCases) {
    SCOPED_TRACE(absl::StrCat(test_case.source_calculator, " -> ",
                              test_case.sink_calculator));
    config.mutable_node(0)->set_calculator(test_case.source_calculator);
    config.mutable_node(1)->set_calculator(test_case.sink_calculator);
    // Use a fresh graph per case; it is destroyed at the end of the iteration.
    CalculatorGraph graph;
    MP_ASSERT_OK(graph.Initialize(config));
    if (test_case.expect_run_ok) {
      MP_EXPECT_OK(graph.Run());
    } else {
      // Type mismatch, expect non-crashing failure.
      EXPECT_FALSE(graph.Run().ok());
    }
  }
}
TEST(CalculatorGraph, LayerOrdering) {
  // Four GlobalCountSourceCalculator nodes share one atomic counter and are
  // placed in source layers 0, 1, 1 and 2. The counter values each node
  // reports show the order in which the layers ran.
  struct SourceNode {
    const char* output_stream;
    int source_layer;
  };
  const SourceNode kSourceNodes[] = {
      {"count_layer_0_node_0", 0},
      {"count_layer_1_node_0", 1},
      {"count_layer_1_node_1", 1},
      {"count_layer_2_node_0", 2},
  };
  CalculatorGraphConfig config;
  for (const SourceNode& source : kSourceNodes) {
    CalculatorGraphConfig::Node* node = config.add_node();
    node->set_calculator("GlobalCountSourceCalculator");
    node->add_input_side_packet("global_counter");
    node->add_output_stream(source.output_stream);
    node->set_source_layer(source.source_layer);
  }
  // Set num threads to 1 because we rely on sequential execution for this test.
  config.set_num_threads(1);
  std::vector<Packet> dump_layer_0_node_0;
  std::vector<Packet> dump_layer_1_node_0;
  std::vector<Packet> dump_layer_1_node_1;
  std::vector<Packet> dump_layer_2_node_0;
  tool::AddVectorSink("count_layer_0_node_0", &config, &dump_layer_0_node_0);
  tool::AddVectorSink("count_layer_1_node_0", &config, &dump_layer_1_node_0);
  tool::AddVectorSink("count_layer_1_node_1", &config, &dump_layer_1_node_1);
  tool::AddVectorSink("count_layer_2_node_0", &config, &dump_layer_2_node_0);
  auto graph = absl::make_unique<CalculatorGraph>();
  std::atomic<int> global_counter(0);
  std::map<std::string, Packet> input_side_packets;
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));
  MP_ASSERT_OK(graph->Initialize(config));
  MP_ASSERT_OK(graph->Run(input_side_packets));
  // Destroy the graph before inspecting the collected packets.
  graph.reset(nullptr);

  const int kNumPackets = GlobalCountSourceCalculator::kNumOutputPackets;
  ASSERT_EQ(kNumPackets, dump_layer_0_node_0.size());
  ASSERT_EQ(kNumPackets, dump_layer_1_node_0.size());
  ASSERT_EQ(kNumPackets, dump_layer_1_node_1.size());
  ASSERT_EQ(kNumPackets, dump_layer_2_node_0.size());
  // Check layer 0: it runs first, so it sees counter values 0..kNumPackets-1.
  for (int i = 0; i < kNumPackets; ++i) {
    EXPECT_EQ(i, dump_layer_0_node_0[i].Get<int>());
    EXPECT_EQ(Timestamp(i), dump_layer_0_node_0[i].Timestamp());
  }
  // Check layer 1: its two nodes are interleaved in arbitrary order, so the
  // i-th packet of each node holds one of the two counter values of round i.
  for (int i = 0; i < kNumPackets; ++i) {
    const int round_start = kNumPackets + i * 2;
    EXPECT_THAT(dump_layer_1_node_0[i].Get<int>(),
                testing::AnyOf(round_start, round_start + 1));
    EXPECT_THAT(dump_layer_1_node_1[i].Get<int>(),
                testing::AnyOf(round_start, round_start + 1));
    EXPECT_EQ(Timestamp(i), dump_layer_1_node_0[i].Timestamp());
    EXPECT_EQ(Timestamp(i), dump_layer_1_node_1[i].Timestamp());
  }
  // Check layer 2: it runs after the three nodes of layers 0 and 1.
  for (int i = 0; i < kNumPackets; ++i) {
    EXPECT_EQ(3 * kNumPackets + i, dump_layer_2_node_0[i].Get<int>());
    EXPECT_EQ(Timestamp(i), dump_layer_2_node_0[i].Timestamp());
  }
  // The value checks above imply one counter increment per emitted packet, so
  // four sources leave the counter at 4 * kNumPackets.
  EXPECT_EQ(
      4 * kNumPackets,
      input_side_packets["global_counter"].Get<std::atomic<int>*>()->load());
}
// Tests for status handler input verification.
TEST(CalculatorGraph, StatusHandlerInputVerification) {
  // Status handlers with all inputs present should be OK.
  auto graph = absl::make_unique<CalculatorGraph>();
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
packet_generator {
packet_generator: "StaticCounterStringGenerator"
output_side_packet: "created_by_factory"
}
packet_generator {
packet_generator: "TaggedIntSplitterPacketGenerator"
input_side_packet: "a_uint64"
output_side_packet: "HIGH:generated_by_generator"
output_side_packet: "LOW:unused_low"
output_side_packet: "PAIR:unused_pair"
}
status_handler {
status_handler: "Uint32StatusHandler"
input_side_packet: "generated_by_generator"
}
status_handler {
status_handler: "StringStatusHandler"
input_side_packet: "created_by_factory"
}
status_handler {
status_handler: "StringStatusHandler"
input_side_packet: "extra_string"
}
)pb");
  MP_ASSERT_OK(graph->Initialize(config));
  Packet extra_string = Adopt(new std::string("foo"));
  Packet a_uint64 = Adopt(new uint64(0));
  MP_EXPECT_OK(
      graph->Run({{"extra_string", extra_string}, {"a_uint64", a_uint64}}));
  // Should fail verification when missing a required input. The generator is
  // OK, but the StringStatusHandler is missing its input.
  EXPECT_FALSE(graph->Run({{"a_uint64", a_uint64}}).ok());
  // Should fail verification when the type of an already created packet is
  // wrong. Here we give the uint64 packet instead of the string packet to the
  // StringStatusHandler.
  EXPECT_FALSE(
      graph->Run({{"extra_string", a_uint64}, {"a_uint64", a_uint64}}).ok());
  // Should fail verification when the type of a packet generated by a packet
  // generator is wrong. Everything is correct except we add a status handler
  // expecting a uint32 but give it the string from StaticCounterStringGenerator.
  auto* invalid_handler = config.add_status_handler();
  invalid_handler->set_status_handler("Uint32StatusHandler");
  invalid_handler->add_input_side_packet("created_by_factory");
  graph = absl::make_unique<CalculatorGraph>();
  absl::Status status = graph->Initialize(config);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.message(),
              testing::AllOf(testing::HasSubstr("Uint32StatusHandler"),
                             // The problematic input side packet.
                             testing::HasSubstr("created_by_factory"),
                             // Actual type.
                             testing::HasSubstr("string"),
                             // Expected type.
                             testing::HasSubstr(
                                 MediaPipeTypeStringOrDemangled<uint32>())));
  // Should fail verification when the type of a to-be-generated packet is
  // wrong. The added handler now expects a string but will receive the uint32
  // generated by the existing generator.
  invalid_handler->set_status_handler("StringStatusHandler");
  invalid_handler->set_input_side_packet(0, "generated_by_generator");
  graph = absl::make_unique<CalculatorGraph>();
  // This is caught earlier, when the type of the PacketGenerator output
  // is compared to the type of the StatusHandler input.
  status = graph->Initialize(config);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.message(),
              testing::AllOf(
                  testing::HasSubstr("StringStatusHandler"),
                  // The problematic input side packet.
                  testing::HasSubstr("generated_by_generator"),
                  // Actual type.
                  testing::HasSubstr(MediaPipeTypeStringOrDemangled<uint32>()),
                  // Expected type.
                  testing::HasSubstr("string")));
}
TEST(CalculatorGraph, GenerateInInitialize) {
  // Generators whose inputs are all known at Initialize() time ("foo1",
  // "foo2") are expected to run there. Generators that also need
  // "input_in_run" ("foo3", "foo4") are expected to run on every Run().
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
output_side_packet: "foo1"
}
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
input_side_packet: "foo1"
output_side_packet: "foo2"
}
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
input_side_packet: "input_in_run"
output_side_packet: "foo3"
}
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
input_side_packet: "input_in_run"
input_side_packet: "foo3"
output_side_packet: "foo4"
}
)pb");
  const int initial_count = StaticCounterStringGenerator::NumPacketsGenerated();
  MP_ASSERT_OK(graph.Initialize(
      config,
      {{"created_by_factory", MakePacket<std::string>("default string")},
       {"input_in_initialize", MakePacket<int>(10)}}));
  EXPECT_EQ(initial_count + 2,
            StaticCounterStringGenerator::NumPacketsGenerated());
  // Two runs, with "input_in_run" = 11 and then 12; each adds two packets.
  for (int run = 0; run < 2; ++run) {
    MP_ASSERT_OK(graph.Run({{"input_in_run", MakePacket<int>(11 + run)}}));
    EXPECT_EQ(initial_count + 4 + 2 * run,
              StaticCounterStringGenerator::NumPacketsGenerated());
  }
}
// Resets every counter side packet used by the HandlersRun test to a freshly
// owned int with value zero. Any existing packet under those names is
// replaced. Other entries in |input_side_packets| are left untouched.
void ResetCounters(std::map<std::string, Packet>* input_side_packets) {
  const char* const kCounterNames[] = {
      "no_input_counter1",          "no_input_counter2",
      "available_input_counter1",   "available_input_counter2",
      "unavailable_input_counter1", "unavailable_input_counter2",
  };
  std::map<std::string, Packet>& packets = *input_side_packets;
  for (const char* counter_name : kCounterNames) {
    packets[counter_name] = AdoptAsUniquePtr(new int(0));
  }
}
// Tests that status handlers run.
// - We specify three status handlers: one with no EXTRA input side packet, one
//   whose EXTRA input side packet is always provided in the call to Run(), and
//   one whose EXTRA input side packet is the one that will not be produced by
//   the FailingPacketGenerator. The first two should process their PRE-RUN
//   status but not their POST-RUN status; the third one should process
//   neither, since the graph execution fails before its PRE-RUN step.
// - We then replace the FailingPacketGenerator with a non-failing generator,
//   and should have all three handlers running both PRE- and POST-RUN (after
//   the FailingSourceCalculator fails).
// - We test that all three status handlers run (with both statuses) at the end
//   of a successful graph run.
// - Finally, we verify that when the status handler fails (either on PRE- or
//   POST-RUN), but the calculators don't, we still receive errors from the
//   calculator run.
TEST(CalculatorGraph, HandlersRun) {
std::unique_ptr<CalculatorGraph> graph(new CalculatorGraph());
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
packet_generator {
packet_generator: "FailingPacketGenerator"
output_side_packet: "unavailable"
}
node { calculator: "FailingSourceCalculator" output_stream: "output" }
status_handler {
status_handler: "IncrementingStatusHandler"
input_side_packet: "COUNTER1:no_input_counter1"
input_side_packet: "COUNTER2:no_input_counter2"
}
status_handler {
status_handler: "IncrementingStatusHandler"
input_side_packet: "COUNTER1:available_input_counter1"
input_side_packet: "COUNTER2:available_input_counter2"
input_side_packet: "EXTRA:available_string"
}
status_handler {
status_handler: "IncrementingStatusHandler"
input_side_packet: "COUNTER1:unavailable_input_counter1"
input_side_packet: "COUNTER2:unavailable_input_counter2"
input_side_packet: "EXTRA:unavailable"
}
)pb");
std::map<std::string, Packet> input_side_packets(
{{"unused_input", AdoptAsUniquePtr(new int(0))},
{"no_input_counter1", AdoptAsUniquePtr(new int(0))},
{"no_input_counter2", AdoptAsUniquePtr(new int(0))},
{"available_input_counter1", AdoptAsUniquePtr(new int(0))},
{"available_input_counter2", AdoptAsUniquePtr(new int(0))},
{"unavailable_input_counter1", AdoptAsUniquePtr(new int(0))},
{"unavailable_input_counter2", AdoptAsUniquePtr(new int(0))},
{"available_string", Adopt(new std::string("foo"))}});
// When the graph fails in initialize (even because of a PacketGenerator
// returning an error), status handlers should not be run.
ASSERT_THAT(graph->Initialize(config).ToString(),
testing::HasSubstr("FailingPacketGenerator"));
EXPECT_EQ(0,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter1")));
EXPECT_EQ(0,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter2")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter1")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter2")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter1")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter2")));
// Add an input side packet to the packet generator so that it doesn't
// run at initialize time.
config.mutable_packet_generator(0)->add_input_side_packet("unused_input");
graph.reset(new CalculatorGraph());
MP_ASSERT_OK(graph->Initialize(config));
ResetCounters(&input_side_packets);
EXPECT_THAT(graph->Run(input_side_packets).ToString(),
testing::HasSubstr("FailingPacketGenerator"));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter1")));
EXPECT_EQ(0,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter1")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter2")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter1")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter2")));
// Replace the failing packet generator with something that works. All three
// status handlers should now process both the PRE and POST-RUN status.
config.mutable_packet_generator(0)->set_packet_generator(
"StaticCounterStringGenerator");
graph.reset(new CalculatorGraph());
MP_ASSERT_OK(graph->Initialize(config));
ResetCounters(&input_side_packets);
// The entire graph should still fail because of the FailingSourceCalculator.
EXPECT_THAT(graph->Run(input_side_packets).ToString(),
testing::HasSubstr("FailingSourceCalculator"));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter1")));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter1")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter1")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter2")));
// Replace the failing calculator with something that works. All three
// status handlers should still process both PRE and POST-RUN status as part
// of the successful graph run.
config.mutable_node(0)->set_calculator("StringEmptySourceCalculator");
graph.reset(new CalculatorGraph());
MP_ASSERT_OK(graph->Initialize(config));
ResetCounters(&input_side_packets);
MP_EXPECT_OK(graph->Run(input_side_packets));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter1")));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter1")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter1")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter2")));
absl::Status run_status;
// Make status handlers fail. The graph should fail.
// First, when the PRE_run fails
IncrementingStatusHandler::SetPreRunStatusResult(
absl::InternalError("Fail at pre-run"));
graph.reset(new CalculatorGraph());
MP_ASSERT_OK(graph->Initialize(config));
ResetCounters(&input_side_packets);
run_status = graph->Run(input_side_packets);
EXPECT_TRUE(run_status.code() == absl::StatusCode::kInternal);
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at pre-run"));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter1")));
EXPECT_EQ(0,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter1")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter1")));
EXPECT_EQ(0, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter2")));
// Second, when the POST_run fails
IncrementingStatusHandler::SetPreRunStatusResult(absl::OkStatus());
IncrementingStatusHandler::SetPostRunStatusResult(
absl::InternalError("Fail at post-run"));
graph.reset(new CalculatorGraph());
MP_ASSERT_OK(graph->Initialize(config));
ResetCounters(&input_side_packets);
run_status = graph->Run(input_side_packets);
EXPECT_TRUE(run_status.code() == absl::StatusCode::kInternal);
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at post-run"));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter1")));
EXPECT_EQ(1,
*GetFromUniquePtr<int>(input_side_packets.at("no_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter1")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("available_input_counter2")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter1")));
EXPECT_EQ(1, *GetFromUniquePtr<int>(
input_side_packets.at("unavailable_input_counter2")));
}
// Checks that CalculatorGraph::Initialize() accepts an rvalue config: the
// config is handed over with std::move and is meant to be consumed rather than
// copied, which avoids duplicating its data. The graph must still run normally.
TEST(CalculatorGraph, CalculatorGraphConfigCopyElision) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  // `config` is consumed here and must not be used afterwards.
  MP_ASSERT_OK(graph.Initialize(std::move(config)));
  // Everything below depends on a started run; stop on failure instead of
  // reporting a cascade of follow-on errors.
  MP_ASSERT_OK(graph.StartRun({}));
  MP_EXPECT_OK(
      graph.AddPacketToInputStream("in", MakePacket<int>(1).At(Timestamp(1))));
  MP_EXPECT_OK(graph.CloseInputStream("in"));
  MP_EXPECT_OK(graph.WaitUntilDone());
}
// Calling SetOffset() from inside Calculator::Process() is not permitted; the
// graph must report it as an absl::StatusCode::kFailedPrecondition error.
TEST(CalculatorGraph, SetOffsetInProcess) {
  CalculatorGraph graph;
  const CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'SetOffsetInProcessCalculator'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  MP_ASSERT_OK(graph.Initialize(config));
  MP_EXPECT_OK(graph.StartRun({}));

  // Delivering one packet drives the calculator's Process(), where it calls
  // SetOffset().
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "in", MakePacket<int>(0).At(Timestamp(0))));

  // The error surfaces when waiting for the graph to settle.
  const absl::Status idle_status = graph.WaitUntilIdle();
  EXPECT_FALSE(idle_status.ok());
  EXPECT_EQ(absl::StatusCode::kFailedPrecondition, idle_status.code());
}
// MediaPipe must drop its references to input packets once it is done with
// them: after the graph goes idle, no tracked object may still be alive.
TEST(CalculatorGraph, InputPacketLifetime) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'mid'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'mid'
output_stream: 'out'
}
)pb");
  LifetimeTracker tracker;
  Timestamp last_timestamp = Timestamp(0);
  // Wraps a newly tracked object in a packet stamped with the next timestamp
  // (1, 2, 3, ...).
  auto make_tracked_packet = [&last_timestamp, &tracker] {
    Packet tracked = Adopt(tracker.MakeObject().release());
    return tracked.At(++last_timestamp);
  };
  MP_ASSERT_OK(graph.Initialize(config));
  MP_EXPECT_OK(graph.StartRun({}));

  // A single packet is released once the graph is idle.
  MP_EXPECT_OK(graph.AddPacketToInputStream("in", make_tracked_packet()));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_EQ(0, tracker.live_count());

  // So is a burst of several packets.
  for (int sent = 0; sent < 3; ++sent) {
    MP_EXPECT_OK(graph.AddPacketToInputStream("in", make_tracked_packet()));
  }
  MP_EXPECT_OK(graph.WaitUntilIdle());
  EXPECT_EQ(0, tracker.live_count());

  MP_EXPECT_OK(graph.CloseInputStream("in"));
  MP_EXPECT_OK(graph.WaitUntilDone());
}
// Demonstrate an if-then-else graph.
TEST(CalculatorGraph, IfThenElse) {
  // This graph has an if-then-else structure. The left branch, selected by the
  // select value 0, applies a double (multiply by 2) operation. The right
  // branch, selected by the select value 1, applies a square operation. The
  // left branch also has some no-op PassThroughCalculators to make the lengths
  // of the two branches different.
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
input_stream: 'select'
node {
calculator: 'DemuxTimedCalculator'
input_stream: 'INPUT:in'
input_stream: 'SELECT:select'
output_stream: 'OUTPUT:0:left'
output_stream: 'OUTPUT:1:right'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'left'
output_stream: 'left1'
}
node {
calculator: 'DoubleIntCalculator'
input_stream: 'left1'
output_stream: 'left2'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'left2'
output_stream: 'left3'
}
node {
calculator: 'SquareIntCalculator'
input_stream: 'right'
output_stream: 'right1'
}
node {
calculator: 'MuxTimedCalculator'
input_stream: 'INPUT:0:left3'
input_stream: 'INPUT:1:right1'
input_stream: 'SELECT:select'
output_stream: 'OUTPUT:out'
}
)pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("out", &config, &packet_dump);
  Timestamp timestamp = Timestamp(0);
  // Sends one (input, select) pair at the next timestamp.
  auto send_inputs = [&graph, &timestamp](int input, int select) {
    ++timestamp;
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "in", MakePacket<int>(input).At(timestamp)));
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "select", MakePacket<int>(select).At(timestamp)));
  };
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));
  // If the "select" input is 0, we apply a double operation. If "select" is 1,
  // we apply a square operation. To make the code easier to understand, define
  // symbolic names for the select values.
  const int kApplyDouble = 0;
  const int kApplySquare = 1;
  // Every size check below is an ASSERT because the following lines index into
  // packet_dump; a short dump must fail cleanly, not read out of bounds.
  send_inputs(1, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(1, packet_dump.size());
  EXPECT_EQ(2, packet_dump[0].Get<int>());
  send_inputs(2, kApplySquare);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(2, packet_dump.size());
  EXPECT_EQ(4, packet_dump[1].Get<int>());
  send_inputs(3, kApplyDouble);
  send_inputs(4, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(4, packet_dump.size());
  EXPECT_EQ(6, packet_dump[2].Get<int>());
  EXPECT_EQ(8, packet_dump[3].Get<int>());
  send_inputs(5, kApplySquare);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(5, packet_dump.size());
  EXPECT_EQ(25, packet_dump[4].Get<int>());
  send_inputs(6, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(6, packet_dump.size());
  EXPECT_EQ(12, packet_dump[5].Get<int>());
  send_inputs(7, kApplySquare);
  send_inputs(8, kApplySquare);
  send_inputs(9, kApplySquare);
  send_inputs(10, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(10, packet_dump.size());
  EXPECT_EQ(49, packet_dump[6].Get<int>());
  EXPECT_EQ(64, packet_dump[7].Get<int>());
  EXPECT_EQ(81, packet_dump[8].Get<int>());
  EXPECT_EQ(20, packet_dump[9].Get<int>());
  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
  EXPECT_EQ(10, packet_dump.size());
}
// A simple output-selecting test calculator that does NOT emit timestamp
// bounds on its unselected outputs.
//
// Inputs:
//   INPUT  - a packet of any type.
//   SELECT - an int choosing which OUTPUT stream receives INPUT.
// Outputs:
//   OUTPUT:0 .. OUTPUT:N-1 - same type as INPUT.
//
// On each Process() call, a present INPUT packet is forwarded to
// OUTPUT:<select>. If INPUT is empty at this timestamp, only OUTPUT:<select>
// has its next timestamp bound advanced to InputTimestamp() + 1. All other
// outputs receive neither packets nor bound updates.
class DemuxUntimedCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
    cc->Inputs().Tag(kInputTag).SetAny();
    cc->Inputs().Tag(kSelectTag).Set<int>();
    for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
         id < cc->Outputs().EndId("OUTPUT"); ++id) {
      cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag(kInputTag));
    }
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    // Validate SELECT up front so a missing packet or an out-of-range index is
    // reported as a Process() error status. Presumably Packet::Get() and
    // Get(tag, index) would otherwise CHECK-fail and abort the test binary.
    RET_CHECK(!cc->Inputs().Tag(kSelectTag).IsEmpty())
        << "DemuxUntimedCalculator requires a SELECT packet at "
        << cc->InputTimestamp().DebugString();
    const int index = cc->Inputs().Tag(kSelectTag).Get<int>();
    const CollectionItemId output_id = cc->Outputs().GetId("OUTPUT", index);
    RET_CHECK(output_id.IsValid())
        << "SELECT index " << index << " has no matching OUTPUT stream.";
    auto& selected_output = cc->Outputs().Get(output_id);

    const auto& input = cc->Inputs().Tag(kInputTag);
    if (!input.IsEmpty()) {
      selected_output.AddPacket(input.Value());
    } else {
      selected_output.SetNextTimestampBound(cc->InputTimestamp() + 1);
    }
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(DemuxUntimedCalculator);
// Demonstrate an if-then-else graph. This test differs from the IfThenElse
// test in that DemuxUntimedCalculator does not propagate timestamp bounds to
// the unselected branch. Instead, the final MuxCalculator uses
// MuxInputStreamHandler, which presumably only waits on the input chosen by
// 'select'.
TEST(CalculatorGraph, IfThenElse2) {
  // This graph has an if-then-else structure. The left branch, selected by the
  // select value 0, applies a double (multiply by 2) operation. The right
  // branch, selected by the select value 1, applies a square operation. The
  // left branch also has some no-op PassThroughCalculators to make the lengths
  // of the two branches different.
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
input_stream: 'select'
node {
calculator: 'DemuxUntimedCalculator'
input_stream: 'INPUT:in'
input_stream: 'SELECT:select'
output_stream: 'OUTPUT:0:left'
output_stream: 'OUTPUT:1:right'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'left'
output_stream: 'left1'
}
node {
calculator: 'DoubleIntCalculator'
input_stream: 'left1'
output_stream: 'left2'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'left2'
output_stream: 'left3'
}
node {
calculator: 'SquareIntCalculator'
input_stream: 'right'
output_stream: 'right1'
}
node {
calculator: 'MuxCalculator'
input_stream: 'INPUT:0:left3'
input_stream: 'INPUT:1:right1'
input_stream: 'SELECT:select'
output_stream: 'OUTPUT:out'
input_stream_handler { input_stream_handler: 'MuxInputStreamHandler' }
}
)pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("out", &config, &packet_dump);
  Timestamp timestamp = Timestamp(0);
  // Sends one (input, select) pair at the next timestamp.
  auto send_inputs = [&graph, &timestamp](int input, int select) {
    ++timestamp;
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "in", MakePacket<int>(input).At(timestamp)));
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "select", MakePacket<int>(select).At(timestamp)));
  };
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));
  // If the "select" input is 0, we apply a double operation. If "select" is 1,
  // we apply a square operation. To make the code easier to understand, define
  // symbolic names for the select values.
  const int kApplyDouble = 0;
  const int kApplySquare = 1;
  // Every size check below is an ASSERT because the following lines index into
  // packet_dump; a short dump must fail cleanly, not read out of bounds.
  send_inputs(1, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(1, packet_dump.size());
  EXPECT_EQ(2, packet_dump[0].Get<int>());
  send_inputs(2, kApplySquare);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(2, packet_dump.size());
  EXPECT_EQ(4, packet_dump[1].Get<int>());
  send_inputs(3, kApplyDouble);
  send_inputs(4, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(4, packet_dump.size());
  EXPECT_EQ(6, packet_dump[2].Get<int>());
  EXPECT_EQ(8, packet_dump[3].Get<int>());
  send_inputs(5, kApplySquare);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(5, packet_dump.size());
  EXPECT_EQ(25, packet_dump[4].Get<int>());
  send_inputs(6, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(6, packet_dump.size());
  EXPECT_EQ(12, packet_dump[5].Get<int>());
  send_inputs(7, kApplySquare);
  send_inputs(8, kApplySquare);
  send_inputs(9, kApplySquare);
  send_inputs(10, kApplyDouble);
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(10, packet_dump.size());
  EXPECT_EQ(49, packet_dump[6].Get<int>());
  EXPECT_EQ(64, packet_dump[7].Get<int>());
  EXPECT_EQ(81, packet_dump[8].Get<int>());
  EXPECT_EQ(20, packet_dump[9].Get<int>());
  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
  EXPECT_EQ(10, packet_dump.size());
}
// Regression test for bug 28321551: the scheduler must run this graph to
// completion without hanging. The only thing checked is that
// CalculatorGraph::Run() returns.
//
// Topology:
//   - OutputAllSourceCalculator (source) emits a lot of packets in one shot,
//     then stops.
//   - OutputOneAtATimeSourceCalculator (source) emits one packet at a time and
//     feeds DecimatorCalculator, which discards most packets and only rarely
//     forwards one.
//   - MergeCalculator (sink) reads both source streams plus the decimated one.
// The two streams fed directly by the sources fill up, so the scheduler
// throttles those sources. A closed source node must never be scheduled again,
// even when a stream it filled goes from "full" back to "not full".
TEST(CalculatorGraph, ClosedSourceNodeShouldNotBeUnthrottled) {
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
num_threads: 1
max_queue_size: 100
node {
calculator: 'OutputAllSourceCalculator'
output_stream: 'first_stream'
}
node {
calculator: 'OutputOneAtATimeSourceCalculator'
output_stream: 'second_stream'
}
node {
calculator: 'DecimatorCalculator'
input_stream: 'second_stream'
output_stream: 'decimated_second_stream'
}
node {
calculator: 'MergeCalculator'
input_stream: 'first_stream'
input_stream: 'second_stream'
input_stream: 'decimated_second_stream'
output_stream: 'output'
}
)pb")));
  MP_ASSERT_OK(graph.Run());
}
// Verifies that a calculator may emit a packet from its Open() method.
//
// Because UnitDelayCalculator::Open() produces an initial packet, the graph is
// in this state before the scheduler starts running:
//   - the downstream PassThroughCalculator is already ready and sits in the
//     scheduler queue, and
//   - with max_queue_size set to 1, GlobalCountSourceCalculator is throttled.
// The scheduler has to make progress from that starting point.
TEST(CalculatorGraph, OutputPacketInOpen) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
max_queue_size: 1
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
node {
calculator: 'UnitDelayCalculator'
input_stream: 'integers'
output_stream: 'delayed_integers'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'delayed_integers'
output_stream: 'output'
}
)pb");
  std::vector<Packet> outputs;
  tool::AddVectorSink("output", &config, &outputs);
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(side_packets));

  // One extra packet comes from UnitDelayCalculator::Open(). Every packet i,
  // including that initial one at index 0, is expected to hold the value i at
  // Timestamp(i).
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets + 1,
            outputs.size());
  for (int i = 0; i <= GlobalCountSourceCalculator::kNumOutputPackets; ++i) {
    EXPECT_EQ(i, outputs[i].Get<int>());
    EXPECT_EQ(Timestamp(i), outputs[i].Timestamp());
  }
}
// Verifies that a calculator may emit a packet from its Open() method when the
// downstream node cannot consume it right away.
//
// The initial packet from UnitDelayCalculator::Open() leaves the graph in this
// state before the scheduler starts running:
//   - the downstream MergeCalculator is NOT ready, because its second input
//     stream has no packet yet, and
//   - with max_queue_size set to 1, GlobalCountSourceCalculator is throttled.
// The scheduler therefore has to schedule a throttled source node right from
// the start.
TEST(CalculatorGraph, OutputPacketInOpen2) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
max_queue_size: 1
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
node {
calculator: 'UnitDelayCalculator'
input_stream: 'integers'
output_stream: 'delayed_integers'
}
node {
calculator: 'MergeCalculator'
input_stream: 'delayed_integers'
input_stream: 'integers'
output_stream: 'output'
}
)pb");
  std::vector<Packet> merged;
  tool::AddVectorSink("output", &config, &merged);
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets + 1,
            merged.size());

  // Each of these packets pairs the delayed value i with the current value
  // i + 1.
  for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets; ++i) {
    const Packet& pair = merged[i];
    EXPECT_EQ(absl::Substitute("Timestamp($0) $1 $2",
                               pair.Timestamp().DebugString(), i, i + 1),
              pair.Get<std::string>());
    EXPECT_EQ(Timestamp(i), pair.Timestamp());
  }

  // The final packet has only the delayed value; the expected string marks
  // the missing 'integers' input as "empty".
  const int last = GlobalCountSourceCalculator::kNumOutputPackets;
  const Packet& tail = merged[last];
  EXPECT_EQ(absl::Substitute("Timestamp($0) $1 empty",
                             tail.Timestamp().DebugString(), last),
            tail.Get<std::string>());
  EXPECT_EQ(Timestamp(last), tail.Timestamp());
}
// Packets must never be visible on input streams during Open(), even when an
// upstream calculator already emitted a packet from its own Open().
TEST(CalculatorGraph, EmptyInputInOpen) {
  const CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
max_queue_size: 1
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
# UnitDelayCalculator outputs a packet during Open().
node {
calculator: 'UnitDelayCalculator'
input_stream: 'integers'
output_stream: 'delayed_integers'
}
node {
calculator: 'AssertEmptyInputInOpenCalculator'
input_stream: 'delayed_integers'
}
node {
calculator: 'AssertEmptyInputInOpenCalculator'
input_stream: 'integers'
}
)pb");
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  // The AssertEmptyInputInOpenCalculator nodes perform the actual check; the
  // run is expected to succeed only if their inputs were empty in Open().
  MP_EXPECT_OK(graph.Run(side_packets));
}
// Regression test for b/33568859: unthrottling must respect source layers.
// Two GlobalCountSourceCalculators share one counter. The source on
// source_layer 0 is expected to emit all of its packets before the layer-1
// source emits any, except for a packet the layer-1 source emits in Open().
TEST(CalculatorGraph, UnthrottleRespectsLayers) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
max_queue_size: 1
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers0'
source_layer: 0
}
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
input_side_packet: 'output_in_open'
output_stream: 'integers1'
source_layer: 1
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'integers1'
output_stream: 'integers1passthrough'
}
)pb");
  std::vector<Packet> layer0_packets;
  std::vector<Packet> layer1_packets;
  tool::AddVectorSink("integers0", &config, &layer0_packets);
  tool::AddVectorSink("integers1passthrough", &config, &layer1_packets);
  std::atomic<int> global_counter(0);
  std::map<std::string, Packet> input_side_packets;
  input_side_packets["global_counter"] = Adopt(new auto(&global_counter));
  // When true, the layer-1 source outputs a packet during Open(). That path
  // used to trigger b/33568859 and make this test fail, so it is kept enabled
  // to verify that the bug stays fixed.
  constexpr bool kOutputInOpen = true;
  input_side_packets["output_in_open"] = MakePacket<bool>(kOutputInOpen);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(input_side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets,
            layer0_packets.size());
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets,
            layer1_packets.size());
  // Check that we ran things in the expected order: the shared counter value
  // `count` increases across the Open() packet, then all of layer 0, then the
  // rest of layer 1.
  int count = 0;
  if (kOutputInOpen) {
    EXPECT_EQ(count, layer1_packets[0].Get<int>());
    EXPECT_EQ(Timestamp(0), layer1_packets[0].Timestamp());
    ++count;
  }
  for (int i = 0; i < GlobalCountSourceCalculator::kNumOutputPackets;
       ++i, ++count) {
    EXPECT_EQ(count, layer0_packets[i].Get<int>());
    EXPECT_EQ(Timestamp(i), layer0_packets[i].Timestamp());
  }
  for (int i = kOutputInOpen ? 1 : 0;
       i < GlobalCountSourceCalculator::kNumOutputPackets; ++i, ++count) {
    EXPECT_EQ(count, layer1_packets[i].Get<int>());
    EXPECT_EQ(Timestamp(i), layer1_packets[i].Timestamp());
  }
}
// Computes the running sum of the integers emitted by the source node using a
// graph with one cycle. IntAdderCalculator adds each new integer to the
// previous sum, which UnitDelayCalculator feeds back through the 'old_sum'
// back edge.
TEST(CalculatorGraph, Cycle) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
node {
calculator: 'IntAdderCalculator'
input_stream: 'integers'
input_stream: 'old_sum'
input_stream_info: {
tag_index: ':1' # 'old_sum'
back_edge: true
}
output_stream: 'sum'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
node {
calculator: 'UnitDelayCalculator'
input_stream: 'sum'
output_stream: 'old_sum'
}
)pb");
  std::vector<Packet> sums;
  tool::AddVectorSink("sum", &config, &sums);
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, sums.size());
  // For inputs 1, 2, 3, ... the n-th (0-based) running sum is the triangular
  // number (n + 1) * (n + 2) / 2.
  for (int n = 0; n < GlobalCountSourceCalculator::kNumOutputPackets; ++n) {
    EXPECT_EQ((n + 1) * (n + 2) / 2, sums[n].Get<int>());
    EXPECT_EQ(Timestamp(n), sums[n].Timestamp());
  }
}
// Computes the running sum of the integers emitted by the source node using a
// graph with one cycle.
//
// Unlike the "Cycle" test, the graph-wide BarrierInputStreamHandler schedules
// nodes with packet timestamps ignored, so only the sum values are checked.
TEST(CalculatorGraph, CycleUntimed) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream_handler {
input_stream_handler: 'BarrierInputStreamHandler'
}
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
node {
calculator: 'IntAdderCalculator'
input_stream: 'integers'
input_stream: 'old_sum'
input_stream_info: {
tag_index: ':1' # 'old_sum'
back_edge: true
}
output_stream: 'sum'
}
node {
calculator: 'UnitDelayUntimedCalculator'
input_stream: 'sum'
output_stream: 'old_sum'
}
)pb");
  std::vector<Packet> sums;
  tool::AddVectorSink("sum", &config, &sums);
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, sums.size());
  // For inputs 1, 2, 3, ... the n-th (0-based) running sum is the triangular
  // number (n + 1) * (n + 2) / 2.
  for (int n = 0; n < GlobalCountSourceCalculator::kNumOutputPackets; ++n) {
    EXPECT_EQ((n + 1) * (n + 2) / 2, sums[n].Get<int>());
  }
}
// Direct form I realization of Example 6.2 from "Discrete-Time Signal
// Processing", 3rd Ed. (Figure 6.6). The linear time-invariant (LTI) system
// function is
//   H(z) = (1 + 2 * z^-1) / (1 - 1.5 * z^-1 + 0.9 * z^-2)
// i.e. y[n] = x[n] + 2 * x[n-1] + 1.5 * y[n-1] - 0.9 * y[n-2].
// The graph has two cycles, both closed by the back edge 'f'.
TEST(CalculatorGraph, DirectFormI) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
node {
calculator: 'IntToFloatCalculator'
input_stream: 'integers'
output_stream: 'x'
}
node {
calculator: 'FloatUnitDelayCalculator'
input_stream: 'x'
output_stream: 'a'
}
node {
calculator: 'FloatScalarMultiplierCalculator'
input_stream: 'a'
output_stream: 'b'
input_side_packet: 'b1'
}
node {
calculator: 'FloatAdderCalculator'
input_stream: 'x'
input_stream: 'b'
output_stream: 'c'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
node {
calculator: 'FloatAdderCalculator'
input_stream: 'c'
input_stream: 'f'
input_stream_info: {
tag_index: ':1' # 'f'
back_edge: true
}
output_stream: 'y'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
node {
calculator: 'FloatUnitDelayCalculator'
input_stream: 'y'
output_stream: 'd'
}
node {
calculator: 'FloatScalarMultiplierCalculator'
input_stream: 'd'
output_stream: 'e'
input_side_packet: 'a1'
}
node {
calculator: 'FloatUnitDelayCalculator'
input_stream: 'd'
output_stream: 'g'
}
node {
calculator: 'FloatScalarMultiplierCalculator'
input_stream: 'g'
output_stream: 'h'
input_side_packet: 'a2'
}
node {
calculator: 'FloatAdderCalculator'
input_stream: 'e'
input_stream: 'h'
output_stream: 'f'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
)pb");
  std::vector<Packet> y_samples;
  tool::AddVectorSink("y", &config, &y_samples);
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));
  // Filter coefficients.
  side_packets["b1"] = Adopt(new float(2.0));
  side_packets["a1"] = Adopt(new float(1.5));
  side_packets["a2"] = Adopt(new float(-0.9));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, y_samples.size());
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, 5);

  // Expected y[n] for the input x[n] = 1, 2, 3, 4, 5.
  const double kExpectedY[] = {1.0, 5.5, 14.35, 26.575, 39.9475};
  for (int n = 0; n < 5; ++n) {
    EXPECT_FLOAT_EQ(kExpectedY[n], y_samples[n].Get<float>());
    EXPECT_EQ(Timestamp(n), y_samples[n].Timestamp());
  }
}
// Direct form II realization of Example 6.2 from "Discrete-Time Signal
// Processing", 3rd Ed. (Figure 6.7). The linear time-invariant (LTI) system
// function is
//   H(z) = (1 + 2 * z^-1) / (1 - 1.5 * z^-1 + 0.9 * z^-2)
// i.e. y[n] = x[n] + 2 * x[n-1] + 1.5 * y[n-1] - 0.9 * y[n-2].
// The graph has two cycles, both closed by the back edge 'f'.
TEST(CalculatorGraph, DirectFormII) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'GlobalCountSourceCalculator'
input_side_packet: 'global_counter'
output_stream: 'integers'
}
node {
calculator: 'IntToFloatCalculator'
input_stream: 'integers'
output_stream: 'x'
}
node {
calculator: 'FloatAdderCalculator'
input_stream: 'x'
input_stream: 'f'
input_stream_info: {
tag_index: ':1' # 'f'
back_edge: true
}
output_stream: 'a'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
node {
calculator: 'FloatUnitDelayCalculator'
input_stream: 'a'
output_stream: 'b'
}
node {
calculator: 'FloatScalarMultiplierCalculator'
input_stream: 'b'
output_stream: 'd'
input_side_packet: 'a1'
}
node {
calculator: 'FloatUnitDelayCalculator'
input_stream: 'b'
output_stream: 'c'
}
node {
calculator: 'FloatScalarMultiplierCalculator'
input_stream: 'c'
output_stream: 'e'
input_side_packet: 'a2'
}
node {
calculator: 'FloatAdderCalculator'
input_stream: 'd'
input_stream: 'e'
output_stream: 'f'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
node {
calculator: 'FloatScalarMultiplierCalculator'
input_stream: 'b'
output_stream: 'g'
input_side_packet: 'b1'
}
node {
calculator: 'FloatAdderCalculator'
input_stream: 'a'
input_stream: 'g'
output_stream: 'y'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
)pb");
  std::vector<Packet> y_samples;
  tool::AddVectorSink("y", &config, &y_samples);
  std::atomic<int> counter(1);
  std::map<std::string, Packet> side_packets;
  side_packets["global_counter"] = Adopt(new auto(&counter));
  // Filter coefficients.
  side_packets["b1"] = Adopt(new float(2.0));
  side_packets["a1"] = Adopt(new float(1.5));
  side_packets["a2"] = Adopt(new float(-0.9));
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run(side_packets));
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, y_samples.size());
  ASSERT_EQ(GlobalCountSourceCalculator::kNumOutputPackets, 5);

  // Expected y[n] for the input x[n] = 1, 2, 3, 4, 5; identical to direct
  // form I because both realize the same H(z).
  const double kExpectedY[] = {1.0, 5.5, 14.35, 26.575, 39.9475};
  for (int n = 0; n < 5; ++n) {
    EXPECT_FLOAT_EQ(kExpectedY[n], y_samples[n].Get<float>());
    EXPECT_EQ(Timestamp(n), y_samples[n].Timestamp());
  }
}
// Calculates the dot products of two streams of three-dimensional vectors.
//
// Each source sequence is split into x/y/z components by a
// DemuxUntimedCalculator driven by the shared "select_0_1_2" stream; the
// matching components are multiplied pairwise and the three products are
// summed into "dot_product".
TEST(CalculatorGraph, DotProduct) {
  // The use of BarrierInputStreamHandler in this graph aligns the input
  // packets to a calculator by arrival order rather than by timestamp.
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream_handler {
input_stream_handler: 'BarrierInputStreamHandler'
}
node {
calculator: 'TestSequence1SourceCalculator'
output_stream: 'test_sequence_1'
}
node {
calculator: 'TestSequence2SourceCalculator'
output_stream: 'test_sequence_2'
}
node {
calculator: 'Modulo3SourceCalculator'
output_stream: 'select_0_1_2'
}
node {
calculator: 'DemuxUntimedCalculator'
input_stream: 'INPUT:test_sequence_1'
input_stream: 'SELECT:select_0_1_2'
output_stream: 'OUTPUT:0:x_1'
output_stream: 'OUTPUT:1:y_1'
output_stream: 'OUTPUT:2:z_1'
}
node {
calculator: 'DemuxUntimedCalculator'
input_stream: 'INPUT:test_sequence_2'
input_stream: 'SELECT:select_0_1_2'
output_stream: 'OUTPUT:0:x_2'
output_stream: 'OUTPUT:1:y_2'
output_stream: 'OUTPUT:2:z_2'
}
node {
calculator: 'IntMultiplierCalculator'
input_stream: 'x_1'
input_stream: 'x_2'
output_stream: 'x_product'
}
node {
calculator: 'IntMultiplierCalculator'
input_stream: 'y_1'
input_stream: 'y_2'
output_stream: 'y_product'
}
node {
calculator: 'IntMultiplierCalculator'
input_stream: 'z_1'
input_stream: 'z_2'
output_stream: 'z_product'
}
node {
calculator: 'IntAdderCalculator'
input_stream: 'x_product'
input_stream: 'y_product'
input_stream: 'z_product'
output_stream: 'dot_product'
}
)pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("dot_product", &config, &packet_dump);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.Run());
  // The calculator graph performs the following computation:
  //   test_sequence_1 is split into x_1, y_1, z_1.
  //   test_sequence_2 is split into x_2, y_2, z_2.
  //   x_product = x_1 * x_2
  //   y_product = y_1 * y_2
  //   z_product = z_1 * z_2
  //   dot_product = x_product + y_product + z_product
  //
  // The values in these streams are:
  //   test_sequence_1: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
  //   test_sequence_2: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  //   x_1: 0, 3, 6, 9, 12
  //   x_2: 1, 4, 7, 10, 13
  //   x_product: 0, 12, 42, 90, 156
  //   y_1: 1, 4, 7, 10, 13
  //   y_2: 2, 5, 8, 11, 14
  //   y_product: 2, 20, 56, 110, 182
  //   z_1: 2, 5, 8, 11, 14
  //   z_2: 3, 6, 9, 12, 15
  //   z_product: 6, 30, 72, 132, 210
  //   dot_product: 8, 62, 170, 332, 548
  const std::vector<int> expected = {8, 62, 170, 332, 548};
  ASSERT_EQ(kTestSequenceLength / 3, packet_dump.size());
  // `expected` is indexed by packet position below, so it must cover exactly
  // the packets that were received; otherwise the loop could read past it.
  ASSERT_EQ(expected.size(), packet_dump.size());
  // An unsigned index avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < expected.size(); ++i) {
    EXPECT_EQ(expected[i], packet_dump[i].Get<int>());
  }
}
// Verifies that Cancel() ends a run whose graph input stream "in_b" is still
// open (only "in_a" is closed): WaitUntilDone() must return kCancelled rather
// than block.
TEST(CalculatorGraph, TerminatesOnCancelWithOpenGraphInputStreams) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'PassThroughCalculator'
input_stream: 'in_a'
input_stream: 'in_b'
output_stream: 'out_a'
output_stream: 'out_b'
}
input_stream: 'in_a'
input_stream: 'in_b'
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// Send one packet on each graph input; afterwards only "in_a" is closed, so
// "in_b" remains open when the run is cancelled.
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_a", MakePacket<int>(1).At(Timestamp(1))));
MP_EXPECT_OK(graph.CloseInputStream("in_a"));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_b", MakePacket<int>(2).At(Timestamp(2))));
// Wait for the graph to become idle before cancelling.
MP_EXPECT_OK(graph.WaitUntilIdle());
graph.Cancel();
// This tests that the graph doesn't deadlock on WaitUntilDone (because
// the scheduler thread is sleeping).
absl::Status status = graph.WaitUntilDone();
EXPECT_EQ(status.code(), absl::StatusCode::kCancelled);
}
// Verifies that Cancel() issued while the scheduler is paused, with a node
// already made runnable, still lets WaitUntilDone() return kCancelled, and
// that a Pause() call made after the cancellation does not block termination.
TEST(CalculatorGraph, TerminatesOnCancelAfterPause) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
input_stream: 'in'
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
graph.Pause();
// Make the PassThroughCalculator runnable while the scheduler is paused.
MP_EXPECT_OK(
graph.AddPacketToInputStream("in", MakePacket<int>(1).At(Timestamp(1))));
// Now cancel the graph run. A non-empty scheduler queue should not prevent
// the scheduler from terminating.
graph.Cancel();
// Any attempt to pause the scheduler after the graph run is cancelled should
// be ignored.
graph.Pause();
// This tests that the graph doesn't deadlock on WaitUntilDone (because
// the scheduler thread is sleeping).
absl::Status status = graph.WaitUntilDone();
EXPECT_EQ(status.code(), absl::StatusCode::kCancelled);
}
// A PacketGenerator that forwards every input side packet, unchanged, to the
// output side packet with the same id. Inputs may be addressed by tag or by
// index, but the output layout must mirror the input layout exactly.
// Generator options are accepted and ignored.
class PassThroughGenerator : public PacketGenerator {
 public:
  // Declares every input as accepting any type and ties each output's type to
  // the input with the same id. Rejects configurations whose input and output
  // tag maps differ.
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs,
      PacketTypeSet* outputs) {
    const bool layouts_match = inputs->TagMap()->SameAs(*outputs->TagMap());
    if (!layouts_match) {
      return absl::InvalidArgumentError(
          "Input and outputs to PassThroughGenerator must use the same tags "
          "and indexes.");
    }
    const CollectionItemId end = inputs->EndId();
    for (CollectionItemId item = inputs->BeginId(); item < end; ++item) {
      auto& input_type = inputs->Get(item);
      input_type.SetAny();
      outputs->Get(item).SetSameAs(&input_type);
    }
    return absl::OkStatus();
  }

  // Copies each input side packet into the output slot with the same id.
  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    const CollectionItemId end = input_side_packets.EndId();
    for (CollectionItemId item = input_side_packets.BeginId(); item < end;
         ++item) {
      output_side_packets->Get(item) = input_side_packets.Get(item);
    }
    return absl::OkStatus();
  }
};
REGISTER_PACKET_GENERATOR(PassThroughGenerator);
// Verifies that a CalculatorGraph can be Run() again successfully after a
// Run() that failed in a packet generator or in a calculator, including when
// the status handler is instructed to fail as well.
TEST(CalculatorGraph, RecoverAfterRunError) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
name: 'calculator1'
calculator: 'CountingSourceCalculator'
output_stream: 'count1'
input_side_packet: 'MAX_COUNT:max_count2'
input_side_packet: 'ERROR_COUNT:max_error2'
}
packet_generator {
packet_generator: 'EnsurePositivePacketGenerator'
input_side_packet: 'max_count1'
output_side_packet: 'max_count2'
input_side_packet: 'max_error1'
output_side_packet: 'max_error2'
}
status_handler {
status_handler: 'FailableStatusHandler'
input_side_packet: 'status_handler_command'
}
)pb");
  int packet_count = 0;
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config, {}));
  MP_ASSERT_OK(graph.ObserveOutputStream("count1",
                                         [&packet_count](const Packet& packet) {
                                           ++packet_count;
                                           return absl::OkStatus();
                                         }));
  // Runs the graph once. `max_count` and `max_error` are fed to
  // EnsurePositivePacketGenerator, whose outputs become the calculator's
  // MAX_COUNT and ERROR_COUNT side packets; `status_handler_command` is the
  // input side packet of FailableStatusHandler.
  auto run_graph = [&graph](int max_count, int max_error,
                            int status_handler_command) {
    return graph.Run(
        {{"max_count1", MakePacket<int>(max_count)},
         {"max_error1", MakePacket<int>(max_error)},
         {"status_handler_command", MakePacket<int>(status_handler_command)}});
  };

  // Set ERROR_COUNT higher than MAX_COUNT and hence the calculator will
  // finish successfully.
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);

  // Fail in PacketGenerator::Generate().
  // Negative max_count1 will cause EnsurePositivePacketGenerator to fail.
  ASSERT_FALSE(run_graph(-1, 20, FailableStatusHandler::kOk).ok());
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);

  // Fail in PacketGenerator::Generate() and also in the status handler
  // (kFailPreRun).
  ASSERT_FALSE(run_graph(-1, 20, FailableStatusHandler::kFailPreRun).ok());
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);

  // Fail in PacketGenerator::Generate() and also in the status handler
  // (kFailPostRun).
  ASSERT_FALSE(run_graph(-1, 20, FailableStatusHandler::kFailPostRun).ok());
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);

  // Fail in Calculator::Process(): ERROR_COUNT is now below MAX_COUNT.
  ASSERT_FALSE(run_graph(1000, 10, FailableStatusHandler::kOk).ok());
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);

  // Fail in Calculator::Process() and also in the status handler
  // (kFailPreRun).
  ASSERT_FALSE(run_graph(1000, 10, FailableStatusHandler::kFailPreRun).ok());
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);

  // Fail in Calculator::Process() and also in the status handler
  // (kFailPostRun).
  ASSERT_FALSE(run_graph(1000, 10, FailableStatusHandler::kFailPostRun).ok());
  packet_count = 0;
  MP_ASSERT_OK(run_graph(10, 20, FailableStatusHandler::kOk));
  EXPECT_EQ(packet_count, 10);
}
// Verifies a per-stream max queue size in ADD_IF_NOT_FULL mode. "in" is
// limited to 1 packet while the graph-wide max_queue_size is 100. While the
// SemaphoreCalculator on "in" is held inside Process(), one more packet can be
// queued and a second AddPacketToInputStream() call returns kUnavailable.
TEST(CalculatorGraph, SetInputStreamMaxQueueSizeWorksSlowCalculator) {
using Semaphore = SemaphoreCalculator::Semaphore;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'SemaphoreCalculator'
input_stream: 'in'
output_stream: 'out'
input_side_packet: 'POST_SEM:post_sem'
input_side_packet: 'WAIT_SEM:wait_sem'
}
node {
calculator: 'SemaphoreCalculator'
input_stream: 'in_2'
output_stream: 'out_2'
input_side_packet: 'POST_SEM:post_sem_busy'
input_side_packet: 'WAIT_SEM:wait_sem_busy'
}
input_stream: 'in'
input_stream: 'in_2'
max_queue_size: 100
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
graph.SetGraphInputStreamAddMode(
CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL);
// Override the graph-wide max_queue_size (100) for "in" only.
MP_ASSERT_OK(graph.SetInputStreamMaxQueueSize("in", 1));
// Per the comments below, the *_entered_process semaphores are posted when
// the matching calculator enters Process(). The *_can_exit_process
// semaphores must be released before that Process() call may return.
Semaphore calc_entered_process(0);
Semaphore calc_can_exit_process(0);
Semaphore calc_entered_process_busy(0);
Semaphore calc_can_exit_process_busy(0);
MP_ASSERT_OK(graph.StartRun({
{"post_sem", MakePacket<Semaphore*>(&calc_entered_process)},
{"wait_sem", MakePacket<Semaphore*>(&calc_can_exit_process)},
{"post_sem_busy", MakePacket<Semaphore*>(&calc_entered_process_busy)},
{"wait_sem_busy", MakePacket<Semaphore*>(&calc_can_exit_process_busy)},
}));
Timestamp timestamp(0);
// Prevent deadlock resolution by running the "busy" SemaphoreCalculator
// for the duration of the test.
MP_EXPECT_OK(
graph.AddPacketToInputStream("in_2", MakePacket<int>(0).At(timestamp)));
// Post-increment: this packet is at Timestamp(0); the loop starts at 1.
MP_EXPECT_OK(
graph.AddPacketToInputStream("in", MakePacket<int>(0).At(timestamp++)));
for (int i = 1; i < 20; ++i, ++timestamp) {
// Wait for the calculator to begin its Process call.
calc_entered_process.Acquire(1);
// Now the calculator is stuck processing a packet. We can queue up
// another one.
MP_EXPECT_OK(
graph.AddPacketToInputStream("in", MakePacket<int>(i).At(timestamp)));
// We should be prevented from adding another, since the queue is now full.
absl::Status status = graph.AddPacketToInputStream(
"in", MakePacket<int>(i).At(timestamp + 1));
EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable);
// Allow calculator to complete its Process call.
calc_can_exit_process.Release(1);
}
// Allow the final Process call to complete.
calc_can_exit_process.Release(1);
// Let the "busy" calculator finish its single packet so the run can end.
calc_can_exit_process_busy.Release(1);
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Verify the scheduler unthrottles the graph input stream to avoid a deadlock,
// and won't enter a busy loop.
TEST(CalculatorGraph, AddPacketNoBusyLoop) {
// The DecimatorCalculator outputs 1 out of every 101 input packets and drops
// the rest, without setting the next timestamp bound on its output. As a
// result, the MergeCalculator is not runnable in between and packets on its
// "in" input stream will be queued and exceed the max queue size.
//
// in
// |
// / \
// / \
// / \
// | \
// v |
// +---------+ |
// 101:1 |Decimator| | <== Packet buildup
// +---------+ |
// | |
// v v
// +----------+
// | Merge |
// +----------+
// |
// v
// out
//
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
max_queue_size: 1
node {
calculator: 'DecimatorCalculator'
input_stream: 'in'
output_stream: 'decimated_in'
}
node {
calculator: 'MergeCalculator'
input_stream: 'decimated_in'
input_stream: 'in'
output_stream: 'out'
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
// WAIT_TILL_NOT_FULL makes AddPacketToInputStream() block while "in" is full.
// The test would hang if the scheduler never unthrottled the stream.
graph.SetGraphInputStreamAddMode(
CalculatorGraph::GraphInputStreamAddMode::WAIT_TILL_NOT_FULL);
std::vector<Packet> out_packets; // Packets from the output stream "out".
MP_ASSERT_OK(
graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) {
out_packets.push_back(packet);
return absl::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
const int kDecimationRatio = DecimatorCalculator::kDecimationRatio;
// To leave the graph input stream "in" in the throttled state, kNumPackets
// can be any value not of the form (k * kDecimationRatio + 1).
const int kNumPackets = 2 * kDecimationRatio;
for (int i = 0; i < kNumPackets; ++i) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", MakePacket<int>(i).At(Timestamp(i))));
}
// The graph input stream "in" is throttled. Wait until the graph is idle.
MP_ASSERT_OK(graph.WaitUntilIdle());
// Check that Pause() does not block forever trying to acquire a mutex.
// This is a regression test for an old bug.
graph.Pause();
graph.Resume();
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
// The expected output packets are:
// "Timestamp(0) 0 0"
// "Timestamp(1) empty 1"
// ...
// "Timestamp(100) empty 100"
// "Timestamp(101) 101 101"
// "Timestamp(102) empty 102"
// ...
// "Timestamp(201) empty 201"
ASSERT_EQ(kNumPackets, out_packets.size());
for (int i = 0; i < out_packets.size(); ++i) {
std::string format = (i % kDecimationRatio == 0) ? "Timestamp($0) $0 $0"
: "Timestamp($0) empty $0";
std::string expected = absl::Substitute(format, i);
EXPECT_EQ(expected, out_packets[i].Get<std::string>());
EXPECT_EQ(Timestamp(i), out_packets[i].Timestamp());
}
}
namespace nested_ns {
typedef std::function<absl::Status(const InputStreamShardSet&,
OutputStreamShardSet*)>
ProcessFunction;
// A Calculator that delegates its Process function to a callback function.
class ProcessCallbackCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).SetAny();
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0));
}
cc->InputSidePackets().Index(0).Set<std::unique_ptr<ProcessFunction>>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
callback_ =
*GetFromUniquePtr<ProcessFunction>(cc->InputSidePackets().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
return callback_(cc->Inputs(), &(cc->Outputs()));
}
private:
ProcessFunction callback_;
};
REGISTER_CALCULATOR(::mediapipe::nested_ns::ProcessCallbackCalculator);
} // namespace nested_ns
// Verifies that a calculator registered inside a nested C++ namespace can be
// referenced from a graph config by its dot-separated qualified name
// ('mediapipe.nested_ns.ProcessCallbackCalculator'). Only Initialize(),
// StartRun() and WaitUntilIdle() are exercised; no packets are sent.
// NOTE: "Namepsace" in the test name is a typo. It is kept as-is so that
// existing test filters that reference this name still match.
TEST(CalculatorGraph, CalculatorInNamepsace) {
CalculatorGraphConfig config;
CHECK(proto_ns::TextFormat::ParseFromString(R"(
input_stream: 'in_a'
node {
calculator: 'mediapipe.nested_ns.ProcessCallbackCalculator'
input_stream: 'in_a'
output_stream: 'out_a'
input_side_packet: 'callback_1'
}
)",
&config));
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
// NOTE(review): callback_1 is a default-constructed (empty) std::function.
// That is presumably safe only because no packet is ever added to "in_a",
// so Process() never invokes it.
nested_ns::ProcessFunction callback_1;
MP_ASSERT_OK(
graph.StartRun({{"callback_1", AdoptAsUniquePtr(new auto(callback_1))}}));
MP_EXPECT_OK(graph.WaitUntilIdle());
}
// A ProcessFunction that forwards every non-empty input packet to the output
// stream with the same index; inputs without a packet produce no output.
absl::Status DoProcess(const InputStreamShardSet& inputs,
                       OutputStreamShardSet* outputs) {
  for (int index = 0; index < inputs.NumEntries(); ++index) {
    const auto& packet = inputs.Index(index).Value();
    if (packet.IsEmpty()) {
      continue;
    }
    outputs->Index(index).AddPacket(packet);
  }
  return absl::OkStatus();
}
// Verifies that ObserveOutputStream() delivers every packet of both an
// internal stream ("count", which also feeds a downstream node) and an
// unconnected graph output ("out").
TEST(CalculatorGraph, ObserveOutputStream) {
  const int max_count = 10;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count'
input_side_packet: 'MAX_COUNT:max_count'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'count'
output_stream: 'mid'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'mid'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket<int>(max_count)}}));

  std::vector<Packet> count_packets;  // Observed on internal stream "count".
  std::vector<Packet> out_packets;    // Observed on unconnected stream "out".
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "count", [&count_packets](const Packet& observed) {
        count_packets.push_back(observed);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "out", [&out_packets](const Packet& observed) {
        out_packets.push_back(observed);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.Run());

  // Both observers must see the values 0 .. max_count - 1, each stamped with
  // the timestamp equal to its value.
  ASSERT_EQ(max_count, count_packets.size());
  int expected_value = 0;
  for (const Packet& observed : count_packets) {
    EXPECT_EQ(expected_value, observed.Get<int>());
    EXPECT_EQ(Timestamp(expected_value), observed.Timestamp());
    ++expected_value;
  }
  ASSERT_EQ(max_count, out_packets.size());
  expected_value = 0;
  for (const Packet& observed : out_packets) {
    EXPECT_EQ(expected_value, observed.Get<int>());
    EXPECT_EQ(Timestamp(expected_value), observed.Timestamp());
    ++expected_value;
  }
}
// Subgraph that exposes a single PassThroughCalculator: packets arriving on
// the INPUT tag are emitted unchanged on the OUTPUT tag.
class PassThroughSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& options) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'INPUT:input'
output_stream: 'OUTPUT:output'
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
output_stream: 'output'
}
)pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(PassThroughSubgraph);
// Verifies that ObserveOutputStream() works on a graph output that is
// produced from inside an expanded subgraph (PassThroughSubgraph).
TEST(CalculatorGraph, ObserveOutputStreamSubgraph) {
  const int max_count = 10;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count'
input_side_packet: 'MAX_COUNT:max_count'
}
node {
calculator: 'PassThroughSubgraph'
input_stream: 'INPUT:count'
output_stream: 'OUTPUT:out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket<int>(max_count)}}));

  // Collect everything that reaches the unconnected output stream "out".
  std::vector<Packet> out_packets;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "out", [&out_packets](const Packet& observed) {
        out_packets.push_back(observed);
        return absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.Run());

  ASSERT_EQ(max_count, out_packets.size());
  int expected_value = 0;
  for (const Packet& observed : out_packets) {
    EXPECT_EQ(expected_value, observed.Get<int>());
    EXPECT_EQ(Timestamp(expected_value), observed.Timestamp());
    ++expected_value;
  }
}
// Verifies that an error returned by an output-stream observer aborts the run.
// The error message must surface in the status returned by graph.Run(), and
// the observer must not be called again after it fails.
TEST(CalculatorGraph, ObserveOutputStreamError) {
  const int max_count = 10;
  const int fail_count = 6;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count'
input_side_packet: 'MAX_COUNT:max_count'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'count'
output_stream: 'mid'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'mid'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket<int>(max_count)}}));

  // "count" is an internal stream and "out" is unconnected; both are observed,
  // but only the "count" observer fails and only its packets are checked.
  std::vector<Packet> count_packets;
  std::vector<Packet> out_packets;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "count", [&count_packets](const Packet& observed) {
        count_packets.push_back(observed);
        // Report an error once fail_count packets have been seen.
        return count_packets.size() >= fail_count
                   ? absl::UnknownError("Expected. MagicString-eatnhuea")
                   : absl::OkStatus();
      }));
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "out", [&out_packets](const Packet& observed) {
        out_packets.push_back(observed);
        return absl::OkStatus();
      }));

  const absl::Status run_status = graph.Run();
  ASSERT_THAT(run_status.message(),
              testing::HasSubstr("MagicString-eatnhuea"));
  ASSERT_EQ(fail_count, count_packets.size());
  int expected_value = 0;
  for (const Packet& observed : count_packets) {
    EXPECT_EQ(expected_value, observed.Get<int>());
    EXPECT_EQ(Timestamp(expected_value), observed.Timestamp());
    ++expected_value;
  }
}
// Verifies that ObserveOutputStream() on a stream name that no node produces
// fails with kNotFound, and that the error message mentions the name.
TEST(CalculatorGraph, ObserveOutputStreamNonexistent) {
  const int max_count = 10;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count'
input_side_packet: 'MAX_COUNT:max_count'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'count'
output_stream: 'mid'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'mid'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket<int>(max_count)}}));

  // Try to observe "not_found", which does not exist in the config above.
  // Registration is expected to fail, so the callback should never run.
  std::vector<Packet> count_packets;
  const absl::Status observe_status = graph.ObserveOutputStream(
      "not_found", [&count_packets](const Packet& observed) {
        count_packets.push_back(observed);
        return absl::OkStatus();
      });
  EXPECT_EQ(observe_status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(observe_status.message(), testing::HasSubstr("not_found"));
}
// Verifies that after a fast source node closes its output stream, a slow
// sink node can still consume every packet that accumulated on its input.
// That is, closing an output stream still lets its mirrors process all the
// packets they have already received.
TEST(CalculatorGraph, FastSourceSlowSink) {
  const int max_count = 10;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
num_threads: 2
max_queue_size: 100
node {
calculator: 'CountingSourceCalculator'
output_stream: 'out'
input_side_packet: 'MAX_COUNT:max_count'
}
node { calculator: 'SlowCountingSinkCalculator' input_stream: 'out' }
)pb");
  CalculatorGraph graph;
  const std::map<std::string, Packet> side_packets = {
      {"max_count", MakePacket<int>(max_count)}};
  MP_ASSERT_OK(graph.Initialize(config, side_packets));
  MP_EXPECT_OK(graph.Run());
}
TEST(CalculatorGraph, GraphFinishesWhilePaused) {
// The graph contains only one node, and the node runs only once. This test
// sets up the following sequence of events (all times in milliseconds):
//
// Application thread Worker thread
//
// T=0 graph.StartRun OneShot20MsCalculator::Process starts
// T=10 graph.Pause
// T=20 OneShot20MsCalculator::Process ends.
// So graph finishes running while paused.
// T=30 graph.Resume
//
// graph.WaitUntilDone must not block forever.
//
// NOTE: the interleaving above is produced only by the wall-clock sleeps
// below, so it is approximate. Under heavy load, the node may not be running
// exactly when Pause() is called.
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node { calculator: 'OneShot20MsCalculator' }
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_EXPECT_OK(graph.StartRun({}));
absl::SleepFor(absl::Milliseconds(10));
graph.Pause();
absl::SleepFor(absl::Milliseconds(20));
graph.Resume();
MP_EXPECT_OK(graph.WaitUntilDone());
}
// Constructing and then destroying a CalculatorGraph that was never
// initialized must be clean. There should be no memory leaks and no error
// messages; checking for the latter requires manual inspection of the test
// log.
TEST(CalculatorGraph, ConstructAndDestruct) {
  CalculatorGraph graph;  // Destroyed when this scope ends.
}
// A regression test for b/36364314. UnitDelayCalculator outputs a packet in
// Open(), and ErrorOnOpenCalculator fails in Open() if ERROR_ON_OPEN is true.
// A run that fails in Open() must not prevent the next run from succeeding.
TEST(CalculatorGraph, RecoverAfterPreviousFailInOpen) {
  const int max_count = 10;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CountingSourceCalculator'
output_stream: 'a'
input_side_packet: 'MAX_COUNT:max_count'
}
node {
calculator: 'UnitDelayCalculator'
input_stream: 'a'
output_stream: 'b'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'b'
output_stream: 'c'
}
node {
calculator: 'ErrorOnOpenCalculator'
input_stream: 'c'
output_stream: 'd'
input_side_packet: 'ERROR_ON_OPEN:fail'
}
node { calculator: 'IntSinkCalculator' input_stream: 'd' }
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.Initialize(config, {{"max_count", MakePacket<int>(max_count)}}));
  // Alternate failing and succeeding runs twice on the same graph.
  for (int attempt = 0; attempt < 2; ++attempt) {
    const absl::Status failed_run =
        graph.Run({{"fail", MakePacket<bool>(true)}});
    EXPECT_FALSE(failed_run.ok());
    MP_EXPECT_OK(graph.Run({{"fail", MakePacket<bool>(false)}}));
  }
}
// Initializes 30 graphs from the same config and runs each one 10 times.
// Each Initialize() is expected to add exactly two generated packets:
// "foo1" and "foo2" depend only on side packets supplied there. Each Run() is
// expected to add two more ("foo3" and "foo4" also need "input_in_run") and
// GlobalCountSourceCalculator::kNumOutputPackets increments of the shared
// global counter.
TEST(CalculatorGraph, ReuseValidatedGraphConfig) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
output_side_packet: "foo1"
}
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
input_side_packet: "foo1"
output_side_packet: "foo2"
}
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
input_side_packet: "input_in_run"
output_side_packet: "foo3"
}
packet_generator {
packet_generator: "StaticCounterStringGenerator"
input_side_packet: "created_by_factory"
input_side_packet: "input_in_initialize"
input_side_packet: "input_in_run"
input_side_packet: "foo3"
output_side_packet: "foo4"
}
node {
calculator: "GlobalCountSourceCalculator"
input_side_packet: "global_counter"
output_stream: "unused"
}
)pb");
  // NOTE: validated_graph only checks that `config` validates; the graphs
  // below are initialized from `config` itself, not from validated_graph.
  ValidatedGraphConfig validated_graph;
  MP_ASSERT_OK(validated_graph.Initialize(config));
  std::atomic<int> global_counter(0);
  Packet global_counter_packet = Adopt(new auto(&global_counter));
  absl::FixedArray<CalculatorGraph> graphs(30);
  // Range-for avoids comparing a signed index against FixedArray::size().
  for (CalculatorGraph& graph : graphs) {
    const int initial_generator_count =
        StaticCounterStringGenerator::NumPacketsGenerated();
    const int initial_calculator_count = global_counter.load();
    MP_ASSERT_OK(graph.Initialize(
        config,
        {{"created_by_factory", MakePacket<std::string>("default string")},
         {"input_in_initialize", MakePacket<int>(10)},
         {"global_counter", global_counter_packet}}));
    EXPECT_EQ(initial_generator_count + 2,
              StaticCounterStringGenerator::NumPacketsGenerated());
    // Initialize() must not run the calculator.
    EXPECT_EQ(initial_calculator_count, global_counter.load());
  }
  for (int round = 0; round < 10; ++round) {
    for (CalculatorGraph& graph : graphs) {
      const int initial_generator_count =
          StaticCounterStringGenerator::NumPacketsGenerated();
      const int initial_calculator_count = global_counter.load();
      MP_ASSERT_OK(graph.Run({{"input_in_run", MakePacket<int>(11)}}));
      EXPECT_EQ(initial_generator_count + 2,
                StaticCounterStringGenerator::NumPacketsGenerated());
      EXPECT_EQ(initial_calculator_count +
                    GlobalCountSourceCalculator::kNumOutputPackets,
                global_counter.load());
    }
  }
}
// Subgraph wiring a RangeCalculator to a StdDevCalculator. RangeCalculator
// receives the "node_converted" input side packet. Its outputs (range, sum,
// mean) plus the computed standard deviation are exposed under the DATA, SUM,
// MEAN and STDDEV tags.
class TestRangeStdDevSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& options) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_side_packet: 'node_converted'
output_stream: 'DATA:range'
output_stream: 'SUM:range_sum'
output_stream: 'MEAN:range_mean'
output_stream: 'STDDEV:range_stddev'
node {
calculator: 'RangeCalculator'
output_stream: 'range'
output_stream: 'range_sum'
output_stream: 'range_mean'
input_side_packet: 'node_converted'
}
node {
calculator: 'StdDevCalculator'
input_stream: 'DATA:range'
input_stream: 'MEAN:range_mean'
output_stream: 'range_stddev'
}
)pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(TestRangeStdDevSubgraph);
// Subgraph that combines the DATA1 and DATA2 streams with a MergeCalculator
// named 'merger' and exposes the result as MERGE. The merged stream is also
// fed to a SaverCalculator, whose output is exposed as FINAL.
class TestMergeSaverSubgraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      const SubgraphOptions& options) override {
    return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'DATA1:range1'
input_stream: 'DATA2:range2'
output_stream: 'MERGE:merge'
output_stream: 'FINAL:final'
node {
name: 'merger'
calculator: 'MergeCalculator'
input_stream: 'range1'
input_stream: 'range2'
output_stream: 'merge'
}
node {
calculator: 'SaverCalculator'
input_stream: 'merge'
output_stream: 'final'
}
)pb");
  }
};
REGISTER_MEDIAPIPE_GRAPH(TestMergeSaverSubgraph);
// Returns a graph config assembled from subgraphs.
//
// Two TestRangeStdDevSubgraph instances take their side packets from packet
// generators applied to "node_3" and "node_5". Their outputs are combined by
// three TestMergeSaverSubgraph instances; range5 goes through the
// 'copy_range5' PassThroughCalculator before reaching the first one. The
// graph-level "MERGE:merge" output stream keeps the merged stream's name
// addressable from the top-level graph.
CalculatorGraphConfig GetConfigWithSubgraphs() {
  return mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
# Ensure stream name for FindOutputStreamManager
output_stream: 'MERGE:merge'
packet_generator {
packet_generator: 'IntSplitterPacketGenerator'
input_side_packet: 'node_3'
output_side_packet: 'node_3_converted'
}
packet_generator {
packet_generator: 'TaggedIntSplitterPacketGenerator'
input_side_packet: 'node_5'
output_side_packet: 'HIGH:unused_high'
output_side_packet: 'LOW:unused_low'
output_side_packet: 'PAIR:node_5_converted'
}
node {
calculator: 'TestRangeStdDevSubgraph'
input_side_packet: 'node_3_converted'
output_stream: 'DATA:range3'
output_stream: 'SUM:range3_sum'
output_stream: 'MEAN:range3_mean'
output_stream: 'STDDEV:range3_stddev'
}
node {
calculator: 'TestRangeStdDevSubgraph'
input_side_packet: 'node_5_converted'
output_stream: 'DATA:range5'
output_stream: 'SUM:range5_sum'
output_stream: 'MEAN:range5_mean'
output_stream: 'STDDEV:range5_stddev'
}
node {
name: 'copy_range5'
calculator: 'PassThroughCalculator'
input_stream: 'range5'
output_stream: 'range5_copy'
}
node {
calculator: 'TestMergeSaverSubgraph'
input_stream: 'DATA1:range3'
input_stream: 'DATA2:range5_copy'
output_stream: 'MERGE:merge'
output_stream: 'FINAL:final'
}
node {
calculator: 'TestMergeSaverSubgraph'
input_stream: 'DATA1:range3_sum'
input_stream: 'DATA2:range5_sum'
output_stream: 'FINAL:final_sum'
}
node {
calculator: 'TestMergeSaverSubgraph'
input_stream: 'DATA1:range3_stddev'
input_stream: 'DATA2:range5_stddev'
output_stream: 'FINAL:final_stddev'
}
)pb");
}
// Runs the shared comprehensive scenario on the config whose nodes are
// expanded from subgraphs (see GetConfigWithSubgraphs()).
TEST(CalculatorGraph, RunsCorrectlyWithSubgraphs) {
  CalculatorGraphConfig subgraph_config = GetConfigWithSubgraphs();
  CalculatorGraph graph;
  RunComprehensiveTest(&graph, subgraph_config, /*define_node_5=*/true);
}
// SetExecutor must not be called more than once for the same executor name.
// Registering "xyz" a second time fails with kAlreadyExists, and the error
// message names the executor. A different name ("abc") in between is fine.
TEST(CalculatorGraph, SetExecutorTwice) {
  CalculatorGraph graph;
  MP_EXPECT_OK(
      graph.SetExecutor("xyz", std::make_shared<ThreadPoolExecutor>(1)));
  MP_EXPECT_OK(
      graph.SetExecutor("abc", std::make_shared<ThreadPoolExecutor>(1)));
  const absl::Status duplicate_status =
      graph.SetExecutor("xyz", std::make_shared<ThreadPoolExecutor>(1));
  EXPECT_EQ(duplicate_status.code(), absl::StatusCode::kAlreadyExists);
  EXPECT_THAT(duplicate_status.message(), testing::HasSubstr("xyz"));
}
// A reserved executor name such as "__gpu" must not be accepted by
// SetExecutor(). The call fails with kInvalidArgument, and the message names
// the executor and says it is reserved.
TEST(CalculatorGraph, ReservedNameSetExecutor) {
  const auto names_reserved_gpu_executor = testing::AllOf(
      testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"));
  CalculatorGraph graph;
  const absl::Status status =
      graph.SetExecutor("__gpu", std::make_shared<ThreadPoolExecutor>(1));
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), names_reserved_gpu_executor);
}
// A reserved executor name such as "__gpu" must not be declared in the graph
// config's executor list. Initialize() fails with kInvalidArgument, and the
// message names the executor and says it is reserved.
TEST(CalculatorGraph, ReservedNameExecutorConfig) {
  const auto names_reserved_gpu_executor = testing::AllOf(
      testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"));
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
executor {
name: '__gpu'
type: 'ThreadPoolExecutor'
options {
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
}
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  const absl::Status init_status = graph.Initialize(config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(init_status.message(), names_reserved_gpu_executor);
}
// A reserved executor name such as "__gpu" must not be used by a node's
// `executor` field. Initialize() fails with kInvalidArgument, and the message
// names the executor and says it is reserved.
TEST(CalculatorGraph, ReservedNameNodeExecutor) {
  const auto names_reserved_gpu_executor = testing::AllOf(
      testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"));
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PassThroughCalculator'
executor: '__gpu'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  const absl::Status init_status = graph.Initialize(config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(init_status.message(), names_reserved_gpu_executor);
}
// An executor referenced by a node must exist. Either the graph creates it
// (an ExecutorConfig with a "type" field) or the caller supplies it through
// CalculatorGraph::SetExecutor(). Here neither happens, so Initialize() must
// report the executor as "not declared".
TEST(CalculatorGraph, NonExistentExecutor) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PassThroughCalculator'
executor: 'xyz'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  const auto reports_undeclared = testing::AllOf(
      testing::HasSubstr("xyz"), testing::HasSubstr("not declared"));
  EXPECT_THAT(init_status.message(), reports_undeclared);
}
// Supplying an executor with CalculatorGraph::SetExecutor() is not enough. A
// node may only use it if an ExecutorConfig also declares it. Initialize()
// must report "xyz" as "not declared".
TEST(CalculatorGraph, UndeclaredExecutor) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PassThroughCalculator'
executor: 'xyz'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.SetExecutor("xyz", std::make_shared<ThreadPoolExecutor>(1)));
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  const auto reports_undeclared = testing::AllOf(
      testing::HasSubstr("xyz"), testing::HasSubstr("not declared"));
  EXPECT_THAT(init_status.message(), reports_undeclared);
}
// An ExecutorConfig without a "type" field declares an executor the caller
// must provide via CalculatorGraph::SetExecutor(). If none is provided,
// Initialize() fails and the message points at SetExecutor.
TEST(CalculatorGraph, UntypedExecutorDeclaredButNotSet) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
executor { name: 'xyz' }
node {
calculator: 'PassThroughCalculator'
executor: 'xyz'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  const auto mentions_set_executor = testing::AllOf(
      testing::HasSubstr("xyz"), testing::HasSubstr("SetExecutor"));
  EXPECT_THAT(init_status.message(), mentions_set_executor);
}
// Two ExecutorConfig entries sharing one name must be rejected by Initialize()
// with a "duplicate" error naming the executor.
TEST(CalculatorGraph, DuplicateExecutorConfig) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
executor { name: 'xyz' }
executor { name: 'xyz' }
node {
calculator: 'PassThroughCalculator'
executor: 'xyz'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.SetExecutor("xyz", std::make_shared<ThreadPoolExecutor>(1)));
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  const auto reports_duplicate = testing::AllOf(
      testing::HasSubstr("xyz"), testing::HasSubstr("duplicate"));
  EXPECT_THAT(init_status.message(), reports_duplicate);
}
// An ExecutorConfig with a "type" field asks the graph to create the executor.
// Also supplying one with CalculatorGraph::SetExecutor() is therefore an
// error, and the message must point at SetExecutor.
TEST(CalculatorGraph, TypedExecutorDeclaredAndSet) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
executor {
name: 'xyz'
type: 'ThreadPoolExecutor'
options {
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
}
}
node {
calculator: 'PassThroughCalculator'
executor: 'xyz'
input_stream: 'in'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(
      graph.SetExecutor("xyz", std::make_shared<ThreadPoolExecutor>(1)));
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  const auto mentions_set_executor = testing::AllOf(
      testing::HasSubstr("xyz"), testing::HasSubstr("SetExecutor"));
  EXPECT_THAT(init_status.message(), mentions_set_executor);
}
// Both the graph-level num_threads field and an ExecutorConfig for the
// default (unnamed) executor configure the same executor. Specifying both
// must make Initialize() fail, mentioning "num_threads" and "default executor".
TEST(CalculatorGraph, NumThreadsAndDefaultExecutorConfig) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
num_threads: 1
executor {
type: 'ThreadPoolExecutor'
options {
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
}
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'mid'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'mid'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  const absl::Status init_status = graph.Initialize(graph_config);
  EXPECT_EQ(init_status.code(), absl::StatusCode::kInvalidArgument);
  const auto reports_conflict = testing::AllOf(
      testing::HasSubstr("num_threads"), testing::HasSubstr("default executor"));
  EXPECT_THAT(init_status.message(), reports_conflict);
}
// The graph-level num_threads field may be combined with an ExecutorConfig
// for a named, non-default executor ("xyz"). Initialize() must succeed.
TEST(CalculatorGraph, NumThreadsAndNonDefaultExecutorConfig) {
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
num_threads: 1
executor {
name: 'xyz'
type: 'ThreadPoolExecutor'
options {
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 1 }
}
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'mid'
}
node {
calculator: 'PassThroughCalculator'
executor: 'xyz'
input_stream: 'mid'
output_stream: 'out'
}
)pb");
  CalculatorGraph graph;
  MP_EXPECT_OK(graph.Initialize(graph_config));
}
// Verifies that the application thread is used only when
// "ApplicationThreadExecutor" is specified. Each case sets the executor type
// (or leaves it unset for "<None>") and num_threads on the ExecutorConfig for
// the default executor, then runs the graph. PthreadSelfSourceCalculator emits
// a pthread_t, which is compared with pthread_self() of the calling thread.
TEST(CalculatorGraph, RunWithNumThreadsInExecutorConfig) {
  const struct {
    std::string executor_type;
    int num_threads;
    bool use_app_thread_is_expected;
  } cases[] = {{"ApplicationThreadExecutor", 0, true},
               {"<None>", 0, false},
               {"ThreadPoolExecutor", 1, false}};
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
executor {
options {
[mediapipe.ThreadPoolExecutorOptions.ext] { num_threads: 0 }
}
}
node { calculator: 'PthreadSelfSourceCalculator' output_stream: 'out' }
)pb");
  ThreadPoolExecutorOptions* default_executor_options =
      config.mutable_executor(0)->mutable_options()->MutableExtension(
          ThreadPoolExecutorOptions::ext);
  // Use an unsigned index: ABSL_ARRAYSIZE yields a size_t, and comparing it
  // with an int is a signed/unsigned comparison.
  for (size_t i = 0; i < ABSL_ARRAYSIZE(cases); ++i) {
    default_executor_options->set_num_threads(cases[i].num_threads);
    config.mutable_executor(0)->clear_type();
    // "<None>" is a sentinel meaning "leave the type field unset".
    if (cases[i].executor_type != "<None>") {
      config.mutable_executor(0)->set_type(cases[i].executor_type);
    }
    CalculatorGraph graph;
    MP_ASSERT_OK(graph.Initialize(config));
    Packet out_packet;
    MP_ASSERT_OK(
        graph.ObserveOutputStream("out", [&out_packet](const Packet& packet) {
          out_packet = packet;
          return absl::OkStatus();
        }));
    MP_ASSERT_OK(graph.Run());
    EXPECT_EQ(cases[i].use_app_thread_is_expected,
              out_packet.Get<pthread_t>() == pthread_self())
        << "for case " << i;
  }
}
// Run() on a graph that was never Initialize()d must return a non-OK status.
TEST(CalculatorGraph, CalculatorGraphNotInitialized) {
  CalculatorGraph uninitialized_graph;
  const absl::Status run_status = uninitialized_graph.Run();
  EXPECT_FALSE(run_status.ok());
}
// Starts a run and then returns without closing inputs or waiting for
// completion. This mimics an ASSERT_ failure that aborts the test body
// mid-run. The CalculatorGraph destructor must not hang in this situation.
TEST(CalculatorGraph, SimulateAssertFailure) {
  CalculatorGraph graph;
  {
    const CalculatorGraphConfig two_thread_config =
        mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
num_threads: 2
node {
calculator: 'PassThroughCalculator'
input_stream: 'in_a'
input_stream: 'in_b'
output_stream: 'out_a'
output_stream: 'out_b'
}
input_stream: 'in_a'
input_stream: 'in_b'
)pb");
    MP_ASSERT_OK(graph.Initialize(two_thread_config));
  }
  MP_ASSERT_OK(graph.StartRun({}));
  MP_EXPECT_OK(graph.WaitUntilIdle());
  // Intentionally stop here; the graph is still running when it is destroyed.
}
// Checks Calculator::InputTimestamp() in Open(), Process() and Close() for a
// source node and a downstream sink node. The checks live in the two
// CheckInputTimestamp*Calculator classes used below. In this variant the
// source node is the one that stops the graph.
TEST(CalculatorGraph, CheckInputTimestamp) {
  const CalculatorGraphConfig source_and_sink =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CheckInputTimestampSourceCalculator'
output_stream: 'integer'
}
node {
calculator: 'CheckInputTimestampSinkCalculator'
input_stream: 'integer'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(source_and_sink));
  MP_ASSERT_OK(graph.Run());
}
// Checks Calculator::InputTimestamp() in Open(), Process() and Close() for a
// source node and a downstream sink node. In this variant the sink node stops
// the graph, so the framework has to close the still-running source node.
TEST(CalculatorGraph, CheckInputTimestamp2) {
  const CalculatorGraphConfig source_and_sink =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: 'CheckInputTimestamp2SourceCalculator'
output_stream: 'integer'
}
node {
calculator: 'CheckInputTimestamp2SinkCalculator'
input_stream: 'integer'
}
)pb");
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(source_and_sink));
  MP_ASSERT_OK(graph.Run());
}
// A graph input stream declared with a tag ("VIDEO_METADATA:video_metadata")
// is still fed by its bare stream name. Every packet sent must reach the sink
// on "output_0".
TEST(CalculatorGraph, GraphInputStreamWithTag) {
  constexpr int kNumPackets = 5;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "VIDEO_METADATA:video_metadata"
input_stream: "max_count"
node {
calculator: "PassThroughCalculator"
input_stream: "FIRST_INPUT:video_metadata"
input_stream: "max_count"
output_stream: "FIRST_INPUT:output_0"
output_stream: "output_1"
}
)pb");
  std::vector<Packet> sink_packets;
  tool::AddVectorSink("output_0", &config, &sink_packets);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));
  for (int ts = 0; ts < kNumPackets; ++ts) {
    MP_ASSERT_OK(graph.AddPacketToInputStream(
        "video_metadata", MakePacket<int>(ts).At(Timestamp(ts))));
  }
  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
  ASSERT_EQ(kNumPackets, sink_packets.size());
}
// Adding a packet to a graph input stream after Initialize() but before
// StartRun() must fail with kFailedPrecondition.
TEST(CalculatorGraph, GraphInputStreamBeforeStartRun) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "VIDEO_METADATA:video_metadata"
input_stream: "max_count"
node {
calculator: "PassThroughCalculator"
input_stream: "FIRST_INPUT:video_metadata"
input_stream: "max_count"
output_stream: "FIRST_INPUT:output_0"
output_stream: "output_1"
}
)pb");
  std::vector<Packet> sink_packets;
  tool::AddVectorSink("output_0", &config, &sink_packets);

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  // No StartRun() call on purpose.
  const absl::Status add_status = graph.AddPacketToInputStream(
      "video_metadata", MakePacket<int>(0).At(Timestamp(0)));
  ASSERT_EQ(add_status.code(), absl::StatusCode::kFailedPrecondition);
}
// Forwards only the first packet received on input stream 0 to output stream
// 0, then closes that output. Every later input packet is ignored.
class FirstPacketFilterCalculator : public CalculatorBase {
 public:
  FirstPacketFilterCalculator() = default;
  ~FirstPacketFilterCalculator() override = default;

  static absl::Status GetContract(CalculatorContract* cc) {
    auto& input_type = cc->Inputs().Index(0);
    input_type.SetAny();
    // The output carries exactly the input's packet type.
    cc->Outputs().Index(0).SetSameAs(&input_type);
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    if (seen_first_packet_) {
      // The output was already closed after the first packet; drop the rest.
      return absl::OkStatus();
    }
    auto& output = cc->Outputs().Index(0);
    output.AddPacket(cc->Inputs().Index(0).Value());
    output.Close();
    seen_first_packet_ = true;
    return absl::OkStatus();
  }

 private:
  // Set once the first packet has been forwarded.
  bool seen_first_packet_ = false;
};
REGISTER_CALCULATOR(FirstPacketFilterCalculator);
// Value passed as the "max_count" input side packet to CountingSourceCalculator
// in the poller tests, and the number of packets those tests expect to poll.
constexpr int kDefaultMaxCount{1000};
// Drains a CountingSourceCalculator output through an OutputStreamPoller. The
// test expects the values 0, 1, 2, ... in order, kDefaultMaxCount packets in
// total, and no further packets once the graph is done.
TEST(CalculatorGraph, TestPollPacket) {
  CalculatorGraphConfig config;
  CalculatorGraphConfig::Node* counting_node = config.add_node();
  counting_node->set_calculator("CountingSourceCalculator");
  counting_node->add_output_stream("output");
  counting_node->add_input_side_packet("MAX_COUNT:max_count");

  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  auto poller_or = graph.AddOutputStreamPoller("output");
  ASSERT_TRUE(poller_or.ok());
  OutputStreamPoller poller = std::move(poller_or.value());
  MP_ASSERT_OK(
      graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));

  Packet packet;
  int expected_value = 0;
  for (; poller.Next(&packet); ++expected_value) {
    EXPECT_EQ(expected_value, packet.Get<int>());
  }
  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
  EXPECT_FALSE(poller.Next(&packet));
  EXPECT_EQ(kDefaultMaxCount, expected_value);
}
// Repeats the poll-all-packets scenario with the poller's max queue size set
// to each value in [1, 9]. Each run must still deliver 0 .. kDefaultMaxCount-1
// in order.
TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) {
  CalculatorGraphConfig config;
  CalculatorGraphConfig::Node* counting_node = config.add_node();
  counting_node->set_calculator("CountingSourceCalculator");
  counting_node->add_output_stream("output");
  counting_node->add_input_side_packet("MAX_COUNT:max_count");

  for (int max_queue = 1; max_queue < 10; ++max_queue) {
    CalculatorGraph graph;
    MP_ASSERT_OK(graph.Initialize(config));
    auto poller_or = graph.AddOutputStreamPoller("output");
    ASSERT_TRUE(poller_or.ok());
    OutputStreamPoller poller = std::move(poller_or.value());
    poller.SetMaxQueueSize(max_queue);
    MP_ASSERT_OK(
        graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));

    Packet packet;
    int expected_value = 0;
    for (; poller.Next(&packet); ++expected_value) {
      EXPECT_EQ(expected_value, packet.Get<int>());
    }
    MP_ASSERT_OK(graph.CloseAllPacketSources());
    MP_ASSERT_OK(graph.WaitUntilDone());
    EXPECT_FALSE(poller.Next(&packet));
    EXPECT_EQ(kDefaultMaxCount, expected_value);
  }
}
// Polls two streams of the same graph in an interleaved loop. "stream1" comes
// from CountingSourceCalculator and "stream2" is its PassThroughCalculator
// copy. Both must deliver 0 .. kDefaultMaxCount-1 in order.
TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) {
  CalculatorGraphConfig config;
  CalculatorGraphConfig::Node* node1 = config.add_node();
  node1->set_calculator("CountingSourceCalculator");
  node1->add_output_stream("stream1");
  node1->add_input_side_packet("MAX_COUNT:max_count");
  CalculatorGraphConfig::Node* node2 = config.add_node();
  node2->set_calculator("PassThroughCalculator");
  node2->add_input_stream("stream1");
  node2->add_output_stream("stream2");
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  auto status_or_poller1 = graph.AddOutputStreamPoller("stream1");
  ASSERT_TRUE(status_or_poller1.ok());
  OutputStreamPoller poller1 = std::move(status_or_poller1.value());
  auto status_or_poller2 = graph.AddOutputStreamPoller("stream2");
  ASSERT_TRUE(status_or_poller2.ok());
  OutputStreamPoller poller2 = std::move(status_or_poller2.value());
  MP_ASSERT_OK(
      graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
  Packet packet1;
  Packet packet2;
  int num_packets1 = 0;
  int num_packets2 = 0;
  // Track completion per poller. Once a poller's Next() has returned false it
  // is not polled again. With a single shared counter, a poller that finished
  // first would be counted as finished on every later iteration, and the loop
  // could exit before the other poller was drained.
  bool poller1_done = false;
  bool poller2_done = false;
  while (!poller1_done || !poller2_done) {
    if (!poller1_done) {
      if (poller1.Next(&packet1)) {
        EXPECT_EQ(num_packets1++, packet1.Get<int>());
      } else {
        poller1_done = true;
      }
    }
    if (!poller2_done) {
      if (poller2.Next(&packet2)) {
        EXPECT_EQ(num_packets2++, packet2.Get<int>());
      } else {
        poller2_done = true;
      }
    }
  }
  MP_ASSERT_OK(graph.CloseAllPacketSources());
  MP_ASSERT_OK(graph.WaitUntilDone());
  EXPECT_FALSE(poller1.Next(&packet1));
  EXPECT_FALSE(poller2.Next(&packet2));
  EXPECT_EQ(kDefaultMaxCount, num_packets1);
  EXPECT_EQ(kDefaultMaxCount, num_packets2);
}
// Source calculator for testing timestamp-bound notifications on an output
// stream. Process() is called with count_ = 0, 1, 2, ... On each call it:
//   * emits an int packet with value and timestamp count_ when
//     count_ % 50 == 1, i.e. at 1, 51 and 101;
//   * otherwise, when count_ % 15 == 7, only advances the output's next
//     timestamp bound to count_ without emitting a packet;
// then sleeps 3 ms. It returns tool::StatusStop() after the 110th call
// (count_ 0..109).
class TimestampBoundTestCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Outputs().Index(0).Set<int>();
    return absl::OkStatus();
  }
  absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
  absl::Status Process(CalculatorContext* cc) final {
    if (count_ % 50 == 1) {
      // Outputs packets at t1, t51 and t101.
      cc->Outputs().Index(0).AddPacket(
          MakePacket<int>(count_).At(Timestamp(count_)));
    } else if (count_ % 15 == 7) {
      // Bound-only update: no packet, just a new timestamp bound.
      cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_));
    }
    // NOTE(review): the sleep presumably spaces out updates so the poller can
    // observe bound changes individually; confirm before removing.
    absl::SleepFor(absl::Milliseconds(3));
    ++count_;
    if (count_ == 110) {
      return tool::StatusStop();
    }
    return absl::OkStatus();
  }

 private:
  // Number of completed Process() calls.
  int count_ = 0;
};
REGISTER_CALCULATOR(TimestampBoundTestCalculator);
// Polls "foo" with observe_timestamp_bounds=true. Empty packets are
// timestamp-bound notifications: there must be at least one, and their
// timestamps must be strictly increasing and below 110. Non-empty packets
// must carry exactly the values 1, 51 and 101.
TEST(CalculatorGraph, TestPollPacketsWithTimestampNotification) {
  std::string config_str = R"(
node {
calculator: "TimestampBoundTestCalculator"
output_stream: "foo"
}
)";
  CalculatorGraphConfig graph_config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(graph_config));
  auto status_or_poller =
      graph.AddOutputStreamPoller("foo", /*observe_timestamp_bounds=*/true);
  ASSERT_TRUE(status_or_poller.ok());
  OutputStreamPoller poller = std::move(status_or_poller.value());
  Packet packet;
  // Keep full Timestamps rather than narrowing their 64-bit values to int.
  std::vector<Timestamp> bound_timestamps;
  std::vector<int> values;
  MP_ASSERT_OK(graph.StartRun({}));
  while (poller.Next(&packet)) {
    if (packet.IsEmpty()) {
      bound_timestamps.push_back(packet.Timestamp());
    } else {
      values.push_back(packet.Get<int>());
    }
  }
  MP_ASSERT_OK(graph.WaitUntilDone());
  ASSERT_FALSE(poller.Next(&packet));
  ASSERT_FALSE(bound_timestamps.empty());
  Timestamp prev_t = Timestamp(0);
  for (const Timestamp t : bound_timestamps) {
    // Separate checks so a failure reports which bound was violated.
    EXPECT_GT(t, prev_t);
    EXPECT_LT(t, Timestamp(110));
    prev_t = t;
  }
  ASSERT_EQ(3, values.size());
  EXPECT_EQ(1, values[0]);
  EXPECT_EQ(51, values[1]);
  EXPECT_EQ(101, values[2]);
}
// SimpleMuxCalculator runs behind a custom input stream handler
// (ImmediateInputStreamHandler). If it receives inputs whose timestamps are
// not monotonically increasing across its streams, the graph must report an
// error. The error must describe the problem and point to the
// ImmediateInputStreamHandler class comment for the fix.
TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) {
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input0'
input_stream: 'input1'
node {
calculator: 'SimpleMuxCalculator'
input_stream: 'input0'
input_stream: 'input1'
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
output_stream: 'output'
}
)pb");
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("output", &config, &packet_dump);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));

  // Two packets on "input0" at timestamps 0 and 1. Each one is expected on
  // "output" before the next is sent.
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "input0", MakePacket<int>(1).At(Timestamp(0))));
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(1, packet_dump.size());
  EXPECT_EQ(1, packet_dump[0].Get<int>());

  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "input0", MakePacket<int>(3).At(Timestamp(1))));
  MP_ASSERT_OK(graph.WaitUntilIdle());
  ASSERT_EQ(2, packet_dump.size());
  EXPECT_EQ(3, packet_dump[1].Get<int>());

  // A packet on "input1" at timestamp 0 now goes backwards relative to what
  // the calculator already processed from "input0", which must be reported.
  MP_EXPECT_OK(graph.AddPacketToInputStream(
      "input1", MakePacket<int>(2).At(Timestamp(0))));
  const absl::Status idle_status = graph.WaitUntilIdle();
  const auto explains_mismatch = testing::AllOf(
      // The core problem.
      testing::HasSubstr("timestamp mismatch on a calculator"),
      testing::HasSubstr(
          "timestamps that are not strictly monotonically increasing"),
      // Link to the possible solution.
      testing::HasSubstr("ImmediateInputStreamHandler class comment"));
  EXPECT_THAT(idle_status.ToString(), explains_mismatch);
}
// Builds a single PassThroughCalculator graph that uses `input_stream_handler`
// and runs it twice on the same CalculatorGraph. The second run uses an
// earlier timestamp (1000) than the first (2000). Each run must output exactly
// its own "input" packet with the expected value and timestamp. When
// `select_packet` is true, a packet is also sent on 'select' at the same
// timestamp.
void DoTestMultipleGraphRuns(absl::string_view input_stream_handler,
                             bool select_packet) {
  // This helper is called once per handler from a single TEST; tag failures
  // with the handler under test.
  SCOPED_TRACE(absl::StrCat("input_stream_handler: ", input_stream_handler));
  // Pass the string_view itself: absl::StrFormat's %s accepts string_view,
  // whereas string_view::data() is not guaranteed to be NUL-terminated.
  std::string graph_proto = absl::StrFormat(R"(
input_stream: 'input'
input_stream: 'select'
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'select'
input_stream_handler {
input_stream_handler: "%s"
}
output_stream: 'output'
output_stream: 'select_out'
}
)",
                                            input_stream_handler);
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
  std::vector<Packet> packet_dump;
  tool::AddVectorSink("output", &config, &packet_dump);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize(config));
  struct Run {
    Timestamp timestamp;
    int value;
  };
  std::vector<Run> runs = {{.timestamp = Timestamp(2000), .value = 2},
                           {.timestamp = Timestamp(1000), .value = 1}};
  for (const Run& run : runs) {
    MP_ASSERT_OK(graph.StartRun({}));
    if (select_packet) {
      MP_EXPECT_OK(graph.AddPacketToInputStream(
          "select", MakePacket<int>(0).At(run.timestamp)));
    }
    MP_EXPECT_OK(graph.AddPacketToInputStream(
        "input", MakePacket<int>(run.value).At(run.timestamp)));
    MP_ASSERT_OK(graph.WaitUntilIdle());
    ASSERT_EQ(1, packet_dump.size());
    EXPECT_EQ(run.value, packet_dump[0].Get<int>());
    EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp());
    MP_ASSERT_OK(graph.CloseAllPacketSources());
    MP_ASSERT_OK(graph.WaitUntilDone());
    // Reset the sink so the next run's output is checked in isolation.
    packet_dump.clear();
  }
}
// Runs DoTestMultipleGraphRuns once per input stream handler. Every handler
// gets a "select" packet except ImmediateInputStreamHandler, which is
// exercised without one.
TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) {
  const struct {
    const char* handler;
    bool select_packet;
  } kHandlerCases[] = {
      {"BarrierInputStreamHandler", true},
      {"DefaultInputStreamHandler", true},
      {"EarlyCloseInputStreamHandler", true},
      {"FixedSizeInputStreamHandler", true},
      {"ImmediateInputStreamHandler", false},
      {"MuxInputStreamHandler", true},
      {"SyncSetInputStreamHandler", true},
      {"TimestampAlignInputStreamHandler", true},
  };
  for (const auto& handler_case : kHandlerCases) {
    DoTestMultipleGraphRuns(handler_case.handler, handler_case.select_packet);
  }
}
} // namespace
} // namespace mediapipe