Updated the Begin/EndLoopCalculator to be able to handle mediapipe::Tensor
type. PiperOrigin-RevId: 508552066
This commit is contained in:
parent
8a49a5f822
commit
d61b7dbef8
|
@ -198,6 +198,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
@ -220,6 +221,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
@ -236,10 +238,12 @@ cc_test(
|
||||||
":begin_loop_calculator",
|
":begin_loop_calculator",
|
||||||
":end_loop_calculator",
|
":end_loop_calculator",
|
||||||
":gate_calculator",
|
":gate_calculator",
|
||||||
|
":pass_through_calculator",
|
||||||
"//mediapipe/framework:calculator_context",
|
"//mediapipe/framework:calculator_context",
|
||||||
"//mediapipe/framework:calculator_contract",
|
"//mediapipe/framework:calculator_contract",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "mediapipe/calculators/core/end_loop_calculator.h"
|
#include "mediapipe/calculators/core/end_loop_calculator.h"
|
||||||
#include "mediapipe/framework/calculator_contract.h"
|
#include "mediapipe/framework/calculator_contract.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -444,5 +445,67 @@ TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
|
||||||
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
|
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status InitBeginEndTensorLoopTestGraph(
|
||||||
|
CalculatorGraph& graph, std::vector<Packet>& output_packets) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
R"pb(
|
||||||
|
num_threads: 4
|
||||||
|
input_stream: "tensors"
|
||||||
|
node {
|
||||||
|
calculator: "BeginLoopTensorCalculator"
|
||||||
|
input_stream: "ITERABLE:tensors"
|
||||||
|
output_stream: "ITEM:tensor"
|
||||||
|
output_stream: "BATCH_END:timestamp"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughCalculator"
|
||||||
|
input_stream: "tensor"
|
||||||
|
output_stream: "passed_tensor"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "EndLoopTensorCalculator"
|
||||||
|
input_stream: "ITEM:passed_tensor"
|
||||||
|
input_stream: "BATCH_END:timestamp"
|
||||||
|
output_stream: "ITERABLE:output_tensors"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output_tensors", &graph_config, &output_packets);
|
||||||
|
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
|
||||||
|
return graph.StartRun({});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BeginEndTensorLoopCalculatorGraphTest, SingleNonEmptyVector) {
|
||||||
|
// Initialize the graph.
|
||||||
|
CalculatorGraph graph;
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
MP_ASSERT_OK(InitBeginEndTensorLoopTestGraph(graph, output_packets));
|
||||||
|
|
||||||
|
// Prepare the inputs and run.
|
||||||
|
Timestamp input_timestamp = Timestamp(0);
|
||||||
|
std::vector<mediapipe::Tensor> tensors;
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
tensors.emplace_back(Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{4, 3, 2, 1});
|
||||||
|
}
|
||||||
|
Packet vector_packet =
|
||||||
|
MakePacket<std::vector<mediapipe::Tensor>>(std::move(tensors));
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"tensors", std::move(vector_packet).At(input_timestamp)));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
// Verify the output packet.
|
||||||
|
EXPECT_EQ(output_packets.size(), 1);
|
||||||
|
const std::vector<Tensor>& output_tensors =
|
||||||
|
output_packets[0].Get<std::vector<Tensor>>();
|
||||||
|
EXPECT_EQ(output_tensors.size(), 4);
|
||||||
|
for (int i = 0; i < output_tensors.size(); i++) {
|
||||||
|
EXPECT_THAT(output_tensors[i].shape().dims,
|
||||||
|
testing::ElementsAre(4, 3, 2, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
@ -55,4 +56,8 @@ REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
|
||||||
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
|
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
|
||||||
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
|
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
|
||||||
|
|
||||||
|
// A calculator to process std::vector<mediapipe::Tensor>.
|
||||||
|
typedef BeginLoopCalculator<std::vector<Tensor>> BeginLoopTensorCalculator;
|
||||||
|
REGISTER_CALCULATOR(BeginLoopTensorCalculator);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||||
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "mediapipe/framework/calculator_context.h"
|
#include "mediapipe/framework/calculator_context.h"
|
||||||
#include "mediapipe/framework/calculator_contract.h"
|
#include "mediapipe/framework/calculator_contract.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
@ -24,6 +23,7 @@
|
||||||
#include "mediapipe/framework/port/integral_types.h"
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
@ -112,13 +112,38 @@ class BeginLoopCalculator : public CalculatorBase {
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
Timestamp last_timestamp = loop_internal_timestamp_;
|
Timestamp last_timestamp = loop_internal_timestamp_;
|
||||||
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
|
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
|
||||||
const IterableT& collection =
|
// Try to consume the ITERABLE packet if possible to obtain the ownership
|
||||||
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
|
// and emit the item packets by moving them.
|
||||||
for (const auto& item : collection) {
|
// If the ITERABLE packet is not consumable, then try to copy each item
|
||||||
cc->Outputs().Tag("ITEM").AddPacket(
|
// instead. If the ITEM type is not copy constructible, an error will be
|
||||||
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
|
// returned.
|
||||||
ForwardClonePackets(cc, loop_internal_timestamp_);
|
auto iterable_ptr_or =
|
||||||
++loop_internal_timestamp_;
|
cc->Inputs().Tag("ITERABLE").Value().Consume<IterableT>();
|
||||||
|
if (iterable_ptr_or.ok()) {
|
||||||
|
for (auto& item : *iterable_ptr_or.value()) {
|
||||||
|
Packet item_packet = MakePacket<ItemT>(std::move(item));
|
||||||
|
cc->Outputs().Tag("ITEM").AddPacket(
|
||||||
|
item_packet.At(loop_internal_timestamp_));
|
||||||
|
ForwardClonePackets(cc, loop_internal_timestamp_);
|
||||||
|
++loop_internal_timestamp_;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if constexpr (std::is_copy_constructible<ItemT>()) {
|
||||||
|
const IterableT& collection =
|
||||||
|
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
|
||||||
|
for (const auto& item : collection) {
|
||||||
|
cc->Outputs().Tag("ITEM").AddPacket(
|
||||||
|
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
|
||||||
|
ForwardClonePackets(cc, loop_internal_timestamp_);
|
||||||
|
++loop_internal_timestamp_;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return absl::InternalError(
|
||||||
|
"The element type is not copiable. Consider making the "
|
||||||
|
"BeginLoopCalculator the sole owner of the input packet so that "
|
||||||
|
"the "
|
||||||
|
"items can be consumed and moved.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +163,6 @@ class BeginLoopCalculator : public CalculatorBase {
|
||||||
.Tag("BATCH_END")
|
.Tag("BATCH_END")
|
||||||
.AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
|
.AddPacket(MakePacket<Timestamp>(cc->InputTimestamp())
|
||||||
.At(Timestamp(loop_internal_timestamp_ - 1)));
|
.At(Timestamp(loop_internal_timestamp_ - 1)));
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/util/render_data.pb.h"
|
#include "mediapipe/util/render_data.pb.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
|
||||||
|
@ -52,8 +53,9 @@ typedef EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>
|
||||||
EndLoopClassificationListCalculator;
|
EndLoopClassificationListCalculator;
|
||||||
REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
|
REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
|
||||||
|
|
||||||
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator;
|
typedef EndLoopCalculator<std::vector<TfLiteTensor>>
|
||||||
REGISTER_CALCULATOR(EndLoopTensorCalculator);
|
EndLoopTfLiteTensorCalculator;
|
||||||
|
REGISTER_CALCULATOR(EndLoopTfLiteTensorCalculator);
|
||||||
|
|
||||||
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
|
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
|
||||||
EndLoopDetectionCalculator;
|
EndLoopDetectionCalculator;
|
||||||
|
@ -62,4 +64,7 @@ REGISTER_CALCULATOR(EndLoopDetectionCalculator);
|
||||||
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
|
typedef EndLoopCalculator<std::vector<Matrix>> EndLoopMatrixCalculator;
|
||||||
REGISTER_CALCULATOR(EndLoopMatrixCalculator);
|
REGISTER_CALCULATOR(EndLoopMatrixCalculator);
|
||||||
|
|
||||||
|
typedef EndLoopCalculator<std::vector<Tensor>> EndLoopTensorCalculator;
|
||||||
|
REGISTER_CALCULATOR(EndLoopTensorCalculator);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
#ifndef MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||||
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
#define MEDIAPIPE_CALCULATORS_CORE_END_LOOP_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "mediapipe/framework/calculator_context.h"
|
#include "mediapipe/framework/calculator_context.h"
|
||||||
#include "mediapipe/framework/calculator_contract.h"
|
#include "mediapipe/framework/calculator_contract.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
@ -75,8 +77,23 @@ class EndLoopCalculator : public CalculatorBase {
|
||||||
if (!input_stream_collection_) {
|
if (!input_stream_collection_) {
|
||||||
input_stream_collection_.reset(new IterableT);
|
input_stream_collection_.reset(new IterableT);
|
||||||
}
|
}
|
||||||
input_stream_collection_->push_back(
|
// Try to consume the item and move it into the collection. If the items
|
||||||
cc->Inputs().Tag("ITEM").template Get<ItemT>());
|
// are not consumable, then try to copy them instead. If the items are
|
||||||
|
// not copiable, then an error will be returned.
|
||||||
|
auto item_ptr_or = cc->Inputs().Tag("ITEM").Value().Consume<ItemT>();
|
||||||
|
if (item_ptr_or.ok()) {
|
||||||
|
input_stream_collection_->push_back(std::move(*item_ptr_or.value()));
|
||||||
|
} else {
|
||||||
|
if constexpr (std::is_copy_constructible_v<ItemT>) {
|
||||||
|
input_stream_collection_->push_back(
|
||||||
|
cc->Inputs().Tag("ITEM").template Get<ItemT>());
|
||||||
|
} else {
|
||||||
|
return absl::InternalError(
|
||||||
|
"The item type is not copiable. Consider making the "
|
||||||
|
"EndLoopCalculator the sole owner of the input packets so that "
|
||||||
|
"it can be moved instead of copying.");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal
|
if (!cc->Inputs().Tag("BATCH_END").Value().IsEmpty()) { // flush signal
|
||||||
|
|
Loading…
Reference in New Issue
Block a user