No public description

PiperOrigin-RevId: 565446429
This commit is contained in:
MediaPipe Team 2023-09-14 12:38:11 -07:00 committed by Copybara-Service
parent 85b19383b9
commit f2b11bf250
2 changed files with 79 additions and 9 deletions

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
@ -163,6 +164,75 @@ TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
PacketOfIntsEq(input_timestamp2, std::vector<int>{3, 4})));
}
TEST(BeginEndLoopCalculatorPossibleDataRaceTest,
EndLoopForIntegersDoesNotRace) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
num_threads: 4
input_stream: "ints"
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints"
output_stream: "ITEM:int"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "IncrementCalculator"
input_stream: "int"
output_stream: "int_plus_one"
}
# BEGIN: Data race possibility
# EndLoop###Calculator and another calculator using the same input
# may introduce race due to EndLoop###Calculator possibly consuming
# packet.
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_one"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_one"
}
node {
calculator: "IncrementCalculator"
input_stream: "int_plus_one"
output_stream: "int_plus_two"
}
# END: Data race possibility
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_two"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_two"
}
)pb");
std::vector<Packet> int_plus_one_packets;
tool::AddVectorSink("ints_plus_one", &graph_config, &int_plus_one_packets);
std::vector<Packet> int_original_packets;
tool::AddVectorSink("ints_plus_two", &graph_config, &int_original_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.StartRun({}));
for (int i = 0; i < 100; ++i) {
std::vector<int> ints = {i, i + 1, i + 2};
Timestamp ts = Timestamp(i);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(ts)));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(int_plus_one_packets,
testing::ElementsAre(
PacketOfIntsEq(ts, std::vector<int>{i + 1, i + 2, i + 3})));
EXPECT_THAT(int_original_packets,
testing::ElementsAre(
PacketOfIntsEq(ts, std::vector<int>{i + 2, i + 3, i + 4})));
int_plus_one_packets.clear();
int_original_packets.clear();
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Passes non empty vector through or outputs empty vector in case of timestamp
// bound update.
class PassThroughOrEmptyVectorCalculator : public CalculatorBase {

View File

@ -55,16 +55,16 @@ class EndLoopCalculator : public CalculatorBase {
if (!input_stream_collection_) {
input_stream_collection_.reset(new IterableT);
}
// Try to consume the item and move it into the collection. If the items
// are not consumable, then try to copy them instead. If the items are
// not copyable, then an error will be returned.
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
if (item_ptr_or.ok()) {
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
if constexpr (std::is_copy_constructible_v<ItemT>) {
input_stream_collection_->push_back(
cc->Inputs().Tag("ITEM").Get<ItemT>());
} else {
if constexpr (std::is_copy_constructible_v<ItemT>) {
input_stream_collection_->push_back(
cc->Inputs().Tag("ITEM").template Get<ItemT>());
// Try to consume the item and move it into the collection. Return an
// error if the items are not consumable.
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
if (item_ptr_or.ok()) {
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
} else {
return absl::InternalError(
"The item type is not copiable. Consider making the "