Stream/SidePacket == and != operators

PiperOrigin-RevId: 504114182
This commit is contained in:
MediaPipe Team 2023-01-23 16:41:32 -08:00 committed by Copybara-Service
parent 873d7181bf
commit 2465e47b01
2 changed files with 59 additions and 0 deletions

View File

@ -206,6 +206,16 @@ class SourceImpl {
return ConnectTo(dest); return ConnectTo(dest);
} }
template <typename U>
bool operator==(const SourceImpl<IsSide, U>& other) {
return base_ == other.base_;
}
template <typename U>
bool operator!=(const SourceImpl<IsSide, U>& other) {
return !(*this == other);
}
Src& SetName(std::string name) { Src& SetName(std::string name) {
base_->name_ = std::move(name); base_->name_ = std::move(name);
return *this; return *this;
@ -218,6 +228,9 @@ class SourceImpl {
} }
private: private:
template <bool, typename U>
friend class SourceImpl;
// Never null. // Never null.
SourceBase* base_; SourceBase* base_;
}; };

View File

@ -494,5 +494,51 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }
TEST(BuilderTest, TestStreamEqualsNotEqualsOperators) {
Graph graph;
Stream<AnyType> input0 = graph.In(0);
EXPECT_TRUE(input0 == input0);
EXPECT_FALSE(input0 != input0);
EXPECT_TRUE(input0 == input0.Cast<int>());
EXPECT_FALSE(input0.Cast<float>() != input0);
EXPECT_TRUE(input0.Cast<float>() == input0.Cast<int>());
EXPECT_FALSE(input0.Cast<float>() != input0.Cast<int>());
Stream<AnyType> input1 = graph.In(1);
EXPECT_FALSE(input0 == input1);
EXPECT_TRUE(input0 != input1);
input1 = input0;
EXPECT_TRUE(input0 == input1);
EXPECT_FALSE(input0 != input1);
EXPECT_TRUE(input0.Cast<int>() == input1.Cast<int>());
EXPECT_FALSE(input0.Cast<float>() != input1.Cast<float>());
}
TEST(BuilderTest, TestSidePacketEqualsNotEqualsOperators) {
Graph graph;
SidePacket<AnyType> side_input0 = graph.SideIn(0);
EXPECT_TRUE(side_input0 == side_input0);
EXPECT_FALSE(side_input0 != side_input0);
EXPECT_TRUE(side_input0 == side_input0.Cast<int>());
EXPECT_FALSE(side_input0.Cast<float>() != side_input0);
EXPECT_TRUE(side_input0.Cast<float>() == side_input0.Cast<int>());
EXPECT_FALSE(side_input0.Cast<float>() != side_input0.Cast<int>());
SidePacket<AnyType> side_input1 = graph.SideIn(1);
EXPECT_FALSE(side_input0 == side_input1);
EXPECT_TRUE(side_input0 != side_input1);
side_input1 = side_input0;
EXPECT_TRUE(side_input0 == side_input1);
EXPECT_FALSE(side_input0 != side_input1);
EXPECT_TRUE(side_input0.Cast<int>() == side_input1.Cast<int>());
EXPECT_FALSE(side_input0.Cast<float>() != side_input1.Cast<float>());
}
} // namespace } // namespace
} // namespace mediapipe::api2::builder } // namespace mediapipe::api2::builder