Project import generated by Copybara.

GitOrigin-RevId: 43cd697ec87dcc5cab5051f27960bb77a057399d
parent 3b6d3c4058
commit 1722d4b8a2

@@ -129,7 +129,11 @@ http_archive(
     ],
     # A compatibility patch
     patches = [
-        "@//third_party:org_tensorflow_528e22eae8bf3206189a066032c66e9e5c9b4a61.diff"
+        "@//third_party:org_tensorflow_528e22eae8bf3206189a066032c66e9e5c9b4a61.diff",
+        # Updates for XNNPACK: https://github.com/tensorflow/tensorflow/commit/cfc31e324c8de6b52f752a39cb161d99d853ca99
+        "@//third_party:org_tensorflow_cfc31e324c8de6b52f752a39cb161d99d853ca99.diff",
+        # CpuInfo's build rule fixes.
+        "@//third_party:org_tensorflow_9696366bcadab23a25c773b3ed405bac8ded4d0d.diff",
     ],
     patch_args = [
         "-p1",

@@ -228,6 +228,7 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:collection_item_id",
         "//mediapipe/framework:packet",
+        "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:integral_types",

@@ -249,6 +250,7 @@ cc_library(
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:collection_item_id",
         "//mediapipe/framework:packet",
+        "//mediapipe/framework/formats:classification_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:integral_types",

@@ -265,10 +267,11 @@ cc_test(
     deps = [
         ":begin_loop_calculator",
         ":end_loop_calculator",
-        "//mediapipe/calculators/core:packet_cloner_calculator",
+        ":gate_calculator",
         "//mediapipe/framework:calculator_context",
         "//mediapipe/framework:calculator_contract",
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:packet",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:parse_text_proto",

@@ -334,6 +337,7 @@ cc_library(
     deps = [
         ":clip_vector_size_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:detection_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",

@@ -693,15 +697,17 @@ cc_test(
     name = "previous_loopback_calculator_test",
     srcs = ["previous_loopback_calculator_test.cc"],
     deps = [
+        ":gate_calculator",
+        ":make_pair_calculator",
+        ":pass_through_calculator",
         ":previous_loopback_calculator",
-        "//mediapipe/calculators/core:make_pair_calculator",
-        "//mediapipe/calculators/core:pass_through_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework:timestamp",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
         "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
         "//mediapipe/framework/tool:sink",
         "@com_google_absl//absl/time",

@@ -769,6 +775,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":split_vector_calculator_cc_proto",
+        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",

@@ -20,6 +20,8 @@
 #include "mediapipe/calculators/core/end_loop_calculator.h"
 #include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/integral_types.h"
 #include "mediapipe/framework/port/parse_text_proto.h"

@@ -28,6 +30,13 @@
 namespace mediapipe {
 namespace {

+MATCHER_P2(PacketOfIntsEq, timestamp, value, "") {
+  Timestamp actual_timestamp = arg.Timestamp();
+  const auto& actual_value = arg.template Get<std::vector<int>>();
+  return testing::Value(actual_timestamp, testing::Eq(timestamp)) &&
+         testing::Value(actual_value, testing::ElementsAreArray(value));
+}
+
 typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntegerCalculator;
 REGISTER_CALCULATOR(BeginLoopIntegerCalculator);

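Note: a sketch of how a matcher like PacketOfIntsEq is typically used against the packets collected by tool::AddVectorSink; the stream contents and expected values here are illustrative and mirror the graph tests later in this commit.

  // Illustrative only: assumes output_packets_ was filled by a vector sink and
  // that the loop incremented each input element by one.
  EXPECT_THAT(output_packets_,
              testing::ElementsAre(
                  PacketOfIntsEq(Timestamp(0), std::vector<int>{1, 2, 3})));
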
@@ -59,8 +68,8 @@ REGISTER_CALCULATOR(EndLoopIntegersCalculator);

 class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
  protected:
-  BeginEndLoopCalculatorGraphTest() {
-    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
+  void SetUp() override {
+    auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
         R"(
         num_threads: 4
         input_stream: "ints"

@ -82,94 +91,222 @@ class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
|
||||||
output_stream: "ITERABLE:ints_plus_one"
|
output_stream: "ITERABLE:ints_plus_one"
|
||||||
}
|
}
|
||||||
)");
|
)");
|
||||||
tool::AddVectorSink("ints_plus_one", &graph_config_, &output_packets_);
|
tool::AddVectorSink("ints_plus_one", &graph_config, &output_packets_);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
}
|
}
|
||||||
|
|
||||||
CalculatorGraphConfig graph_config_;
|
void SendPacketOfInts(Timestamp timestamp, std::vector<int> ints) {
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraph graph_;
|
||||||
std::vector<Packet> output_packets_;
|
std::vector<Packet> output_packets_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
TEST_F(BeginEndLoopCalculatorGraphTest, InputStreamForIterableIsEmpty) {
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
|
// EndLoopCalc will forward the timestamp bound because there are no packets
|
||||||
|
// to process.
|
||||||
|
ASSERT_EQ(0, output_packets_.size());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
|
TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
|
||||||
CalculatorGraph graph;
|
SendPacketOfInts(Timestamp(0), {});
|
||||||
MP_EXPECT_OK(graph.Initialize(graph_config_));
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
MP_EXPECT_OK(graph.StartRun({}));
|
|
||||||
auto input_vector = absl::make_unique<std::vector<int>>();
|
|
||||||
Timestamp input_timestamp = Timestamp(0);
|
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
|
||||||
"ints", Adopt(input_vector.release()).At(input_timestamp)));
|
|
||||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
||||||
|
|
||||||
// EndLoopCalc will forward the timestamp bound because there are no elements
|
// EndLoopCalc will forward the timestamp bound because there are no elements
|
||||||
// in collection to output.
|
// in collection to output.
|
||||||
ASSERT_EQ(0, output_packets_.size());
|
EXPECT_TRUE(output_packets_.empty());
|
||||||
|
|
||||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
|
TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
|
||||||
CalculatorGraph graph;
|
|
||||||
MP_EXPECT_OK(graph.Initialize(graph_config_));
|
|
||||||
MP_EXPECT_OK(graph.StartRun({}));
|
|
||||||
auto input_vector = absl::make_unique<std::vector<int>>();
|
|
||||||
input_vector->emplace_back(0);
|
|
||||||
input_vector->emplace_back(1);
|
|
||||||
input_vector->emplace_back(2);
|
|
||||||
Timestamp input_timestamp = Timestamp(0);
|
Timestamp input_timestamp = Timestamp(0);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPacketOfInts(input_timestamp, {0, 1, 2});
|
||||||
"ints", Adopt(input_vector.release()).At(input_timestamp)));
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
||||||
|
|
||||||
ASSERT_EQ(1, output_packets_.size());
|
EXPECT_THAT(output_packets_,
|
||||||
EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
|
testing::ElementsAre(
|
||||||
std::vector<int> expected_output_vector = {1, 2, 3};
|
PacketOfIntsEq(input_timestamp, std::vector<int>{1, 2, 3})));
|
||||||
EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());
|
|
||||||
|
|
||||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
|
TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
|
||||||
CalculatorGraph graph;
|
|
||||||
MP_EXPECT_OK(graph.Initialize(graph_config_));
|
|
||||||
MP_EXPECT_OK(graph.StartRun({}));
|
|
||||||
|
|
||||||
auto input_vector0 = absl::make_unique<std::vector<int>>();
|
|
||||||
input_vector0->emplace_back(0);
|
|
||||||
input_vector0->emplace_back(1);
|
|
||||||
Timestamp input_timestamp0 = Timestamp(0);
|
Timestamp input_timestamp0 = Timestamp(0);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPacketOfInts(input_timestamp0, {0, 1});
|
||||||
"ints", Adopt(input_vector0.release()).At(input_timestamp0)));
|
|
||||||
|
|
||||||
auto input_vector1 = absl::make_unique<std::vector<int>>();
|
|
||||||
Timestamp input_timestamp1 = Timestamp(1);
|
Timestamp input_timestamp1 = Timestamp(1);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPacketOfInts(input_timestamp1, {});
|
||||||
"ints", Adopt(input_vector1.release()).At(input_timestamp1)));
|
|
||||||
|
|
||||||
auto input_vector2 = absl::make_unique<std::vector<int>>();
|
|
||||||
input_vector2->emplace_back(2);
|
|
||||||
input_vector2->emplace_back(3);
|
|
||||||
Timestamp input_timestamp2 = Timestamp(2);
|
Timestamp input_timestamp2 = Timestamp(2);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPacketOfInts(input_timestamp2, {2, 3});
|
||||||
"ints", Adopt(input_vector2.release()).At(input_timestamp2)));
|
|
||||||
|
|
||||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
|
||||||
ASSERT_EQ(2, output_packets_.size());
|
|
||||||
|
|
||||||
EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
|
|
||||||
std::vector<int> expected_output_vector0 = {1, 2};
|
|
||||||
EXPECT_EQ(expected_output_vector0,
|
|
||||||
output_packets_[0].Get<std::vector<int>>());
|
|
||||||
|
|
||||||
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
|
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
|
||||||
// no elements in vector to process.
|
// no elements in vector to process.
|
||||||
|
EXPECT_THAT(output_packets_,
|
||||||
|
testing::ElementsAre(
|
||||||
|
PacketOfIntsEq(input_timestamp0, std::vector<int>{1, 2}),
|
||||||
|
PacketOfIntsEq(input_timestamp2, std::vector<int>{3, 4})));
|
||||||
|
}
|
||||||
|
|
||||||
EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
|
// Passes non empty vector through or outputs empty vector in case of timestamp
|
||||||
std::vector<int> expected_output_vector2 = {3, 4};
|
// bound update.
|
||||||
EXPECT_EQ(expected_output_vector2,
|
class PassThroughOrEmptyVectorCalculator : public CalculatorBase {
|
||||||
output_packets_[1].Get<std::vector<int>>());
|
public:
|
||||||
|
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
cc->Inputs().Index(0).Set<std::vector<int>>();
|
||||||
|
cc->Outputs().Index(0).Set<std::vector<int>>();
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
::mediapipe::Status Open(CalculatorContext* cc) override {
|
||||||
|
cc->SetOffset(TimestampDiff(0));
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
::mediapipe::Status Process(CalculatorContext* cc) override {
|
||||||
|
if (!cc->Inputs().Index(0).IsEmpty()) {
|
||||||
|
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||||
|
} else {
|
||||||
|
cc->Outputs().Index(0).AddPacket(
|
||||||
|
MakePacket<std::vector<int>>(std::vector<int>())
|
||||||
|
.At(cc->InputTimestamp()));
|
||||||
|
}
|
||||||
|
return ::mediapipe::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_CALCULATOR(PassThroughOrEmptyVectorCalculator);
|
||||||
|
|
||||||
|
class BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest
|
||||||
|
: public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
R"(
|
||||||
|
num_threads: 4
|
||||||
|
input_stream: "ints"
|
||||||
|
input_stream: "force_ints_to_be_timestamp_bound_update"
|
||||||
|
node {
|
||||||
|
calculator: "GateCalculator"
|
||||||
|
input_stream: "ints"
|
||||||
|
input_stream: "DISALLOW:force_ints_to_be_timestamp_bound_update"
|
||||||
|
output_stream: "ints_passed_through"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "BeginLoopIntegerCalculator"
|
||||||
|
input_stream: "ITERABLE:ints_passed_through"
|
||||||
|
output_stream: "ITEM:int"
|
||||||
|
output_stream: "BATCH_END:timestamp"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "IncrementCalculator"
|
||||||
|
input_stream: "int"
|
||||||
|
output_stream: "int_plus_one"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "EndLoopIntegersCalculator"
|
||||||
|
input_stream: "ITEM:int_plus_one"
|
||||||
|
input_stream: "BATCH_END:timestamp"
|
||||||
|
output_stream: "ITERABLE:ints_plus_one"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "PassThroughOrEmptyVectorCalculator"
|
||||||
|
input_stream: "ints_plus_one"
|
||||||
|
output_stream: "ints_plus_one_passed_through"
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
tool::AddVectorSink("ints_plus_one_passed_through", &graph_config,
|
||||||
|
&output_packets_);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SendPacketOfIntsOrBound(Timestamp timestamp, std::vector<int> ints) {
|
||||||
|
// All "ints" packets which are empty are forced to be just timestamp
|
||||||
|
// bound updates for begin loop calculator.
|
||||||
|
bool force_ints_to_be_timestamp_bound_update = ints.empty();
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"force_ints_to_be_timestamp_bound_update",
|
||||||
|
MakePacket<bool>(force_ints_to_be_timestamp_bound_update)
|
||||||
|
.At(timestamp)));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraph graph_;
|
||||||
|
std::vector<Packet> output_packets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
|
||||||
|
SingleEmptyVector) {
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(0), {});
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
|
EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
|
||||||
|
Timestamp(0), std::vector<int>{})));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
|
||||||
|
SingleNonEmptyVector) {
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(0), {0, 1, 2});
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
|
EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
|
||||||
|
Timestamp(0), std::vector<int>{1, 2, 3})));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest, MultipleVectors) {
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(0), {});
|
||||||
|
// Waiting until idle to guarantee all timestamp bound updates are processed
|
||||||
|
// individually. (Timestamp bounds updates occur in the provide config only
|
||||||
|
// if input is an empty vector.)
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(1), {0, 1});
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(2), {});
|
||||||
|
// Waiting until idle to guarantee all timestamp bound updates are processed
|
||||||
|
// individually. (Timestamp bounds updates occur in the provide config only
|
||||||
|
// if input is an empty vector.)
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(3), {2, 3});
|
||||||
|
SendPacketOfIntsOrBound(Timestamp(4), {});
|
||||||
|
// Waiting until idle to guarantee all timestamp bound updates are processed
|
||||||
|
// individually. (Timestamp bounds updates occur in the provide config only
|
||||||
|
// if input is an empty vector.)
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
testing::ElementsAre(PacketOfIntsEq(Timestamp(0), std::vector<int>{}),
|
||||||
|
PacketOfIntsEq(Timestamp(1), std::vector<int>{1, 2}),
|
||||||
|
PacketOfIntsEq(Timestamp(2), std::vector<int>{}),
|
||||||
|
PacketOfIntsEq(Timestamp(3), std::vector<int>{3, 4}),
|
||||||
|
PacketOfIntsEq(Timestamp(4), std::vector<int>{})));
|
||||||
}
|
}
|
||||||
|
|
||||||
class MultiplierCalculator : public CalculatorBase {
|
class MultiplierCalculator : public CalculatorBase {
|
||||||
|
@@ -199,8 +336,8 @@ REGISTER_CALCULATOR(MultiplierCalculator);

 class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
  protected:
-  BeginEndLoopCalculatorGraphWithClonedInputsTest() {
-    graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
+  void SetUp() override {
+    auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
         R"(
         num_threads: 4
         input_stream: "ints"

@ -226,109 +363,85 @@ class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
|
||||||
output_stream: "ITERABLE:multiplied_ints"
|
output_stream: "ITERABLE:multiplied_ints"
|
||||||
}
|
}
|
||||||
)");
|
)");
|
||||||
tool::AddVectorSink("multiplied_ints", &graph_config_, &output_packets_);
|
tool::AddVectorSink("multiplied_ints", &graph_config, &output_packets_);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
}
|
}
|
||||||
|
|
||||||
CalculatorGraphConfig graph_config_;
|
void SendPackets(Timestamp timestamp, int multiplier, std::vector<int> ints) {
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"multiplier", MakePacket<int>(multiplier).At(timestamp)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SendMultiplier(Timestamp timestamp, int multiplier) {
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"multiplier", MakePacket<int>(multiplier).At(timestamp)));
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraph graph_;
|
||||||
std::vector<Packet> output_packets_;
|
std::vector<Packet> output_packets_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
|
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest,
|
||||||
CalculatorGraph graph;
|
InputStreamForIterableIsEmpty) {
|
||||||
MP_EXPECT_OK(graph.Initialize(graph_config_));
|
|
||||||
MP_EXPECT_OK(graph.StartRun({}));
|
|
||||||
auto input_vector = absl::make_unique<std::vector<int>>();
|
|
||||||
Timestamp input_timestamp = Timestamp(42);
|
Timestamp input_timestamp = Timestamp(42);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendMultiplier(input_timestamp, /*multiplier=*/2);
|
||||||
"ints", Adopt(input_vector.release()).At(input_timestamp)));
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
auto multiplier = absl::make_unique<int>(2);
|
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
// EndLoopCalc will forward the timestamp bound because there are no packets
|
||||||
"multiplier", Adopt(multiplier.release()).At(input_timestamp)));
|
// to process.
|
||||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
ASSERT_EQ(0, output_packets_.size());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
|
||||||
|
SendPackets(Timestamp(0), /*multiplier=*/2, /*ints=*/{});
|
||||||
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
|
|
||||||
// EndLoopCalc will forward the timestamp bound because there are no elements
|
// EndLoopCalc will forward the timestamp bound because there are no elements
|
||||||
// in collection to output.
|
// in collection to output.
|
||||||
ASSERT_EQ(0, output_packets_.size());
|
EXPECT_TRUE(output_packets_.empty());
|
||||||
|
|
||||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
|
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
|
||||||
CalculatorGraph graph;
|
|
||||||
MP_EXPECT_OK(graph.Initialize(graph_config_));
|
|
||||||
MP_EXPECT_OK(graph.StartRun({}));
|
|
||||||
auto input_vector = absl::make_unique<std::vector<int>>();
|
|
||||||
input_vector->emplace_back(0);
|
|
||||||
input_vector->emplace_back(1);
|
|
||||||
input_vector->emplace_back(2);
|
|
||||||
Timestamp input_timestamp = Timestamp(42);
|
Timestamp input_timestamp = Timestamp(42);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPackets(input_timestamp, /*multiplier=*/2, /*ints=*/{0, 1, 2});
|
||||||
"ints", Adopt(input_vector.release()).At(input_timestamp)));
|
MP_ASSERT_OK(graph_.WaitUntilIdle());
|
||||||
auto multiplier = absl::make_unique<int>(2);
|
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
|
||||||
"multiplier", Adopt(multiplier.release()).At(input_timestamp)));
|
|
||||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
|
||||||
|
|
||||||
ASSERT_EQ(1, output_packets_.size());
|
EXPECT_THAT(output_packets_,
|
||||||
EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
|
testing::ElementsAre(
|
||||||
std::vector<int> expected_output_vector = {0, 2, 4};
|
PacketOfIntsEq(input_timestamp, std::vector<int>{0, 2, 4})));
|
||||||
EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());
|
|
||||||
|
|
||||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
|
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
|
||||||
CalculatorGraph graph;
|
|
||||||
MP_EXPECT_OK(graph.Initialize(graph_config_));
|
|
||||||
MP_EXPECT_OK(graph.StartRun({}));
|
|
||||||
|
|
||||||
auto input_vector0 = absl::make_unique<std::vector<int>>();
|
|
||||||
input_vector0->emplace_back(0);
|
|
||||||
input_vector0->emplace_back(1);
|
|
||||||
Timestamp input_timestamp0 = Timestamp(42);
|
Timestamp input_timestamp0 = Timestamp(42);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPackets(input_timestamp0, /*multiplier=*/2, /*ints=*/{0, 1});
|
||||||
"ints", Adopt(input_vector0.release()).At(input_timestamp0)));
|
|
||||||
auto multiplier0 = absl::make_unique<int>(2);
|
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
|
||||||
"multiplier", Adopt(multiplier0.release()).At(input_timestamp0)));
|
|
||||||
|
|
||||||
auto input_vector1 = absl::make_unique<std::vector<int>>();
|
|
||||||
Timestamp input_timestamp1 = Timestamp(43);
|
Timestamp input_timestamp1 = Timestamp(43);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPackets(input_timestamp1, /*multiplier=*/2, /*ints=*/{});
|
||||||
"ints", Adopt(input_vector1.release()).At(input_timestamp1)));
|
|
||||||
auto multiplier1 = absl::make_unique<int>(2);
|
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
|
||||||
"multiplier", Adopt(multiplier1.release()).At(input_timestamp1)));
|
|
||||||
|
|
||||||
auto input_vector2 = absl::make_unique<std::vector<int>>();
|
|
||||||
input_vector2->emplace_back(2);
|
|
||||||
input_vector2->emplace_back(3);
|
|
||||||
Timestamp input_timestamp2 = Timestamp(44);
|
Timestamp input_timestamp2 = Timestamp(44);
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
SendPackets(input_timestamp2, /*multiplier=*/3, /*ints=*/{2, 3});
|
||||||
"ints", Adopt(input_vector2.release()).At(input_timestamp2)));
|
|
||||||
auto multiplier2 = absl::make_unique<int>(3);
|
|
||||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
|
||||||
"multiplier", Adopt(multiplier2.release()).At(input_timestamp2)));
|
|
||||||
|
|
||||||
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
MP_ASSERT_OK(graph_.CloseAllPacketSources());
|
||||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
MP_ASSERT_OK(graph_.WaitUntilDone());
|
||||||
|
|
||||||
ASSERT_EQ(2, output_packets_.size());
|
|
||||||
|
|
||||||
EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
|
|
||||||
std::vector<int> expected_output_vector0 = {0, 2};
|
|
||||||
EXPECT_EQ(expected_output_vector0,
|
|
||||||
output_packets_[0].Get<std::vector<int>>());
|
|
||||||
|
|
||||||
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
|
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
|
||||||
// no elements in vector to process.
|
// no elements in vector to process.
|
||||||
|
EXPECT_THAT(output_packets_,
|
||||||
EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
|
testing::ElementsAre(
|
||||||
std::vector<int> expected_output_vector2 = {6, 9};
|
PacketOfIntsEq(input_timestamp0, std::vector<int>{0, 2}),
|
||||||
EXPECT_EQ(expected_output_vector2,
|
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
|
||||||
output_packets_[1].Get<std::vector<int>>());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@@ -16,6 +16,7 @@

 #include <vector>

+#include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"

@@ -31,4 +32,9 @@ typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
     BeginLoopNormalizedRectCalculator;
 REGISTER_CALCULATOR(BeginLoopNormalizedRectCalculator);

+// A calculator to process std::vector<Detection>.
+typedef BeginLoopCalculator<std::vector<::mediapipe::Detection>>
+    BeginLoopDetectionCalculator;
+REGISTER_CALCULATOR(BeginLoopDetectionCalculator);
+
 }  // namespace mediapipe

@@ -52,20 +52,28 @@ namespace mediapipe {
 //   output_stream: "OUTPUT:aggregated_result"  # IterableU @ext_ts
 // }
 //
-// BeginLoopCalculator accepts an optional input stream tagged with "TICK"
-// which if non-empty, wakes up the calculator and calls
-// BeginLoopCalculator::Process(). Input streams tagged with "CLONE" are cloned
-// to the corresponding output streams at loop timestamps. This ensures that a
-// MediaPipe graph or sub-graph can run multiple times, once per element in the
-// "ITERABLE" for each packet clone of the packets in the "CLONE" input streams.
+// Input streams tagged with "CLONE" are cloned to the corresponding output
+// streams at loop timestamps. This ensures that a MediaPipe graph or sub-graph
+// can run multiple times, once per element in the "ITERABLE", for each packet
+// clone of the packets in the "CLONE" input streams.
 template <typename IterableT>
 class BeginLoopCalculator : public CalculatorBase {
   using ItemT = typename IterableT::value_type;

  public:
   static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    // The below enables processing of timestamp bound updates, which in turn
+    // enables correct timestamp propagation by the companion EndLoopCalculator.
+    //
+    // For instance, the Process() function will still be invoked even if an
+    // upstream calculator has updated the timestamp bound for the ITERABLE
+    // input instead of providing an actual value.
+    cc->SetProcessTimestampBounds(true);
+
     // A non-empty packet in the optional "TICK" input stream wakes up the
     // calculator.
+    // DEPRECATED as timestamp bound updates are processed by default in this
+    // calculator.
     if (cc->Inputs().HasTag("TICK")) {
       cc->Inputs().Tag("TICK").SetAny();
     }

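For reference, the graph test later in this commit defines a small calculator that opts into timestamp-bound processing in exactly this way; a condensed sketch of it is below, showing the pattern of calling SetProcessTimestampBounds(true) in GetContract() and then handling an empty input inside Process().

  // Condensed from the PassThroughOrEmptyVectorCalculator defined in the
  // begin/end loop graph test in this commit: it passes a non-empty
  // std::vector<int> through unchanged, and emits an empty vector when
  // Process() is invoked only because the input's timestamp bound advanced.
  class PassThroughOrEmptyVectorCalculator : public CalculatorBase {
   public:
    static ::mediapipe::Status GetContract(CalculatorContract* cc) {
      cc->SetProcessTimestampBounds(true);
      cc->Inputs().Index(0).Set<std::vector<int>>();
      cc->Outputs().Index(0).Set<std::vector<int>>();
      return ::mediapipe::OkStatus();
    }

    ::mediapipe::Status Open(CalculatorContext* cc) override {
      cc->SetOffset(TimestampDiff(0));
      return ::mediapipe::OkStatus();
    }

    ::mediapipe::Status Process(CalculatorContext* cc) override {
      if (!cc->Inputs().Index(0).IsEmpty()) {
        // Real packet arrived: forward it as-is.
        cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
      } else {
        // Timestamp bound update only: emit an empty vector at that timestamp.
        cc->Outputs().Index(0).AddPacket(
            MakePacket<std::vector<int>>(std::vector<int>())
                .At(cc->InputTimestamp()));
      }
      return ::mediapipe::OkStatus();
    }
  };
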
@@ -17,6 +17,7 @@
 #include <vector>

 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"

 namespace mediapipe {

@@ -25,4 +26,8 @@ typedef ClipVectorSizeCalculator<::mediapipe::NormalizedRect>
     ClipNormalizedRectVectorSizeCalculator;
 REGISTER_CALCULATOR(ClipNormalizedRectVectorSizeCalculator);

+typedef ClipVectorSizeCalculator<::mediapipe::Detection>
+    ClipDetectionVectorSizeCalculator;
+REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
+
 }  // namespace mediapipe

@@ -16,6 +16,7 @@

 #include <vector>

+#include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/util/render_data.pb.h"

@@ -37,4 +38,8 @@ typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
     EndLoopRenderDataCalculator;
 REGISTER_CALCULATOR(EndLoopRenderDataCalculator);

+typedef EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>
+    EndLoopClassificationListCalculator;
+REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
+
 }  // namespace mediapipe

@@ -25,13 +25,17 @@ namespace mediapipe {
 // together with some previous output.
 //
 // For the first packet that arrives on the MAIN input, the timestamp bound is
-// advanced on the output. Downstream calculators will see this as an empty
+// advanced on the PREV_LOOP. Downstream calculators will see this as an empty
 // packet. This way they are not kept waiting for the previous output, which
 // for the first iteration does not exist.
 //
-// Thereafter, each packet received on MAIN is matched with a packet received
-// on LOOP; the LOOP packet's timestamp is changed to that of the MAIN packet,
-// and it is output on PREV_LOOP.
+// Thereafter,
+// - Each non-empty MAIN packet results in:
+//   a) a PREV_LOOP packet with contents of the LOOP packet received at the
+//      timestamp of the previous non-empty MAIN packet
+//   b) or in a PREV_LOOP timestamp bound update if the LOOP packet was empty.
+// - Each empty MAIN packet indicating a timestamp bound update results in a
+//   PREV_LOOP timestamp bound update.
 //
 // Example config:
 // node {

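As a usage note, the tests added in this commit build such a loop by declaring the LOOP stream as a back edge, so the graph accepts the cycle. A minimal sketch of the wiring, with illustrative stream names, embedded in the same ParseTextProtoOrDie style the tests use:

  // A minimal sketch: MAIN carries the per-iteration input, LOOP is the output
  // fed back from downstream, and PREV_LOOP delivers the previous iteration's
  // LOOP packet (or a timestamp bound update on the first iteration).
  auto config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
    input_stream: 'in'
    node {
      calculator: 'PreviousLoopbackCalculator'
      input_stream: 'MAIN:in'
      input_stream: 'LOOP:out'
      input_stream_info: { tag_index: 'LOOP' back_edge: true }
      output_stream: 'PREV_LOOP:prev'
    }
    # 'out' would be produced by whatever downstream node consumes 'prev'.
  )");
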
@ -56,83 +60,115 @@ class PreviousLoopbackCalculator : public CalculatorBase {
|
||||||
// TODO: an optional PREV_TIMESTAMP output could be added to
|
// TODO: an optional PREV_TIMESTAMP output could be added to
|
||||||
// carry the original timestamp of the packet on PREV_LOOP.
|
// carry the original timestamp of the packet on PREV_LOOP.
|
||||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||||
|
// Process() function is invoked in response to MAIN/LOOP stream timestamp
|
||||||
|
// bound updates.
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status Open(CalculatorContext* cc) final {
|
::mediapipe::Status Open(CalculatorContext* cc) final {
|
||||||
main_id_ = cc->Inputs().GetId("MAIN", 0);
|
main_id_ = cc->Inputs().GetId("MAIN", 0);
|
||||||
loop_id_ = cc->Inputs().GetId("LOOP", 0);
|
loop_id_ = cc->Inputs().GetId("LOOP", 0);
|
||||||
loop_out_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
|
prev_loop_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Get(loop_out_id_)
|
.Get(prev_loop_id_)
|
||||||
.SetHeader(cc->Inputs().Get(loop_id_).Header());
|
.SetHeader(cc->Inputs().Get(loop_id_).Header());
|
||||||
|
|
||||||
// Use an empty packet for the first round, since there is no previous
|
|
||||||
// output.
|
|
||||||
loopback_packets_.push_back({});
|
|
||||||
|
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
::mediapipe::Status Process(CalculatorContext* cc) final {
|
::mediapipe::Status Process(CalculatorContext* cc) final {
|
||||||
Packet& main_packet = cc->Inputs().Get(main_id_).Value();
|
// Non-empty packets and empty packets indicating timestamp bound updates
|
||||||
|
// are guaranteed to have timestamps greater than timestamps of previous
|
||||||
|
// packets within the same stream. Calculator tracks and operates on such
|
||||||
|
// packets.
|
||||||
|
|
||||||
|
const Packet& main_packet = cc->Inputs().Get(main_id_).Value();
|
||||||
|
if (prev_main_ts_ < main_packet.Timestamp()) {
|
||||||
|
Timestamp loop_timestamp;
|
||||||
if (!main_packet.IsEmpty()) {
|
if (!main_packet.IsEmpty()) {
|
||||||
main_ts_.push_back(main_packet.Timestamp());
|
loop_timestamp = prev_non_empty_main_ts_;
|
||||||
}
|
prev_non_empty_main_ts_ = main_packet.Timestamp();
|
||||||
Packet& loopback_packet = cc->Inputs().Get(loop_id_).Value();
|
|
||||||
if (!loopback_packet.IsEmpty()) {
|
|
||||||
loopback_packets_.push_back(loopback_packet);
|
|
||||||
while (!main_ts_.empty() &&
|
|
||||||
main_ts_.front() <= loopback_packets_.front().Timestamp()) {
|
|
||||||
main_ts_.pop_front();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto& loop_out = cc->Outputs().Get(loop_out_id_);
|
|
||||||
|
|
||||||
while (!main_ts_.empty() && !loopback_packets_.empty()) {
|
|
||||||
Timestamp main_timestamp = main_ts_.front();
|
|
||||||
main_ts_.pop_front();
|
|
||||||
Packet previous_loopback = loopback_packets_.front().At(main_timestamp);
|
|
||||||
loopback_packets_.pop_front();
|
|
||||||
|
|
||||||
if (previous_loopback.IsEmpty()) {
|
|
||||||
// TODO: SetCompleteTimestampBound would be more useful.
|
|
||||||
loop_out.SetNextTimestampBound(main_timestamp + 1);
|
|
||||||
} else {
|
} else {
|
||||||
loop_out.AddPacket(std::move(previous_loopback));
|
// Calculator advances PREV_LOOP timestamp bound in response to empty
|
||||||
|
// MAIN packet, hence not caring about corresponding loop packet.
|
||||||
|
loop_timestamp = Timestamp::Unset();
|
||||||
|
}
|
||||||
|
main_packet_specs_.push_back({.timestamp = main_packet.Timestamp(),
|
||||||
|
.loop_timestamp = loop_timestamp});
|
||||||
|
prev_main_ts_ = main_packet.Timestamp();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Packet& loop_packet = cc->Inputs().Get(loop_id_).Value();
|
||||||
|
if (prev_loop_ts_ < loop_packet.Timestamp()) {
|
||||||
|
loop_packets_.push_back(loop_packet);
|
||||||
|
prev_loop_ts_ = loop_packet.Timestamp();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& prev_loop = cc->Outputs().Get(prev_loop_id_);
|
||||||
|
while (!main_packet_specs_.empty() && !loop_packets_.empty()) {
|
||||||
|
// The earliest MAIN packet.
|
||||||
|
const MainPacketSpec& main_spec = main_packet_specs_.front();
|
||||||
|
// The earliest LOOP packet.
|
||||||
|
const Packet& loop_candidate = loop_packets_.front();
|
||||||
|
// Match LOOP and MAIN packets.
|
||||||
|
if (main_spec.loop_timestamp < loop_candidate.Timestamp()) {
|
||||||
|
// No LOOP packet can match the MAIN packet under review.
|
||||||
|
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
|
||||||
|
main_packet_specs_.pop_front();
|
||||||
|
} else if (main_spec.loop_timestamp > loop_candidate.Timestamp()) {
|
||||||
|
// No MAIN packet can match the LOOP packet under review.
|
||||||
|
loop_packets_.pop_front();
|
||||||
|
} else {
|
||||||
|
// Exact match found.
|
||||||
|
if (loop_candidate.IsEmpty()) {
|
||||||
|
// However, LOOP packet is empty.
|
||||||
|
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
|
||||||
|
} else {
|
||||||
|
prev_loop.AddPacket(loop_candidate.At(main_spec.timestamp));
|
||||||
|
}
|
||||||
|
loop_packets_.pop_front();
|
||||||
|
main_packet_specs_.pop_front();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// In case of an empty loopback input, the next timestamp bound for
|
if (main_packet_specs_.empty() && cc->Inputs().Get(main_id_).IsDone()) {
|
||||||
// loopback input is the loopback timestamp + 1. The next timestamp bound
|
prev_loop.Close();
|
||||||
// for output is set and the main_ts_ vector is truncated accordingly.
|
|
||||||
if (loopback_packet.IsEmpty() &&
|
|
||||||
loopback_packet.Timestamp() != Timestamp::Unstarted()) {
|
|
||||||
Timestamp loopback_bound =
|
|
||||||
loopback_packet.Timestamp().NextAllowedInStream();
|
|
||||||
while (!main_ts_.empty() && main_ts_.front() <= loopback_bound) {
|
|
||||||
main_ts_.pop_front();
|
|
||||||
}
|
|
||||||
if (main_ts_.empty()) {
|
|
||||||
loop_out.SetNextTimestampBound(loopback_bound.NextAllowedInStream());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!main_ts_.empty()) {
|
|
||||||
loop_out.SetNextTimestampBound(main_ts_.front());
|
|
||||||
}
|
|
||||||
if (cc->Inputs().Get(main_id_).IsDone() && main_ts_.empty()) {
|
|
||||||
loop_out.Close();
|
|
||||||
}
|
}
|
||||||
return ::mediapipe::OkStatus();
|
return ::mediapipe::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
struct MainPacketSpec {
|
||||||
|
Timestamp timestamp;
|
||||||
|
// Expected timestamp of the packet from LOOP stream that corresponds to the
|
||||||
|
// packet from MAIN stream descirbed by this spec.
|
||||||
|
Timestamp loop_timestamp;
|
||||||
|
};
|
||||||
|
|
||||||
CollectionItemId main_id_;
|
CollectionItemId main_id_;
|
||||||
CollectionItemId loop_id_;
|
CollectionItemId loop_id_;
|
||||||
CollectionItemId loop_out_id_;
|
CollectionItemId prev_loop_id_;
|
||||||
|
|
||||||
std::deque<Timestamp> main_ts_;
|
// Contains specs for MAIN packets which only can be:
|
||||||
std::deque<Packet> loopback_packets_;
|
// - non-empty packets
|
||||||
|
// - empty packets indicating timestamp bound updates
|
||||||
|
//
|
||||||
|
// Sorted according to packet timestamps.
|
||||||
|
std::deque<MainPacketSpec> main_packet_specs_;
|
||||||
|
Timestamp prev_main_ts_ = Timestamp::Unstarted();
|
||||||
|
Timestamp prev_non_empty_main_ts_ = Timestamp::Unstarted();
|
||||||
|
|
||||||
|
// Contains LOOP packets which only can be:
|
||||||
|
// - the very first empty packet
|
||||||
|
// - non empty packets
|
||||||
|
// - empty packets indicating timestamp bound updates
|
||||||
|
//
|
||||||
|
// Sorted according to packet timestamps.
|
||||||
|
std::deque<Packet> loop_packets_;
|
||||||
|
// Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to
|
||||||
|
// allow addition of the very first empty packet (which doesn't indicate
|
||||||
|
// timestamp bound change necessarily).
|
||||||
|
Timestamp prev_loop_ts_ = Timestamp::Unset();
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(PreviousLoopbackCalculator);
|
REGISTER_CALCULATOR(PreviousLoopbackCalculator);
|
||||||
|
|
||||||
|
|
|
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include <algorithm>
+#include <functional>
 #include <memory>
 #include <string>
 #include <vector>

@@ -25,12 +26,17 @@
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/integral_types.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/timestamp.h"
 #include "mediapipe/framework/tool/sink.h"

 namespace mediapipe {

+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Pair;
+using ::testing::Value;
 namespace {

 // Returns the timestamp values for a vector of Packets.

|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MATCHER(EmptyPacket, negation ? "isn't empty" : "is empty") {
|
||||||
|
if (arg.IsEmpty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MATCHER_P(IntPacket, value, "") {
|
||||||
|
return Value(arg.template Get<int>(), Eq(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
MATCHER_P2(PairPacket, timestamp, pair, "") {
|
||||||
|
Timestamp actual_timestamp = arg.Timestamp();
|
||||||
|
const auto& actual_pair = arg.template Get<std::pair<Packet, Packet>>();
|
||||||
|
return Value(actual_timestamp, Eq(timestamp)) && Value(actual_pair, pair);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
|
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
|
||||||
std::vector<Packet> in_prev;
|
std::vector<Packet> in_prev;
|
||||||
CalculatorGraphConfig graph_config_ =
|
CalculatorGraphConfig graph_config_ =
|
||||||
|
@@ -81,32 +104,30 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
     MP_EXPECT_OK(graph_.AddPacketToInputStream(
         input_name, MakePacket<int>(n).At(Timestamp(n))));
   };
-  auto pair_values = [](const Packet& packet) {
-    auto pair = packet.Get<std::pair<Packet, Packet>>();
-    int first = pair.first.IsEmpty() ? -1 : pair.first.Get<int>();
-    int second = pair.second.IsEmpty() ? -1 : pair.second.Get<int>();
-    return std::make_pair(first, second);
-  };

   send_packet("in", 1);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1}));
-  EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(1, -1));
+  EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1));
+  EXPECT_THAT(in_prev.back(),
+              PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())));

   send_packet("in", 2);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 2}));
-  EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(2, 1));
+  EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2));
+  EXPECT_THAT(in_prev.back(),
+              PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))));

   send_packet("in", 5);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 2, 5}));
-  EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(5, 2));
+  EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5));
+  EXPECT_THAT(in_prev.back(),
+              PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(2))));

   send_packet("in", 15);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 2, 5, 15}));
-  EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(15, 5));
+  EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5, 15));
+  EXPECT_THAT(in_prev.back(),
+              PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5))));

   MP_EXPECT_OK(graph_.CloseAllInputStreams());
   MP_EXPECT_OK(graph_.WaitUntilDone());

@@ -185,24 +206,24 @@ TEST(PreviousLoopbackCalculator, ClosesCorrectly) {

   send_packet("in", 1);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1}));
+  EXPECT_THAT(TimestampValues(outputs), ElementsAre(1));

   send_packet("in", 2);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 2}));
+  EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2));

   send_packet("in", 5);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 2, 5}));
+  EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2, 5));

   send_packet("in", 15);
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 2, 5, 15}));
+  EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2, 5, 15));

   MP_EXPECT_OK(graph_.CloseAllInputStreams());
   MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(outputs),
-            (std::vector<int64>{1, 2, 5, 15, Timestamp::Max().Value()}));
+  EXPECT_THAT(TimestampValues(outputs),
+              ElementsAre(1, 2, 5, 15, Timestamp::Max().Value()));

   MP_EXPECT_OK(graph_.WaitUntilDone());
 }

@@ -247,16 +268,12 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
         input_name, MakePacket<int>(n).At(Timestamp(n))));
   };

-  send_packet("in", 0);
-  MP_EXPECT_OK(graph_.WaitUntilIdle());
-  EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{0}));
-
-  for (int main_ts = 1; main_ts < 50; ++main_ts) {
+  for (int main_ts = 0; main_ts < 50; ++main_ts) {
     send_packet("in", main_ts);
     MP_EXPECT_OK(graph_.WaitUntilIdle());
     std::vector<int64> ts_values = TimestampValues(outputs);
     EXPECT_EQ(ts_values.size(), main_ts + 1);
-    for (int j = 0; j < main_ts; ++j) {
+    for (int j = 0; j < main_ts + 1; ++j) {
       EXPECT_EQ(ts_values[j], j);
     }
   }

@ -266,5 +283,487 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
|
||||||
MP_EXPECT_OK(graph_.WaitUntilDone());
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class PreviousLoopbackCalculatorProcessingTimestampsTest
|
||||||
|
: public testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||||
|
input_stream: 'input'
|
||||||
|
input_stream: 'force_main_empty'
|
||||||
|
input_stream: 'force_loop_empty'
|
||||||
|
# Used to indicate "main" timestamp bound updates.
|
||||||
|
node {
|
||||||
|
calculator: 'GateCalculator'
|
||||||
|
input_stream: 'input'
|
||||||
|
input_stream: 'DISALLOW:force_main_empty'
|
||||||
|
output_stream: 'main'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'PreviousLoopbackCalculator'
|
||||||
|
input_stream: 'MAIN:main'
|
||||||
|
input_stream: 'LOOP:loop'
|
||||||
|
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||||
|
output_stream: 'PREV_LOOP:prev_loop'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'PassThroughCalculator'
|
||||||
|
input_stream: 'input'
|
||||||
|
input_stream: 'prev_loop'
|
||||||
|
output_stream: 'passed_through_input'
|
||||||
|
output_stream: 'passed_through_prev_loop'
|
||||||
|
}
|
||||||
|
# Used to indicate "loop" timestamp bound updates.
|
||||||
|
node {
|
||||||
|
calculator: 'GateCalculator'
|
||||||
|
input_stream: 'input'
|
||||||
|
input_stream: 'DISALLOW:force_loop_empty'
|
||||||
|
output_stream: 'loop'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'MakePairCalculator'
|
||||||
|
input_stream: 'passed_through_input'
|
||||||
|
input_stream: 'passed_through_prev_loop'
|
||||||
|
output_stream: 'passed_through_input_and_prev_loop'
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
|
||||||
|
&output_packets_);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SendPackets(int timestamp, int input, bool force_main_empty,
|
||||||
|
bool force_loop_empty) {
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"force_main_empty",
|
||||||
|
MakePacket<bool>(force_main_empty).At(Timestamp(timestamp))));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"force_loop_empty",
|
||||||
|
MakePacket<bool>(force_loop_empty).At(Timestamp(timestamp))));
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraph graph_;
|
||||||
|
std::vector<Packet> output_packets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||||
|
MultiplePacketsEmptyMainNonEmptyLoop) {
|
||||||
|
SendPackets(/*timestamp=*/1, /*input=*/1, /*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/3, /*input=*/3, /*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/5, /*input=*/5, /*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||||
|
/*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(
|
||||||
|
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||||
|
MultiplePacketsNonEmptyMainEmptyLoop) {
|
||||||
|
SendPackets(/*timestamp=*/1, /*input=*/1,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/2, /*input=*/2,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/3, /*input=*/3,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/5, /*input=*/5,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(
|
||||||
|
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||||
|
MultiplePacketsAlteringMainNonEmptyLoop) {
|
||||||
|
SendPackets(/*timestamp=*/1, /*input=*/1,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/3, /*input=*/3,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1)))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/5, /*input=*/5,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||||
|
/*force_main_empty=*/true,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(
|
||||||
|
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
|
||||||
|
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||||
|
MultiplePacketsNonEmptyMainAlteringLoop) {
|
||||||
|
SendPackets(/*timestamp=*/1, /*input=*/1,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/2, /*input=*/2,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/3, /*input=*/3,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/5, /*input=*/5,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/15, /*input=*/15,
|
||||||
|
/*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(
|
||||||
|
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
|
||||||
|
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
|
||||||
|
MultiplePacketsCheckIfLastCorrectAlteringMainAlteringLoop) {
|
||||||
|
int num_packets = 1000;
|
||||||
|
for (int i = 0; i < num_packets; ++i) {
|
||||||
|
bool force_main_empty = i % 3 == 0 ? true : false;
|
||||||
|
bool force_loop_empty = i % 2 == 0 ? true : false;
|
||||||
|
SendPackets(/*timestamp=*/i + 1, /*input=*/i + 1, force_main_empty,
|
||||||
|
force_loop_empty);
|
||||||
|
}
|
||||||
|
SendPackets(/*timestamp=*/num_packets + 1,
|
||||||
|
/*input=*/num_packets + 1, /*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
SendPackets(/*timestamp=*/num_packets + 2,
|
||||||
|
/*input=*/num_packets + 2, /*force_main_empty=*/false,
|
||||||
|
/*force_loop_empty=*/false);
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
ASSERT_FALSE(output_packets_.empty());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_.back(),
|
||||||
|
PairPacket(Timestamp(num_packets + 2),
|
||||||
|
Pair(IntPacket(num_packets + 2), IntPacket(num_packets + 1))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
// Similar to GateCalculator, but it doesn't propagate timestamp bound updates.
class DroppingGateCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Inputs().Tag("DISALLOW").Set<bool>();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) final {
    if (!cc->Inputs().Index(0).IsEmpty() &&
        !cc->Inputs().Tag("DISALLOW").Get<bool>()) {
      cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    }
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(DroppingGateCalculator);

// Tests PreviousLoopbackCalculator in cases when there are no "LOOP" timestamp
// bound updates or non-empty packets for a while, and then they start to
// arrive. So, "PREV_LOOP" is delayed for a couple of inputs.
|
class PreviousLoopbackCalculatorDelayBehaviorTest : public testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
|
||||||
|
input_stream: 'input'
|
||||||
|
# Drops "loop" when set to "true", delaying output of prev_loop, hence
|
||||||
|
# delaying output of the graph.
|
||||||
|
input_stream: 'delay_next_output'
|
||||||
|
node {
|
||||||
|
calculator: 'PreviousLoopbackCalculator'
|
||||||
|
input_stream: 'MAIN:input'
|
||||||
|
input_stream: 'LOOP:loop'
|
||||||
|
input_stream_info: { tag_index: 'LOOP' back_edge: true }
|
||||||
|
output_stream: 'PREV_LOOP:prev_loop'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'PassThroughCalculator'
|
||||||
|
input_stream: 'input'
|
||||||
|
input_stream: 'prev_loop'
|
||||||
|
output_stream: 'passed_through_input'
|
||||||
|
output_stream: 'passed_through_prev_loop'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'DroppingGateCalculator'
|
||||||
|
input_stream: 'input'
|
||||||
|
input_stream: 'DISALLOW:delay_next_output'
|
||||||
|
output_stream: 'loop'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'MakePairCalculator'
|
||||||
|
input_stream: 'passed_through_input'
|
||||||
|
input_stream: 'passed_through_prev_loop'
|
||||||
|
output_stream: 'passed_through_input_and_prev_loop'
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
|
||||||
|
&output_packets_);
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph_.StartRun({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SendPackets(int timestamp, int input, bool delay_next_output) {
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
|
||||||
|
MP_ASSERT_OK(graph_.AddPacketToInputStream(
|
||||||
|
"delay_next_output",
|
||||||
|
MakePacket<bool>(delay_next_output).At(Timestamp(timestamp))));
|
||||||
|
}
|
||||||
|
|
||||||
|
CalculatorGraph graph_;
|
||||||
|
std::vector<Packet> output_packets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest, MultipleDelayedOutputs) {
|
||||||
|
SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(
|
||||||
|
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest,
|
||||||
|
NonDelayedOutputFollowedByMultipleDelayedOutputs) {
|
||||||
|
SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
|
||||||
|
|
||||||
|
SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(
|
||||||
|
output_packets_,
|
||||||
|
ElementsAre(
|
||||||
|
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
|
||||||
|
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
|
||||||
|
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)))));
|
||||||
|
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllInputStreams());
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
@ -48,6 +49,10 @@ typedef SplitVectorCalculator<::mediapipe::NormalizedLandmark, false>
|
||||||
SplitLandmarkVectorCalculator;
|
SplitLandmarkVectorCalculator;
|
||||||
REGISTER_CALCULATOR(SplitLandmarkVectorCalculator);
|
REGISTER_CALCULATOR(SplitLandmarkVectorCalculator);
|
||||||
|
|
||||||
|
typedef SplitVectorCalculator<::mediapipe::NormalizedLandmarkList, false>
|
||||||
|
SplitNormalizedLandmarkListVectorCalculator;
|
||||||
|
REGISTER_CALCULATOR(SplitNormalizedLandmarkListVectorCalculator);
|
||||||
|
|
||||||
typedef SplitVectorCalculator<::mediapipe::NormalizedRect, false>
|
typedef SplitVectorCalculator<::mediapipe::NormalizedRect, false>
|
||||||
SplitNormalizedRectVectorCalculator;
|
SplitNormalizedRectVectorCalculator;
|
||||||
REGISTER_CALCULATOR(SplitNormalizedRectVectorCalculator);
|
REGISTER_CALCULATOR(SplitNormalizedRectVectorCalculator);
|
||||||
|
@ -57,4 +62,9 @@ typedef SplitVectorCalculator<::tflite::gpu::gl::GlBuffer, true>
|
||||||
MovableSplitGlBufferVectorCalculator;
|
MovableSplitGlBufferVectorCalculator;
|
||||||
REGISTER_CALCULATOR(MovableSplitGlBufferVectorCalculator);
|
REGISTER_CALCULATOR(MovableSplitGlBufferVectorCalculator);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
typedef SplitVectorCalculator<::mediapipe::Detection, false>
|
||||||
|
SplitDetectionVectorCalculator;
|
||||||
|
REGISTER_CALCULATOR(SplitDetectionVectorCalculator);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -422,9 +422,12 @@ cc_library(
|
||||||
":recolor_calculator_cc_proto",
|
":recolor_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
"//mediapipe/framework/formats:image_frame_opencv",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/util:color_cc_proto",
|
"//mediapipe/util:color_cc_proto",
|
||||||
|
"//mediapipe/framework/port:opencv_core",
|
||||||
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
] + select({
|
] + select({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
"//mediapipe/gpu:disable_gpu": [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
|
|
|
@ -17,6 +17,9 @@
|
||||||
#include "mediapipe/calculators/image/recolor_calculator.pb.h"
|
#include "mediapipe/calculators/image/recolor_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/util/color.pb.h"
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
@ -39,8 +42,6 @@ namespace mediapipe {
|
||||||
// The luminance of the input image is used to adjust the blending weight,
|
// The luminance of the input image is used to adjust the blending weight,
|
||||||
// to help preserve image textures.
|
// to help preserve image textures.
|
||||||
//
|
//
|
||||||
// TODO implement cpu support.
|
|
||||||
//
|
|
||||||
// Inputs:
|
// Inputs:
|
||||||
// One of the following IMAGE tags:
|
// One of the following IMAGE tags:
|
||||||
// IMAGE: An ImageFrame input image, RGB or RGBA.
|
// IMAGE: An ImageFrame input image, RGB or RGBA.
|
||||||
|
@ -71,6 +72,8 @@ namespace mediapipe {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
|
// Note: Cannot mix-match CPU & GPU inputs/outputs.
|
||||||
|
// CPU-in & CPU-out <or> GPU-in & GPU-out
|
||||||
class RecolorCalculator : public CalculatorBase {
|
class RecolorCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
RecolorCalculator() = default;
|
RecolorCalculator() = default;
|
||||||
|
@ -138,6 +141,11 @@ REGISTER_CALCULATOR(RecolorCalculator);
|
||||||
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
|
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Confirm only one of the input streams is present.
|
||||||
|
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
|
||||||
|
// Confirm only one of the output streams is present.
|
||||||
|
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU"));
|
||||||
|
|
||||||
if (use_gpu) {
|
if (use_gpu) {
|
||||||
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
#if !defined(MEDIAPIPE_DISABLE_GPU)
|
||||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
@ -193,7 +201,62 @@ REGISTER_CALCULATOR(RecolorCalculator);
}

::mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
  // was: return ::mediapipe::UnimplementedError("CPU support is not implemented yet.");
  if (cc->Inputs().Tag("MASK").IsEmpty()) {
    return ::mediapipe::OkStatus();
  }
  // Get inputs and setup output.
  const auto& input_img = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
  const auto& mask_img = cc->Inputs().Tag("MASK").Get<ImageFrame>();

  cv::Mat input_mat = formats::MatView(&input_img);
  cv::Mat mask_mat = formats::MatView(&mask_img);

  RET_CHECK(input_mat.channels() == 3);  // RGB only.

  if (mask_mat.channels() > 1) {
    std::vector<cv::Mat> channels;
    cv::split(mask_mat, channels);
    if (mask_channel_ == mediapipe::RecolorCalculatorOptions_MaskChannel_ALPHA)
      mask_mat = channels[3];
    else
      mask_mat = channels[0];
  }
  cv::Mat mask_full;
  cv::resize(mask_mat, mask_full, input_mat.size());

  auto output_img = absl::make_unique<ImageFrame>(
      input_img.Format(), input_mat.cols, input_mat.rows);
  cv::Mat output_mat = mediapipe::formats::MatView(output_img.get());

  // From GPU shader:
  /*
      vec4 weight = texture2D(mask, sample_coordinate);
      vec4 color1 = texture2D(frame, sample_coordinate);
      vec4 color2 = vec4(recolor, 1.0);

      float luminance = dot(color1.rgb, vec3(0.299, 0.587, 0.114));
      float mix_value = weight.MASK_COMPONENT * luminance;

      fragColor = mix(color1, color2, mix_value);
  */
  for (int i = 0; i < output_mat.rows; ++i) {
    for (int j = 0; j < output_mat.cols; ++j) {
      float weight = mask_full.at<uchar>(i, j) * (1.0 / 255.0);
      cv::Vec3f color1 = input_mat.at<cv::Vec3b>(i, j);
      cv::Vec3f color2 = {color_[0], color_[1], color_[2]};

      float luminance =
          (color1[0] * 0.299 + color1[1] * 0.587 + color1[2] * 0.114) / 255;
      float mix_value = weight * luminance;

      cv::Vec3b mix_color = color1 * (1.0 - mix_value) + color2 * mix_value;
      output_mat.at<cv::Vec3b>(i, j) = mix_color;
    }
  }

  cc->Outputs().Tag("IMAGE").Add(output_img.release(), cc->InputTimestamp());

  return ::mediapipe::OkStatus();
}
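// For illustration only (a sketch, not part of this change): with the CPU path
// above, a RecolorCalculator node could be configured roughly as follows. The
// IMAGE/MASK tag names come from the code above; the option field names are
// assumptions based on recolor_calculator.proto.
//
//   node {
//     calculator: "RecolorCalculator"
//     input_stream: "IMAGE:input_frame"
//     input_stream: "MASK:segmentation_mask"
//     output_stream: "IMAGE:recolored_frame"
//     options: {
//       [mediapipe.RecolorCalculatorOptions.ext] {
//         color { r: 0 g: 125 b: 255 }
//         mask_channel: RED
//       }
//     }
//   }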
||||||
|
|
||||||
::mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
::mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
|
@ -303,9 +366,9 @@ void RecolorCalculator::GlRender() {

  if (!options.has_color()) RET_CHECK_FAIL() << "Missing color option.";

  color_.push_back(options.color().r());  // was: options.color().r() / 255.0
  color_.push_back(options.color().g());  // was: options.color().g() / 255.0
  color_.push_back(options.color().b());  // was: options.color().b() / 255.0

  return ::mediapipe::OkStatus();
}

@ -378,8 +441,8 @@ void RecolorCalculator::GlRender() {
  glUseProgram(program_);
  glUniform1i(glGetUniformLocation(program_, "frame"), 1);
  glUniform1i(glGetUniformLocation(program_, "mask"), 2);
  glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0] / 255.0,
              color_[1] / 255.0, color_[2] / 255.0);
  // was: glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1], color_[2]);
#endif  // !MEDIAPIPE_DISABLE_GPU

  return ::mediapipe::OkStatus();
|
|
@ -1110,6 +1110,7 @@ cc_test(
|
||||||
],
|
],
|
||||||
"//mediapipe:android": [
|
"//mediapipe:android": [
|
||||||
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib",
|
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib",
|
||||||
|
"@org_tensorflow//tensorflow/core:android_tensorflow_test_lib",
|
||||||
],
|
],
|
||||||
"//mediapipe:ios": [
|
"//mediapipe:ios": [
|
||||||
"@org_tensorflow//tensorflow/core:ios_tensorflow_test_lib",
|
"@org_tensorflow//tensorflow/core:ios_tensorflow_test_lib",
|
||||||
|
|
|
@ -222,9 +222,11 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":util",
|
":util",
|
||||||
":tflite_inference_calculator_cc_proto",
|
":tflite_inference_calculator_cc_proto",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/util:resource_util",
|
"//mediapipe/util:resource_util",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
@ -254,6 +256,10 @@ cc_library(
|
||||||
"//mediapipe:android": [
|
"//mediapipe:android": [
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||||
],
|
],
|
||||||
|
}) + select({
|
||||||
|
"//conditions:default": [
|
||||||
|
"//mediapipe/util:cpu_util",
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -308,6 +314,20 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tflite_model_calculator",
|
||||||
|
srcs = ["tflite_model_calculator.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":util",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tflite_tensors_to_segmentation_calculator",
|
name = "tflite_tensors_to_segmentation_calculator",
|
||||||
srcs = ["tflite_tensors_to_segmentation_calculator.cc"],
|
srcs = ["tflite_tensors_to_segmentation_calculator.cc"],
|
||||||
|
@ -478,6 +498,9 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":tflite_inference_calculator",
|
":tflite_inference_calculator",
|
||||||
":tflite_inference_calculator_cc_proto",
|
":tflite_inference_calculator_cc_proto",
|
||||||
|
":tflite_model_calculator",
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||||
|
"//mediapipe/calculators/util:local_file_contents_calculator",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
@ -485,7 +508,9 @@ cc_test(
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/framework/tool:validate_type",
|
"//mediapipe/framework/tool:validate_type",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
|
@ -511,3 +536,19 @@ cc_test(
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "tflite_model_calculator_test",
|
||||||
|
srcs = ["tflite_model_calculator_test.cc"],
|
||||||
|
data = ["testdata/add.bin"],
|
||||||
|
deps = [
|
||||||
|
":tflite_model_calculator",
|
||||||
|
"//mediapipe/calculators/core:constant_side_packet_calculator",
|
||||||
|
"//mediapipe/calculators/util:local_file_contents_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -17,10 +17,16 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
|
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/tflite/util.h"
|
#include "mediapipe/calculators/tflite/util.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
|
||||||
|
#if !defined(__EMSCRIPTEN__)
|
||||||
|
#include "mediapipe/util/cpu_util.h"
|
||||||
|
#endif // !__EMSCRIPTEN__
|
||||||
|
|
||||||
#include "mediapipe/util/resource_util.h"
|
#include "mediapipe/util/resource_util.h"
|
||||||
#include "tensorflow/lite/error_reporter.h"
|
#include "tensorflow/lite/error_reporter.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
@ -50,7 +56,7 @@
|
||||||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
|
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
|
||||||
#endif // iOS
|
#endif // iOS
|
||||||
|
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||||
#if defined(MEDIAPIPE_ANDROID)
|
#if defined(MEDIAPIPE_ANDROID)
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#endif // ANDROID
|
#endif // ANDROID
|
@ -113,6 +119,23 @@ struct GPUData {
};
#endif

// Returns number of threads to configure XNNPACK delegate with.
// (Equal to user provided value if specified. Otherwise, it returns number of
// high cores (hard-coded to 1 for __EMSCRIPTEN__))
int GetXnnpackNumThreads(
    const mediapipe::TfLiteInferenceCalculatorOptions& opts) {
  static constexpr int kDefaultNumThreads = -1;
  if (opts.has_delegate() && opts.delegate().has_xnnpack() &&
      opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) {
    return opts.delegate().xnnpack().num_threads();
  }
#if !defined(__EMSCRIPTEN__)
  return InferHigherCoreIds().size();
#else
  return 1;
#endif  // !__EMSCRIPTEN__
}

// Calculator Header Section

// Runs inference on the provided input TFLite tensors and TFLite model.

@ -139,6 +162,9 @@ struct GPUData {
// Input side packet:
//   CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
//                                   instead of the builtin one.
//   MODEL (optional) - Use to specify TfLite model
//                      (std::unique_ptr<tflite::FlatBufferModel,
//                       std::function<void(tflite::FlatBufferModel*)>>)
//
// Example use:
// node {

@ -153,6 +179,20 @@ struct GPUData {
//   }
// }
//
// or
//
// node {
//   calculator: "TfLiteInferenceCalculator"
//   input_stream: "TENSORS:tensor_image"
//   input_side_packet: "MODEL:model"
//   output_stream: "TENSORS:tensors"
//   options: {
//     [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
//       delegate { gpu {} }
//     }
//   }
// }
//
// IMPORTANT Notes:
//  Tensors are assumed to be ordered correctly (sequentially added to model).
//  Input tensors are assumed to be of the correct size and already normalized.

@ -165,6 +205,9 @@ class TfLiteInferenceCalculator : public CalculatorBase {
 public:
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
  using TfLiteModelPtr =
      std::unique_ptr<tflite::FlatBufferModel,
                      std::function<void(tflite::FlatBufferModel*)>>;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);

@ -173,12 +216,12 @@ class TfLiteInferenceCalculator : public CalculatorBase {
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  // removed: ::mediapipe::Status LoadOptions(CalculatorContext* cc);
  ::mediapipe::Status LoadModel(CalculatorContext* cc);
  ::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
  ::mediapipe::Status LoadDelegate(CalculatorContext* cc);

  Packet model_packet_;
  std::unique_ptr<tflite::Interpreter> interpreter_;
  // removed: std::unique_ptr<tflite::FlatBufferModel> model_;
  TfLiteDelegatePtr delegate_;

#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)

@ -198,7 +241,6 @@ class TfLiteInferenceCalculator : public CalculatorBase {
      edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice();
#endif

  // removed: std::string model_path_ = "";
  bool gpu_inference_ = false;
  bool gpu_input_ = false;
  bool gpu_output_ = false;

@ -217,6 +259,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

  const auto& options =
      cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
  RET_CHECK(!options.model_path().empty() ^
            cc->InputSidePackets().HasTag("MODEL"))
      << "Either model as side packet or model path in options is required.";

  bool use_gpu =
      options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();

@ -249,6 +295,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
        .Tag("CUSTOM_OP_RESOLVER")
        .Set<tflite::ops::builtin::BuiltinOpResolver>();
  }
  if (cc->InputSidePackets().HasTag("MODEL")) {
    cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
  }

  if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)

@ -267,7 +316,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));

  const auto& options =
      cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
  gpu_inference_ = options.use_gpu();
  // was: MP_RETURN_IF_ERROR(LoadOptions(cc));

  if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)

@ -492,34 +543,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);

// Calculator Auxiliary Section

// removed (this logic now lives in Open() and GetModelAsPacket()):
//   ::mediapipe::Status TfLiteInferenceCalculator::LoadOptions(
//       CalculatorContext* cc) {
//     // Get calculator options specified in the graph.
//     const auto& options =
//         cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
//     // Get model name.
//     if (!options.model_path().empty()) {
//       std::string model_path = options.model_path();
//       ASSIGN_OR_RETURN(model_path_, mediapipe::PathToResourceAsFile(model_path));
//     } else {
//       LOG(ERROR) << "Must specify path to TFLite model.";
//       return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound,
//                                  "Must specify path to TFLite model.");
//     }
//     // Get execution modes.
//     gpu_inference_ =
//         options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
//     return ::mediapipe::OkStatus();
//   }

::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
    CalculatorContext* cc) {
  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
  const auto& model = *model_packet_.Get<TfLiteModelPtr>();
  // was: model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
  //      RET_CHECK(model_);

  tflite::ops::builtin::BuiltinOpResolver op_resolver;
  if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {

@ -529,9 +556,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  }
#if defined(MEDIAPIPE_EDGE_TPU)
  interpreter_ =
      BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());  // was: *model_
#else
  tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);  // was: *model_
#endif  // MEDIAPIPE_EDGE_TPU

  RET_CHECK(interpreter_);

@ -557,6 +584,28 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  return ::mediapipe::OkStatus();
}

::mediapipe::StatusOr<Packet> TfLiteInferenceCalculator::GetModelAsPacket(
    const CalculatorContext& cc) {
  const auto& options =
      cc.Options<mediapipe::TfLiteInferenceCalculatorOptions>();
  if (!options.model_path().empty()) {
    std::string model_path = options.model_path();

    ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path));

    auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
    RET_CHECK(model) << "Failed to load model from path.";
    return MakePacket<TfLiteModelPtr>(TfLiteModelPtr(
        model.release(), [](tflite::FlatBufferModel* model) { delete model; }));
  }
  if (cc.InputSidePackets().HasTag("MODEL")) {
    return cc.InputSidePackets().Tag("MODEL");
  }
  return ::mediapipe::Status(
      ::mediapipe::StatusCode::kNotFound,
      "Must specify TFLite model as path or loaded model.");
}

::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
    CalculatorContext* cc) {
  const auto& calculator_opts =

@ -587,6 +636,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
  }
#endif  // MEDIAPIPE_ANDROID

#if defined(__EMSCRIPTEN__)
  const bool xnnpack_requested = true;
#else
  const bool xnnpack_requested = calculator_opts.has_delegate() &&
                                 calculator_opts.delegate().has_xnnpack();
#endif  // __EMSCRIPTEN__

  if (xnnpack_requested) {
    TfLiteXNNPackDelegateOptions xnnpack_opts{};
    xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts);
    delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
                                  &TfLiteXNNPackDelegateDelete);
    RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
                 kTfLiteOk);
  }

  // Return, no need for GPU delegate below.
  return ::mediapipe::OkStatus();
}

@ -45,11 +45,17 @@ message TfLiteInferenceCalculatorOptions {
  message Gpu {}
  // Android only.
  message Nnapi {}
  message Xnnpack {
    // Number of threads for XNNPACK delegate. (By default, calculator tries
    // to choose optimal number of threads depending on the device.)
    optional int32 num_threads = 1 [default = -1];
  }

  oneof delegate {
    TfLite tflite = 1;
    Gpu gpu = 2;
    Nnapi nnapi = 3;
    Xnnpack xnnpack = 4;
  }
}
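// For illustration (a sketch, not part of the generated change): the new
// xnnpack delegate option is selected in a graph config the same way the
// smoke test below does it. All values here come from that test.
//
//   node {
//     calculator: "TfLiteInferenceCalculator"
//     input_stream: "TENSORS:tensor_in"
//     output_stream: "TENSORS:tensor_out"
//     options {
//       [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
//         model_path: "mediapipe/calculators/tflite/testdata/add.bin"
//         delegate { xnnpack { num_threads: 10 } }
//       }
//     }
//   }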
|
@ -41,7 +41,7 @@ namespace mediapipe {
|
||||||
|
|
||||||
using ::tflite::Interpreter;
|
using ::tflite::Interpreter;
|
||||||
|
|
||||||
void DoSmokeTest(absl::string_view delegate) {
|
void DoSmokeTest(const std::string& graph_proto) {
|
||||||
const int width = 8;
|
const int width = 8;
|
||||||
const int height = 8;
|
const int height = 8;
|
||||||
const int channels = 3;
|
const int channels = 3;
|
||||||
|
@ -69,24 +69,9 @@ void DoSmokeTest(absl::string_view delegate) {
|
||||||
auto input_vec = absl::make_unique<std::vector<TfLiteTensor>>();
|
auto input_vec = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||||
input_vec->emplace_back(*tensor);
|
input_vec->emplace_back(*tensor);
|
||||||
|
|
||||||
std::string graph_proto = R"(
|
|
||||||
input_stream: "tensor_in"
|
|
||||||
node {
|
|
||||||
calculator: "TfLiteInferenceCalculator"
|
|
||||||
input_stream: "TENSORS:tensor_in"
|
|
||||||
output_stream: "TENSORS:tensor_out"
|
|
||||||
options {
|
|
||||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
|
||||||
model_path: "mediapipe/calculators/tflite/testdata/add.bin"
|
|
||||||
$delegate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
ASSERT_EQ(absl::StrReplaceAll({{"$delegate", delegate}}, &graph_proto), 1);
|
|
||||||
// Prepare single calculator graph to and wait for packets.
|
// Prepare single calculator graph to and wait for packets.
|
||||||
CalculatorGraphConfig graph_config =
|
CalculatorGraphConfig graph_config =
|
||||||
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
|
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
|
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
|
||||||
CalculatorGraph graph(graph_config);
|
CalculatorGraph graph(graph_config);
|
||||||
|
@ -119,8 +104,70 @@ void DoSmokeTest(absl::string_view delegate) {
|
||||||
|
|
||||||
// Tests a simple add model that adds an input tensor to itself.
|
// Tests a simple add model that adds an input tensor to itself.
|
||||||
TEST(TfLiteInferenceCalculatorTest, SmokeTest) {
|
TEST(TfLiteInferenceCalculatorTest, SmokeTest) {
|
||||||
DoSmokeTest(/*delegate=*/"");
|
std::string graph_proto = R"(
|
||||||
DoSmokeTest(/*delegate=*/"delegate { tflite {} }");
|
input_stream: "tensor_in"
|
||||||
|
node {
|
||||||
|
calculator: "TfLiteInferenceCalculator"
|
||||||
|
input_stream: "TENSORS:tensor_in"
|
||||||
|
output_stream: "TENSORS:tensor_out"
|
||||||
|
options {
|
||||||
|
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||||
|
model_path: "mediapipe/calculators/tflite/testdata/add.bin"
|
||||||
|
$delegate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
DoSmokeTest(
|
||||||
|
/*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}}));
|
||||||
|
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||||
|
graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
|
||||||
|
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||||
|
graph_proto, {{"$delegate", "delegate { xnnpack {} }"}}));
|
||||||
|
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||||
|
graph_proto,
|
||||||
|
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
|
||||||
|
std::string graph_proto = R"(
|
||||||
|
input_stream: "tensor_in"
|
||||||
|
|
||||||
|
node {
|
||||||
|
calculator: "ConstantSidePacketCalculator"
|
||||||
|
output_side_packet: "PACKET:model_path"
|
||||||
|
options: {
|
||||||
|
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||||
|
packet { string_value: "mediapipe/calculators/tflite/testdata/add.bin" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
node {
|
||||||
|
calculator: "LocalFileContentsCalculator"
|
||||||
|
input_side_packet: "FILE_PATH:model_path"
|
||||||
|
output_side_packet: "CONTENTS:model_blob"
|
||||||
|
}
|
||||||
|
|
||||||
|
node {
|
||||||
|
calculator: "TfLiteModelCalculator"
|
||||||
|
input_side_packet: "MODEL_BLOB:model_blob"
|
||||||
|
output_side_packet: "MODEL:model"
|
||||||
|
}
|
||||||
|
|
||||||
|
node {
|
||||||
|
calculator: "TfLiteInferenceCalculator"
|
||||||
|
input_stream: "TENSORS:tensor_in"
|
||||||
|
output_stream: "TENSORS:tensor_out"
|
||||||
|
input_side_packet: "MODEL:model"
|
||||||
|
options {
|
||||||
|
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||||
|
use_gpu: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
DoSmokeTest(graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
86  mediapipe/calculators/tflite/tflite_model_calculator.cc  Normal file
@ -0,0 +1,86 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <functional>
#include <memory>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/model.h"

namespace mediapipe {

// Loads TfLite model from model blob specified as input side packet and
// outputs corresponding side packet.
//
// Input side packets:
//   MODEL_BLOB - TfLite model blob/file-contents (std::string). You can read
//                model blob from file (using whatever APIs you have) and pass
//                it to the graph as input side packet or you can use some of
//                calculators like LocalFileContentsCalculator to get model
//                blob and use it as input here.
//
// Output side packets:
//   MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
//           std::function<void(tflite::FlatBufferModel*)>>)
//
// Example use:
//
// node {
//   calculator: "TfLiteModelCalculator"
//   input_side_packet: "MODEL_BLOB:model_blob"
//   output_side_packet: "MODEL:model"
// }
//
class TfLiteModelCalculator : public CalculatorBase {
 public:
  using TfLiteModelPtr =
      std::unique_ptr<tflite::FlatBufferModel,
                      std::function<void(tflite::FlatBufferModel*)>>;

  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
    cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Open(CalculatorContext* cc) override {
    const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
    const std::string& model_blob = model_packet.Get<std::string>();
    std::unique_ptr<tflite::FlatBufferModel> model =
        tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
                                                 model_blob.size());
    RET_CHECK(model) << "Failed to load TfLite model from blob.";

    cc->OutputSidePackets().Tag("MODEL").Set(
        MakePacket<TfLiteModelPtr>(TfLiteModelPtr(
            model.release(), [model_packet](tflite::FlatBufferModel* model) {
              // Keeping model_packet in order to keep underlying model blob
              // which can be released only after TfLite model is not needed
              // anymore (deleted).
              delete model;
            })));

    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(TfLiteModelCalculator);

}  // namespace mediapipe
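The custom deleter above captures model_packet by value so that the blob backing the FlatBufferModel stays alive until the model itself is destroyed. Below is a minimal standalone sketch of that ownership idiom, using hypothetical names and no MediaPipe types; it is illustrative only and not part of this change.

#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <string>

// Stand-in for tflite::FlatBufferModel: it only *views* the buffer it was
// built from, so that buffer must outlive it.
struct BufferView {
  const char* data;
  std::size_t size;
};

int main() {
  // Shared ownership of the backing bytes (analogous to the model blob held
  // inside the input side packet).
  auto blob = std::make_shared<std::string>("...model bytes...");

  using ViewPtr = std::unique_ptr<BufferView, std::function<void(BufferView*)>>;

  // The deleter captures `blob` by value, so the bytes stay alive at least as
  // long as the view, no matter when the original handle is dropped.
  ViewPtr view(new BufferView{blob->data(), blob->size()},
               [blob](BufferView* v) { delete v; });

  blob.reset();  // Safe: the deleter still holds a reference.
  std::cout << view->size << " bytes still reachable\n";
  return 0;
}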
88  mediapipe/calculators/tflite/tflite_model_calculator_test.cc  Normal file
@ -0,0 +1,88 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT
#include "tensorflow/lite/model.h"

namespace mediapipe {

TEST(TfLiteModelCalculatorTest, SmokeTest) {
  // Prepare a single-calculator graph and wait for packets.
  CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
      CalculatorGraphConfig>(
      R"(
        node {
          calculator: "ConstantSidePacketCalculator"
          output_side_packet: "PACKET:model_path"
          options: {
            [mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
              packet {
                string_value: "mediapipe/calculators/tflite/testdata/add.bin"
              }
            }
          }
        }

        node {
          calculator: "LocalFileContentsCalculator"
          input_side_packet: "FILE_PATH:model_path"
          output_side_packet: "CONTENTS:model_blob"
        }

        node {
          calculator: "TfLiteModelCalculator"
          input_side_packet: "MODEL_BLOB:model_blob"
          output_side_packet: "MODEL:model"
        }
      )");
  CalculatorGraph graph(graph_config);
  MP_ASSERT_OK(graph.StartRun({}));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  auto status_or_packet = graph.GetOutputSidePacket("model");
  MP_ASSERT_OK(status_or_packet);
  auto model_packet = status_or_packet.ValueOrDie();
  const auto& model = model_packet.Get<
      std::unique_ptr<tflite::FlatBufferModel,
                      std::function<void(tflite::FlatBufferModel*)>>>();

  auto expected_model = tflite::FlatBufferModel::BuildFromFile(
      "mediapipe/calculators/tflite/testdata/add.bin");

  EXPECT_EQ(model->GetModel()->version(),
            expected_model->GetModel()->version());
  EXPECT_EQ(model->GetModel()->buffers()->size(),
            expected_model->GetModel()->buffers()->size());
  const int num_subgraphs = expected_model->GetModel()->subgraphs()->size();
  EXPECT_EQ(model->GetModel()->subgraphs()->size(), num_subgraphs);
  for (int i = 0; i < num_subgraphs; ++i) {
    const auto* expected_subgraph =
        expected_model->GetModel()->subgraphs()->Get(i);
    const auto* subgraph = model->GetModel()->subgraphs()->Get(i);
    const int num_tensors = expected_subgraph->tensors()->size();
    EXPECT_EQ(subgraph->tensors()->size(), num_tensors);
    for (int j = 0; j < num_tensors; ++j) {
      EXPECT_EQ(subgraph->tensors()->Get(j)->name()->str(),
                expected_subgraph->tensors()->Get(j)->name()->str());
    }
  }
}

}  // namespace mediapipe
@ -129,24 +129,45 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
     num_classes *= raw_score_tensor->dims->data[i];
   }
 
+  if (options_.binary_classification()) {
+    RET_CHECK_EQ(num_classes, 1);
+    // Number of classes for binary classification.
+    num_classes = 2;
+  }
   if (label_map_loaded_) {
     RET_CHECK_EQ(num_classes, label_map_.size());
   }
   const float* raw_scores = raw_score_tensor->data.f;
 
   auto classification_list = absl::make_unique<ClassificationList>();
+  if (options_.binary_classification()) {
+    Classification* class_first = classification_list->add_classification();
+    Classification* class_second = classification_list->add_classification();
+    class_first->set_index(0);
+    class_second->set_index(1);
+    class_first->set_score(raw_scores[0]);
+    class_second->set_score(1. - raw_scores[0]);
+
+    if (label_map_loaded_) {
+      class_first->set_label(label_map_[0]);
+      class_second->set_label(label_map_[1]);
+    }
+  } else {
   for (int i = 0; i < num_classes; ++i) {
     if (options_.has_min_score_threshold() &&
         raw_scores[i] < options_.min_score_threshold()) {
       continue;
     }
-    Classification* classification = classification_list->add_classification();
+    Classification* classification =
+        classification_list->add_classification();
     classification->set_index(i);
     classification->set_score(raw_scores[i]);
 
     if (label_map_loaded_) {
       classification->set_label(label_map_[i]);
     }
   }
+  }
 
   // Note that partial_sort will raise error when top_k_ >
   // classification_list->classification_size().
@ -32,4 +32,10 @@ message TfLiteTensorsToClassificationCalculatorOptions {
   optional int32 top_k = 2;
   // Path to a label map file for getting the actual name of class ids.
   optional string label_map_path = 3;
+  // Whether the input is a single float for binary classification.
+  // When true, only a single float is expected in the input tensor and the
+  // label map, if provided, is expected to have exactly two labels.
+  // The single score (float) represents the probability of the first label,
+  // and 1 - score is the probability of the second label.
+  optional bool binary_classification = 4;
 }
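As a worked example of the binary_classification path above (illustrative values only): a raw sigmoid score of 0.8 yields a first classification with score 0.8 and a second with score 0.2. A tiny sketch of just that expansion, with the index/label bookkeeping from the calculator omitted:

#include <cstdio>

int main() {
  const float raw_score = 0.8f;  // probability of the first label
  const float scores[2] = {raw_score, 1.0f - raw_score};
  std::printf("label 0: %.2f, label 1: %.2f\n", scores[0], scores[1]);
  return 0;
}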
@ -998,6 +998,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:classification_cc_proto",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:ret_check",
@ -1015,6 +1016,7 @@ cc_library(
     deps = [
         ":collection_has_min_size_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
@ -1022,6 +1024,18 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_test(
+    name = "collection_has_min_size_calculator_test",
+    srcs = ["collection_has_min_size_calculator_test.cc"],
+    deps = [
+        ":collection_has_min_size_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+    ],
+)
+
 cc_library(
     name = "association_calculator",
     hdrs = ["association_calculator.h"],
@ -15,6 +15,9 @@
 #include "mediapipe/calculators/util/collection_has_min_size_calculator.h"
 
+#include <vector>
+
+#include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 
 namespace mediapipe {
@ -23,4 +26,9 @@ typedef CollectionHasMinSizeCalculator<std::vector<::mediapipe::NormalizedRect>>
     NormalizedRectVectorHasMinSizeCalculator;
 REGISTER_CALCULATOR(NormalizedRectVectorHasMinSizeCalculator);
 
+typedef CollectionHasMinSizeCalculator<
+    std::vector<::mediapipe::NormalizedLandmarkList>>
+    NormalizedLandmarkListVectorHasMinSizeCalculator;
+REGISTER_CALCULATOR(NormalizedLandmarkListVectorHasMinSizeCalculator);
+
 }  // namespace mediapipe
@ -0,0 +1,156 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/util/collection_has_min_size_calculator.h"

#include <memory>
#include <string>
#include <vector>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT

namespace mediapipe {

typedef CollectionHasMinSizeCalculator<std::vector<int>>
    TestIntCollectionHasMinSizeCalculator;
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);

void AddInputVector(const std::vector<int>& input, int64 timestamp,
                    CalculatorRunner* runner) {
  runner->MutableInputs()
      ->Tag("ITERABLE")
      .packets.push_back(
          MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}

TEST(TestIntCollectionHasMinSizeCalculator, DoesHaveMinSize) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestIntCollectionHasMinSizeCalculator"
        input_stream: "ITERABLE:input_vector"
        output_stream: "output_vector"
        options {
          [mediapipe.CollectionHasMinSizeCalculatorOptions.ext] { min_size: 2 }
        }
      )");
  CalculatorRunner runner(node_config);
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;

  AddInputVector({1, 2}, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  EXPECT_TRUE(outputs[0].Get<bool>());

  AddInputVector({1, 2, 3}, /*timestamp=*/2, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(2, outputs.size());
  EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
  EXPECT_TRUE(outputs[1].Get<bool>());
}

TEST(TestIntCollectionHasMinSizeCalculator,
     DoesHaveMinSize_MinSizeAsSidePacket) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestIntCollectionHasMinSizeCalculator"
        input_stream: "ITERABLE:input_vector"
        input_side_packet: "min_size"
        output_stream: "output_vector"
      )");
  CalculatorRunner runner(node_config);
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;

  runner.MutableSidePackets()->Index(0) = MakePacket<int>(2);

  AddInputVector({1, 2}, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  EXPECT_TRUE(outputs[0].Get<bool>());

  AddInputVector({1, 2, 3}, /*timestamp=*/2, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(2, outputs.size());
  EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
  EXPECT_TRUE(outputs[1].Get<bool>());
}

TEST(TestIntCollectionHasMinSizeCalculator, DoesNotHaveMinSize) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestIntCollectionHasMinSizeCalculator"
        input_stream: "ITERABLE:input_vector"
        output_stream: "output_vector"
        options {
          [mediapipe.CollectionHasMinSizeCalculatorOptions.ext] { min_size: 3 }
        }
      )");
  CalculatorRunner runner(node_config);
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;

  AddInputVector({1}, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  EXPECT_FALSE(outputs[0].Get<bool>());

  AddInputVector({1, 2}, /*timestamp=*/2, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(2, outputs.size());
  EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
  EXPECT_FALSE(outputs[1].Get<bool>());
}

TEST(TestIntCollectionHasMinSizeCalculator,
     DoesNotHaveMinSize_MinSizeAsSidePacket) {
  CalculatorGraphConfig::Node node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
        calculator: "TestIntCollectionHasMinSizeCalculator"
        input_stream: "ITERABLE:input_vector"
        input_side_packet: "min_size"
        output_stream: "output_vector"
      )");
  CalculatorRunner runner(node_config);
  const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;

  runner.MutableSidePackets()->Index(0) = MakePacket<int>(3);

  AddInputVector({1}, /*timestamp=*/1, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(1, outputs.size());
  EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
  EXPECT_FALSE(outputs[0].Get<bool>());

  AddInputVector({1, 2}, /*timestamp=*/2, &runner);
  MP_ASSERT_OK(runner.Run());

  EXPECT_EQ(2, outputs.size());
  EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
  EXPECT_FALSE(outputs[1].Get<bool>());
}

}  // namespace mediapipe
@ -17,6 +17,7 @@
 
 #include <vector>
 
+#include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 
@ -31,4 +32,8 @@ typedef FilterCollectionCalculator<
     FilterLandmarkListCollectionCalculator;
 REGISTER_CALCULATOR(FilterLandmarkListCollectionCalculator);
 
+typedef FilterCollectionCalculator<std::vector<::mediapipe::ClassificationList>>
+    FilterClassificationListCollectionCalculator;
+REGISTER_CALCULATOR(FilterClassificationListCollectionCalculator);
+
 }  // namespace mediapipe
@ -29,6 +29,7 @@ namespace {
 
 constexpr char kLandmarksTag[] = "LANDMARKS";
 constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
+constexpr char kRenderScaleTag[] = "RENDER_SCALE";
 constexpr char kRenderDataTag[] = "RENDER_DATA";
 constexpr char kLandmarkLabel[] = "KEYPOINT";
 constexpr int kMaxLandmarkThickness = 18;
@ -71,6 +72,83 @@ void SetColorSizeValueFromZ(float z, float z_min, float z_max,
   render_annotation->set_thickness(thickness);
 }
 
+template <class LandmarkType>
+void AddConnectionToRenderData(const LandmarkType& start,
+                               const LandmarkType& end, int gray_val1,
+                               int gray_val2, float thickness, bool normalized,
+                               RenderData* render_data) {
+  auto* connection_annotation = render_data->add_render_annotations();
+  RenderAnnotation::GradientLine* line =
+      connection_annotation->mutable_gradient_line();
+  line->set_x_start(start.x());
+  line->set_y_start(start.y());
+  line->set_x_end(end.x());
+  line->set_y_end(end.y());
+  line->set_normalized(normalized);
+  line->mutable_color1()->set_r(gray_val1);
+  line->mutable_color1()->set_g(gray_val1);
+  line->mutable_color1()->set_b(gray_val1);
+  line->mutable_color2()->set_r(gray_val2);
+  line->mutable_color2()->set_g(gray_val2);
+  line->mutable_color2()->set_b(gray_val2);
+  connection_annotation->set_thickness(thickness);
+}
+
+template <class LandmarkListType, class LandmarkType>
+void AddConnectionsWithDepth(const LandmarkListType& landmarks,
+                             const std::vector<int>& landmark_connections,
+                             float thickness, bool normalized, float min_z,
+                             float max_z, RenderData* render_data) {
+  for (int i = 0; i < landmark_connections.size(); i += 2) {
+    const auto& ld0 = landmarks.landmark(landmark_connections[i]);
+    const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
+    const int gray_val1 =
+        255 - static_cast<int>(Remap(ld0.z(), min_z, max_z, 255));
+    const int gray_val2 =
+        255 - static_cast<int>(Remap(ld1.z(), min_z, max_z, 255));
+    AddConnectionToRenderData<LandmarkType>(ld0, ld1, gray_val1, gray_val2,
+                                            thickness, normalized, render_data);
+  }
+}
+
+template <class LandmarkType>
+void AddConnectionToRenderData(const LandmarkType& start,
+                               const LandmarkType& end,
+                               const Color& connection_color, float thickness,
+                               bool normalized, RenderData* render_data) {
+  auto* connection_annotation = render_data->add_render_annotations();
+  RenderAnnotation::Line* line = connection_annotation->mutable_line();
+  line->set_x_start(start.x());
+  line->set_y_start(start.y());
+  line->set_x_end(end.x());
+  line->set_y_end(end.y());
+  line->set_normalized(normalized);
+  SetColor(connection_annotation, connection_color);
+  connection_annotation->set_thickness(thickness);
+}
+
+template <class LandmarkListType, class LandmarkType>
+void AddConnections(const LandmarkListType& landmarks,
+                    const std::vector<int>& landmark_connections,
+                    const Color& connection_color, float thickness,
+                    bool normalized, RenderData* render_data) {
+  for (int i = 0; i < landmark_connections.size(); i += 2) {
+    const auto& ld0 = landmarks.landmark(landmark_connections[i]);
+    const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
+    AddConnectionToRenderData<LandmarkType>(ld0, ld1, connection_color,
+                                            thickness, normalized, render_data);
+  }
+}
+
+RenderAnnotation* AddPointRenderData(const Color& landmark_color,
+                                     float thickness, RenderData* render_data) {
+  auto* landmark_data_annotation = render_data->add_render_annotations();
+  landmark_data_annotation->set_scene_tag(kLandmarkLabel);
+  SetColor(landmark_data_annotation, landmark_color);
+  landmark_data_annotation->set_thickness(thickness);
+  return landmark_data_annotation;
+}
+
 }  // namespace
 
 // A calculator that converts Landmark proto to RenderData proto for
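The gray values above come from 255 - Remap(z, min_z, max_z, 255), so nearer landmarks (smaller z) render brighter. A small sketch of that mapping follows, with a hypothetical linear Remap standing in for the one defined earlier in this file (assumption: Remap rescales z from [z_min, z_max] onto [0, range]); it is illustrative only.

#include <algorithm>
#include <cstdio>
#include <initializer_list>

// Hypothetical stand-in for the file's Remap(): linear rescale of z from
// [z_min, z_max] onto [0, range].
float Remap(float z, float z_min, float z_max, float range) {
  const float t = (z - z_min) / std::max(z_max - z_min, 1e-6f);
  return t * range;
}

int main() {
  const float z_min = -0.1f, z_max = 0.1f;
  for (float z : {-0.1f, 0.0f, 0.1f}) {
    const int gray = 255 - static_cast<int>(Remap(z, z_min, z_max, 255));
    std::printf("z=%+.2f -> gray %d\n", z, gray);  // nearer => brighter
  }
  return 0;
}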
@ -107,29 +185,6 @@ class LandmarksToRenderDataCalculator : public CalculatorBase {
   ::mediapipe::Status Process(CalculatorContext* cc) override;
 
  private:
-  static void AddConnectionToRenderData(
-      float start_x, float start_y, float end_x, float end_y,
-      const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
-      RenderData* render_data);
-  static void SetRenderAnnotationColorThickness(
-      const LandmarksToRenderDataCalculatorOptions& options,
-      RenderAnnotation* render_annotation);
-  static RenderAnnotation* AddPointRenderData(
-      const LandmarksToRenderDataCalculatorOptions& options,
-      RenderData* render_data);
-  static void AddConnectionToRenderData(
-      float start_x, float start_y, float end_x, float end_y,
-      const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
-      int gray_val1, int gray_val2, RenderData* render_data);
-
-  template <class LandmarkListType>
-  void AddConnections(const LandmarkListType& landmarks, bool normalized,
-                      RenderData* render_data);
-  template <class LandmarkListType>
-  void AddConnectionsWithDepth(const LandmarkListType& landmarks,
-                               bool normalized, float min_z, float max_z,
-                               RenderData* render_data);
-
   LandmarksToRenderDataCalculatorOptions options_;
 };
 REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
@ -150,6 +205,9 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
   if (cc->Inputs().HasTag(kNormLandmarksTag)) {
     cc->Inputs().Tag(kNormLandmarksTag).Set<NormalizedLandmarkList>();
   }
+  if (cc->Inputs().HasTag(kRenderScaleTag)) {
+    cc->Inputs().Tag(kRenderScaleTag).Set<float>();
+  }
   cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
   return ::mediapipe::OkStatus();
 }
@ -169,11 +227,26 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
   float z_min = 0.f;
   float z_max = 0.f;
 
+  // Apply scale to `thickness` of rendered landmarks and connections to make
+  // them bigger when object (e.g. pose, hand or face) is closer/bigger and
+  // smaller when object is further/smaller.
+  float thickness = options_.thickness();
+  if (cc->Inputs().HasTag(kRenderScaleTag)) {
+    const float render_scale = cc->Inputs().Tag(kRenderScaleTag).Get<float>();
+    thickness *= render_scale;
+  }
+
+  // Parse landmarks connections to a vector.
+  RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0)
+      << "Number of entries in landmark connections must be a multiple of 2";
+  std::vector<int> landmark_connections;
+  for (int i = 0; i < options_.landmark_connections_size(); i += 1) {
+    landmark_connections.push_back(options_.landmark_connections(i));
+  }
+
   if (cc->Inputs().HasTag(kLandmarksTag)) {
     const LandmarkList& landmarks =
         cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
-    RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0)
-        << "Number of entries in landmark connections must be a multiple of 2";
     if (visualize_depth) {
       GetMinMaxZ<LandmarkList, Landmark>(landmarks, &z_min, &z_max);
     }
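The RENDER_SCALE input simply multiplies the configured thickness, so callers can derive the scale from how large the tracked object appears. A tiny illustration with hypothetical numbers (not taken from this change):

#include <cstdio>

int main() {
  float thickness = 4.0f;           // stand-in for options_.thickness()
  const float render_scale = 0.5f;  // e.g. object appears half as large
  thickness *= render_scale;
  std::printf("rendered thickness: %.1f\n", thickness);  // prints 2.0
  return 0;
}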
@ -181,8 +254,8 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
     visualize_depth &= ((z_max - z_min) > 1e-3);
     for (int i = 0; i < landmarks.landmark_size(); ++i) {
       const Landmark& landmark = landmarks.landmark(i);
-      auto* landmark_data_render =
-          AddPointRenderData(options_, render_data.get());
+      auto* landmark_data_render = AddPointRenderData(
+          options_.landmark_color(), thickness, render_data.get());
       if (visualize_depth) {
         SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
                                landmark_data_render);
@ -193,19 +266,19 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
       landmark_data->set_y(landmark.y());
     }
     if (visualize_depth) {
-      AddConnectionsWithDepth<LandmarkList>(landmarks, /*normalized=*/false,
+      AddConnectionsWithDepth<LandmarkList, Landmark>(
+          landmarks, landmark_connections, thickness, /*normalized=*/false,
           z_min, z_max, render_data.get());
     } else {
-      AddConnections<LandmarkList>(landmarks, /*normalized=*/false,
-                                   render_data.get());
+      AddConnections<LandmarkList, Landmark>(
+          landmarks, landmark_connections, options_.connection_color(),
+          thickness, /*normalized=*/false, render_data.get());
     }
   }
 
   if (cc->Inputs().HasTag(kNormLandmarksTag)) {
     const NormalizedLandmarkList& landmarks =
         cc->Inputs().Tag(kNormLandmarksTag).Get<NormalizedLandmarkList>();
-    RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0)
-        << "Number of entries in landmark connections must be a multiple of 2";
     if (visualize_depth) {
       GetMinMaxZ<NormalizedLandmarkList, NormalizedLandmark>(landmarks, &z_min,
                                                              &z_max);
@ -214,8 +287,8 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
     visualize_depth &= ((z_max - z_min) > 1e-3);
     for (int i = 0; i < landmarks.landmark_size(); ++i) {
       const NormalizedLandmark& landmark = landmarks.landmark(i);
-      auto* landmark_data_render =
-          AddPointRenderData(options_, render_data.get());
+      auto* landmark_data_render = AddPointRenderData(
+          options_.landmark_color(), thickness, render_data.get());
       if (visualize_depth) {
         SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
                                landmark_data_render);
@ -226,11 +299,13 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
       landmark_data->set_y(landmark.y());
     }
     if (visualize_depth) {
-      AddConnectionsWithDepth<NormalizedLandmarkList>(
-          landmarks, /*normalized=*/true, z_min, z_max, render_data.get());
+      AddConnectionsWithDepth<NormalizedLandmarkList, NormalizedLandmark>(
+          landmarks, landmark_connections, thickness, /*normalized=*/true,
+          z_min, z_max, render_data.get());
     } else {
-      AddConnections<NormalizedLandmarkList>(landmarks, /*normalized=*/true,
-                                             render_data.get());
+      AddConnections<NormalizedLandmarkList, NormalizedLandmark>(
+          landmarks, landmark_connections, options_.connection_color(),
+          thickness, /*normalized=*/true, render_data.get());
     }
   }
 
@ -240,84 +315,4 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
   return ::mediapipe::OkStatus();
 }
 
-template <class LandmarkListType>
-void LandmarksToRenderDataCalculator::AddConnectionsWithDepth(
-    const LandmarkListType& landmarks, bool normalized, float min_z,
-    float max_z, RenderData* render_data) {
-  for (int i = 0; i < options_.landmark_connections_size(); i += 2) {
-    const auto& ld0 = landmarks.landmark(options_.landmark_connections(i));
-    const auto& ld1 = landmarks.landmark(options_.landmark_connections(i + 1));
-    const int gray_val1 =
-        255 - static_cast<int>(Remap(ld0.z(), min_z, max_z, 255));
-    const int gray_val2 =
-        255 - static_cast<int>(Remap(ld1.z(), min_z, max_z, 255));
-    AddConnectionToRenderData(ld0.x(), ld0.y(), ld1.x(), ld1.y(), options_,
-                              normalized, gray_val1, gray_val2, render_data);
-  }
-}
-
-void LandmarksToRenderDataCalculator::AddConnectionToRenderData(
-    float start_x, float start_y, float end_x, float end_y,
-    const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
-    int gray_val1, int gray_val2, RenderData* render_data) {
-  auto* connection_annotation = render_data->add_render_annotations();
-  RenderAnnotation::GradientLine* line =
-      connection_annotation->mutable_gradient_line();
-  line->set_x_start(start_x);
-  line->set_y_start(start_y);
-  line->set_x_end(end_x);
-  line->set_y_end(end_y);
-  line->set_normalized(normalized);
-  line->mutable_color1()->set_r(gray_val1);
-  line->mutable_color1()->set_g(gray_val1);
-  line->mutable_color1()->set_b(gray_val1);
-  line->mutable_color2()->set_r(gray_val2);
-  line->mutable_color2()->set_g(gray_val2);
-  line->mutable_color2()->set_b(gray_val2);
-  connection_annotation->set_thickness(options.thickness());
-}
-
-template <class LandmarkListType>
-void LandmarksToRenderDataCalculator::AddConnections(
-    const LandmarkListType& landmarks, bool normalized,
-    RenderData* render_data) {
-  for (int i = 0; i < options_.landmark_connections_size(); i += 2) {
-    const auto& ld0 = landmarks.landmark(options_.landmark_connections(i));
-    const auto& ld1 = landmarks.landmark(options_.landmark_connections(i + 1));
-    AddConnectionToRenderData(ld0.x(), ld0.y(), ld1.x(), ld1.y(), options_,
-                              normalized, render_data);
-  }
-}
-
-void LandmarksToRenderDataCalculator::AddConnectionToRenderData(
-    float start_x, float start_y, float end_x, float end_y,
-    const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
-    RenderData* render_data) {
-  auto* connection_annotation = render_data->add_render_annotations();
-  RenderAnnotation::Line* line = connection_annotation->mutable_line();
-  line->set_x_start(start_x);
-  line->set_y_start(start_y);
-  line->set_x_end(end_x);
-  line->set_y_end(end_y);
-  line->set_normalized(normalized);
-  SetColor(connection_annotation, options.connection_color());
-  connection_annotation->set_thickness(options.thickness());
-}
-
-RenderAnnotation* LandmarksToRenderDataCalculator::AddPointRenderData(
-    const LandmarksToRenderDataCalculatorOptions& options,
-    RenderData* render_data) {
-  auto* landmark_data_annotation = render_data->add_render_annotations();
-  landmark_data_annotation->set_scene_tag(kLandmarkLabel);
-  SetRenderAnnotationColorThickness(options, landmark_data_annotation);
-  return landmark_data_annotation;
-}
-
-void LandmarksToRenderDataCalculator::SetRenderAnnotationColorThickness(
-    const LandmarksToRenderDataCalculatorOptions& options,
-    RenderAnnotation* render_annotation) {
-  SetColor(render_annotation, options.landmark_color());
-  render_annotation->set_thickness(options.thickness());
-}
-
 }  // namespace mediapipe
@ -276,6 +276,7 @@ TEST_F(PacketLatencyCalculatorTest, DoesNotOutputUntilReferencePacketReceived) {
       "delayed_packet_0", Adopt(new double()).At(Timestamp(2))));
 
   // Send a reference packet with timestamp 10 usec.
+  simulation_clock_->Sleep(absl::Microseconds(1));
   MP_ASSERT_OK(graph_.AddPacketToInputStream(
       "camera_frames", Adopt(new double()).At(Timestamp(10))));
   simulation_clock_->Sleep(absl::Microseconds(1));
@ -138,7 +138,7 @@ cc_library(
     srcs = ["flow_to_image_calculator.cc"],
     visibility = ["//visibility:public"],
    deps = [
-        "//mediapipe/calculators/video:flow_to_image_calculator_cc_proto",
+        ":flow_to_image_calculator_cc_proto",
        "//mediapipe/calculators/video/tool:flow_quantizer_model",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:image_format_cc_proto",
@ -384,20 +384,18 @@ cc_test(
     ],
 )
 
-MEDIAPIPE_DEPS = [
-    "//mediapipe/calculators/video:box_tracker_calculator",
-    "//mediapipe/calculators/video:flow_packager_calculator",
-    "//mediapipe/calculators/video:motion_analysis_calculator",
-    "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
-    "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
-]
-
 mediapipe_binary_graph(
     name = "parallel_tracker_binarypb",
     graph = "testdata/parallel_tracker_graph.pbtxt",
     output_name = "testdata/parallel_tracker.binarypb",
     visibility = ["//visibility:public"],
-    deps = MEDIAPIPE_DEPS,
+    deps = [
+        ":box_tracker_calculator",
+        ":flow_packager_calculator",
+        ":motion_analysis_calculator",
+        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
+        "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
+    ],
 )
 
 mediapipe_binary_graph(
@ -405,7 +403,13 @@ mediapipe_binary_graph(
     graph = "testdata/tracker_graph.pbtxt",
     output_name = "testdata/tracker.binarypb",
     visibility = ["//visibility:public"],
-    deps = MEDIAPIPE_DEPS,
+    deps = [
+        ":box_tracker_calculator",
+        ":flow_packager_calculator",
+        ":motion_analysis_calculator",
+        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
+        "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
+    ],
 )
 
 cc_test(
@ -0,0 +1,7 @@
tricorder: {
  options: {
    builder: {
      config: "android_arm64"
    }
  }
}
@ -95,7 +95,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem(
     const std::vector<FocusPointFrame>& focus_point_frames,
     const std::vector<FocusPointFrame>& prior_focus_point_frames,
     const int original_width, const int original_height, const int output_width,
-    const int output_height, std::vector<cv::Mat>* all_xforms) {
+    const int output_height, std::vector<cv::Mat>* all_transforms) {
   RET_CHECK_GE(original_width, output_width);
   RET_CHECK_GE(original_height, output_height);
   const bool should_solve_x_problem = original_width != output_width;
@ -138,9 +138,10 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem(
   Solver::Options options;
   options.linear_solver_type = ceres::DENSE_QR;
 
-  Solver::Summary summary;
-  Solve(options, &problem_x, &summary);
-  all_xforms->clear();
+  Solver::Summary summary_x, summary_y;
+  Solve(options, &problem_x, &summary_x);
+  Solve(options, &problem_y, &summary_y);
+  all_transforms->clear();
   for (int i = 0;
        i < focus_point_frames.size() + prior_focus_point_frames.size(); i++) {
     // Code below assigns values into an affine model, defined as:
@ -160,7 +161,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem(
                      yb_, yc_, yd_, yk_);
       transform.at<float>(1, 2) = delta;
     }
-    all_xforms->push_back(transform);
+    all_transforms->push_back(transform);
   }
   return mediapipe::OkStatus();
 }
@ -40,14 +40,14 @@ class PolynomialRegressionPathSolver {
   // Given a series of focus points on frames, uses polynomial regression to
   // compute a best guess of a 1D camera movement trajectory along x-axis and
   // y-axis, such that focus points can be preserved as much as possible. The
-  // returned |all_xforms| hold the camera location at each timestamp
+  // returned |all_transforms| hold the camera location at each timestamp
   // corresponding to each input frame.
   ::mediapipe::Status ComputeCameraPath(
       const std::vector<FocusPointFrame>& focus_point_frames,
       const std::vector<FocusPointFrame>& prior_focus_point_frames,
       const int original_width, const int original_height,
       const int output_width, const int output_height,
-      std::vector<cv::Mat>* all_xforms);
+      std::vector<cv::Mat>* all_transforms);
 
  private:
   // Adds a new cost function, constructed using |in| and |out|, into |problem|.
@ -24,3 +24,17 @@ cc_binary(
         "//mediapipe/graphs/hair_segmentation:mobile_calculators",
     ],
 )
+
+cc_binary(
+    name = "hair_segmentation_cpu",
+    deps = [
+        "//mediapipe/examples/desktop:demo_run_graph_main",
+    ] + select({
+        "//mediapipe/gpu:disable_gpu": [
+            "//mediapipe/graphs/hair_segmentation:desktop_calculators",
+        ],
+        "//conditions:default": [
+            "//mediapipe/graphs/hair_segmentation:mobile_calculators",
+        ],
+    }),
+)
@ -361,6 +361,7 @@ cc_library(
         "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework:status_handler_cc_proto",
+        "//mediapipe/framework:stream_handler_cc_proto",
         "//mediapipe/framework/port:any_proto",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:options_util",
@ -84,7 +84,7 @@ class CalculatorContract {
     return *output_side_packets_;
   }
 
-  // Set this Node's default InputStreamHandler.
+  // Specifies the preferred InputStreamHandler for this Node.
   // If there is an InputStreamHandler specified in the graph (.pbtxt) for this
   // Node, then the graph's InputStreamHandler will take priority.
   void SetInputStreamHandler(const std::string& name) {
@ -104,6 +104,29 @@ class CalculatorContract {
     return input_stream_handler_options_;
   }
 
+  // The next few methods are concerned with timestamp bound propagation
+  // (see scheduling_sync.md#input-policies). Every calculator that processes
+  // live inputs should specify either ProcessTimestampBounds or
+  // TimestampOffset. Calculators that produce output at the same timestamp as
+  // the input, or with a fixed offset, should declare this fact using
+  // SetTimestampOffset. Calculators that require custom timestamp bound
+  // calculations should use SetProcessTimestampBounds.
+
+  // When true, Process is called for every new timestamp bound, with or without
+  // new packets. A call to Process with only an input timestamp bound is
+  // normally used to compute a new output timestamp bound.
+  void SetProcessTimestampBounds(bool process_timestamps) {
+    process_timestamps_ = process_timestamps;
+  }
+  bool GetProcessTimestampBounds() const { return process_timestamps_; }
+
+  // Specifies the maximum difference between input and output timestamps.
+  // When specified, the mediapipe framework automatically computes output
+  // timestamp bounds based on input timestamps. The special value
+  // TimestampDiff::Unset disables the timestamp offset.
+  void SetTimestampOffset(TimestampDiff offset) { timestamp_offset_ = offset; }
+  TimestampDiff GetTimestampOffset() const { return timestamp_offset_; }
+
   class GraphServiceRequest {
    public:
     // APIs that should be used by calculators.
@ -147,6 +170,8 @@ class CalculatorContract {
   MediaPipeOptions input_stream_handler_options_;
   std::string node_name_;
   std::map<std::string, GraphServiceRequest> service_requests_;
+  bool process_timestamps_ = false;
+  TimestampDiff timestamp_offset_ = TimestampDiff::Unset();
 };
 
 }  // namespace mediapipe
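For a calculator whose output always carries the same timestamp as its input, the new contract API is declared in GetContract. A minimal sketch follows; PassThroughExampleCalculator is a hypothetical name and the snippet only illustrates SetTimestampOffset, not a calculator shipped by this commit.

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// Hypothetical calculator: emits output at the same timestamp as its input,
// so it declares a zero timestamp offset instead of custom bound handling.
class PassThroughExampleCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
    // Same-timestamp output: lets the framework propagate bounds for us.
    cc->SetTimestampOffset(TimestampDiff(0));
    return ::mediapipe::OkStatus();
  }

  ::mediapipe::Status Process(CalculatorContext* cc) override {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return ::mediapipe::OkStatus();
  }
};
REGISTER_CALCULATOR(PassThroughExampleCalculator);

}  // namespace mediapipe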
@ -143,7 +143,7 @@ class CalculatorGraph {
       const std::string& graph_type = "",
       const Subgraph::SubgraphOptions* options = nullptr);
 
-  // Resturns the canonicalized CalculatorGraphConfig for this graph.
+  // Returns the canonicalized CalculatorGraphConfig for this graph.
   const CalculatorGraphConfig& Config() const {
     return validated_graph_->Config();
   }
@ -31,6 +31,17 @@ namespace {
 typedef std::function<::mediapipe::Status(CalculatorContext* cc)>
     CalculatorContextFunction;
 
+// Returns the contents of a set of Packets.
+// The contents must be copyable.
+template <typename T>
+std::vector<T> GetContents(const std::vector<Packet>& packets) {
+  std::vector<T> result;
+  for (Packet p : packets) {
+    result.push_back(p.Get<T>());
+  }
+  return result;
+}
+
 // A simple Semaphore for synchronizing test threads.
 class AtomicSemaphore {
  public:
@ -671,9 +682,9 @@ REGISTER_CALCULATOR(BoundToPacketCalculator);
 
 // A Calculator that produces packets at timestamps beyond the input timestamp.
 class FuturePacketCalculator : public CalculatorBase {
+ public:
   static constexpr int64 kOutputFutureMicros = 3;
 
- public:
   static ::mediapipe::Status GetContract(CalculatorContract* cc) {
     cc->Inputs().Index(0).Set<int>();
     cc->Outputs().Index(0).Set<int>();
@ -742,9 +753,8 @@ TEST(CalculatorGraphBoundsTest, OffsetBoundPropagation) {
   MP_ASSERT_OK(graph.WaitUntilDone());
 }
 
-// Shows that bounds changes alone do not invoke Process.
-// Note: Bounds changes alone will invoke Process eventually
-// when SetOffset is cleared, see: go/mediapipe-realtime-graph.
+// Shows that timestamp bounds changes alone do not invoke Process,
+// without SetProcessTimestampBounds(true).
 TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
   // OffsetBoundCalculator produces only timestamp bounds.
   // The BoundToPacketCalculator delivers an output packet whenever the
@ -753,8 +763,13 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
       ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
         input_stream: 'input'
         node {
-          calculator: 'OffsetBoundCalculator'
+          calculator: 'FuturePacketCalculator'
           input_stream: 'input'
+          output_stream: 'input_2'
+        }
+        node {
+          calculator: 'OffsetBoundCalculator'
+          input_stream: 'input_2'
           output_stream: 'bounds'
         }
         node {
@ -778,6 +793,7 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
   for (int i = 0; i < kNumInputs; ++i) {
     Packet p = MakePacket<int>(33).At(Timestamp(i));
     MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
   }
 
   // No packets arrive, because updated timestamp bounds do not invoke
@ -1104,5 +1120,254 @@ TEST(CalculatorGraphBoundsTest, BoundsForEmptyInputs_SyncSets) {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
+// A Calculator that produces a packet for each timestamp bounds update.
+class ProcessBoundToPacketCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
+      cc->Inputs().Index(i).SetAny();
+    }
+    for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
+      cc->Outputs().Index(i).Set<Timestamp>();
+    }
+    cc->SetInputStreamHandler("ImmediateInputStreamHandler");
+    cc->SetProcessTimestampBounds(true);
+    return ::mediapipe::OkStatus();
+  }
+
+  ::mediapipe::Status Process(CalculatorContext* cc) final {
+    for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
+      Timestamp t = cc->Inputs().Index(i).Value().Timestamp();
+      if (t == cc->InputTimestamp() &&
+          t >= cc->Outputs().Index(i).NextTimestampBound()) {
+        cc->Outputs().Index(i).Add(new auto(t), t);
+      }
+    }
+    return ::mediapipe::OkStatus();
+  }
+};
+REGISTER_CALCULATOR(ProcessBoundToPacketCalculator);
+
+// A Calculator that passes through each packet and timestamp immediately.
+class ImmediatePassthroughCalculator : public CalculatorBase {
+ public:
+  static ::mediapipe::Status GetContract(CalculatorContract* cc) {
+    for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
+      cc->Inputs().Index(i).SetAny();
+    }
+    for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
+      cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
+    }
+    cc->SetInputStreamHandler("ImmediateInputStreamHandler");
+    cc->SetProcessTimestampBounds(true);
+    return ::mediapipe::OkStatus();
+  }
+
+  ::mediapipe::Status Process(CalculatorContext* cc) final {
+    for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
+      if (!cc->Inputs().Index(i).IsEmpty()) {
+        cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value());
+      } else {
+        Timestamp input_bound =
+            cc->Inputs().Index(i).Value().Timestamp().NextAllowedInStream();
+        if (cc->Outputs().Index(i).NextTimestampBound() < input_bound) {
+          cc->Outputs().Index(i).SetNextTimestampBound(input_bound);
+        }
+      }
+    }
+    return ::mediapipe::OkStatus();
+  }
+};
+REGISTER_CALCULATOR(ImmediatePassthroughCalculator);
+
+// Shows that Process is called for input-sets without input packets.
+void TestProcessForEmptyInputs(const std::string& input_stream_handler) {
+  // FuturePacketCalculator and OffsetBoundCalculator produce only ts bounds.
+  // The ProcessBoundToPacketCalculator has SetProcessTimestampBounds(true),
+  // and produces an output packet for every timestamp bound update.
+  std::string config_str = R"(
+    input_stream: 'input'
+    node {
+      calculator: 'FuturePacketCalculator'
+      input_stream: 'input'
+      output_stream: 'futures'
+    }
+    node {
+      calculator: 'OffsetBoundCalculator'
+      input_stream: 'futures'
+      output_stream: 'bounds'
+    }
+    node {
+      calculator: 'ProcessBoundToPacketCalculator'
+      input_stream: 'bounds'
+      output_stream: 'bounds_ts'
+      input_stream_handler { $input_stream_handler }
+    }
+  )";
+  absl::StrReplaceAll({{"$input_stream_handler", input_stream_handler}},
+                      &config_str);
+  CalculatorGraphConfig config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
+  CalculatorGraph graph;
+  std::vector<Packet> input_ts_packets;
+  std::vector<Packet> bounds_ts_packets;
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.ObserveOutputStream("bounds_ts", [&](const Packet& p) {
+    bounds_ts_packets.push_back(p);
+    return ::mediapipe::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Add four packets into the graph.
+  constexpr int kFutureMicros = FuturePacketCalculator::kOutputFutureMicros;
+  Packet p;
+  p = MakePacket<int>(33).At(Timestamp(0));
+  MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  p = MakePacket<int>(33).At(Timestamp(10));
+  MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  p = MakePacket<int>(33).At(Timestamp(20));
+  MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  p = MakePacket<int>(33).At(Timestamp(30));
+  MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Packets arrive.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_EQ(bounds_ts_packets.size(), 4);
+
+  std::vector<Timestamp> expected = {
+      Timestamp(0 + kFutureMicros), Timestamp(10 + kFutureMicros),
+      Timestamp(20 + kFutureMicros), Timestamp(30 + kFutureMicros)};
+  EXPECT_EQ(GetContents<Timestamp>(bounds_ts_packets), expected);
+
+  // Shutdown the graph.
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
+// Shows that Process is called for input-sets without input packets
+// using a DefaultInputStreamHandler.
+TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Default) {
+  TestProcessForEmptyInputs(R"(
+    input_stream_handler: "DefaultInputStreamHandler")");
+}
+
+// Shows that Process is called for input-sets without input packets
+// using an ImmediateInputStreamHandler.
+TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Immediate) {
+  TestProcessForEmptyInputs(R"(
+    input_stream_handler: "ImmediateInputStreamHandler")");
+}
+
+// Shows that Process is called for input-sets without input packets
+// using a SyncSetInputStreamHandler with a single sync-set.
+TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_SyncSet) {
+  TestProcessForEmptyInputs(R"(
+    input_stream_handler: "SyncSetInputStreamHandler")");
+}
+
+// Shows that Process is called for input-sets without input packets
+// using a SyncSetInputStreamHandler with multiple sync-sets.
+TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_SyncSets) {
+  TestProcessForEmptyInputs(R"(
+    input_stream_handler: "SyncSetInputStreamHandler"
+    options {
+      [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
+        sync_set { tag_index: ":0" }
+      }
+    }
+  )");
+}
+
+// Demonstrates the functionality of an "ImmediatePassthroughCalculator".
+// The ImmediatePassthroughCalculator simply relays each input packet to
+// the corresponding output stream. ProcessTimestampBounds is needed to
+// relay timestamp bounds as well as packets.
+TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Passthrough) {
+  // OffsetBoundCalculator produces timestamp bounds.
+  // ImmediatePassthroughCalculator relays packets and bounds.
+  // ProcessBoundToPacketCalculator reports packets and bounds as packets.
+  std::string config_str = R"(
+    input_stream: "input_0"
+    input_stream: "input_1"
+    node {
+      calculator: "OffsetBoundCalculator"
+      input_stream: "input_1"
+      output_stream: "bound_1"
+    }
+    node {
+      calculator: "ImmediatePassthroughCalculator"
+      input_stream: "input_0"
+      input_stream: "bound_1"
+      output_stream: "same_0"
+      output_stream: "same_1"
+    }
+    node {
+      calculator: "ProcessBoundToPacketCalculator"
+      input_stream: "same_0"
+      input_stream: "same_1"
+      output_stream: "output_0"
+      output_stream: "output_1"
+    }
+  )";
+  CalculatorGraphConfig config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
+  CalculatorGraph graph;
+  std::vector<Packet> output_0_packets;
+  std::vector<Packet> output_1_packets;
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
+    output_0_packets.push_back(p);
+    return ::mediapipe::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.ObserveOutputStream("output_1", [&](const Packet& p) {
+    output_1_packets.push_back(p);
+    return ::mediapipe::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Add four packets to input_0.
+  for (int i = 0; i < 4; ++i) {
+    Packet p = MakePacket<int>(33).At(Timestamp(i * 10));
+    MP_ASSERT_OK(graph.AddPacketToInputStream("input_0", p));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  }
+
+  // Packets arrive.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_EQ(output_0_packets.size(), 4);
+  EXPECT_EQ(output_1_packets.size(), 0);
+  std::vector<Timestamp> expected =  //
+      {Timestamp(0), Timestamp(10), Timestamp(20), Timestamp(30)};
+  EXPECT_EQ(GetContents<Timestamp>(output_0_packets), expected);
+
+  // Add two timestamp bounds to bound_1.
+  for (int i = 0; i < 2; ++i) {
+    Packet p = MakePacket<int>(33).At(Timestamp(10 + i * 10));
+    MP_ASSERT_OK(graph.AddPacketToInputStream("input_1", p));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  }
+
+  // Bounds arrive.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_EQ(output_0_packets.size(), 4);
+  EXPECT_EQ(output_1_packets.size(), 2);
+  expected =  //
+      {Timestamp(10), Timestamp(20)};
+  EXPECT_EQ(GetContents<Timestamp>(output_1_packets), expected);
+
+  // Shutdown the graph.
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
 }  // namespace
 }  // namespace mediapipe
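
The tests above drain "bounds_ts" through ObserveOutputStream callbacks. As an aside, the same stream could be read synchronously with an OutputStreamPoller, which this commit also makes move-assignable further down; the following is only a hedged sketch of that alternative, not code from the change:

    // Hypothetical polling variant of the observer used in the tests above.
    auto poller_or = graph.AddOutputStreamPoller("bounds_ts");
    MP_ASSERT_OK(poller_or.status());
    OutputStreamPoller poller = std::move(poller_or).ValueOrDie();
    MP_ASSERT_OK(graph.StartRun({}));
    Packet packet;
    while (poller.Next(&packet)) {  // Blocks until a packet arrives or the stream closes.
      LOG(INFO) << "bounds_ts packet at " << packet.Timestamp();
    }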

@@ -97,6 +97,7 @@ Timestamp CalculatorNode::SourceProcessOrder(

   const NodeTypeInfo& node_type_info =
       validated_graph_->CalculatorInfos()[node_id_];
+  const CalculatorContract& contract = node_type_info.Contract();

   uses_gpu_ =
       node_type_info.InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
@@ -147,6 +148,14 @@ Timestamp CalculatorNode::SourceProcessOrder(
       use_calc_specified ? handler_config : node_config.input_stream_handler(),
       node_type_info.InputStreamTypes()));

+  for (auto& stream : output_stream_handler_->OutputStreams()) {
+    stream->Spec()->offset_enabled =
+        (contract.GetTimestampOffset() != TimestampDiff::Unset());
+    stream->Spec()->offset = contract.GetTimestampOffset();
+  }
+  input_stream_handler_->SetProcessTimestampBounds(
+      contract.GetProcessTimestampBounds());
+
   return InitializeInputStreams(input_stream_managers, output_stream_managers);
 }

@@ -18,6 +18,10 @@ namespace mediapipe {

 namespace {

+// List of namespaces that can register calculators inside the namespace
+// and still refer to them using an unqualified name. This whitelist
+// is meant to facilitate migration from unqualified to fully qualified
+// calculator names.
 constexpr char const* kTopNamespaces[] = {
     "mediapipe",
 };

@@ -49,3 +49,10 @@ mediapipe_cc_proto_library(
     visibility = ["//visibility:public"],
     deps = [":rasterization_proto"],
 )
+
+# Expose the proto source files for building mediapipe AAR.
+filegroup(
+    name = "protos_src",
+    srcs = glob(["*.proto"]),
+    visibility = ["//mediapipe:__subpackages__"],
+)

@@ -16,6 +16,9 @@ syntax = "proto2";

 package mediapipe;

+option java_package = "com.google.mediapipe.formats.annotation.proto";
+option java_outer_classname = "RasterizationProto";
+
 // A Region can be represented in each frame as a set of scanlines
 // (compressed RLE, similar to rasterization of polygons).
 // For each scanline with y-coordinate y, we save (possibly multiple) intervals

@@ -23,6 +23,9 @@ package mediapipe;

 import "mediapipe/framework/formats/annotation/rasterization.proto";

+option java_package = "com.google.mediapipe.formats.proto";
+option java_outer_classname = "LocationDataProto";
+
 message LocationData {
   // The supported formats for representing location data. A single location
   // must store its data in exactly one way.

@@ -22,6 +22,8 @@

 namespace mediapipe {

+using SyncSet = InputStreamHandler::SyncSet;
+
 ::mediapipe::Status InputStreamHandler::InitializeInputStreamManagers(
     InputStreamManager* flat_input_stream_managers) {
   for (CollectionItemId id = input_stream_managers_.BeginId();

@@ -300,4 +302,92 @@ void InputStreamHandler::SetLatePreparation(bool late_preparation) {
   late_preparation_ = late_preparation;
 }

+SyncSet::SyncSet(InputStreamHandler* input_stream_handler,
+                 std::vector<CollectionItemId> stream_ids)
+    : input_stream_handler_(input_stream_handler),
+      stream_ids_(std::move(stream_ids)) {}
+
+NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) {
+  Timestamp min_bound = Timestamp::Done();
+  Timestamp min_packet = Timestamp::Done();
+  for (CollectionItemId id : stream_ids_) {
+    const auto& stream = input_stream_handler_->input_stream_managers_.Get(id);
+    bool empty;
+    Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
+    if (empty) {
+      min_bound = std::min(min_bound, stream_timestamp);
+    } else {
+      min_packet = std::min(min_packet, stream_timestamp);
+    }
+  }
+  *min_stream_timestamp = std::min(min_packet, min_bound);
+  if (*min_stream_timestamp == Timestamp::Done()) {
+    last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream();
+    return NodeReadiness::kReadyForClose;
+  }
+  if (!input_stream_handler_->process_timestamps_) {
+    // Only an input_ts with packets can be processed.
+    // Note that (min_bound - 1) is the highest fully settled timestamp.
+    if (min_bound > min_packet) {
+      last_processed_ts_ = *min_stream_timestamp;
+      return NodeReadiness::kReadyForProcess;
+    }
+  } else {
+    // Any unprocessed input_ts can be processed.
+    // Note that (min_bound - 1) is the highest fully settled timestamp.
+    Timestamp input_timestamp =
+        std::min(min_packet, min_bound.PreviousAllowedInStream());
+    if (input_timestamp >
+        std::max(last_processed_ts_, Timestamp::Unstarted())) {
+      *min_stream_timestamp = input_timestamp;
+      last_processed_ts_ = input_timestamp;
+      return NodeReadiness::kReadyForProcess;
+    }
+  }
+  return NodeReadiness::kNotReady;
+}
+
+Timestamp SyncSet::LastProcessed() const { return last_processed_ts_; }
+
+Timestamp SyncSet::MinPacketTimestamp() const {
+  Timestamp result = Timestamp::Done();
+  for (CollectionItemId id : stream_ids_) {
+    const auto& stream = input_stream_handler_->input_stream_managers_.Get(id);
+    bool empty;
+    Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
+    if (!empty) {
+      result = std::min(result, stream_timestamp);
+    }
+  }
+  return result;
+}
+
+void SyncSet::FillInputSet(Timestamp input_timestamp,
+                           InputStreamShardSet* input_set) {
+  CHECK(input_timestamp.IsAllowedInStream());
+  CHECK(input_set);
+  for (CollectionItemId id : stream_ids_) {
+    const auto& stream = input_stream_handler_->input_stream_managers_.Get(id);
+    int num_packets_dropped = 0;
+    bool stream_is_done = false;
+    Packet current_packet = stream->PopPacketAtTimestamp(
+        input_timestamp, &num_packets_dropped, &stream_is_done);
+    CHECK_EQ(num_packets_dropped, 0)
+        << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
+                            num_packets_dropped, stream->Name());
+    input_stream_handler_->AddPacketToShard(
+        &input_set->Get(id), std::move(current_packet), stream_is_done);
+  }
+}
+
+void SyncSet::FillInputBounds(InputStreamShardSet* input_set) {
+  for (CollectionItemId id : stream_ids_) {
+    const auto* stream = input_stream_handler_->input_stream_managers_.Get(id);
+    Timestamp bound = stream->MinTimestampOrBound(nullptr);
+    input_stream_handler_->AddPacketToShard(
+        &input_set->Get(id), Packet().At(bound.PreviousAllowedInStream()),
+        bound == Timestamp::Done());
+  }
+}
+
 }  // namespace mediapipe
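
A brief worked example of GetReadiness above: with a sync-set over streams A and B, where A holds a packet at timestamp 3 and B is empty with its bound advanced to 5, min_packet is 3 and min_bound is 5, so timestamp 3 is fully settled and kReadyForProcess is returned with *min_stream_timestamp == 3. If instead both streams are empty with bounds at 5, the default path reports kNotReady, while with process_timestamps_ enabled the settled timestamp 4 (min_bound.PreviousAllowedInStream()) becomes the reported input timestamp, provided it is newer than last_processed_ts_.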

@@ -74,9 +74,7 @@ class InputStreamHandler {
       : input_stream_managers_(std::move(tag_map)),
         calculator_context_manager_(calculator_context_manager),
         options_(options),
-        calculator_run_in_parallel_(calculator_run_in_parallel),
-        late_preparation_(false),
-        batch_size_(1) {}
+        calculator_run_in_parallel_(calculator_run_in_parallel) {}

   virtual ~InputStreamHandler() = default;

@@ -174,6 +172,57 @@ class InputStreamHandler {
     return unset_header_count_.load(std::memory_order_relaxed);
   }

+  // When true, Calculator::Process is called for any increase in the
+  // timestamp bound, whether or not any packets are available.
+  // Calculator::Process is called when the minimum timestamp bound
+  // increases for any synchronized set of input streams.
+  // DefaultInputStreamHandler groups all input streams into a single set.
+  // ImmediateInputStreamHandler treats each input stream as a separate set.
+  void SetProcessTimestampBounds(bool process_ts) {
+    process_timestamps_ = process_ts;
+  }
+
+  // When true, Calculator::Process is called for every input timestamp bound.
+  bool ProcessTimestampBounds() { return process_timestamps_; }
+
+  // A helper class to build input packet sets for a certain set of streams.
+  //
+  // ReadyForProcess requires all of the streams to be fully determined
+  // at the same input-timestamp.
+  // This is the readiness policy for all streams in DefaultInputStreamHandler.
+  // It is also the policy for each sync-set in SyncSetInputStreamHandler.
+  // It is also the policy for each input-stream in ImmediateInputStreamHandler.
+  //
+  // If ProcessTimestampBounds() is set, then a fully determined input timestamp
+  // with only empty input packets will qualify as ReadyForProcess.
+  class SyncSet {
+   public:
+    // Creates a SyncSet for a certain set of streams, |stream_ids|.
+    SyncSet(InputStreamHandler* input_stream_handler,
+            std::vector<CollectionItemId> stream_ids);
+
+    // Answers whether this stream is ready for Process or Close.
+    NodeReadiness GetReadiness(Timestamp* min_stream_timestamp);
+
+    // Returns the latest timestamp returned for processing.
+    Timestamp LastProcessed() const;
+
+    // The earliest available packet timestamp, or Timestamp::Done.
+    Timestamp MinPacketTimestamp() const;
+
+    // Moves packets from all input streams to the input_set.
+    void FillInputSet(Timestamp input_timestamp,
+                      InputStreamShardSet* input_set);
+
+    // Copies timestamp bounds from all input streams to the input_set.
+    void FillInputBounds(InputStreamShardSet* input_set);
+
+   private:
+    InputStreamHandler* input_stream_handler_;
+    std::vector<CollectionItemId> stream_ids_;
+    Timestamp last_processed_ts_ = Timestamp::Unset();
+  };
+
  protected:
   typedef internal::Collection<InputStreamManager*> InputStreamManagerSet;

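The DefaultInputStreamHandler and SyncSetInputStreamHandler changes later in this commit delegate to this helper. As a condensed, hypothetical sketch of the pattern (the class name and the omitted constructor plumbing are illustrative only), a handler that treats all of its streams as one synchronized set simply forwards both virtual hooks to a single SyncSet member:

    // Sketch only; mirrors the DefaultInputStreamHandler change below.
    class SingleSetInputStreamHandler : public InputStreamHandler {
     protected:
      NodeReadiness GetNodeReadiness(Timestamp* min_stream_timestamp) override {
        return sync_set_.GetReadiness(min_stream_timestamp);
      }
      void FillInputSet(Timestamp input_timestamp,
                        InputStreamShardSet* input_set) override {
        sync_set_.FillInputSet(input_timestamp, input_set);
      }

     private:
      // In practice initialized with every CollectionItemId of the TagMap.
      SyncSet sync_set_{this, {}};
    };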

@@ -240,11 +289,14 @@ class InputStreamHandler {
   // The variable is set to false by default. A subclass should set it to true
   // with SetLatePreparation(true) in the constructor if the input sets need to
   // be filled in ProcessNode().
-  bool late_preparation_;
+  bool late_preparation_ = false;

   // Determines how many sets of input packets are collected before a
   // CalculatorNode is scheduled.
-  int batch_size_;
+  int batch_size_ = 1;

+  // When true, any increase in timestamp bound invokes Calculator::Process.
+  bool process_timestamps_ = false;
+
   // A callback to notify the observer when all the input stream headers
   // (excluding headers of back edges) become available.

@@ -107,6 +107,9 @@ CalculatorContext* LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
 template <>
 CalculatorContract*
     LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;
+#elif _MSC_VER
+// MSVC interprets these declarations as definitions and during linking it
+// generates an error about multiple definitions of current_.
 #else
 template <>
 thread_local CalculatorContext*

@@ -46,6 +46,7 @@ class OutputStreamHandler {
   // ids of upstream sources that affect it.
   typedef std::unordered_map<std::string, std::unordered_set<int>>
       OutputStreamToSourcesMap;
+  typedef internal::Collection<OutputStreamManager*> OutputStreamManagerSet;

   // The constructor of the OutputStreamHandler takes four arguments.
   // The tag_map argument holds the information needed for tag/index retrieval

@@ -119,9 +120,11 @@ class OutputStreamHandler {
   // collection for debugging purpose.
   std::string FirstStreamName() const;

- protected:
-  typedef internal::Collection<OutputStreamManager*> OutputStreamManagerSet;
+  const OutputStreamManagerSet& OutputStreams() {
+    return output_stream_managers_;
+  }
+
+ protected:
   // Checks if the given input bound should be propagated or not. If any output
   // streams with OffsetEnabled() need to have the timestamp bounds updated,
   // then propagates the timestamp bounds of all output streams with

@@ -27,6 +27,9 @@ class OutputStreamPoller {
   OutputStreamPoller(const OutputStreamPoller&) = delete;
   OutputStreamPoller& operator=(const OutputStreamPoller&) = delete;
   OutputStreamPoller(OutputStreamPoller&&) = default;
+  // Move assignment needs to be explicitly defaulted to allow ASSIGN_OR_RETURN
+  // on `StatusOr<OutputStreamPoller>`.
+  OutputStreamPoller& operator=(OutputStreamPoller&&) = default;

   // Resets OutputStramPollerImpl and cleans the internal packet queue.
   void Reset() {

mediapipe/framework/profiler/testdata/profile_latency_test.pbtxt (new file, vendored, 97 lines)
@@ -0,0 +1,97 @@
graph_trace: {
  calculator_name : ["ACalculator", "BCalculator"]
  stream_name : [ "", "input1", "a_b"]
  base_time : 0
  base_timestamp : 100

  # Fire off three input packets and have them spend time in Calculator A.
  # Drop the middle packet.

  calculator_trace: {
    node_id: -1
    input_timestamp: 100
    event_type : PROCESS
    finish_time : 1000
    output_trace: {
      packet_timestamp: 100
      stream_id : 1
    }
    thread_id : 1
  }
  calculator_trace: {
    node_id: -1
    input_timestamp: 101
    event_type : PROCESS
    finish_time : 2000
    output_trace: {
      packet_timestamp: 101
      stream_id : 1
    }
    thread_id : 1
  }

  calculator_trace: {
    node_id: 0
    input_timestamp: 100
    event_type : PROCESS
    start_time : 1200 # 200 after initial input (emits at 1000)
    finish_time : 1500 # Speed to delivery is 500 (1500 - 1000)
    input_trace: {
      packet_timestamp: 100
      stream_id : 1
    }
    thread_id : 1
  }

  calculator_trace: {
    node_id: 0
    input_timestamp: 101
    event_type : PROCESS
    start_time : 2100 # 100 after initial input (emits at 2000)
    finish_time : 2500 # Speed to delivery is 500 (2500 - 2000)
    input_trace: {
      packet_timestamp: 101
      stream_id : 1
    }
    thread_id : 1
  }

  calculator_trace: {
    node_id: 1
    input_timestamp: 100
    event_type : PROCESS
    start_time : 1600 # 600 after the initial input (emits at 1000)
    finish_time : 2000 # Speed to delivery is 1000 (2000 - 1000)
    input_trace: {
      packet_timestamp: 100
      stream_id : 1
    }
    thread_id : 1
  }

  calculator_trace: {
    node_id: 1
    input_timestamp: 101
    event_type : PROCESS
    start_time : 2900 # 700 after the initial input (emits at 2000)
    finish_time : 3100 # Speed to delivery is 1000 (3000 - 2000)
    input_trace: {
      packet_timestamp: 101
      stream_id : 1
    }
    thread_id : 1
  }
}
config: {
  node: {
    name: "ACalculator"
    calculator: "ACalculator"
    input_stream: "input1"
    output_stream: "a_b"
  }
  node: {
    name: "BCalculator"
    calculator: "BCalculator"
    input_stream: "a_b"
  }
}

mediapipe/framework/profiler/testdata/profile_process_test.pbtxt (new file, vendored, 122 lines)
@@ -0,0 +1,122 @@
graph_trace: {
  calculator_name : ["ACalculator", "BCalculator"]
  stream_name : [ "", "input1"]
  base_time : 0
  base_timestamp : 100

  # Fire off three input packets and have them spend time in Calculator A.
  # Drop the middle packet.

  calculator_trace: {
    node_id: -1
    input_timestamp: 100
    event_type : PROCESS
    finish_time : 1000
    output_trace: {
      packet_timestamp: 100
      stream_id : 1
    }
    thread_id : 1
  }
  calculator_trace: {
    node_id: -1
    input_timestamp: 101
    event_type : PROCESS
    finish_time : 2000
    output_trace: {
      packet_timestamp: 101
      stream_id : 1
    }
    thread_id : 1
  }
  calculator_trace: {
    node_id: -1
    input_timestamp: 102
    event_type : PROCESS
    finish_time : 3000
    output_trace: {
      packet_timestamp: 102
      stream_id : 1
    }
    thread_id : 1
  }

  # First event is disconnected. We'll see the output_trace later.
  calculator_trace: {
    node_id: 0
    input_timestamp: 100
    event_type : PROCESS
    start_time : 1100
    input_trace: {
      packet_timestamp: 100
      stream_id : 1
    }
    thread_id : 1
  }

  # # We're going to drop this packet.
  calculator_trace: {
    node_id: 0
    input_timestamp: 101
    event_type : PROCESS
    start_time : 2100
    input_trace: {
      packet_timestamp: 101
      stream_id : 1
    }
    thread_id : 1
  }
  # # Here's that matching output trace.
  calculator_trace: {
    node_id: 0
    input_timestamp: 100
    event_type : PROCESS
    finish_time : 1500
    input_trace: {
      packet_timestamp: 100
      stream_id : 1
    }
    thread_id : 1
  }
  # Third packet is processed all at the same time.
  calculator_trace: {
    node_id: 0
    input_timestamp: 102
    event_type : PROCESS
    start_time : 3100
    finish_time : 3600
    input_trace: {
      packet_timestamp: 102
      stream_id : 1
    }
    thread_id : 1
  }

  # A second calculator will process an input in order to affect the
  # time_percent.

  calculator_trace: {
    node_id: 1
    input_timestamp: 102
    event_type : PROCESS
    start_time : 3200
    finish_time : 3500
    input_trace: {
      packet_timestamp: 102
      stream_id : 1
    }
    thread_id : 1
  }
}
config: {
  node: {
    name: "ACalculator"
    calculator: "ACalculator"
    input_stream: "input1"
  }
  node: {
    name: "BCalculator"
    calculator: "BCalculator"
    input_stream: "input1"
  }
}

@@ -25,7 +25,11 @@
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/status.h"

+#ifdef __APPLE__
+#define AUTORELEASEPOOL @autoreleasepool
+#else
 #define AUTORELEASEPOOL
+#endif  // __APPLE__

 namespace mediapipe {
 namespace internal {
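
On Apple builds the macro now expands to @autoreleasepool, and elsewhere it expands to nothing, so the braced scope that follows it compiles either way. A minimal usage sketch (the surrounding loop and DoWorkItem are hypothetical, not part of this change):

    while (has_work) {
      AUTORELEASEPOOL {
        // On Apple, Objective-C temporaries created here are released at the
        // end of every iteration; elsewhere this is an ordinary block.
        DoWorkItem();
      }
    }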

@@ -17,16 +17,28 @@
 #include <algorithm>

 #include "absl/strings/substitute.h"
+#include "mediapipe/framework/input_stream_handler.h"

 namespace mediapipe {

 REGISTER_INPUT_STREAM_HANDLER(DefaultInputStreamHandler);

+// Returns all CollectionItemId's for a Collection TagMap.
+std::vector<CollectionItemId> GetIds(
+    const std::shared_ptr<tool::TagMap>& tag_map) {
+  std::vector<CollectionItemId> result;
+  for (auto id = tag_map->BeginId(); id < tag_map->EndId(); ++id) {
+    result.push_back(id);
+  }
+  return result;
+}
+
 DefaultInputStreamHandler::DefaultInputStreamHandler(
     std::shared_ptr<tool::TagMap> tag_map, CalculatorContextManager* cc_manager,
     const MediaPipeOptions& options, bool calculator_run_in_parallel)
     : InputStreamHandler(std::move(tag_map), cc_manager, options,
-                         calculator_run_in_parallel) {
+                         calculator_run_in_parallel),
+      sync_set_(this, GetIds(input_stream_managers_.TagMap())) {
   if (options.HasExtension(DefaultInputStreamHandlerOptions::ext)) {
     SetBatchSize(options.GetExtension(DefaultInputStreamHandlerOptions::ext)
                      .batch_size());
@@ -35,47 +47,12 @@ DefaultInputStreamHandler::DefaultInputStreamHandler(

 NodeReadiness DefaultInputStreamHandler::GetNodeReadiness(
     Timestamp* min_stream_timestamp) {
-  DCHECK(min_stream_timestamp);
-  *min_stream_timestamp = Timestamp::Done();
-  Timestamp min_bound = Timestamp::Done();
-  for (const auto& stream : input_stream_managers_) {
-    bool empty;
-    Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
-    if (empty) {
-      min_bound = std::min(min_bound, stream_timestamp);
-    }
-    *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp);
-  }
-
-  if (*min_stream_timestamp == Timestamp::Done()) {
-    return NodeReadiness::kReadyForClose;
-  }
-
-  if (min_bound > *min_stream_timestamp) {
-    return NodeReadiness::kReadyForProcess;
-  }
-
-  CHECK_EQ(min_bound, *min_stream_timestamp);
-  return NodeReadiness::kNotReady;
+  return sync_set_.GetReadiness(min_stream_timestamp);
 }

 void DefaultInputStreamHandler::FillInputSet(Timestamp input_timestamp,
                                              InputStreamShardSet* input_set) {
-  CHECK(input_timestamp.IsAllowedInStream());
-  CHECK(input_set);
-  for (CollectionItemId id = input_stream_managers_.BeginId();
-       id < input_stream_managers_.EndId(); ++id) {
-    auto& stream = input_stream_managers_.Get(id);
-    int num_packets_dropped = 0;
-    bool stream_is_done = false;
-    Packet current_packet = stream->PopPacketAtTimestamp(
-        input_timestamp, &num_packets_dropped, &stream_is_done);
-    CHECK_EQ(num_packets_dropped, 0)
-        << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
-                            num_packets_dropped, stream->Name());
-    AddPacketToShard(&input_set->Get(id), std::move(current_packet),
-                     stream_is_done);
-  }
+  sync_set_.FillInputSet(input_timestamp, input_set);
 }

 }  // namespace mediapipe

@@ -45,6 +45,9 @@ class DefaultInputStreamHandler : public InputStreamHandler {
   // Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
   void FillInputSet(Timestamp input_timestamp,
                     InputStreamShardSet* input_set) override;
+
+  // The packet-set builder.
+  SyncSet sync_set_;
 };

 }  // namespace mediapipe

@@ -19,6 +19,8 @@

 namespace mediapipe {

+using SyncSet = InputStreamHandler::SyncSet;
+
 // An input stream handler that delivers input packets to the Calculator
 // immediately, with no dependency between input streams. It also invokes
 // Calculator::Process when any input stream becomes done.
@@ -47,8 +49,11 @@ class ImmediateInputStreamHandler : public InputStreamHandler {
   void FillInputSet(Timestamp input_timestamp,
                     InputStreamShardSet* input_set) override;

-  // Record of the last reported timestamp bound for each input stream.
-  mediapipe::internal::Collection<Timestamp> timestamp_bounds_;
+  absl::Mutex mutex_;
+  // The packet-set builder for each input stream.
+  std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
+  // The input timestamp for each kReadyForProcess input stream.
+  std::vector<Timestamp> ready_timestamps_ ABSL_GUARDED_BY(mutex_);
 };
 REGISTER_INPUT_STREAM_HANDLER(ImmediateInputStreamHandler);

@@ -57,31 +62,47 @@ ImmediateInputStreamHandler::ImmediateInputStreamHandler(
     CalculatorContextManager* calculator_context_manager,
     const MediaPipeOptions& options, bool calculator_run_in_parallel)
     : InputStreamHandler(tag_map, calculator_context_manager, options,
-                         calculator_run_in_parallel),
-      timestamp_bounds_(std::move(tag_map)) {}
+                         calculator_run_in_parallel) {
+  for (auto id = tag_map->BeginId(); id < tag_map->EndId(); ++id) {
+    sync_sets_.emplace_back(this, std::vector<CollectionItemId>{id});
+    ready_timestamps_.push_back(Timestamp::Unset());
+  }
+}

 NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
     Timestamp* min_stream_timestamp) {
-  *min_stream_timestamp = Timestamp::Done();
+  absl::MutexLock lock(&mutex_);
   Timestamp input_timestamp = Timestamp::Done();
+  Timestamp min_bound = Timestamp::Done();
   bool stream_became_done = false;
-  for (CollectionItemId i = input_stream_managers_.BeginId();
-       i < input_stream_managers_.EndId(); ++i) {
-    const auto& stream = input_stream_managers_.Get(i);
-    bool empty;
-    Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
-    if (!empty) {
-      input_timestamp = std::min(input_timestamp, stream_timestamp);
-    }
-    *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp);
-    if (stream_timestamp != timestamp_bounds_.Get(i)) {
-      if (stream_timestamp == Timestamp::Done()) {
-        stream_became_done = true;
-      }
-      timestamp_bounds_.Get(i) = stream_timestamp;
-    }
-  }
+  for (int i = 0; i < sync_sets_.size(); ++i) {
+    if (ready_timestamps_[i] > Timestamp::Unset()) {
+      min_bound = std::min(min_bound, ready_timestamps_[i]);
+      input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
+      continue;
+    }
+    Timestamp prev_ts = sync_sets_[i].LastProcessed();
+    Timestamp stream_ts;
+    NodeReadiness readiness = sync_sets_[i].GetReadiness(&stream_ts);
+    min_bound = std::min(min_bound, stream_ts);
+    if (readiness == NodeReadiness::kReadyForProcess) {
+      ready_timestamps_[i] = stream_ts;
+      input_timestamp = std::min(input_timestamp, stream_ts);
+    } else if (readiness == NodeReadiness::kReadyForClose) {
+      CHECK_EQ(stream_ts, Timestamp::Done());
+      if (ProcessTimestampBounds()) {
+        // With kReadyForClose, the timestamp-bound Done is returned.
+        // This bound is processed using the preceding input-timestamp.
+        // TODO: Make all InputStreamHandlers process Done() like this.
+        ready_timestamps_[i] = stream_ts.PreviousAllowedInStream();
+        input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
+      } else if (prev_ts < Timestamp::Done()) {
+        stream_became_done = true;
+      }
+      ready_timestamps_[i] = Timestamp::Done();
+    }
+  }
+  *min_stream_timestamp = min_bound;

   if (*min_stream_timestamp == Timestamp::Done()) {
     return NodeReadiness::kReadyForClose;
@@ -94,6 +115,8 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
   }

   if (stream_became_done) {
+    // The stream_became_done logic is kept for backward compatibility.
+    // Note that the minimum bound is returned in min_stream_timestamp.
     return NodeReadiness::kReadyForProcess;
   }

@@ -102,23 +125,13 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(

 void ImmediateInputStreamHandler::FillInputSet(Timestamp input_timestamp,
                                                InputStreamShardSet* input_set) {
-  CHECK(input_timestamp.IsAllowedInStream());
-  CHECK(input_set);
-  for (CollectionItemId id = input_stream_managers_.BeginId();
-       id < input_stream_managers_.EndId(); ++id) {
-    auto& stream = input_stream_managers_.Get(id);
-    if (stream->QueueHead().Timestamp() == input_timestamp) {
-      int num_packets_dropped = 0;
-      bool stream_is_done = false;
-      Packet current_packet = stream->PopPacketAtTimestamp(
-          input_timestamp, &num_packets_dropped, &stream_is_done);
-      AddPacketToShard(&input_set->Get(id), std::move(current_packet),
-                       stream_is_done);
+  absl::MutexLock lock(&mutex_);
+  for (int i = 0; i < sync_sets_.size(); ++i) {
+    if (ready_timestamps_[i] == input_timestamp) {
+      sync_sets_[i].FillInputSet(input_timestamp, input_set);
+      ready_timestamps_[i] = Timestamp::Unset();
     } else {
-      Timestamp bound = stream->MinTimestampOrBound(nullptr);
-      AddPacketToShard(&input_set->Get(id),
-                       Packet().At(bound.PreviousAllowedInStream()),
-                       bound == Timestamp::Done());
+      sync_sets_[i].FillInputBounds(input_set);
     }
   }
 }
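
With this rewrite, each input stream of ImmediateInputStreamHandler is tracked by its own single-stream SyncSet. The handler is selected per node in the graph config, as the tests added earlier in this commit do:

    node {
      calculator: "ProcessBoundToPacketCalculator"
      input_stream: "bounds"
      output_stream: "bounds_ts"
      input_stream_handler {
        input_stream_handler: "ImmediateInputStreamHandler"
      }
    }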

@@ -17,6 +17,7 @@
 // TODO: Move protos in another CL after the C++ code migration.
 #include "absl/strings/substitute.h"
 #include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/collection_item_id.h"
 #include "mediapipe/framework/input_stream_handler.h"
 #include "mediapipe/framework/mediapipe_options.pb.h"
 #include "mediapipe/framework/packet_set.h"

@@ -69,7 +70,7 @@ class SyncSetInputStreamHandler : public InputStreamHandler {
  private:
   absl::Mutex mutex_;
   // The ids of each set of inputs.
-  std::vector<std::vector<CollectionItemId>> sync_sets_ ABSL_GUARDED_BY(mutex_);
+  std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
   // The index of the ready sync set. A value of -1 indicates that no
   // sync sets are ready.
   int ready_sync_set_index_ ABSL_GUARDED_BY(mutex_) = -1;
@@ -98,7 +99,7 @@ void SyncSetInputStreamHandler::PrepareForRun(
   sync_sets_.clear();
   std::set<CollectionItemId> used_ids;
   for (const auto& sync_set : handler_options.sync_set()) {
-    sync_sets_.emplace_back();
+    std::vector<CollectionItemId> stream_ids;
     CHECK_LT(0, sync_set.tag_index_size());
     for (const auto& tag_index : sync_set.tag_index()) {
       std::string tag;
@@ -109,8 +110,9 @@ void SyncSetInputStreamHandler::PrepareForRun(
       CHECK(!::mediapipe::ContainsKey(used_ids, id))
           << "stream \"" << tag_index << "\" is in more than one sync set.";
       used_ids.insert(id);
-      sync_sets_.back().push_back(id);
+      stream_ids.push_back(id);
     }
+    sync_sets_.emplace_back(this, std::move(stream_ids));
   }
   std::vector<CollectionItemId> remaining_ids;
   for (CollectionItemId id = input_stream_managers_.BeginId();
@@ -120,7 +122,7 @@ void SyncSetInputStreamHandler::PrepareForRun(
     }
   }
   if (!remaining_ids.empty()) {
-    sync_sets_.push_back(std::move(remaining_ids));
+    sync_sets_.emplace_back(this, std::move(remaining_ids));
   }
   ready_sync_set_index_ = -1;
   ready_timestamp_ = Timestamp::Done();
@@ -137,24 +139,14 @@ NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
   absl::MutexLock lock(&mutex_);
   if (ready_sync_set_index_ >= 0) {
     *min_stream_timestamp = ready_timestamp_;
+    // TODO: Return kNotReady unless a new ready syncset is found.
     return NodeReadiness::kReadyForProcess;
   }
   for (int sync_set_index = 0; sync_set_index < sync_sets_.size();
        ++sync_set_index) {
-    const std::vector<CollectionItemId>& sync_set = sync_sets_[sync_set_index];
-    *min_stream_timestamp = Timestamp::Done();
-    Timestamp min_bound = Timestamp::Done();
-    for (CollectionItemId id : sync_set) {
-      const auto& stream = input_stream_managers_.Get(id);
-      bool empty;
-      Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
-      if (empty) {
-        min_bound = std::min(min_bound, stream_timestamp);
-      }
-      *min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp);
-    }
-
-    if (*min_stream_timestamp == Timestamp::Done()) {
+    NodeReadiness readiness =
+        sync_sets_[sync_set_index].GetReadiness(min_stream_timestamp);
+    if (readiness == NodeReadiness::kReadyForClose) {
       // This sync set is done, remove it. Note that this invalidates
       // sync set indexes higher than sync_set_index. However, we are
       // guaranteed that we were not ready before entering the outer
@@ -165,15 +157,14 @@ NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
       continue;
     }

-    if (min_bound > *min_stream_timestamp) {
+    if (readiness == NodeReadiness::kReadyForProcess) {
+      // TODO: Prioritize sync-sets to avoid starvation.
       if (*min_stream_timestamp < ready_timestamp_) {
         // Store the timestamp and corresponding sync set index for the
         // sync set with the earliest arrival timestamp.
         ready_timestamp_ = *min_stream_timestamp;
         ready_sync_set_index_ = sync_set_index;
       }
-    } else {
-      CHECK_EQ(min_bound, *min_stream_timestamp);
     }
   }
   if (ready_sync_set_index_ >= 0) {
@@ -188,44 +179,17 @@ NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
   return NodeReadiness::kNotReady;
 }

-void SyncSetInputStreamHandler::FillInputBounds(
-    Timestamp input_timestamp, InputStreamShardSet* input_set) {
-  for (int i = 0; i < sync_sets_.size(); ++i) {
-    if (i != ready_sync_set_index_) {
-      // Set the input streams for the not-ready sync sets.
-      for (CollectionItemId id : sync_sets_[i]) {
-        const auto stream = input_stream_managers_.Get(id);
-        Timestamp bound = stream->MinTimestampOrBound(nullptr);
-        AddPacketToShard(&input_set->Get(id),
-                         Packet().At(bound.PreviousAllowedInStream()),
-                         bound == Timestamp::Done());
-      }
-    }
-  }
-}
-
 void SyncSetInputStreamHandler::FillInputSet(Timestamp input_timestamp,
                                              InputStreamShardSet* input_set) {
   // Assume that all current packets are already cleared.
-  CHECK(input_timestamp.IsAllowedInStream());
-  CHECK(input_set);
   absl::MutexLock lock(&mutex_);
   CHECK_LE(0, ready_sync_set_index_);
-  CHECK_EQ(input_timestamp, ready_timestamp_);
-  // Set the input streams for the ready sync set.
-  for (CollectionItemId id : sync_sets_[ready_sync_set_index_]) {
-    const auto& stream = input_stream_managers_.Get(id);
-    int num_packets_dropped = 0;
-    bool stream_is_done = false;
-    Packet current_packet = stream->PopPacketAtTimestamp(
-        input_timestamp, &num_packets_dropped, &stream_is_done);
-    CHECK_EQ(num_packets_dropped, 0)
-        << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
-                            num_packets_dropped, stream->Name());
-    AddPacketToShard(&input_set->Get(id), std::move(current_packet),
-                     stream_is_done);
+  sync_sets_[ready_sync_set_index_].FillInputSet(input_timestamp, input_set);
+  for (int i = 0; i < sync_sets_.size(); ++i) {
+    if (i != ready_sync_set_index_) {
+      sync_sets_[i].FillInputBounds(input_set);
+    }
   }
-  FillInputBounds(input_timestamp, input_set);
   ready_sync_set_index_ = -1;
   ready_timestamp_ = Timestamp::Done();
 }
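
The multi-sync-set behavior exercised by SyncSetInputStreamHandler is configured per node through SyncSetInputStreamHandlerOptions; the ProcessTimestampBounds_SyncSets test added earlier in this commit uses exactly this shape:

    input_stream_handler {
      input_stream_handler: "SyncSetInputStreamHandler"
      options {
        [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
          sync_set { tag_index: ":0" }
        }
      }
    }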

@@ -122,7 +122,6 @@ std::string TimestampDiff::DebugString() const {
 }

 Timestamp Timestamp::NextAllowedInStream() const {
-  CHECK(IsAllowedInStream()) << "Timestamp is: " << DebugString();
   if (*this >= Max() || *this == PreStream()) {
     // Indicates that no further timestamps may occur.
     return OneOverPostStream();

@@ -247,6 +247,12 @@ class TimestampDiff {
   TimestampDiff operator-(const TimestampDiff other) const;
   Timestamp operator+(const Timestamp other) const;

+  // Special values.
+
+  static TimestampDiff Unset() {
+    return TimestampDiff(Timestamp::Unset().Value());
+  }
+
  private:
   TimestampBaseType timestamp_;
 };
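
TimestampDiff::Unset() gives callers a sentinel for "no timestamp offset declared". The calculator_node.cc change earlier in this commit relies on it when deciding whether an output stream may propagate bounds by offset:

    stream->Spec()->offset_enabled =
        (contract.GetTimestampOffset() != TimestampDiff::Unset());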
@@ -815,16 +815,25 @@ NodeTypeInfo::NodeRef ValidatedGraphConfig::NodeForSorterIndex(
       sorted_nodes_.push_back(&tmp_calculators.back());
     }
   }
+  if (cyclic) {
+    // This reads from a partially altered config_ (by node Swap()), but we
+    // assume the nodes in the cycle are not altered, as TopologicalSorter
+    // reports cyclicity before processing any node in the cycle.
+    auto node_name_formatter = [this](std::string* out, int i) {
+      const auto& n = NodeForSorterIndex(i);
+      absl::StrAppend(out, n.type == NodeTypeInfo::NodeType::CALCULATOR
+                               ? tool::CanonicalNodeName(Config(), n.index)
+                               : DebugName(Config(), n.type, n.index));
+    };
+    return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
+           << "Generator side packet cycle or calculator stream cycle detected "
+              "in graph: ["
+           << absl::StrJoin(cycle_indexes, ", ", node_name_formatter) << "]";
+  }
   generator_configs.Swap(config_.mutable_packet_generator());
   tmp_generators.swap(generators_);
   node_configs.Swap(config_.mutable_node());
   tmp_calculators.swap(calculators_);
-  if (cyclic) {
-    return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
-           << "Generator side packet cycle or calculator stream cycle detected "
-              "in graph. Cycle indexes: "
-           << absl::StrJoin(cycle_indexes, ", ");
-  }
 #if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE))
   VLOG(2) << "AFTER TOPOLOGICAL SORT:\n" << config_.DebugString();
 #endif  // !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE)
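The improved cycle error relies on absl::StrJoin's formatter overload to turn node indexes into readable names. A minimal stand-alone sketch of that pattern (the node table below is hypothetical, not part of MediaPipe):

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  // Hypothetical stand-in for the sorter's node table.
  const std::vector<std::string> node_names = {"NodeA", "NodeB", "NodeC"};
  const std::vector<int> cycle_indexes = {0, 2};

  // Like node_name_formatter above: the formatter receives the output string
  // and one element, and appends that element's display form.
  auto name_formatter = [&node_names](std::string* out, int i) {
    absl::StrAppend(out, node_names[i]);
  };
  std::cout << "cycle detected in graph: ["
            << absl::StrJoin(cycle_indexes, ", ", name_formatter) << "]\n";
  // Prints: cycle detected in graph: [NodeA, NodeC]
  return 0;
}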
@@ -57,15 +57,14 @@
 #include <EGL/egl.h>
 #include <GLES2/gl2.h>
 #include <GLES2/gl2ext.h>
-#ifdef __ANDROID__
+#if defined(__ANDROID__)
 // Weak-link all GL APIs included from this point on.
 // TODO: Annotate these with availability attributes for the
 // appropriate versions of Android, by including gl{3,31,31}.h and resetting
 // GL_APICALL for each.
 #undef GL_APICALL
 #define GL_APICALL __attribute__((weak_import)) KHRONOS_APICALL
-#endif  // __ANDROID__
+#endif  // defined(__ANDROID__)
 
 #include <GLES3/gl32.h>
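For context, weak-linking means a GL entry point that is missing from the device's driver resolves to a null address instead of failing at load time, so callers can probe before calling. A hedged sketch of that check (MediaPipe's own wrapper, SymbolAvailable(), appears in the gl_context.cc hunk further down):

#include <GLES3/gl32.h>

// Sketch only: valid when the GL headers were included with GL_APICALL
// redefined to __attribute__((weak_import)), as in the block above.
bool GlGetStringiAvailable() {
  // A weakly-linked function that the runtime could not resolve has a null
  // address, so taking its address doubles as an availability check.
  return &glGetStringi != nullptr;
}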
@@ -83,6 +83,10 @@ class GlCalculatorHelperImpl {
   GLuint framebuffer_ = 0;
 
   GpuResources& gpu_resources_;
+
+  // Whether the given GlContext supports linear filtering of float textures;
+  // computed once so that SetStandardTextureParams can apply the right filter.
+  bool can_linear_filter_float_textures_;
 };
 
 }  // namespace mediapipe
@@ -22,6 +22,17 @@ GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc,
                                                GpuResources* gpu_resources)
     : gpu_resources_(*gpu_resources) {
   gl_context_ = gpu_resources_.gl_context(cc);
+  // The extension may be present on GL ES 2.0 and up (at least through
+  // ES 3.2). GL_ES_VERSION_2_0 is defined on all of those versions, so a
+  // single check covers GLES >= 2.0.
+#if GL_ES_VERSION_2_0
+  // No linear float filtering by default; check the extensions.
+  can_linear_filter_float_textures_ =
+      gl_context_->HasGlExtension("OES_texture_float_linear");
+#else
+  // Any float32 texture we create should automatically have linear filtering.
+  can_linear_filter_float_textures_ = true;
+#endif  // GL_ES_VERSION_2_0
 }
 
 GlCalculatorHelperImpl::~GlCalculatorHelperImpl() {
@@ -89,13 +100,15 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
 
 void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target,
                                                       GLint internal_format) {
+  // Default to using linear filter everywhere. For float32 textures, fall back
+  // to GL_NEAREST if linear filtering unsupported.
   GLint filter;
   switch (internal_format) {
     case GL_R32F:
     case GL_RGBA32F:
-      // 32F (unlike 16f) textures do not support texture filtering
+      // 32F (unlike 16f) textures do not always support texture filtering
       // (According to OpenGL ES specification [TEXTURE IMAGE SPECIFICATION])
-      filter = GL_NEAREST;
+      filter = can_linear_filter_float_textures_ ? GL_LINEAR : GL_NEAREST;
       break;
     default:
       filter = GL_LINEAR;
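For orientation, the standard texture parameters reduce to plain glTexParameteri calls; a simplified sketch of the pattern (not MediaPipe's exact implementation), with the filter chosen the way the hunk above does:

#include <GLES2/gl2.h>

// Simplified sketch: apply filtering and clamp-to-edge wrapping to the
// currently bound texture on `target`.
void SetStandardTextureParamsSketch(GLenum target, bool is_float32_texture,
                                    bool can_linear_filter_float_textures) {
  const GLint filter = (is_float32_texture && !can_linear_filter_float_textures)
                           ? GL_NEAREST
                           : GL_LINEAR;
  glTexParameteri(target, GL_TEXTURE_MIN_FILTER, filter);
  glTexParameteri(target, GL_TEXTURE_MAG_FILTER, filter);
  glTexParameteri(target, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
  glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
}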
@@ -203,6 +203,69 @@ bool GlContext::ParseGlVersion(absl::string_view version_string, GLint* major,
   return true;
 }
 
+bool GlContext::HasGlExtension(absl::string_view extension) const {
+  return gl_extensions_.find(extension) != gl_extensions_.end();
+}
+
+// Function for GL 3.0+ to query for and store all of our available GL
+// extensions in an easily-accessible set. The glGetString call is actually
+// *not* required to work with GL_EXTENSIONS for newer GL versions, so we must
+// maintain both variations of this function.
+::mediapipe::Status GlContext::GetGlExtensions() {
+  gl_extensions_.clear();
+  // glGetStringi was only introduced in GL 3.0+, so we bail out of this
+  // function if that function is not defined, regardless of the version number
+  // reported. The function is also fully stubbed out if we're linking against
+  // an API version without a glGetStringi declaration. Although Emscripten
+  // sometimes provides this function, its default library implementation
+  // appears to only provide glGetString, so we skip this for Emscripten
+  // platforms to avoid possible undefined symbol or runtime errors.
+#if (GL_VERSION_3_0 || GL_ES_VERSION_3_0) && !defined(__EMSCRIPTEN__)
+  if (!SymbolAvailable(&glGetStringi)) {
+    LOG(ERROR) << "GL major version > 3.0 indicated, but glGetStringi not "
+               << "defined. Falling back to deprecated GL extensions querying "
+               << "method.";
+    return ::mediapipe::InternalError("glGetStringi not defined, but queried");
+  }
+  int num_extensions = 0;
+  glGetIntegerv(GL_NUM_EXTENSIONS, &num_extensions);
+  if (glGetError() != 0) {
+    return ::mediapipe::InternalError(
+        "Error querying for number of extensions");
+  }
+
+  for (int i = 0; i < num_extensions; ++i) {
+    const GLubyte* res = glGetStringi(GL_EXTENSIONS, i);
+    if (glGetError() != 0 || res == nullptr) {
+      return ::mediapipe::InternalError(
+          "Error querying for an extension by index");
+    }
+    const char* signed_res = reinterpret_cast<const char*>(res);
+    gl_extensions_.insert(signed_res);
+  }
+
+  return ::mediapipe::OkStatus();
+#else
+  return ::mediapipe::InternalError("GL version mismatch in GlGetExtensions");
+#endif  // (GL_VERSION_3_0 || GL_ES_VERSION_3_0) && !defined(__EMSCRIPTEN__)
+}
+
+// Same as GetGlExtensions() above, but for pre-GL3.0, where glGetStringi did
+// not exist.
+::mediapipe::Status GlContext::GetGlExtensionsCompat() {
+  gl_extensions_.clear();
+
+  const GLubyte* res = glGetString(GL_EXTENSIONS);
+  if (glGetError() != 0 || res == nullptr) {
+    LOG(ERROR) << "Error querying for GL extensions";
+    return ::mediapipe::InternalError("Error querying for GL extensions");
+  }
+  const char* signed_res = reinterpret_cast<const char*>(res);
+  gl_extensions_ = absl::StrSplit(signed_res, ' ');
+
+  return ::mediapipe::OkStatus();
+}
+
 ::mediapipe::Status GlContext::FinishInitialization(bool create_thread) {
   if (create_thread) {
     thread_ = absl::make_unique<GlContext::DedicatedThread>();
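GetGlExtensionsCompat() leans on absl::StrSplit's ability to collect directly into a std::set of string_views. A small self-contained sketch of that conversion; the extension string here is made up, and the string_view pieces are only safe because the backing storage outlives the set (which also holds for glGetString's static result):

#include <iostream>
#include <set>

#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

int main() {
  // Stand-in for the static string returned by glGetString(GL_EXTENSIONS).
  static const char kExtensions[] =
      "GL_OES_texture_float GL_OES_texture_float_linear GL_EXT_blend_minmax";
  std::set<absl::string_view> extensions = absl::StrSplit(kExtensions, ' ');
  std::cout << extensions.count("GL_OES_texture_float_linear") << "\n";  // 1
  return 0;
}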
@@ -232,8 +295,13 @@ bool GlContext::ParseGlVersion(absl::string_view version_string, GLint* major,
 
     LOG(INFO) << "GL version: " << gl_major_version_ << "." << gl_minor_version_
               << " (" << glGetString(GL_VERSION) << ")";
+    if (gl_major_version_ >= 3) {
+      auto status = GetGlExtensions();
+      if (status.ok()) {
         return ::mediapipe::OkStatus();
+      }
+    }
+    return GetGlExtensionsCompat();
   });
 }
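Once FinishInitialization() has filled gl_extensions_, feature checks reduce to a set lookup through the new accessor. A hedged usage sketch (how the caller obtained its GlContext is assumed, not shown here):

#include <memory>

#include "mediapipe/gpu/gl_context.h"

// Sketch only: assumes `gl_context` is a successfully initialized
// mediapipe::GlContext obtained elsewhere.
bool SupportsFloatLinearFiltering(
    const std::shared_ptr<mediapipe::GlContext>& gl_context) {
  return gl_context->HasGlExtension("OES_texture_float_linear");
}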
@@ -237,6 +237,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   static bool ParseGlVersion(absl::string_view version_string, GLint* major,
                              GLint* minor);
 
+  // Simple query for GL extension support; only valid after GlContext has
+  // finished its initialization successfully.
+  bool HasGlExtension(absl::string_view extension) const;
+
   int64_t gl_finish_count() { return gl_finish_count_; }
 
   // Used by GlFinishSyncPoint. The count_to_pass cannot exceed the current
@@ -346,6 +350,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   bool HasContext() const;
   bool CheckForGlErrors();
   void LogUncheckedGlErrors(bool had_gl_errors);
+  ::mediapipe::Status GetGlExtensions();
+  ::mediapipe::Status GetGlExtensionsCompat();
 
   // The following ContextBinding functions have platform-specific
   // implementations.
@@ -366,6 +372,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   GLint gl_major_version_ = 0;
   GLint gl_minor_version_ = 0;
 
+  // glGetString and glGetStringi both return pointers to static strings,
+  // so we should be fine storing the extension pieces as string_view's.
+  std::set<absl::string_view> gl_extensions_;
+
   // Number of glFinish calls completed on the GL thread.
   // Changes should be guarded by mutex_. However, we use simple atomic
   // loads for efficiency on the fast path.
@@ -24,6 +24,19 @@ namespace mediapipe {
 #define _STRINGIFY(_x) __STRINGIFY(_x)
 #endif
 
+// Our fragment shaders use DEFAULT_PRECISION to define the default precision
+// for a type. The macro strips out the precision declaration on desktop GL,
+// where it's not supported.
+//
+// Note: this does not use a raw string literal because some compilers don't
+// handle raw strings inside macros correctly. It uses a macro because we want
+// to be able to concatenate strings by juxtaposition, so we can export
+// const char* static data containing the pre-expanded strings.
+//
+// TODO: this was written before we could rely on C++11 support.
+// Consider replacing it with constexpr std::string concatenation, or replacing
+// the static variables with functions.
 #define PRECISION_COMPAT \
   GLES_VERSION_COMPAT \
   "#ifdef GL_ES \n" \

@@ -42,10 +55,15 @@ namespace mediapipe {
   "#define out varying\n" \
   "#endif  // __VERSION__ < 130\n"
 
+// Note: on systems where highp precision for floats is not supported (look up
+// GL_FRAGMENT_PRECISION_HIGH), we replace it with mediump.
 #define FRAGMENT_PREAMBLE \
   PRECISION_COMPAT \
   "#if __VERSION__ < 130\n" \
   "#define in varying\n" \
+  "#if GL_ES && !GL_FRAGMENT_PRECISION_HIGH\n" \
+  "#define highp mediump\n" \
+  "#endif  // GL_ES && !GL_FRAGMENT_PRECISION_HIGH\n" \
   "#endif  // __VERSION__ < 130\n"
 
 const GLchar* const kMediaPipeVertexShaderPreamble = VERTEX_PREAMBLE;
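The "concatenate by juxtaposition" point is ordinary adjacent string-literal concatenation in C++; a minimal sketch with a made-up preamble and shader body (not MediaPipe's):

#include <cstdio>

// Hypothetical preamble macro: adjacent string literals merge into a single
// constant at compile time, which is why the preambles are macros rather than
// std::string objects.
#define DEMO_PREAMBLE          \
  "#ifdef GL_ES\n"             \
  "precision mediump float;\n" \
  "#endif\n"

static const char* const kDemoFragmentShader =
    DEMO_PREAMBLE "void main() { gl_FragColor = vec4(1.0); }\n";

int main() {
  std::puts(kDemoFragmentShader);  // One pre-expanded string literal.
  return 0;
}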
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+load(
+    "//mediapipe/framework/tool:mediapipe_graph.bzl",
+    "mediapipe_binary_graph",
+)
+
 licenses(["notice"])  # Apache 2.0
 
 package(default_visibility = ["//visibility:public"])

@@ -33,9 +38,19 @@ cc_library(
     ],
 )
 
-load(
-    "//mediapipe/framework/tool:mediapipe_graph.bzl",
-    "mediapipe_binary_graph",
-)
+cc_library(
+    name = "desktop_calculators",
+    deps = [
+        "//mediapipe/calculators/core:flow_limiter_calculator",
+        "//mediapipe/calculators/core:previous_loopback_calculator",
+        "//mediapipe/calculators/image:image_transformation_calculator",
+        "//mediapipe/calculators/image:recolor_calculator",
+        "//mediapipe/calculators/image:set_alpha_calculator",
+        "//mediapipe/calculators/tflite:tflite_converter_calculator",
+        "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
+        "//mediapipe/calculators/tflite:tflite_inference_calculator",
+        "//mediapipe/calculators/tflite:tflite_tensors_to_segmentation_calculator",
+    ],
+)
 
 mediapipe_binary_graph(
@@ -0,0 +1,152 @@
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on CPU.
# Used in the example in
# mediapipe/examples/desktop/hair_segmentation:hair_segmentation_cpu

# Images on CPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"

# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# TfLiteTensorsToSegmentationCalculator downstream in the graph to finish
# generating the corresponding hair mask before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# TfLiteTensorsToSegmentationCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessary computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy
# processing previous inputs.
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "input_video"
  input_stream: "FINISHED:hair_mask"
  input_stream_info: {
    tag_index: "FINISHED"
    back_edge: true
  }
  output_stream: "throttled_input_video"
}

# Transforms the input image on CPU to a 512x512 image. To scale the image, by
# default it uses the STRETCH scale mode that maps the entire input image to the
# entire transformed image. As a result, image aspect ratio may be changed and
# objects in the image may be deformed (stretched or squeezed), but the hair
# segmentation model used in this graph is agnostic to that deformation.
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:throttled_input_video"
  output_stream: "IMAGE:transformed_input_video"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 512
      output_height: 512
    }
  }
}

# Caches a mask fed back from the previous round of hair segmentation, and upon
# the arrival of the next input image sends out the cached mask with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous mask. Note that upon the arrival of the very first
# input image, an empty packet is sent out to jump start the feedback loop.
node {
  calculator: "PreviousLoopbackCalculator"
  input_stream: "MAIN:throttled_input_video"
  input_stream: "LOOP:hair_mask"
  input_stream_info: {
    tag_index: "LOOP"
    back_edge: true
  }
  output_stream: "PREV_LOOP:previous_hair_mask"
}

# Embeds the hair mask generated from the previous round of hair segmentation
# as the alpha channel of the current input image.
node {
  calculator: "SetAlphaCalculator"
  input_stream: "IMAGE:transformed_input_video"
  input_stream: "ALPHA:previous_hair_mask"
  output_stream: "IMAGE:mask_embedded_input_video"
}

# Converts the transformed input image on CPU into an image tensor stored in
# TfLiteTensor. The zero_center option is set to false to normalize the
# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. With the
# max_num_channels option set to 4, all 4 RGBA channels are contained in the
# image tensor.
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE:mask_embedded_input_video"
  output_stream: "TENSORS:image_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
      zero_center: false
      max_num_channels: 4
    }
  }
}

# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
  calculator: "TfLiteCustomOpResolverCalculator"
  output_side_packet: "op_resolver"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
      use_gpu: false
    }
  }
}

# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# tensor representing the hair segmentation, which has the same width and height
# as the input image tensor.
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:segmentation_tensor"
  input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "mediapipe/models/hair_segmentation.tflite"
      use_gpu: false
    }
  }
}

# Decodes the segmentation tensor generated by the TensorFlow Lite model into a
# mask of values in [0, 255], stored in a CPU buffer. It also
# takes the mask generated previously as another input to improve the temporal
# consistency.
node {
  calculator: "TfLiteTensorsToSegmentationCalculator"
  input_stream: "TENSORS:segmentation_tensor"
  input_stream: "PREV_MASK:previous_hair_mask"
  output_stream: "MASK:hair_mask"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] {
      tensor_width: 512
      tensor_height: 512
      tensor_channels: 2
      combine_with_previous_ratio: 0.9
      output_layer_index: 1
    }
  }
}

# Colors the hair segmentation with the color specified in the option.
node {
  calculator: "RecolorCalculator"
  input_stream: "IMAGE:throttled_input_video"
  input_stream: "MASK:hair_mask"
  output_stream: "IMAGE:output_video"
  node_options: {
    [type.googleapis.com/mediapipe.RecolorCalculatorOptions] {
      color { r: 0 g: 0 b: 255 }
      mask_channel: RED
    }
  }
}
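For orientation, a CPU graph like this is normally driven from C++ with mediapipe::CalculatorGraph. The sketch below follows the general pattern of MediaPipe's desktop demo runners but is simplified; the pbtxt path and the caller-supplied input packet are assumptions for illustration, not part of this commit:

#include <string>
#include <utility>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

// Schematic runner: loads the graph text, feeds one packet on "input_video",
// and polls "output_video".
::mediapipe::Status RunHairSegmentationOnce(const std::string& graph_path,
                                            const mediapipe::Packet& input_frame) {
  std::string graph_text;
  auto status = mediapipe::file::GetContents(graph_path, &graph_text);
  if (!status.ok()) return status;
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graph_text);

  mediapipe::CalculatorGraph graph;
  status = graph.Initialize(config);
  if (!status.ok()) return status;

  auto poller_or = graph.AddOutputStreamPoller("output_video");
  if (!poller_or.ok()) return poller_or.status();
  mediapipe::OutputStreamPoller poller = std::move(poller_or.ValueOrDie());

  status = graph.StartRun({});
  if (!status.ok()) return status;

  // Feed a single frame at timestamp 0; a real driver loops over video frames.
  status = graph.AddPacketToInputStream(
      "input_video", input_frame.At(mediapipe::Timestamp(0)));
  if (!status.ok()) return status;

  mediapipe::Packet recolored;
  if (poller.Next(&recolored)) {
    // recolored holds the blue-tinted frame produced by RecolorCalculator.
  }

  status = graph.CloseInputStream("input_video");
  if (!status.ok()) return status;
  return graph.WaitUntilDone();
}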
@@ -78,6 +78,33 @@ cat > $(OUTS) <<EOF
     srcs = ["//mediapipe/framework/formats:protos_src"],
 )
 
+_proto_java_src_generator(
+    name = "rasterization_proto",
+    proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
+    java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
+    srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
+)
+
+_proto_java_src_generator(
+    name = "location_data_proto",
+    proto_src = "mediapipe/framework/formats/location_data.proto",
+    java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
+    srcs = [
+        "//mediapipe/framework/formats:protos_src",
+        "//mediapipe/framework/formats/annotation:protos_src",
+    ],
+)
+
+_proto_java_src_generator(
+    name = "detection_proto",
+    proto_src = "mediapipe/framework/formats/detection.proto",
+    java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
+    srcs = [
+        "//mediapipe/framework/formats:protos_src",
+        "//mediapipe/framework/formats/annotation:protos_src",
+    ],
+)
+
 android_library(
     name = name + "_android_lib",
     srcs = [

@@ -86,6 +113,9 @@ cat > $(OUTS) <<EOF
         "//mediapipe/java/com/google/mediapipe/glutil:java_src",
         "com/google/mediapipe/proto/CalculatorProto.java",
         "com/google/mediapipe/formats/proto/LandmarkProto.java",
+        "com/google/mediapipe/formats/proto/DetectionProto.java",
+        "com/google/mediapipe/formats/proto/LocationDataProto.java",
+        "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
     ],
     manifest = "AndroidManifest.xml",
     proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@@ -43,31 +43,6 @@ MEDIAPIPE_IOS_HDRS = [
     "NSError+util_status.h",
 ]
 
-MEDIAPIPE_IOS_CC_DEPS = [
-    ":CFHolder",
-    ":util",
-    "//mediapipe/framework/port:map_util",
-    "//mediapipe/framework/port:ret_check",
-    "//mediapipe/framework/port:source_location",
-    "//mediapipe/framework/port:status",
-    "//mediapipe/framework/port:statusor",
-    "//mediapipe/framework/port:threadpool",
-    "//mediapipe/framework:calculator_framework",
-    "//mediapipe/framework:mediapipe_profiling",
-    "//mediapipe/gpu:MPPGraphGPUData",
-    "//mediapipe/gpu:pixel_buffer_pool_util",
-    "//mediapipe/gpu:gpu_buffer",
-    "//mediapipe/gpu:gl_base",
-    "//mediapipe/gpu:gpu_shared_data_internal",
-    "//mediapipe/gpu:graph_support",
-    # Other deps
-    "//mediapipe/util:cpu_util",
-    "@com_google_absl//absl/base:core_headers",
-    "@com_google_absl//absl/memory",
-    "@com_google_absl//absl/strings",
-    "@com_google_absl//absl/synchronization",
-]
-
 objc_library(
     name = "mediapipe_framework_ios",
     srcs = MEDIAPIPE_IOS_SRCS,

@@ -80,8 +55,28 @@ objc_library(
         "Accelerate",
     ],
     visibility = ["//mediapipe/framework:mediapipe_internal"],
-    deps = MEDIAPIPE_IOS_CC_DEPS + [
-        # These are objc_library deps.
+    deps = [
+        ":CFHolder",
+        ":util",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:mediapipe_profiling",
+        "//mediapipe/framework/port:map_util",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/framework/port:source_location",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/framework/port:statusor",
+        "//mediapipe/framework/port:threadpool",
+        "//mediapipe/gpu:MPPGraphGPUData",
+        "//mediapipe/gpu:gl_base",
+        "//mediapipe/gpu:gpu_buffer",
+        "//mediapipe/gpu:gpu_shared_data_internal",
+        "//mediapipe/gpu:graph_support",
+        "//mediapipe/gpu:pixel_buffer_pool_util",
+        "//mediapipe/util:cpu_util",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "@google_toolbox_for_mac//:GTM_Defines",
     ],
 )
@@ -426,7 +426,7 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
   cv::Point point_to_draw(x, y);
   const cv::Scalar color = MediapipeColorToOpenCVColor(annotation.color());
   const int thickness = annotation.thickness();
-  cv::circle(mat_image_, point_to_draw, thickness, color, thickness);
+  cv::circle(mat_image_, point_to_draw, thickness, color, -1);
 }
 
 void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
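In OpenCV, a negative thickness asks cv::circle for a filled circle (cv::FILLED equals -1), so the point is now drawn as a solid dot of radius thickness rather than a ring whose stroke width happens to equal its radius. A minimal stand-alone illustration:

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>

int main() {
  cv::Mat canvas = cv::Mat::zeros(100, 100, CV_8UC3);
  const cv::Point center(50, 50);
  const cv::Scalar blue(255, 0, 0);
  cv::circle(canvas, center, /*radius=*/4, blue, /*thickness=*/2);           // ring
  cv::circle(canvas, center, /*radius=*/4, blue, /*thickness=*/cv::FILLED);  // solid dot
  return 0;
}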
@@ -22,7 +22,6 @@ cc_library(
     hdrs = ["media_sequence_util.h"],
     visibility = [
         "//mediapipe:__subpackages__",
-        "//research/action_recognition/sequence:__subpackages__",
     ],
     deps = [
         "//mediapipe/framework/port:core_proto",

@@ -38,7 +37,6 @@ cc_library(
     hdrs = ["media_sequence.h"],
     visibility = [
         "//mediapipe:__subpackages__",
-        "//research/action_recognition/sequence:__subpackages__",
     ],
     deps = [
         ":media_sequence_util",
@@ -36,7 +36,11 @@ cd /tmp/build_opencv
 git clone https://github.com/opencv/opencv_contrib.git
 git clone https://github.com/opencv/opencv.git
 mkdir opencv/release
-cd opencv/release
+cd opencv_contrib
+git checkout 3.4
+cd ../opencv
+git checkout 3.4
+cd release
 cmake .. -DCMAKE_BUILD_TYPE=RELEASE -DCMAKE_INSTALL_PREFIX=/usr/local \
       -DBUILD_TESTS=OFF -DBUILD_PERF_TESTS=OFF -DBUILD_opencv_ts=OFF \
       -DOPENCV_EXTRA_MODULES_PATH=/tmp/build_opencv/opencv_contrib/modules \
third_party/org_tensorflow_9696366bcadab23a25c773b3ed405bac8ded4d0d.diff (new vendored file, 112 lines)
@@ -0,0 +1,112 @@
diff --git a/third_party/cpuinfo/BUILD.bazel b/third_party/cpuinfo/BUILD.bazel
index 8d89521612..6ea60acdda 100644
--- a/third_party/cpuinfo/BUILD.bazel
+++ b/third_party/cpuinfo/BUILD.bazel
@@ -116,6 +111,8 @@ cc_library(
         ":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
         ":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
         ":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+        ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+        ":tvos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
         ":emscripten_wasm": COMMON_SRCS + EMSCRIPTEN_SRCS,
     }),
     copts = C99OPTS + [
@@ -212,7 +209,7 @@ config_setting(
 config_setting(
     name = "ios_armv7",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "ios",
         "cpu": "ios_armv7",
     },
 )
@@ -220,7 +217,7 @@ config_setting(
 config_setting(
     name = "ios_arm64",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "ios",
         "cpu": "ios_arm64",
     },
 )
@@ -228,7 +225,7 @@ config_setting(
 config_setting(
     name = "ios_arm64e",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "ios",
        "cpu": "ios_arm64e",
     },
 )
@@ -236,7 +233,7 @@ config_setting(
 config_setting(
     name = "ios_x86",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "ios",
         "cpu": "ios_i386",
     },
 )
@@ -244,7 +241,7 @@ config_setting(
 config_setting(
     name = "ios_x86_64",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "ios",
         "cpu": "ios_x86_64",
     },
 )
@@ -252,7 +249,7 @@ config_setting(
 config_setting(
     name = "watchos_armv7k",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "watchos",
         "cpu": "watchos_armv7k",
     },
 )
@@ -260,7 +257,7 @@ config_setting(
 config_setting(
     name = "watchos_arm64_32",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "watchos",
         "cpu": "watchos_arm64_32",
     },
 )
@@ -268,7 +265,7 @@ config_setting(
 config_setting(
     name = "watchos_x86",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "watchos",
         "cpu": "watchos_i386",
     },
 )
@@ -276,7 +273,7 @@ config_setting(
 config_setting(
     name = "watchos_x86_64",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "watchos",
         "cpu": "watchos_x86_64",
     },
 )
@@ -284,7 +281,7 @@ config_setting(
 config_setting(
     name = "tvos_arm64",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "tvos",
         "cpu": "tvos_arm64",
     },
 )
@@ -292,7 +289,7 @@ config_setting(
 config_setting(
     name = "tvos_x86_64",
     values = {
-        "crosstool_top": "//tools/osx/crosstool:crosstool",
+        "apple_platform_type": "tvos",
         "cpu": "tvos_x86_64",
     },
 )
third_party/org_tensorflow_cfc31e324c8de6b52f752a39cb161d99d853ca99.diff (new vendored file, 3083 lines)
File diff suppressed because it is too large.