Internal change
PiperOrigin-RevId: 489135553
This commit is contained in:
parent
899c87466e
commit
ea4989b6f1
|
@ -85,75 +85,6 @@ std::string SourceString(Timestamp t) {
|
||||||
: absl::StrCat("Timestamp(", t.DebugString(), ")");
|
: absl::StrCat("Timestamp(", t.DebugString(), ")");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::string SourceString(Packet packet) {
|
|
||||||
std::ostringstream oss;
|
|
||||||
if (packet.IsEmpty()) {
|
|
||||||
oss << "Packet()";
|
|
||||||
} else {
|
|
||||||
oss << "MakePacket<" << MediaPipeTypeStringOrDemangled<T>() << ">("
|
|
||||||
<< packet.Get<T>() << ")";
|
|
||||||
}
|
|
||||||
oss << ".At(" << SourceString(packet.Timestamp()) << ")";
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PacketContainer, typename PacketContent>
|
|
||||||
class PacketsEqMatcher
|
|
||||||
: public ::testing::MatcherInterface<const PacketContainer&> {
|
|
||||||
public:
|
|
||||||
PacketsEqMatcher(PacketContainer packets) : packets_(packets) {}
|
|
||||||
void DescribeTo(::std::ostream* os) const override {
|
|
||||||
*os << "The expected packet contents: \n";
|
|
||||||
Print(packets_, os);
|
|
||||||
}
|
|
||||||
bool MatchAndExplain(
|
|
||||||
const PacketContainer& value,
|
|
||||||
::testing::MatchResultListener* listener) const override {
|
|
||||||
if (!Equals(packets_, value)) {
|
|
||||||
if (listener->IsInterested()) {
|
|
||||||
*listener << "The actual packet contents: \n";
|
|
||||||
Print(value, listener->stream());
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool Equals(const PacketContainer& c1, const PacketContainer& c2) const {
|
|
||||||
if (c1.size() != c2.size()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) {
|
|
||||||
Packet p1 = *i1, p2 = *i2;
|
|
||||||
if (p1.Timestamp() != p2.Timestamp() || p1.IsEmpty() != p2.IsEmpty() ||
|
|
||||||
(!p1.IsEmpty() &&
|
|
||||||
p1.Get<PacketContent>() != p2.Get<PacketContent>())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
void Print(const PacketContainer& packets, ::std::ostream* os) const {
|
|
||||||
for (auto it = packets.begin(); it != packets.end(); ++it) {
|
|
||||||
const Packet& packet = *it;
|
|
||||||
*os << (it == packets.begin() ? "{" : "");
|
|
||||||
*os << SourceString<PacketContent>(packet);
|
|
||||||
*os << (std::next(it) == packets.end() ? "}" : ", ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const PacketContainer packets_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename PacketContainer, typename PacketContent>
|
|
||||||
::testing::Matcher<const PacketContainer&> PacketsEq(
|
|
||||||
const PacketContainer& packets) {
|
|
||||||
return MakeMatcher(
|
|
||||||
new PacketsEqMatcher<PacketContainer, PacketContent>(packets));
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Calculator::Process callback function.
|
// A Calculator::Process callback function.
|
||||||
typedef std::function<absl::Status(const InputStreamShardSet&,
|
typedef std::function<absl::Status(const InputStreamShardSet&,
|
||||||
OutputStreamShardSet*)>
|
OutputStreamShardSet*)>
|
||||||
|
@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
|
||||||
// The processing time "sleep_time" is reduced from 22ms to 12ms to create
|
// The processing time "sleep_time" is reduced from 22ms to 12ms to create
|
||||||
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
|
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
|
||||||
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||||
auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>;
|
|
||||||
auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>;
|
|
||||||
|
|
||||||
// Configure the test.
|
// Configure the test.
|
||||||
SetUpInputData();
|
SetUpInputData();
|
||||||
SetUpSimulationClock();
|
SetUpSimulationClock();
|
||||||
|
@ -839,13 +767,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||||
input_packets_[0], input_packets_[2], input_packets_[15],
|
input_packets_[0], input_packets_[2], input_packets_[15],
|
||||||
input_packets_[17], input_packets_[19],
|
input_packets_[17], input_packets_[19],
|
||||||
};
|
};
|
||||||
EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output));
|
EXPECT_THAT(out_1_packets_,
|
||||||
|
ElementsAreArray(PacketMatchers<int>(expected_output)));
|
||||||
|
|
||||||
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
||||||
std::vector<Packet> expected_output_2 = {
|
std::vector<Packet> expected_output_2 = {
|
||||||
input_packets_[0], input_packets_[2], input_packets_[4],
|
input_packets_[0], input_packets_[2], input_packets_[4],
|
||||||
input_packets_[15], input_packets_[17], input_packets_[19],
|
input_packets_[15], input_packets_[17], input_packets_[19],
|
||||||
};
|
};
|
||||||
EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2));
|
EXPECT_THAT(out_2_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<int>(expected_output_2)));
|
||||||
|
|
||||||
// Validate the ALLOW stream output.
|
// Validate the ALLOW stream output.
|
||||||
std::vector<Packet> expected_allow = {
|
std::vector<Packet> expected_allow = {
|
||||||
|
@ -871,7 +802,8 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||||
MakePacket<bool>(true).At(Timestamp(190000)),
|
MakePacket<bool>(true).At(Timestamp(190000)),
|
||||||
MakePacket<bool>(false).At(Timestamp(200000)),
|
MakePacket<bool>(false).At(Timestamp(200000)),
|
||||||
};
|
};
|
||||||
EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow));
|
EXPECT_THAT(allow_packets_,
|
||||||
|
ElementsAreArray(PacketMatchers<bool>(expected_allow)));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets,
|
std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets,
|
||||||
|
@ -891,9 +823,6 @@ std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets,
|
||||||
// Shows how FlowLimiterCalculator releases auxiliary input packets.
|
// Shows how FlowLimiterCalculator releases auxiliary input packets.
|
||||||
// In this test, auxiliary input packets arrive at twice the primary rate.
|
// In this test, auxiliary input packets arrive at twice the primary rate.
|
||||||
TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
|
TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
|
||||||
auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>;
|
|
||||||
auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>;
|
|
||||||
|
|
||||||
// Configure the test.
|
// Configure the test.
|
||||||
SetUpInputData();
|
SetUpInputData();
|
||||||
SetUpSimulationClock();
|
SetUpSimulationClock();
|
||||||
|
@ -1011,7 +940,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
|
||||||
MakePacket<int>(6).At(Timestamp(60000)),
|
MakePacket<int>(6).At(Timestamp(60000)),
|
||||||
Packet().At(Timestamp(80000)),
|
Packet().At(Timestamp(80000)),
|
||||||
};
|
};
|
||||||
EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output));
|
EXPECT_THAT(out_1_packets_,
|
||||||
|
ElementsAreArray(PacketMatchers<int>(expected_output)));
|
||||||
|
|
||||||
// Packets following input packets 2 and 6, and not input packets 4 and 8.
|
// Packets following input packets 2 and 6, and not input packets 4 and 8.
|
||||||
std::vector<Packet> expected_auxiliary_output = {
|
std::vector<Packet> expected_auxiliary_output = {
|
||||||
|
@ -1031,12 +961,13 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
|
||||||
};
|
};
|
||||||
std::vector<Packet> actual_2 =
|
std::vector<Packet> actual_2 =
|
||||||
StripBoundsUpdates(out_2_packets, Timestamp(90000));
|
StripBoundsUpdates(out_2_packets, Timestamp(90000));
|
||||||
EXPECT_THAT(actual_2, IntPacketsEq(expected_auxiliary_output));
|
EXPECT_THAT(actual_2,
|
||||||
|
ElementsAreArray(PacketMatchers<int>(expected_auxiliary_output)));
|
||||||
std::vector<Packet> expected_3 =
|
std::vector<Packet> expected_3 =
|
||||||
StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999));
|
StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999));
|
||||||
std::vector<Packet> actual_3 =
|
std::vector<Packet> actual_3 =
|
||||||
StripBoundsUpdates(out_3_packets, Timestamp(39999));
|
StripBoundsUpdates(out_3_packets, Timestamp(39999));
|
||||||
EXPECT_THAT(actual_3, IntPacketsEq(expected_3));
|
EXPECT_THAT(actual_3, ElementsAreArray(PacketMatchers<int>(expected_3)));
|
||||||
|
|
||||||
// Validate the ALLOW stream output.
|
// Validate the ALLOW stream output.
|
||||||
std::vector<Packet> expected_allow = {
|
std::vector<Packet> expected_allow = {
|
||||||
|
@ -1045,7 +976,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
|
||||||
MakePacket<bool>(true).At(Timestamp(60000)),
|
MakePacket<bool>(true).At(Timestamp(60000)),
|
||||||
MakePacket<bool>(false).At(Timestamp(80000)),
|
MakePacket<bool>(false).At(Timestamp(80000)),
|
||||||
};
|
};
|
||||||
EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow));
|
EXPECT_THAT(allow_packets_,
|
||||||
|
ElementsAreArray(PacketMatchers<bool>(expected_allow)));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
|
@ -1469,6 +1469,7 @@ cc_test(
|
||||||
"//mediapipe/framework/stream_handler:mux_input_stream_handler",
|
"//mediapipe/framework/stream_handler:mux_input_stream_handler",
|
||||||
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
|
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
|
||||||
"//mediapipe/framework/tool:sink",
|
"//mediapipe/framework/tool:sink",
|
||||||
|
"//mediapipe/util:packet_test_util",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -98,14 +98,13 @@ void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) {
|
||||||
manager_->LockIntroData();
|
manager_->LockIntroData();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CalculatorGraph::GraphInputStream::SetNextTimestampBound(
|
||||||
|
Timestamp timestamp) {
|
||||||
|
shard_.SetNextTimestampBound(timestamp);
|
||||||
|
}
|
||||||
|
|
||||||
void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() {
|
void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() {
|
||||||
// Since GraphInputStream doesn't allow SetOffset() and
|
manager_->PropagateUpdatesToMirrors(shard_.NextTimestampBound(), &shard_);
|
||||||
// SetNextTimestampBound(), the timestamp bound to propagate is only
|
|
||||||
// determined by the timestamp of the output packets.
|
|
||||||
CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name()
|
|
||||||
<< "\" failed";
|
|
||||||
manager_->PropagateUpdatesToMirrors(
|
|
||||||
shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CalculatorGraph::GraphInputStream::Close() {
|
void CalculatorGraph::GraphInputStream::Close() {
|
||||||
|
@ -868,6 +867,19 @@ absl::Status CalculatorGraph::AddPacketToInputStream(
|
||||||
return AddPacketToInputStreamInternal(stream_name, std::move(packet));
|
return AddPacketToInputStreamInternal(stream_name, std::move(packet));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status CalculatorGraph::SetInputStreamTimestampBound(
|
||||||
|
const std::string& stream_name, Timestamp timestamp) {
|
||||||
|
std::unique_ptr<GraphInputStream>* stream =
|
||||||
|
mediapipe::FindOrNull(graph_input_streams_, stream_name);
|
||||||
|
RET_CHECK(stream).SetNoLogging() << absl::Substitute(
|
||||||
|
"SetInputStreamTimestampBound called on input stream \"$0\" which is not "
|
||||||
|
"a graph input stream.",
|
||||||
|
stream_name);
|
||||||
|
(*stream)->SetNextTimestampBound(timestamp);
|
||||||
|
(*stream)->PropagateUpdatesToMirrors();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
// We avoid having two copies of this code for AddPacketToInputStream(
|
// We avoid having two copies of this code for AddPacketToInputStream(
|
||||||
// const Packet&) and AddPacketToInputStream(Packet &&) by having this
|
// const Packet&) and AddPacketToInputStream(Packet &&) by having this
|
||||||
// internal-only templated version. T&& is a forwarding reference here, so
|
// internal-only templated version. T&& is a forwarding reference here, so
|
||||||
|
|
|
@ -257,6 +257,10 @@ class CalculatorGraph {
|
||||||
absl::Status AddPacketToInputStream(const std::string& stream_name,
|
absl::Status AddPacketToInputStream(const std::string& stream_name,
|
||||||
Packet&& packet);
|
Packet&& packet);
|
||||||
|
|
||||||
|
// Indicates that input will arrive no earlier than a certain timestamp.
|
||||||
|
absl::Status SetInputStreamTimestampBound(const std::string& stream_name,
|
||||||
|
Timestamp timestamp);
|
||||||
|
|
||||||
// Sets the queue size of a graph input stream, overriding the graph default.
|
// Sets the queue size of a graph input stream, overriding the graph default.
|
||||||
absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name,
|
absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name,
|
||||||
int max_queue_size);
|
int max_queue_size);
|
||||||
|
@ -425,6 +429,8 @@ class CalculatorGraph {
|
||||||
|
|
||||||
void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); }
|
void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); }
|
||||||
|
|
||||||
|
void SetNextTimestampBound(Timestamp timestamp);
|
||||||
|
|
||||||
void PropagateUpdatesToMirrors();
|
void PropagateUpdatesToMirrors();
|
||||||
|
|
||||||
void Close();
|
void Close();
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/str_replace.h"
|
#include "absl/strings/str_replace.h"
|
||||||
#include "mediapipe/framework/calculator_context.h"
|
#include "mediapipe/framework/calculator_context.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
@ -24,6 +26,7 @@
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/framework/thread_pool_executor.h"
|
#include "mediapipe/framework/thread_pool_executor.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
#include "mediapipe/util/packet_test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -1536,7 +1539,7 @@ class EmptyPacketCalculator : public CalculatorBase {
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(EmptyPacketCalculator);
|
REGISTER_CALCULATOR(EmptyPacketCalculator);
|
||||||
|
|
||||||
// This test shows that an output timestamp bound can be specified by outputing
|
// This test shows that an output timestamp bound can be specified by outputting
|
||||||
// an empty packet with a settled timestamp.
|
// an empty packet with a settled timestamp.
|
||||||
TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
|
TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
|
||||||
// OffsetAndBoundCalculator runs on parallel threads and sends ts
|
// OffsetAndBoundCalculator runs on parallel threads and sends ts
|
||||||
|
@ -1585,5 +1588,194 @@ TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test shows that input timestamp bounds can be specified using
|
||||||
|
// CalculatorGraph::SetInputStreamTimestampBound.
|
||||||
|
TEST(CalculatorGraphBoundsTest, SetInputStreamTimestampBound) {
|
||||||
|
std::string config_str = R"(
|
||||||
|
input_stream: "input_0"
|
||||||
|
node {
|
||||||
|
calculator: "ProcessBoundToPacketCalculator"
|
||||||
|
input_stream: "input_0"
|
||||||
|
output_stream: "output_0"
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||||
|
CalculatorGraph graph;
|
||||||
|
std::vector<Packet> output_0_packets;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
|
||||||
|
output_0_packets.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
// Send in timestamp bounds.
|
||||||
|
for (int i = 0; i < 9; ++i) {
|
||||||
|
const int ts = 10 + i * 10;
|
||||||
|
MP_ASSERT_OK(graph.SetInputStreamTimestampBound(
|
||||||
|
"input_0", Timestamp(ts).NextAllowedInStream()));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9 timestamp bounds are converted to packets.
|
||||||
|
EXPECT_EQ(output_0_packets.size(), 9);
|
||||||
|
for (int i = 0; i < 9; ++i) {
|
||||||
|
EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the graph.
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test shows how an input stream with infrequent packets, such as
|
||||||
|
// configuration protobufs, can be consumed while processing more frequent
|
||||||
|
// packets, such as video frames.
|
||||||
|
TEST(CalculatorGraphBoundsTest, TimestampBoundsForInfrequentInput) {
|
||||||
|
// PassThroughCalculator consuming two input streams, with default ISH.
|
||||||
|
std::string config_str = R"pb(
|
||||||
|
input_stream: "INFREQUENT:config"
|
||||||
|
input_stream: "FREQUENT:frame"
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughCalculator"
|
||||||
|
input_stream: "CONFIG:config"
|
||||||
|
input_stream: "VIDEO:frame"
|
||||||
|
output_stream: "VIDEO:output_frame"
|
||||||
|
output_stream: "CONFIG:output_config"
|
||||||
|
}
|
||||||
|
)pb";
|
||||||
|
|
||||||
|
CalculatorGraphConfig config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||||
|
CalculatorGraph graph;
|
||||||
|
std::vector<Packet> frame_packets;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(config));
|
||||||
|
MP_ASSERT_OK(graph.ObserveOutputStream(
|
||||||
|
"output_frame",
|
||||||
|
[&](const Packet& p) {
|
||||||
|
frame_packets.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
},
|
||||||
|
/*observe_bound_updates=*/true));
|
||||||
|
std::vector<Packet> config_packets;
|
||||||
|
MP_ASSERT_OK(graph.ObserveOutputStream(
|
||||||
|
"output_config",
|
||||||
|
[&](const Packet& p) {
|
||||||
|
config_packets.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
},
|
||||||
|
/*observe_bound_updates=*/true));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
// Utility functions to send packets or timestamp bounds.
|
||||||
|
auto send_fn = [&](std::string stream, std::string value, int ts) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
stream,
|
||||||
|
MakePacket<std::string>(absl::StrCat(value)).At(Timestamp(ts))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
};
|
||||||
|
auto bound_fn = [&](std::string stream, int ts) {
|
||||||
|
MP_ASSERT_OK(graph.SetInputStreamTimestampBound(stream, Timestamp(ts)));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send in a frame packet.
|
||||||
|
send_fn("frame", "frame_0", 0);
|
||||||
|
// The frame is not processed yet.
|
||||||
|
EXPECT_THAT(frame_packets, ElementsAreArray(PacketMatchers<std::string>({})));
|
||||||
|
bound_fn("config", 10000);
|
||||||
|
// The frame is processed after a fresh config timestamp bound arrives.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
})));
|
||||||
|
|
||||||
|
// Send in a frame packet.
|
||||||
|
send_fn("frame", "frame_1", 20000);
|
||||||
|
// The frame is not processed yet.
|
||||||
|
// The PassThroughCalculator with TimestampOffset 0 now propagates
|
||||||
|
// Timestamp bound 10000 to both "output_frame" and "output_config",
|
||||||
|
// which appears here as Packet().At(Timestamp(9999). The timestamp
|
||||||
|
// bounds at 29999 and 50000 are propagated similarly.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
})));
|
||||||
|
bound_fn("config", 30000);
|
||||||
|
// The frame is processed after a fresh config timestamp bound arrives.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
MakePacket<std::string>("frame_1").At(Timestamp(20000)),
|
||||||
|
})));
|
||||||
|
|
||||||
|
// Send in a frame packet.
|
||||||
|
send_fn("frame", "frame_2", 40000);
|
||||||
|
// The frame is not processed yet.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
MakePacket<std::string>("frame_1").At(Timestamp(20000)),
|
||||||
|
Packet().At(Timestamp(29999)),
|
||||||
|
})));
|
||||||
|
send_fn("config", "config_1", 50000);
|
||||||
|
// The frame is processed after a fresh config arrives.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
MakePacket<std::string>("frame_1").At(Timestamp(20000)),
|
||||||
|
Packet().At(Timestamp(29999)),
|
||||||
|
MakePacket<std::string>("frame_2").At(Timestamp(40000)),
|
||||||
|
})));
|
||||||
|
|
||||||
|
// Send in a frame packet.
|
||||||
|
send_fn("frame", "frame_3", 60000);
|
||||||
|
// The frame is not processed yet.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
MakePacket<std::string>("frame_1").At(Timestamp(20000)),
|
||||||
|
Packet().At(Timestamp(29999)),
|
||||||
|
MakePacket<std::string>("frame_2").At(Timestamp(40000)),
|
||||||
|
Packet().At(Timestamp(50000)),
|
||||||
|
})));
|
||||||
|
bound_fn("config", 70000);
|
||||||
|
// The frame is processed after a fresh config timestamp bound arrives.
|
||||||
|
EXPECT_THAT(frame_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
MakePacket<std::string>("frame_0").At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
MakePacket<std::string>("frame_1").At(Timestamp(20000)),
|
||||||
|
Packet().At(Timestamp(29999)),
|
||||||
|
MakePacket<std::string>("frame_2").At(Timestamp(40000)),
|
||||||
|
Packet().At(Timestamp(50000)),
|
||||||
|
MakePacket<std::string>("frame_3").At(Timestamp(60000)),
|
||||||
|
})));
|
||||||
|
|
||||||
|
// One config packet is deleivered.
|
||||||
|
EXPECT_THAT(config_packets,
|
||||||
|
ElementsAreArray(PacketMatchers<std::string>({
|
||||||
|
Packet().At(Timestamp(0)),
|
||||||
|
Packet().At(Timestamp(9999)),
|
||||||
|
Packet().At(Timestamp(20000)),
|
||||||
|
Packet().At(Timestamp(29999)),
|
||||||
|
Packet().At(Timestamp(40000)),
|
||||||
|
MakePacket<std::string>("config_1").At(Timestamp(50000)),
|
||||||
|
Packet().At(Timestamp(60000)),
|
||||||
|
})));
|
||||||
|
|
||||||
|
// Shutdown the graph.
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -32,30 +32,29 @@ namespace mediapipe {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template <typename PayloadType>
|
template <typename PayloadType>
|
||||||
class PacketMatcher : public ::testing::MatcherInterface<const Packet&> {
|
class PacketMatcher : public testing::MatcherInterface<const Packet&> {
|
||||||
public:
|
public:
|
||||||
template <typename InnerMatcher>
|
template <typename InnerMatcher>
|
||||||
explicit PacketMatcher(InnerMatcher inner_matcher)
|
explicit PacketMatcher(InnerMatcher inner_matcher)
|
||||||
: inner_matcher_(
|
: inner_matcher_(
|
||||||
::testing::SafeMatcherCast<const PayloadType&>(inner_matcher)) {}
|
testing::SafeMatcherCast<const PayloadType&>(inner_matcher)) {}
|
||||||
|
|
||||||
// Returns true iff the packet contains value of PayloadType satisfying
|
// Returns true iff the packet contains value of PayloadType satisfying
|
||||||
// the inner matcher.
|
// the inner matcher.
|
||||||
bool MatchAndExplain(
|
bool MatchAndExplain(const Packet& packet,
|
||||||
const Packet& packet,
|
testing::MatchResultListener* listener) const override {
|
||||||
::testing::MatchResultListener* listener) const override {
|
|
||||||
if (!packet.ValidateAsType<PayloadType>().ok()) {
|
if (!packet.ValidateAsType<PayloadType>().ok()) {
|
||||||
*listener << packet.DebugString() << " does not contain expected type "
|
*listener << packet.DebugString() << " does not contain expected type "
|
||||||
<< ExpectedTypeName();
|
<< ExpectedTypeName();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
::testing::StringMatchResultListener match_listener;
|
testing::StringMatchResultListener match_listener;
|
||||||
const PayloadType& payload = packet.Get<PayloadType>();
|
const PayloadType& payload = packet.Get<PayloadType>();
|
||||||
const bool matches =
|
const bool matches =
|
||||||
inner_matcher_.MatchAndExplain(payload, &match_listener);
|
inner_matcher_.MatchAndExplain(payload, &match_listener);
|
||||||
const std::string explanation = match_listener.str();
|
const std::string explanation = match_listener.str();
|
||||||
*listener << packet.DebugString() << " containing value "
|
*listener << packet.DebugString() << " containing value "
|
||||||
<< ::testing::PrintToString(payload);
|
<< testing::PrintToString(payload);
|
||||||
if (!explanation.empty()) {
|
if (!explanation.empty()) {
|
||||||
*listener << ", which " << explanation;
|
*listener << ", which " << explanation;
|
||||||
}
|
}
|
||||||
|
@ -78,9 +77,28 @@ class PacketMatcher : public ::testing::MatcherInterface<const Packet&> {
|
||||||
return ::mediapipe::Demangle(typeid(PayloadType).name());
|
return ::mediapipe::Demangle(typeid(PayloadType).name());
|
||||||
}
|
}
|
||||||
|
|
||||||
const ::testing::Matcher<const PayloadType&> inner_matcher_;
|
const testing::Matcher<const PayloadType&> inner_matcher_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline std::string SourceString(Timestamp t) {
|
||||||
|
return (t.IsSpecialValue())
|
||||||
|
? t.DebugString()
|
||||||
|
: absl::StrCat("Timestamp(", t.DebugString(), ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::string SourceString(Packet packet) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
if (packet.IsEmpty()) {
|
||||||
|
oss << "Packet()";
|
||||||
|
} else {
|
||||||
|
oss << "MakePacket<" << MediaPipeTypeStringOrDemangled<T>() << ">("
|
||||||
|
<< packet.Get<T>() << ")";
|
||||||
|
}
|
||||||
|
oss << ".At(" << SourceString(packet.Timestamp()) << ")";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
// Creates matcher validating that the packet contains value of expected type
|
// Creates matcher validating that the packet contains value of expected type
|
||||||
|
@ -91,9 +109,9 @@ class PacketMatcher : public ::testing::MatcherInterface<const Packet&> {
|
||||||
//
|
//
|
||||||
// EXPECT_THAT(MakePacket<int>(42), PacketContains<int>(Eq(42)))
|
// EXPECT_THAT(MakePacket<int>(42), PacketContains<int>(Eq(42)))
|
||||||
template <typename PayloadType, typename InnerMatcher>
|
template <typename PayloadType, typename InnerMatcher>
|
||||||
inline ::testing::Matcher<const Packet&> PacketContains(
|
inline testing::Matcher<const Packet&> PacketContains(
|
||||||
InnerMatcher inner_matcher) {
|
InnerMatcher inner_matcher) {
|
||||||
return ::testing::MakeMatcher(
|
return testing::MakeMatcher(
|
||||||
new internal::PacketMatcher<PayloadType>(inner_matcher));
|
new internal::PacketMatcher<PayloadType>(inner_matcher));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +128,7 @@ inline ::testing::Matcher<const Packet&> PacketContains(
|
||||||
// Eq(42)))
|
// Eq(42)))
|
||||||
template <typename PayloadType, typename TimestampMatcher,
|
template <typename PayloadType, typename TimestampMatcher,
|
||||||
typename ContentMatcher>
|
typename ContentMatcher>
|
||||||
inline ::testing::Matcher<const Packet&> PacketContainsTimestampAndPayload(
|
inline testing::Matcher<const Packet&> PacketContainsTimestampAndPayload(
|
||||||
TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) {
|
TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) {
|
||||||
return testing::AllOf(
|
return testing::AllOf(
|
||||||
testing::Property("Packet::Timestamp", &Packet::Timestamp,
|
testing::Property("Packet::Timestamp", &Packet::Timestamp,
|
||||||
|
@ -118,6 +136,46 @@ inline ::testing::Matcher<const Packet&> PacketContainsTimestampAndPayload(
|
||||||
PacketContains<PayloadType>(content_matcher));
|
PacketContains<PayloadType>(content_matcher));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class PacketEqMatcher : public testing::MatcherInterface<Packet> {
|
||||||
|
public:
|
||||||
|
PacketEqMatcher(Packet packet) : packet_(packet) {}
|
||||||
|
void DescribeTo(::std::ostream* os) const override {
|
||||||
|
*os << "The expected packet: " << internal::SourceString<T>(packet_);
|
||||||
|
}
|
||||||
|
bool MatchAndExplain(Packet value,
|
||||||
|
testing::MatchResultListener* listener) const override {
|
||||||
|
bool unequal = (value.Timestamp() != packet_.Timestamp() ||
|
||||||
|
value.IsEmpty() != packet_.IsEmpty() ||
|
||||||
|
(!value.IsEmpty() && value.Get<T>() != packet_.Get<T>()));
|
||||||
|
if (unequal && listener->IsInterested()) {
|
||||||
|
*listener << "The actual packet: " << internal::SourceString<T>(value);
|
||||||
|
}
|
||||||
|
return !unequal;
|
||||||
|
}
|
||||||
|
const Packet packet_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
testing::Matcher<Packet> PacketEq(Packet packet) {
|
||||||
|
return MakeMatcher(new PacketEqMatcher<T>(packet));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<testing::Matcher<Packet>> PacketMatchers(
|
||||||
|
std::vector<Packet> packets) {
|
||||||
|
std::vector<testing::Matcher<Packet>> result;
|
||||||
|
for (const auto& packet : packets) {
|
||||||
|
result.push_back(PacketEq<T>(packet));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
using mediapipe::PacketContains;
|
||||||
|
using mediapipe::PacketContainsTimestampAndPayload;
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_
|
#endif // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_
|
||||||
|
|
Loading…
Reference in New Issue
Block a user