Updated the Begin/EndLoopCalculator to be able to handle mediapipe::Tensor
type. PiperOrigin-RevId: 508552066
This commit is contained in:
parent
8a49a5f822
commit
d61b7dbef8
|
@ -198,6 +198,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -220,6 +221,7 @@ cc_library(
|
|||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -236,10 +238,12 @@ cc_test(
|
|||
":begin_loop_calculator",
|
||||
":end_loop_calculator",
|
||||
":gate_calculator",
|
||||
":pass_through_calculator",
|
||||
"//mediapipe/framework:calculator_context",
|
||||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mediapipe/calculators/core/end_loop_calculator.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
|
@ -444,5 +445,67 @@ TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
|
|||
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
|
||||
}
|
||||
|
||||
absl::Status InitBeginEndTensorLoopTestGraph(
|
||||
CalculatorGraph& graph, std::vector<Packet>& output_packets) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
num_threads: 4
|
||||
input_stream: "tensors"
|
||||
node {
|
||||
calculator: "BeginLoopTensorCalculator"
|
||||
input_stream: "ITERABLE:tensors"
|
||||
output_stream: "ITEM:tensor"
|
||||
output_stream: "BATCH_END:timestamp"
|
||||
}
|
||||
node {
|
||||
calculator: "PassThroughCalculator"
|
||||
input_stream: "tensor"
|
||||
output_stream: "passed_tensor"
|
||||
}
|
||||
node {
|
||||
calculator: "EndLoopTensorCalculator"
|
||||
input_stream: "ITEM:passed_tensor"
|
||||
input_stream: "BATCH_END:timestamp"
|
||||
output_stream: "ITERABLE:output_tensors"
|
||||
}
|
||||
)pb");
|
||||
tool::AddVectorSink("output_tensors", &graph_config, &output_packets);
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
|
||||
return graph.StartRun({});
|
||||
}
|
||||
|
||||
TEST(BeginEndTensorLoopCalculatorGraphTest, SingleNonEmptyVector) {
|
||||
// Initialize the graph.
|
||||
CalculatorGraph graph;
|
||||
std::vector<Packet> output_packets;
|
||||
MP_ASSERT_OK(InitBeginEndTensorLoopTestGraph(graph, output_packets));
|
||||
|
||||
// Prepare the inputs and run.
|
||||
Timestamp input_timestamp = Timestamp(0);
|
||||
std::vector<mediapipe::Tensor> tensors;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tensors.emplace_back(Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{4, 3, 2, 1});
|
||||
}
|
||||
Packet vector_packet =
|
||||
MakePacket<std::vector<mediapipe::Tensor>>(std::move(tensors));
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"tensors", std::move(vector_packet).At(input_timestamp)));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Verify the output packet.
|
||||
EXPECT_EQ(output_packets.size(), 1);
|
||||
const std::vector<Tensor>& output_tensors =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
EXPECT_EQ(output_tensors.size(), 4);
|
||||
for (int i = 0; i < output_tensors.size(); i++) {
|
||||
EXPECT_THAT(output_tensors[i].shape().dims,
|
||||
testing::ElementsAre(4, 3, 2, 1));
|
||||
}
|
||||
|
||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -55,4 +56,8 @@ REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
|
|||
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
|
||||
|
||||
// A calculator to process std::vector<mediapipe::Tensor>.
|
||||
typedef BeginLoopCalculator<std::vector<Tensor>> BeginLoopTensorCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopTensorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
@ -24,6 +23,7 @@
|
|||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -112,13 +112,38 @@ class BeginLoopCalculator : public CalculatorBase {
|
|||
absl::Status Process(CalculatorContext* cc) final {
|
||||
Timestamp last_timestamp = loop_internal_timestamp_;
|
||||
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
|
||||
const IterableT& collection =
|
||||
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
|
||||
for (const auto& item : collection) {
|
||||
cc->Outputs().Tag("ITEM").AddPacket(
|
||||
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
|
||||
ForwardClonePackets(cc, loop_internal_timestamp_);
|
||||
++loop_internal_timestamp_;
|
||||
// Try to consume the ITERABLE packet if possible to obtain the ownership
|
||||
// and emit the item packets by moving them.
|
||||
// If the ITERABLE packet is not consumable, then try to copy each item
|
||||
// instead. If the ITEM type is not copy constructible, an error will be
|
||||
// returned.
|
||||
auto iterable_ptr_or =
|
||||
cc->Inputs().Tag("ITERABLE").Value().Consume<IterableT>();
|
||||
if (iterable_ptr_or.ok()) {
|
||||
for (auto& item : *iterable_ptr_or.value()) {
|
||||
Packet item_packet = MakePacket<ItemT>(std::move(item));
|
||||
cc->Outputs().Tag("ITEM").AddPacket(
|
||||
item_packet.At(loop_internal_timestamp_));
|
||||
ForwardClonePackets(cc, loop_internal_timestamp_);
|
||||
++loop_internal_timestamp_;
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_copy_constructible<ItemT>()) {
|
||||
const IterableT& collection =
|
||||
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
|
||||
for (const auto& item : collection) {
|
||||
cc->Outputs().Tag("ITEM").AddPacket(
|
||||
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
|
||||
ForwardClonePackets(cc, loop_internal_timestamp_);
|
||||
++loop_internal_timestamp_;
|
||||
}
|
||||
} else {
|
||||
return absl::InternalError(
|
||||
"The element type is not copiable. Consider making the "
|
||||
"BeginLoopCalculator the sole owner of the input packet so that "
|
||||
"the "
|
||||
"items can be consumed and moved.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,7 +163,6 @@ class BeginLoopCalculator : public CalculatorBase {
|
|||
.Tag("BATCH_END")
|
||||
.AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
|
||||
.At(Timestamp(loop_internal_timestamp_ - 1)));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
|
@ -52,8 +53,9 @@ typedef EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>
|
|||
EndLoopClassificationListCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopTensorCalculator);
|
||||
typedef EndLoopCalculator<std::vector<TfLiteTensor>>
|
||||
EndLoopTfLiteTensorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopTfLiteTensorCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
|
||||
EndLoopDetectionCalculator;
|
||||
|
@ -62,4 +64,7 @@ REGISTER_CALCULATOR(EndLoopDetectionCalculator);
|
|||
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopMatrixCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<Tensor>> EndLoopTensorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopTensorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
@ -75,8 +77,23 @@ class EndLoopCalculator : public CalculatorBase {
|
|||
if (!input_stream_collection_) {
|
||||
input_stream_collection_.reset(new IterableT);
|
||||
}
|
||||
input_stream_collection_->push_back(
|
||||
cc->Inputs().Tag("ITEM").template Get<ItemT>());
|
||||
// Try to consume the item and move it into the collection. If the items
|
||||
// are not consumable, then try to copy them instead. If the items are
|
||||
// not copiable, then an error will be returned.
|
||||
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
|
||||
if (item_ptr_or.ok()) {
|
||||
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
|
||||
} else {
|
||||
if constexpr (std::is_copy_constructible_v<ItemT>) {
|
||||
input_stream_collection_->push_back(
|
||||
cc->Inputs().Tag("ITEM").template Get<ItemT>());
|
||||
} else {
|
||||
return absl::InternalError(
|
||||
"The item type is not copiable. Consider making the "
|
||||
"EndLoopCalculator the sole owner of the input packets so that "
|
||||
"it can be moved instead of copying.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal
|
||||
|
|
Loading…
Reference in New Issue
Block a user