diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index fd9c4049b..827727056 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -198,6 +198,7 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -220,6 +221,7 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -236,10 +238,12 @@ cc_test( ":begin_loop_calculator", ":end_loop_calculator", ":gate_calculator", + ":pass_through_calculator", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:parse_text_proto", diff --git a/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc b/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc index b1ebdd086..281a6fa8c 100644 --- a/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc +++ b/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/calculators/core/end_loop_calculator.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -444,5 +445,67 @@ TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) { PacketOfIntsEq(input_timestamp2, std::vector{6, 9}))); } +absl::Status InitBeginEndTensorLoopTestGraph( + CalculatorGraph& graph, std::vector& output_packets) { + auto graph_config = ParseTextProtoOrDie( + R"pb( + num_threads: 4 + input_stream: "tensors" + node { + calculator: "BeginLoopTensorCalculator" + input_stream: "ITERABLE:tensors" + output_stream: "ITEM:tensor" + output_stream: "BATCH_END:timestamp" + } + node { + calculator: "PassThroughCalculator" + input_stream: "tensor" + output_stream: "passed_tensor" + } + node { + calculator: "EndLoopTensorCalculator" + input_stream: "ITEM:passed_tensor" + input_stream: "BATCH_END:timestamp" + output_stream: "ITERABLE:output_tensors" + } + )pb"); + tool::AddVectorSink("output_tensors", &graph_config, &output_packets); + MP_RETURN_IF_ERROR(graph.Initialize(graph_config)); + return graph.StartRun({}); +} + +TEST(BeginEndTensorLoopCalculatorGraphTest, SingleNonEmptyVector) { + // Initialize the graph. + CalculatorGraph graph; + std::vector output_packets; + MP_ASSERT_OK(InitBeginEndTensorLoopTestGraph(graph, output_packets)); + + // Prepare the inputs and run. + Timestamp input_timestamp = Timestamp(0); + std::vector tensors; + for (int i = 0; i < 4; i++) { + tensors.emplace_back(Tensor::ElementType::kFloat32, + Tensor::Shape{4, 3, 2, 1}); + } + Packet vector_packet = + MakePacket>(std::move(tensors)); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tensors", std::move(vector_packet).At(input_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Verify the output packet. + EXPECT_EQ(output_packets.size(), 1); + const std::vector& output_tensors = + output_packets[0].Get>(); + EXPECT_EQ(output_tensors.size(), 4); + for (int i = 0; i < output_tensors.size(); i++) { + EXPECT_THAT(output_tensors[i].shape().dims, + testing::ElementsAre(4, 3, 2, 1)); + } + + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc index 5bf8e65fc..441c66937 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.cc +++ b/mediapipe/calculators/core/begin_loop_calculator.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" namespace mediapipe { @@ -55,4 +56,8 @@ REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator); typedef BeginLoopCalculator> BeginLoopUint64tCalculator; REGISTER_CALCULATOR(BeginLoopUint64tCalculator); +// A calculator to process std::vector. +typedef BeginLoopCalculator> BeginLoopTensorCalculator; +REGISTER_CALCULATOR(BeginLoopTensorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/begin_loop_calculator.h b/mediapipe/calculators/core/begin_loop_calculator.h index 6d17f9953..81fff39da 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.h +++ b/mediapipe/calculators/core/begin_loop_calculator.h @@ -15,7 +15,6 @@ #ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_ -#include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,6 +23,7 @@ #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" namespace mediapipe { @@ -112,13 +112,38 @@ class BeginLoopCalculator : public CalculatorBase { absl::Status Process(CalculatorContext* cc) final { Timestamp last_timestamp = loop_internal_timestamp_; if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) { - const IterableT& collection = - cc->Inputs().Tag("ITERABLE").template Get(); - for (const auto& item : collection) { - cc->Outputs().Tag("ITEM").AddPacket( - MakePacket(item).At(loop_internal_timestamp_)); - ForwardClonePackets(cc, loop_internal_timestamp_); - ++loop_internal_timestamp_; + // Try to consume the ITERABLE packet if possible to obtain the ownership + // and emit the item packets by moving them. + // If the ITERABLE packet is not consumable, then try to copy each item + // instead. If the ITEM type is not copy constructible, an error will be + // returned. + auto iterable_ptr_or = + cc->Inputs().Tag("ITERABLE").Value().Consume(); + if (iterable_ptr_or.ok()) { + for (auto& item : *iterable_ptr_or.value()) { + Packet item_packet = MakePacket(std::move(item)); + cc->Outputs().Tag("ITEM").AddPacket( + item_packet.At(loop_internal_timestamp_)); + ForwardClonePackets(cc, loop_internal_timestamp_); + ++loop_internal_timestamp_; + } + } else { + if constexpr (std::is_copy_constructible()) { + const IterableT& collection = + cc->Inputs().Tag("ITERABLE").template Get(); + for (const auto& item : collection) { + cc->Outputs().Tag("ITEM").AddPacket( + MakePacket(item).At(loop_internal_timestamp_)); + ForwardClonePackets(cc, loop_internal_timestamp_); + ++loop_internal_timestamp_; + } + } else { + return absl::InternalError( + "The element type is not copiable. Consider making the " + "BeginLoopCalculator the sole owner of the input packet so that " + "the " + "items can be consumed and moved."); + } } } @@ -138,7 +163,6 @@ class BeginLoopCalculator : public CalculatorBase { .Tag("BATCH_END") .AddPacket(MakePacket(cc->InputTimestamp()) .At(Timestamp(loop_internal_timestamp_ - 1))); - return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/end_loop_calculator.cc b/mediapipe/calculators/core/end_loop_calculator.cc index 45cd8a9fd..4109fba06 100644 --- a/mediapipe/calculators/core/end_loop_calculator.cc +++ b/mediapipe/calculators/core/end_loop_calculator.cc @@ -21,6 +21,7 @@ #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/util/render_data.pb.h" #include "tensorflow/lite/interpreter.h" @@ -52,8 +53,9 @@ typedef EndLoopCalculator> EndLoopClassificationListCalculator; REGISTER_CALCULATOR(EndLoopClassificationListCalculator); -typedef EndLoopCalculator> EndLoopTensorCalculator; -REGISTER_CALCULATOR(EndLoopTensorCalculator); +typedef EndLoopCalculator> + EndLoopTfLiteTensorCalculator; +REGISTER_CALCULATOR(EndLoopTfLiteTensorCalculator); typedef EndLoopCalculator> EndLoopDetectionCalculator; @@ -62,4 +64,7 @@ REGISTER_CALCULATOR(EndLoopDetectionCalculator); typedef EndLoopCalculator> EndLoopMatrixCalculator; REGISTER_CALCULATOR(EndLoopMatrixCalculator); +typedef EndLoopCalculator> EndLoopTensorCalculator; +REGISTER_CALCULATOR(EndLoopTensorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/end_loop_calculator.h b/mediapipe/calculators/core/end_loop_calculator.h index 9f56657d0..2598194e6 100644 --- a/mediapipe/calculators/core/end_loop_calculator.h +++ b/mediapipe/calculators/core/end_loop_calculator.h @@ -15,6 +15,8 @@ #ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_ +#include + #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" @@ -75,8 +77,23 @@ class EndLoopCalculator : public CalculatorBase { if (!input_stream_collection_) { input_stream_collection_.reset(new IterableT); } - input_stream_collection_->push_back( - cc->Inputs().Tag("ITEM").template Get()); + // Try to consume the item and move it into the collection. If the items + // are not consumable, then try to copy them instead. If the items are + // not copiable, then an error will be returned. + auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume(); + if (item_ptr_or.ok()) { + input_stream_collection_->push_back(std::move(*item_ptr_or.value())); + } else { + if constexpr (std::is_copy_constructible_v) { + input_stream_collection_->push_back( + cc->Inputs().Tag("ITEM").template Get()); + } else { + return absl::InternalError( + "The item type is not copiable. Consider making the " + "EndLoopCalculator the sole owner of the input packets so that " + "it can be moved instead of copying."); + } + } } if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal