Updated the Begin/EndLoopCalculator to be able to handle mediapipe::Tensor

type.

PiperOrigin-RevId: 508552066
This commit is contained in:
MediaPipe Team 2023-02-09 20:31:14 -08:00 committed by Copybara-Service
parent 8a49a5f822
commit d61b7dbef8
6 changed files with 131 additions and 13 deletions

View File

@ -198,6 +198,7 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -220,6 +221,7 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -236,10 +238,12 @@ cc_test(
":begin_loop_calculator", ":begin_loop_calculator",
":end_loop_calculator", ":end_loop_calculator",
":gate_calculator", ":gate_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",

View File

@ -20,6 +20,7 @@
#include "mediapipe/calculators/core/end_loop_calculator.h" #include "mediapipe/calculators/core/end_loop_calculator.h"
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -444,5 +445,67 @@ TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9}))); PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
} }
absl::Status InitBeginEndTensorLoopTestGraph(
CalculatorGraph& graph, std::vector<Packet>& output_packets) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
num_threads: 4
input_stream: "tensors"
node {
calculator: "BeginLoopTensorCalculator"
input_stream: "ITERABLE:tensors"
output_stream: "ITEM:tensor"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "PassThroughCalculator"
input_stream: "tensor"
output_stream: "passed_tensor"
}
node {
calculator: "EndLoopTensorCalculator"
input_stream: "ITEM:passed_tensor"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:output_tensors"
}
)pb");
tool::AddVectorSink("output_tensors", &graph_config, &output_packets);
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
return graph.StartRun({});
}
TEST(BeginEndTensorLoopCalculatorGraphTest, SingleNonEmptyVector) {
// Initialize the graph.
CalculatorGraph graph;
std::vector<Packet> output_packets;
MP_ASSERT_OK(InitBeginEndTensorLoopTestGraph(graph, output_packets));
// Prepare the inputs and run.
Timestamp input_timestamp = Timestamp(0);
std::vector<mediapipe::Tensor> tensors;
for (int i = 0; i < 4; i++) {
tensors.emplace_back(Tensor::ElementType::kFloat32,
Tensor::Shape{4, 3, 2, 1});
}
Packet vector_packet =
MakePacket<std::vector<mediapipe::Tensor>>(std::move(tensors));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensors", std::move(vector_packet).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Verify the output packet.
EXPECT_EQ(output_packets.size(), 1);
const std::vector<Tensor>& output_tensors =
output_packets[0].Get<std::vector<Tensor>>();
EXPECT_EQ(output_tensors.size(), 4);
for (int i = 0; i < output_tensors.size(); i++) {
EXPECT_THAT(output_tensors[i].shape().dims,
testing::ElementsAre(4, 3, 2, 1));
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -20,6 +20,7 @@
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe { namespace mediapipe {
@ -55,4 +56,8 @@ REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator; typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
REGISTER_CALCULATOR(BeginLoopUint64tCalculator); REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
// A calculator to process std::vector<mediapipe::Tensor>.
typedef BeginLoopCalculator<std::vector<Tensor>> BeginLoopTensorCalculator;
REGISTER_CALCULATOR(BeginLoopTensorCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,7 +15,6 @@
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_ #ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -24,6 +23,7 @@
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
namespace mediapipe { namespace mediapipe {
@ -112,6 +112,23 @@ class BeginLoopCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
Timestamp last_timestamp = loop_internal_timestamp_; Timestamp last_timestamp = loop_internal_timestamp_;
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) { if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
// Try to consume the ITERABLE packet if possible to obtain the ownership
// and emit the item packets by moving them.
// If the ITERABLE packet is not consumable, then try to copy each item
// instead. If the ITEM type is not copy constructible, an error will be
// returned.
auto iterable_ptr_or =
cc->Inputs().Tag("ITERABLE").Value().Consume<IterableT>();
if (iterable_ptr_or.ok()) {
for (auto& item : *iterable_ptr_or.value()) {
Packet item_packet = MakePacket<ItemT>(std::move(item));
cc->Outputs().Tag("ITEM").AddPacket(
item_packet.At(loop_internal_timestamp_));
ForwardClonePackets(cc, loop_internal_timestamp_);
++loop_internal_timestamp_;
}
} else {
if constexpr (std::is_copy_constructible<ItemT>()) {
const IterableT& collection = const IterableT& collection =
cc->Inputs().Tag("ITERABLE").template Get<IterableT>(); cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
for (const auto& item : collection) { for (const auto& item : collection) {
@ -120,6 +137,14 @@ class BeginLoopCalculator : public CalculatorBase {
ForwardClonePackets(cc, loop_internal_timestamp_); ForwardClonePackets(cc, loop_internal_timestamp_);
++loop_internal_timestamp_; ++loop_internal_timestamp_;
} }
} else {
return absl::InternalError(
"The element type is not copiable. Consider making the "
"BeginLoopCalculator the sole owner of the input packet so that "
"the "
"items can be consumed and moved.");
}
}
} }
// The collection was empty and nothing was processed. // The collection was empty and nothing was processed.
@ -138,7 +163,6 @@ class BeginLoopCalculator : public CalculatorBase {
.Tag("BATCH_END") .Tag("BATCH_END")
.AddPacket(MakePacket<Timestamp>(cc->InputTimestamp()) .AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
.At(Timestamp(loop_internal_timestamp_ - 1))); .At(Timestamp(loop_internal_timestamp_ - 1)));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -21,6 +21,7 @@
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/util/render_data.pb.h" #include "mediapipe/util/render_data.pb.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
@ -52,8 +53,9 @@ typedef EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>
EndLoopClassificationListCalculator; EndLoopClassificationListCalculator;
REGISTER_CALCULATOR(EndLoopClassificationListCalculator); REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator; typedef EndLoopCalculator<std::vector<TfLiteTensor>>
REGISTER_CALCULATOR(EndLoopTensorCalculator); EndLoopTfLiteTensorCalculator;
REGISTER_CALCULATOR(EndLoopTfLiteTensorCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>> typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
EndLoopDetectionCalculator; EndLoopDetectionCalculator;
@ -62,4 +64,7 @@ REGISTER_CALCULATOR(EndLoopDetectionCalculator);
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator; typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
REGISTER_CALCULATOR(EndLoopMatrixCalculator); REGISTER_CALCULATOR(EndLoopMatrixCalculator);
typedef EndLoopCalculator<std::vector<Tensor>> EndLoopTensorCalculator;
REGISTER_CALCULATOR(EndLoopTensorCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,6 +15,8 @@
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_ #ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
#include <type_traits>
#include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -75,8 +77,23 @@ class EndLoopCalculator : public CalculatorBase {
if (!input_stream_collection_) { if (!input_stream_collection_) {
input_stream_collection_.reset(new IterableT); input_stream_collection_.reset(new IterableT);
} }
// Try to consume the item and move it into the collection. If the items
// are not consumable, then try to copy them instead. If the items are
// not copiable, then an error will be returned.
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
if (item_ptr_or.ok()) {
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
} else {
if constexpr (std::is_copy_constructible_v<ItemT>) {
input_stream_collection_->push_back( input_stream_collection_->push_back(
cc->Inputs().Tag("ITEM").template Get<ItemT>()); cc->Inputs().Tag("ITEM").template Get<ItemT>());
} else {
return absl::InternalError(
"The item type is not copiable. Consider making the "
"EndLoopCalculator the sole owner of the input packets so that "
"it can be moved instead of copying.");
}
}
} }
if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal