Project import generated by Copybara.

GitOrigin-RevId: 43cd697ec87dcc5cab5051f27960bb77a057399d
This commit is contained in:
MediaPipe Team 2020-03-20 13:09:58 -07:00 committed by jqtang
parent 3b6d3c4058
commit 1722d4b8a2
71 changed files with 6114 additions and 626 deletions

View File

@ -129,7 +129,11 @@ http_archive(
],
# A compatibility patch
patches = [
"@//third_party:org_tensorflow_528e22eae8bf3206189a066032c66e9e5c9b4a61.diff"
"@//third_party:org_tensorflow_528e22eae8bf3206189a066032c66e9e5c9b4a61.diff",
# Updates for XNNPACK: https://github.com/tensorflow/tensorflow/commit/cfc31e324c8de6b52f752a39cb161d99d853ca99
"@//third_party:org_tensorflow_cfc31e324c8de6b52f752a39cb161d99d853ca99.diff",
# CpuInfo's build rule fixes.
"@//third_party:org_tensorflow_9696366bcadab23a25c773b3ed405bac8ded4d0d.diff",
],
patch_args = [
"-p1",

View File

@ -228,6 +228,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types",
@ -249,6 +250,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types",
@ -265,10 +267,11 @@ cc_test(
deps = [
":begin_loop_calculator",
":end_loop_calculator",
"//mediapipe/calculators/core:packet_cloner_calculator",
":gate_calculator",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
@ -334,6 +337,7 @@ cc_library(
deps = [
":clip_vector_size_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -693,15 +697,17 @@ cc_test(
name = "previous_loopback_calculator_test",
srcs = ["previous_loopback_calculator_test.cc"],
deps = [
":gate_calculator",
":make_pair_calculator",
":pass_through_calculator",
":previous_loopback_calculator",
"//mediapipe/calculators/core:make_pair_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/time",
@ -769,6 +775,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":split_vector_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",

View File

@ -20,6 +20,8 @@
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
@ -28,6 +30,13 @@
namespace mediapipe {
namespace {
MATCHER_P2(PacketOfIntsEq, timestamp, value, "") {
Timestamp actual_timestamp = arg.Timestamp();
const auto& actual_value = arg.template Get<std::vector<int>>();
return testing::Value(actual_timestamp, testing::Eq(timestamp)) &&
testing::Value(actual_value, testing::ElementsAreArray(value));
}
typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntegerCalculator;
REGISTER_CALCULATOR(BeginLoopIntegerCalculator);
@ -59,8 +68,8 @@ REGISTER_CALCULATOR(EndLoopIntegersCalculator);
class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
protected:
BeginEndLoopCalculatorGraphTest() {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
void SetUp() override {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
num_threads: 4
input_stream: "ints"
@ -82,94 +91,222 @@ class BeginEndLoopCalculatorGraphTest : public ::testing::Test {
output_stream: "ITERABLE:ints_plus_one"
}
)");
tool::AddVectorSink("ints_plus_one", &graph_config_, &output_packets_);
tool::AddVectorSink("ints_plus_one", &graph_config, &output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_ASSERT_OK(graph_.StartRun({}));
}
CalculatorGraphConfig graph_config_;
void SendPacketOfInts(Timestamp timestamp, std::vector<int> ints) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphTest, InputStreamForIterableIsEmpty) {
MP_ASSERT_OK(graph_.WaitUntilIdle());
// EndLoopCalc will forward the timestamp bound because there are no packets
// to process.
ASSERT_EQ(0, output_packets_.size());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, SingleEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp = Timestamp(0);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
SendPacketOfInts(Timestamp(0), {});
MP_ASSERT_OK(graph_.WaitUntilIdle());
// EndLoopCalc will forward the timestamp bound because there are no elements
// in collection to output.
ASSERT_EQ(0, output_packets_.size());
EXPECT_TRUE(output_packets_.empty());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, SingleNonEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
input_vector->emplace_back(0);
input_vector->emplace_back(1);
input_vector->emplace_back(2);
Timestamp input_timestamp = Timestamp(0);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
SendPacketOfInts(input_timestamp, {0, 1, 2});
MP_ASSERT_OK(graph_.WaitUntilIdle());
ASSERT_EQ(1, output_packets_.size());
EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector = {1, 2, 3};
EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());
EXPECT_THAT(output_packets_,
testing::ElementsAre(
PacketOfIntsEq(input_timestamp, std::vector<int>{1, 2, 3})));
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector0 = absl::make_unique<std::vector<int>>();
input_vector0->emplace_back(0);
input_vector0->emplace_back(1);
Timestamp input_timestamp0 = Timestamp(0);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector0.release()).At(input_timestamp0)));
SendPacketOfInts(input_timestamp0, {0, 1});
auto input_vector1 = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp1 = Timestamp(1);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector1.release()).At(input_timestamp1)));
SendPacketOfInts(input_timestamp1, {});
auto input_vector2 = absl::make_unique<std::vector<int>>();
input_vector2->emplace_back(2);
input_vector2->emplace_back(3);
Timestamp input_timestamp2 = Timestamp(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector2.release()).At(input_timestamp2)));
SendPacketOfInts(input_timestamp2, {2, 3});
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(2, output_packets_.size());
EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector0 = {1, 2};
EXPECT_EQ(expected_output_vector0,
output_packets_[0].Get<std::vector<int>>());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
// no elements in vector to process.
EXPECT_THAT(output_packets_,
testing::ElementsAre(
PacketOfIntsEq(input_timestamp0, std::vector<int>{1, 2}),
PacketOfIntsEq(input_timestamp2, std::vector<int>{3, 4})));
}
EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
std::vector<int> expected_output_vector2 = {3, 4};
EXPECT_EQ(expected_output_vector2,
output_packets_[1].Get<std::vector<int>>());
// Passes non empty vector through or outputs empty vector in case of timestamp
// bound update.
class PassThroughOrEmptyVectorCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->SetProcessTimestampBounds(true);
cc->Inputs().Index(0).Set<std::vector<int>>();
cc->Outputs().Index(0).Set<std::vector<int>>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (!cc->Inputs().Index(0).IsEmpty()) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
} else {
cc->Outputs().Index(0).AddPacket(
MakePacket<std::vector<int>>(std::vector<int>())
.At(cc->InputTimestamp()));
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(PassThroughOrEmptyVectorCalculator);
class BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest
: public ::testing::Test {
protected:
void SetUp() override {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
num_threads: 4
input_stream: "ints"
input_stream: "force_ints_to_be_timestamp_bound_update"
node {
calculator: "GateCalculator"
input_stream: "ints"
input_stream: "DISALLOW:force_ints_to_be_timestamp_bound_update"
output_stream: "ints_passed_through"
}
node {
calculator: "BeginLoopIntegerCalculator"
input_stream: "ITERABLE:ints_passed_through"
output_stream: "ITEM:int"
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "IncrementCalculator"
input_stream: "int"
output_stream: "int_plus_one"
}
node {
calculator: "EndLoopIntegersCalculator"
input_stream: "ITEM:int_plus_one"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:ints_plus_one"
}
node {
calculator: "PassThroughOrEmptyVectorCalculator"
input_stream: "ints_plus_one"
output_stream: "ints_plus_one_passed_through"
}
)");
tool::AddVectorSink("ints_plus_one_passed_through", &graph_config,
&output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_ASSERT_OK(graph_.StartRun({}));
}
void SendPacketOfIntsOrBound(Timestamp timestamp, std::vector<int> ints) {
// All "ints" packets which are empty are forced to be just timestamp
// bound updates for begin loop calculator.
bool force_ints_to_be_timestamp_bound_update = ints.empty();
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"force_ints_to_be_timestamp_bound_update",
MakePacket<bool>(force_ints_to_be_timestamp_bound_update)
.At(timestamp)));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
SingleEmptyVector) {
SendPacketOfIntsOrBound(Timestamp(0), {});
MP_ASSERT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
Timestamp(0), std::vector<int>{})));
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest,
SingleNonEmptyVector) {
SendPacketOfIntsOrBound(Timestamp(0), {0, 1, 2});
MP_ASSERT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(output_packets_, testing::ElementsAre(PacketOfIntsEq(
Timestamp(0), std::vector<int>{1, 2, 3})));
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest, MultipleVectors) {
SendPacketOfIntsOrBound(Timestamp(0), {});
// Waiting until idle to guarantee all timestamp bound updates are processed
// individually. (Timestamp bounds updates occur in the provide config only
// if input is an empty vector.)
MP_ASSERT_OK(graph_.WaitUntilIdle());
SendPacketOfIntsOrBound(Timestamp(1), {0, 1});
SendPacketOfIntsOrBound(Timestamp(2), {});
// Waiting until idle to guarantee all timestamp bound updates are processed
// individually. (Timestamp bounds updates occur in the provide config only
// if input is an empty vector.)
MP_ASSERT_OK(graph_.WaitUntilIdle());
SendPacketOfIntsOrBound(Timestamp(3), {2, 3});
SendPacketOfIntsOrBound(Timestamp(4), {});
// Waiting until idle to guarantee all timestamp bound updates are processed
// individually. (Timestamp bounds updates occur in the provide config only
// if input is an empty vector.)
MP_ASSERT_OK(graph_.WaitUntilIdle());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
EXPECT_THAT(
output_packets_,
testing::ElementsAre(PacketOfIntsEq(Timestamp(0), std::vector<int>{}),
PacketOfIntsEq(Timestamp(1), std::vector<int>{1, 2}),
PacketOfIntsEq(Timestamp(2), std::vector<int>{}),
PacketOfIntsEq(Timestamp(3), std::vector<int>{3, 4}),
PacketOfIntsEq(Timestamp(4), std::vector<int>{})));
}
class MultiplierCalculator : public CalculatorBase {
@ -199,8 +336,8 @@ REGISTER_CALCULATOR(MultiplierCalculator);
class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
protected:
BeginEndLoopCalculatorGraphWithClonedInputsTest() {
graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>(
void SetUp() override {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"(
num_threads: 4
input_stream: "ints"
@ -226,109 +363,85 @@ class BeginEndLoopCalculatorGraphWithClonedInputsTest : public ::testing::Test {
output_stream: "ITERABLE:multiplied_ints"
}
)");
tool::AddVectorSink("multiplied_ints", &graph_config_, &output_packets_);
tool::AddVectorSink("multiplied_ints", &graph_config, &output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_ASSERT_OK(graph_.StartRun({}));
}
CalculatorGraphConfig graph_config_;
void SendPackets(Timestamp timestamp, int multiplier, std::vector<int> ints) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"ints", MakePacket<std::vector<int>>(std::move(ints)).At(timestamp)));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"multiplier", MakePacket<int>(multiplier).At(timestamp)));
}
void SendMultiplier(Timestamp timestamp, int multiplier) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"multiplier", MakePacket<int>(multiplier).At(timestamp)));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest,
InputStreamForIterableIsEmpty) {
Timestamp input_timestamp = Timestamp(42);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
auto multiplier = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
SendMultiplier(input_timestamp, /*multiplier=*/2);
MP_ASSERT_OK(graph_.WaitUntilIdle());
// EndLoopCalc will forward the timestamp bound because there are no packets
// to process.
ASSERT_EQ(0, output_packets_.size());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleEmptyVector) {
SendPackets(Timestamp(0), /*multiplier=*/2, /*ints=*/{});
MP_ASSERT_OK(graph_.WaitUntilIdle());
// EndLoopCalc will forward the timestamp bound because there are no elements
// in collection to output.
ASSERT_EQ(0, output_packets_.size());
EXPECT_TRUE(output_packets_.empty());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, SingleNonEmptyVector) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector = absl::make_unique<std::vector<int>>();
input_vector->emplace_back(0);
input_vector->emplace_back(1);
input_vector->emplace_back(2);
Timestamp input_timestamp = Timestamp(42);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector.release()).At(input_timestamp)));
auto multiplier = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier.release()).At(input_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
SendPackets(input_timestamp, /*multiplier=*/2, /*ints=*/{0, 1, 2});
MP_ASSERT_OK(graph_.WaitUntilIdle());
ASSERT_EQ(1, output_packets_.size());
EXPECT_EQ(input_timestamp, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector = {0, 2, 4};
EXPECT_EQ(expected_output_vector, output_packets_[0].Get<std::vector<int>>());
EXPECT_THAT(output_packets_,
testing::ElementsAre(
PacketOfIntsEq(input_timestamp, std::vector<int>{0, 2, 4})));
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
}
TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config_));
MP_EXPECT_OK(graph.StartRun({}));
auto input_vector0 = absl::make_unique<std::vector<int>>();
input_vector0->emplace_back(0);
input_vector0->emplace_back(1);
Timestamp input_timestamp0 = Timestamp(42);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector0.release()).At(input_timestamp0)));
auto multiplier0 = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier0.release()).At(input_timestamp0)));
SendPackets(input_timestamp0, /*multiplier=*/2, /*ints=*/{0, 1});
auto input_vector1 = absl::make_unique<std::vector<int>>();
Timestamp input_timestamp1 = Timestamp(43);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector1.release()).At(input_timestamp1)));
auto multiplier1 = absl::make_unique<int>(2);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier1.release()).At(input_timestamp1)));
SendPackets(input_timestamp1, /*multiplier=*/2, /*ints=*/{});
auto input_vector2 = absl::make_unique<std::vector<int>>();
input_vector2->emplace_back(2);
input_vector2->emplace_back(3);
Timestamp input_timestamp2 = Timestamp(44);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"ints", Adopt(input_vector2.release()).At(input_timestamp2)));
auto multiplier2 = absl::make_unique<int>(3);
MP_ASSERT_OK(graph.AddPacketToInputStream(
"multiplier", Adopt(multiplier2.release()).At(input_timestamp2)));
SendPackets(input_timestamp2, /*multiplier=*/3, /*ints=*/{2, 3});
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(2, output_packets_.size());
EXPECT_EQ(input_timestamp0, output_packets_[0].Timestamp());
std::vector<int> expected_output_vector0 = {0, 2};
EXPECT_EQ(expected_output_vector0,
output_packets_[0].Get<std::vector<int>>());
MP_ASSERT_OK(graph_.CloseAllPacketSources());
MP_ASSERT_OK(graph_.WaitUntilDone());
// At input_timestamp1, EndLoopCalc will forward timestamp bound as there are
// no elements in vector to process.
EXPECT_EQ(input_timestamp2, output_packets_[1].Timestamp());
std::vector<int> expected_output_vector2 = {6, 9};
EXPECT_EQ(expected_output_vector2,
output_packets_[1].Get<std::vector<int>>());
EXPECT_THAT(output_packets_,
testing::ElementsAre(
PacketOfIntsEq(input_timestamp0, std::vector<int>{0, 2}),
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
}
} // namespace

View File

@ -16,6 +16,7 @@
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -31,4 +32,9 @@ typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
BeginLoopNormalizedRectCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedRectCalculator);
// A calculator to process std::vector<Detection>.
typedef BeginLoopCalculator<std::vector<::mediapipe::Detection>>
BeginLoopDetectionCalculator;
REGISTER_CALCULATOR(BeginLoopDetectionCalculator);
} // namespace mediapipe

View File

@ -52,20 +52,28 @@ namespace mediapipe {
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
// }
//
// BeginLoopCalculator accepts an optional input stream tagged with "TICK"
// which if non-empty, wakes up the calculator and calls
// BeginLoopCalculator::Process(). Input streams tagged with "CLONE" are cloned
// to the corresponding output streams at loop timestamps. This ensures that a
// MediaPipe graph or sub-graph can run multiple times, once per element in the
// "ITERABLE" for each pakcet clone of the packets in the "CLONE" input streams.
// Input streams tagged with "CLONE" are cloned to the corresponding output
// streams at loop timestamps. This ensures that a MediaPipe graph or sub-graph
// can run multiple times, once per element in the "ITERABLE" for each pakcet
// clone of the packets in the "CLONE" input streams.
template <typename IterableT>
class BeginLoopCalculator : public CalculatorBase {
using ItemT = typename IterableT::value_type;
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
// The below enables processing of timestamp bound updates, and that enables
// correct timestamp propagation by the companion EndLoopCalculator.
//
// For instance, Process() function will be still invoked even if upstream
// calculator has updated timestamp bound for ITERABLE input instead of
// providing actual value.
cc->SetProcessTimestampBounds(true);
// A non-empty packet in the optional "TICK" input stream wakes up the
// calculator.
// DEPRECATED as timestamp bound updates are processed by default in this
// calculator.
if (cc->Inputs().HasTag("TICK")) {
cc->Inputs().Tag("TICK").SetAny();
}

View File

@ -17,6 +17,7 @@
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
@ -25,4 +26,8 @@ typedef ClipVectorSizeCalculator<::mediapipe::NormalizedRect>
ClipNormalizedRectVectorSizeCalculator;
REGISTER_CALCULATOR(ClipNormalizedRectVectorSizeCalculator);
typedef ClipVectorSizeCalculator<::mediapipe::Detection>
ClipDetectionVectorSizeCalculator;
REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator);
} // namespace mediapipe

View File

@ -16,6 +16,7 @@
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"
@ -37,4 +38,8 @@ typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>>
EndLoopRenderDataCalculator;
REGISTER_CALCULATOR(EndLoopRenderDataCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::ClassificationList>>
EndLoopClassificationListCalculator;
REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
} // namespace mediapipe

View File

@ -25,13 +25,17 @@ namespace mediapipe {
// together with some previous output.
//
// For the first packet that arrives on the MAIN input, the timestamp bound is
// advanced on the output. Downstream calculators will see this as an empty
// advanced on the PREV_LOOP. Downstream calculators will see this as an empty
// packet. This way they are not kept waiting for the previous output, which
// for the first iteration does not exist.
//
// Thereafter, each packet received on MAIN is matched with a packet received
// on LOOP; the LOOP packet's timestamp is changed to that of the MAIN packet,
// and it is output on PREV_LOOP.
// Thereafter,
// - Each non-empty MAIN packet results in:
// a) a PREV_LOOP packet with contents of the LOOP packet received at the
// timestamp of the previous non-empty MAIN packet
// b) or in a PREV_LOOP timestamp bound update if the LOOP packet was empty.
// - Each empty MAIN packet indicating timestamp bound update results in a
// PREV_LOOP timestamp bound update.
//
// Example config:
// node {
@ -56,83 +60,115 @@ class PreviousLoopbackCalculator : public CalculatorBase {
// TODO: an optional PREV_TIMESTAMP output could be added to
// carry the original timestamp of the packet on PREV_LOOP.
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
// Process() function is invoked in response to MAIN/LOOP stream timestamp
// bound updates.
cc->SetProcessTimestampBounds(true);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
main_id_ = cc->Inputs().GetId("MAIN", 0);
loop_id_ = cc->Inputs().GetId("LOOP", 0);
loop_out_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
prev_loop_id_ = cc->Outputs().GetId("PREV_LOOP", 0);
cc->Outputs()
.Get(loop_out_id_)
.Get(prev_loop_id_)
.SetHeader(cc->Inputs().Get(loop_id_).Header());
// Use an empty packet for the first round, since there is no previous
// output.
loopback_packets_.push_back({});
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
Packet& main_packet = cc->Inputs().Get(main_id_).Value();
if (!main_packet.IsEmpty()) {
main_ts_.push_back(main_packet.Timestamp());
}
Packet& loopback_packet = cc->Inputs().Get(loop_id_).Value();
if (!loopback_packet.IsEmpty()) {
loopback_packets_.push_back(loopback_packet);
while (!main_ts_.empty() &&
main_ts_.front() <= loopback_packets_.front().Timestamp()) {
main_ts_.pop_front();
}
}
auto& loop_out = cc->Outputs().Get(loop_out_id_);
// Non-empty packets and empty packets indicating timestamp bound updates
// are guaranteed to have timestamps greater than timestamps of previous
// packets within the same stream. Calculator tracks and operates on such
// packets.
while (!main_ts_.empty() && !loopback_packets_.empty()) {
Timestamp main_timestamp = main_ts_.front();
main_ts_.pop_front();
Packet previous_loopback = loopback_packets_.front().At(main_timestamp);
loopback_packets_.pop_front();
if (previous_loopback.IsEmpty()) {
// TODO: SetCompleteTimestampBound would be more useful.
loop_out.SetNextTimestampBound(main_timestamp + 1);
const Packet& main_packet = cc->Inputs().Get(main_id_).Value();
if (prev_main_ts_ < main_packet.Timestamp()) {
Timestamp loop_timestamp;
if (!main_packet.IsEmpty()) {
loop_timestamp = prev_non_empty_main_ts_;
prev_non_empty_main_ts_ = main_packet.Timestamp();
} else {
loop_out.AddPacket(std::move(previous_loopback));
// Calculator advances PREV_LOOP timestamp bound in response to empty
// MAIN packet, hence not caring about corresponding loop packet.
loop_timestamp = Timestamp::Unset();
}
main_packet_specs_.push_back({.timestamp = main_packet.Timestamp(),
.loop_timestamp = loop_timestamp});
prev_main_ts_ = main_packet.Timestamp();
}
const Packet& loop_packet = cc->Inputs().Get(loop_id_).Value();
if (prev_loop_ts_ < loop_packet.Timestamp()) {
loop_packets_.push_back(loop_packet);
prev_loop_ts_ = loop_packet.Timestamp();
}
auto& prev_loop = cc->Outputs().Get(prev_loop_id_);
while (!main_packet_specs_.empty() && !loop_packets_.empty()) {
// The earliest MAIN packet.
const MainPacketSpec& main_spec = main_packet_specs_.front();
// The earliest LOOP packet.
const Packet& loop_candidate = loop_packets_.front();
// Match LOOP and MAIN packets.
if (main_spec.loop_timestamp < loop_candidate.Timestamp()) {
// No LOOP packet can match the MAIN packet under review.
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
main_packet_specs_.pop_front();
} else if (main_spec.loop_timestamp > loop_candidate.Timestamp()) {
// No MAIN packet can match the LOOP packet under review.
loop_packets_.pop_front();
} else {
// Exact match found.
if (loop_candidate.IsEmpty()) {
// However, LOOP packet is empty.
prev_loop.SetNextTimestampBound(main_spec.timestamp + 1);
} else {
prev_loop.AddPacket(loop_candidate.At(main_spec.timestamp));
}
loop_packets_.pop_front();
main_packet_specs_.pop_front();
}
}
// In case of an empty loopback input, the next timestamp bound for
// loopback input is the loopback timestamp + 1. The next timestamp bound
// for output is set and the main_ts_ vector is truncated accordingly.
if (loopback_packet.IsEmpty() &&
loopback_packet.Timestamp() != Timestamp::Unstarted()) {
Timestamp loopback_bound =
loopback_packet.Timestamp().NextAllowedInStream();
while (!main_ts_.empty() && main_ts_.front() <= loopback_bound) {
main_ts_.pop_front();
}
if (main_ts_.empty()) {
loop_out.SetNextTimestampBound(loopback_bound.NextAllowedInStream());
}
}
if (!main_ts_.empty()) {
loop_out.SetNextTimestampBound(main_ts_.front());
}
if (cc->Inputs().Get(main_id_).IsDone() && main_ts_.empty()) {
loop_out.Close();
if (main_packet_specs_.empty() && cc->Inputs().Get(main_id_).IsDone()) {
prev_loop.Close();
}
return ::mediapipe::OkStatus();
}
private:
struct MainPacketSpec {
Timestamp timestamp;
// Expected timestamp of the packet from LOOP stream that corresponds to the
// packet from MAIN stream descirbed by this spec.
Timestamp loop_timestamp;
};
CollectionItemId main_id_;
CollectionItemId loop_id_;
CollectionItemId loop_out_id_;
CollectionItemId prev_loop_id_;
std::deque<Timestamp> main_ts_;
std::deque<Packet> loopback_packets_;
// Contains specs for MAIN packets which only can be:
// - non-empty packets
// - empty packets indicating timestamp bound updates
//
// Sorted according to packet timestamps.
std::deque<MainPacketSpec> main_packet_specs_;
Timestamp prev_main_ts_ = Timestamp::Unstarted();
Timestamp prev_non_empty_main_ts_ = Timestamp::Unstarted();
// Contains LOOP packets which only can be:
// - the very first empty packet
// - non empty packets
// - empty packets indicating timestamp bound updates
//
// Sorted according to packet timestamps.
std::deque<Packet> loop_packets_;
// Using "Timestamp::Unset" instead of "Timestamp::Unstarted" in order to
// allow addition of the very first empty packet (which doesn't indicate
// timestamp bound change necessarily).
Timestamp prev_loop_ts_ = Timestamp::Unset();
};
REGISTER_CALCULATOR(PreviousLoopbackCalculator);

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>
@ -25,12 +26,17 @@
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Pair;
using ::testing::Value;
namespace {
// Returns the timestamp values for a vector of Packets.
@ -43,6 +49,23 @@ std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
return result;
}
MATCHER(EmptyPacket, negation ? "isn't empty" : "is empty") {
if (arg.IsEmpty()) {
return true;
}
return false;
}
MATCHER_P(IntPacket, value, "") {
return Value(arg.template Get<int>(), Eq(value));
}
MATCHER_P2(PairPacket, timestamp, pair, "") {
Timestamp actual_timestamp = arg.Timestamp();
const auto& actual_pair = arg.template Get<std::pair<Packet, Packet>>();
return Value(actual_timestamp, Eq(timestamp)) && Value(actual_pair, pair);
}
TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
std::vector<Packet> in_prev;
CalculatorGraphConfig graph_config_ =
@ -81,32 +104,30 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) {
MP_EXPECT_OK(graph_.AddPacketToInputStream(
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
auto pair_values = [](const Packet& packet) {
auto pair = packet.Get<std::pair<Packet, Packet>>();
int first = pair.first.IsEmpty() ? -1 : pair.first.Get<int>();
int second = pair.second.IsEmpty() ? -1 : pair.second.Get<int>();
return std::make_pair(first, second);
};
send_packet("in", 1);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(1, -1));
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1));
EXPECT_THAT(in_prev.back(),
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())));
send_packet("in", 2);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 2}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(2, 1));
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2));
EXPECT_THAT(in_prev.back(),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))));
send_packet("in", 5);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 2, 5}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(5, 2));
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5));
EXPECT_THAT(in_prev.back(),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(2))));
send_packet("in", 15);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(in_prev), (std::vector<int64>{1, 2, 5, 15}));
EXPECT_EQ(pair_values(in_prev.back()), std::make_pair(15, 5));
EXPECT_THAT(TimestampValues(in_prev), ElementsAre(1, 2, 5, 15));
EXPECT_THAT(in_prev.back(),
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
@ -185,24 +206,24 @@ TEST(PreviousLoopbackCalculator, ClosesCorrectly) {
send_packet("in", 1);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1}));
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1));
send_packet("in", 2);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 2}));
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2));
send_packet("in", 5);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 2, 5}));
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2, 5));
send_packet("in", 15);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{1, 2, 5, 15}));
EXPECT_THAT(TimestampValues(outputs), ElementsAre(1, 2, 5, 15));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs),
(std::vector<int64>{1, 2, 5, 15, Timestamp::Max().Value()}));
EXPECT_THAT(TimestampValues(outputs),
ElementsAre(1, 2, 5, 15, Timestamp::Max().Value()));
MP_EXPECT_OK(graph_.WaitUntilDone());
}
@ -247,16 +268,12 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
input_name, MakePacket<int>(n).At(Timestamp(n))));
};
send_packet("in", 0);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_EQ(TimestampValues(outputs), (std::vector<int64>{0}));
for (int main_ts = 1; main_ts < 50; ++main_ts) {
for (int main_ts = 0; main_ts < 50; ++main_ts) {
send_packet("in", main_ts);
MP_EXPECT_OK(graph_.WaitUntilIdle());
std::vector<int64> ts_values = TimestampValues(outputs);
EXPECT_EQ(ts_values.size(), main_ts + 1);
for (int j = 0; j < main_ts; ++j) {
for (int j = 0; j < main_ts + 1; ++j) {
EXPECT_EQ(ts_values[j], j);
}
}
@ -266,5 +283,487 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
MP_EXPECT_OK(graph_.WaitUntilDone());
}
class PreviousLoopbackCalculatorProcessingTimestampsTest
: public testing::Test {
protected:
void SetUp() override {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
input_stream: 'force_main_empty'
input_stream: 'force_loop_empty'
# Used to indicate "main" timestamp bound updates.
node {
calculator: 'GateCalculator'
input_stream: 'input'
input_stream: 'DISALLOW:force_main_empty'
output_stream: 'main'
}
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:main'
input_stream: 'LOOP:loop'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:prev_loop'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'prev_loop'
output_stream: 'passed_through_input'
output_stream: 'passed_through_prev_loop'
}
# Used to indicate "loop" timestamp bound updates.
node {
calculator: 'GateCalculator'
input_stream: 'input'
input_stream: 'DISALLOW:force_loop_empty'
output_stream: 'loop'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'passed_through_input'
input_stream: 'passed_through_prev_loop'
output_stream: 'passed_through_input_and_prev_loop'
}
)");
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
&output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
MP_ASSERT_OK(graph_.StartRun({}));
}
void SendPackets(int timestamp, int input, bool force_main_empty,
bool force_loop_empty) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"force_main_empty",
MakePacket<bool>(force_main_empty).At(Timestamp(timestamp))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"force_loop_empty",
MakePacket<bool>(force_loop_empty).At(Timestamp(timestamp))));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsEmptyMainNonEmptyLoop) {
SendPackets(/*timestamp=*/1, /*input=*/1, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
SendPackets(/*timestamp=*/3, /*input=*/3, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
SendPackets(/*timestamp=*/5, /*input=*/5, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsNonEmptyMainEmptyLoop) {
SendPackets(/*timestamp=*/1, /*input=*/1,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/2, /*input=*/2,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
SendPackets(/*timestamp=*/3, /*input=*/3,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
SendPackets(/*timestamp=*/5, /*input=*/5,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsAlteringMainNonEmptyLoop) {
SendPackets(/*timestamp=*/1, /*input=*/1,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/2, /*input=*/2, /*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket()))));
SendPackets(/*timestamp=*/3, /*input=*/3,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1)))));
SendPackets(/*timestamp=*/5, /*input=*/5,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/true,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), IntPacket(1))),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsNonEmptyMainAlteringLoop) {
SendPackets(/*timestamp=*/1, /*input=*/1,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/2, /*input=*/2,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
SendPackets(/*timestamp=*/3, /*input=*/3,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket()))));
SendPackets(/*timestamp=*/5, /*input=*/5,
/*force_main_empty=*/false,
/*force_loop_empty=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3)))));
SendPackets(/*timestamp=*/15, /*input=*/15,
/*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), IntPacket(3))),
PairPacket(Timestamp(15), Pair(IntPacket(15), EmptyPacket()))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest,
MultiplePacketsCheckIfLastCorrectAlteringMainAlteringLoop) {
int num_packets = 1000;
for (int i = 0; i < num_packets; ++i) {
bool force_main_empty = i % 3 == 0 ? true : false;
bool force_loop_empty = i % 2 == 0 ? true : false;
SendPackets(/*timestamp=*/i + 1, /*input=*/i + 1, force_main_empty,
force_loop_empty);
}
SendPackets(/*timestamp=*/num_packets + 1,
/*input=*/num_packets + 1, /*force_main_empty=*/false,
/*force_loop_empty=*/false);
SendPackets(/*timestamp=*/num_packets + 2,
/*input=*/num_packets + 2, /*force_main_empty=*/false,
/*force_loop_empty=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
ASSERT_FALSE(output_packets_.empty());
EXPECT_THAT(
output_packets_.back(),
PairPacket(Timestamp(num_packets + 2),
Pair(IntPacket(num_packets + 2), IntPacket(num_packets + 1))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
// Similar to GateCalculator, but it doesn't propagate timestamp bound updates.
class DroppingGateCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Inputs().Tag("DISALLOW").Set<bool>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
if (!cc->Inputs().Index(0).IsEmpty() &&
!cc->Inputs().Tag("DISALLOW").Get<bool>()) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(DroppingGateCalculator);
// Tests PreviousLoopbackCalculator in cases when there are no "LOOP" timestamp
// bound updates and non-empty packets for a while and the aforementioned start
// to arrive at some point. So, "PREV_LOOP" is delayed for a couple of inputs.
class PreviousLoopbackCalculatorDelayBehaviorTest : public testing::Test {
protected:
void SetUp() override {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
# Drops "loop" when set to "true", delaying output of prev_loop, hence
# delaying output of the graph.
input_stream: 'delay_next_output'
node {
calculator: 'PreviousLoopbackCalculator'
input_stream: 'MAIN:input'
input_stream: 'LOOP:loop'
input_stream_info: { tag_index: 'LOOP' back_edge: true }
output_stream: 'PREV_LOOP:prev_loop'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'prev_loop'
output_stream: 'passed_through_input'
output_stream: 'passed_through_prev_loop'
}
node {
calculator: 'DroppingGateCalculator'
input_stream: 'input'
input_stream: 'DISALLOW:delay_next_output'
output_stream: 'loop'
}
node {
calculator: 'MakePairCalculator'
input_stream: 'passed_through_input'
input_stream: 'passed_through_prev_loop'
output_stream: 'passed_through_input_and_prev_loop'
}
)");
tool::AddVectorSink("passed_through_input_and_prev_loop", &graph_config,
&output_packets_);
MP_ASSERT_OK(graph_.Initialize(graph_config, {}));
MP_ASSERT_OK(graph_.StartRun({}));
}
void SendPackets(int timestamp, int input, bool delay_next_output) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"input", MakePacket<int>(input).At(Timestamp(timestamp))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"delay_next_output",
MakePacket<bool>(delay_next_output).At(Timestamp(timestamp))));
}
CalculatorGraph graph_;
std::vector<Packet> output_packets_;
};
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest, MultipleDelayedOutputs) {
SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), EmptyPacket())),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
TEST_F(PreviousLoopbackCalculatorDelayBehaviorTest,
NonDelayedOutputFollowedByMultipleDelayedOutputs) {
SendPackets(/*timestamp=*/1, /*input=*/1, /*delay_next_output=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket()))));
SendPackets(/*timestamp=*/2, /*input=*/2, /*delay_next_output=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
SendPackets(/*timestamp=*/3, /*input=*/3, /*delay_next_output=*/true);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1)))));
SendPackets(/*timestamp=*/5, /*input=*/5, /*delay_next_output=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket()))));
SendPackets(/*timestamp=*/15, /*input=*/15, /*delay_next_output=*/false);
MP_EXPECT_OK(graph_.WaitUntilIdle());
EXPECT_THAT(
output_packets_,
ElementsAre(
PairPacket(Timestamp(1), Pair(IntPacket(1), EmptyPacket())),
PairPacket(Timestamp(2), Pair(IntPacket(2), IntPacket(1))),
PairPacket(Timestamp(3), Pair(IntPacket(3), EmptyPacket())),
PairPacket(Timestamp(5), Pair(IntPacket(5), EmptyPacket())),
PairPacket(Timestamp(15), Pair(IntPacket(15), IntPacket(5)))));
MP_EXPECT_OK(graph_.CloseAllInputStreams());
MP_EXPECT_OK(graph_.WaitUntilDone());
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -16,6 +16,7 @@
#include <vector>
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "tensorflow/lite/interpreter.h"
@ -48,6 +49,10 @@ typedef SplitVectorCalculator<::mediapipe::NormalizedLandmark, false>
SplitLandmarkVectorCalculator;
REGISTER_CALCULATOR(SplitLandmarkVectorCalculator);
typedef SplitVectorCalculator<::mediapipe::NormalizedLandmarkList, false>
SplitNormalizedLandmarkListVectorCalculator;
REGISTER_CALCULATOR(SplitNormalizedLandmarkListVectorCalculator);
typedef SplitVectorCalculator<::mediapipe::NormalizedRect, false>
SplitNormalizedRectVectorCalculator;
REGISTER_CALCULATOR(SplitNormalizedRectVectorCalculator);
@ -57,4 +62,9 @@ typedef SplitVectorCalculator<::tflite::gpu::gl::GlBuffer, true>
MovableSplitGlBufferVectorCalculator;
REGISTER_CALCULATOR(MovableSplitGlBufferVectorCalculator);
#endif
typedef SplitVectorCalculator<::mediapipe::Detection, false>
SplitDetectionVectorCalculator;
REGISTER_CALCULATOR(SplitDetectionVectorCalculator);
} // namespace mediapipe

View File

@ -422,9 +422,12 @@ cc_library(
":recolor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:color_cc_proto",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [

View File

@ -17,6 +17,9 @@
#include "mediapipe/calculators/image/recolor_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/color.pb.h"
@ -39,8 +42,6 @@ namespace mediapipe {
// The luminance of the input image is used to adjust the blending weight,
// to help preserve image textures.
//
// TODO implement cpu support.
//
// Inputs:
// One of the following IMAGE tags:
// IMAGE: An ImageFrame input image, RGB or RGBA.
@ -71,6 +72,8 @@ namespace mediapipe {
// }
// }
//
// Note: Cannot mix-match CPU & GPU inputs/outputs.
// CPU-in & CPU-out <or> GPU-in & GPU-out
class RecolorCalculator : public CalculatorBase {
public:
RecolorCalculator() = default;
@ -138,6 +141,11 @@ REGISTER_CALCULATOR(RecolorCalculator);
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
// Confirm only one of the input streams is present.
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
// Confirm only one of the output streams is present.
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU"));
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
@ -193,7 +201,62 @@ REGISTER_CALCULATOR(RecolorCalculator);
}
::mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) {
return ::mediapipe::UnimplementedError("CPU support is not implemented yet.");
if (cc->Inputs().Tag("MASK").IsEmpty()) {
return ::mediapipe::OkStatus();
}
// Get inputs and setup output.
const auto& input_img = cc->Inputs().Tag("IMAGE").Get<ImageFrame>();
const auto& mask_img = cc->Inputs().Tag("MASK").Get<ImageFrame>();
cv::Mat input_mat = formats::MatView(&input_img);
cv::Mat mask_mat = formats::MatView(&mask_img);
RET_CHECK(input_mat.channels() == 3); // RGB only.
if (mask_mat.channels() > 1) {
std::vector<cv::Mat> channels;
cv::split(mask_mat, channels);
if (mask_channel_ == mediapipe::RecolorCalculatorOptions_MaskChannel_ALPHA)
mask_mat = channels[3];
else
mask_mat = channels[0];
}
cv::Mat mask_full;
cv::resize(mask_mat, mask_full, input_mat.size());
auto output_img = absl::make_unique<ImageFrame>(
input_img.Format(), input_mat.cols, input_mat.rows);
cv::Mat output_mat = mediapipe::formats::MatView(output_img.get());
// From GPU shader:
/*
vec4 weight = texture2D(mask, sample_coordinate);
vec4 color1 = texture2D(frame, sample_coordinate);
vec4 color2 = vec4(recolor, 1.0);
float luminance = dot(color1.rgb, vec3(0.299, 0.587, 0.114));
float mix_value = weight.MASK_COMPONENT * luminance;
fragColor = mix(color1, color2, mix_value);
*/
for (int i = 0; i < output_mat.rows; ++i) {
for (int j = 0; j < output_mat.cols; ++j) {
float weight = mask_full.at<uchar>(i, j) * (1.0 / 255.0);
cv::Vec3f color1 = input_mat.at<cv::Vec3b>(i, j);
cv::Vec3f color2 = {color_[0], color_[1], color_[2]};
float luminance =
(color1[0] * 0.299 + color1[1] * 0.587 + color1[2] * 0.114) / 255;
float mix_value = weight * luminance;
cv::Vec3b mix_color = color1 * (1.0 - mix_value) + color2 * mix_value;
output_mat.at<cv::Vec3b>(i, j) = mix_color;
}
}
cc->Outputs().Tag("IMAGE").Add(output_img.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) {
@ -303,9 +366,9 @@ void RecolorCalculator::GlRender() {
if (!options.has_color()) RET_CHECK_FAIL() << "Missing color option.";
color_.push_back(options.color().r() / 255.0);
color_.push_back(options.color().g() / 255.0);
color_.push_back(options.color().b() / 255.0);
color_.push_back(options.color().r());
color_.push_back(options.color().g());
color_.push_back(options.color().b());
return ::mediapipe::OkStatus();
}
@ -378,8 +441,8 @@ void RecolorCalculator::GlRender() {
glUseProgram(program_);
glUniform1i(glGetUniformLocation(program_, "frame"), 1);
glUniform1i(glGetUniformLocation(program_, "mask"), 2);
glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1],
color_[2]);
glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0] / 255.0,
color_[1] / 255.0, color_[2] / 255.0);
#endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus();

View File

@ -1110,6 +1110,7 @@ cc_test(
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:android_tensorflow_lib_with_ops_lite_proto_no_rtti_lib",
"@org_tensorflow//tensorflow/core:android_tensorflow_test_lib",
],
"//mediapipe:ios": [
"@org_tensorflow//tensorflow/core:ios_tensorflow_test_lib",

View File

@ -222,9 +222,11 @@ cc_library(
deps = [
":util",
":tflite_inference_calculator_cc_proto",
"@com_google_absl//absl/memory",
"//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/port:ret_check",
@ -254,6 +256,10 @@ cc_library(
"//mediapipe:android": [
"@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
],
}) + select({
"//conditions:default": [
"//mediapipe/util:cpu_util",
],
}),
alwayslink = 1,
)
@ -308,6 +314,20 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "tflite_model_calculator",
srcs = ["tflite_model_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":util",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework",
],
alwayslink = 1,
)
cc_library(
name = "tflite_tensors_to_segmentation_calculator",
srcs = ["tflite_tensors_to_segmentation_calculator.cc"],
@ -478,6 +498,9 @@ cc_test(
deps = [
":tflite_inference_calculator",
":tflite_inference_calculator_cc_proto",
":tflite_model_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
@ -485,7 +508,9 @@ cc_test(
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:validate_type",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
@ -511,3 +536,19 @@ cc_test(
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
cc_test(
name = "tflite_model_calculator_test",
srcs = ["tflite_model_calculator_test.cc"],
data = ["testdata/add.bin"],
deps = [
":tflite_model_calculator",
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@org_tensorflow//tensorflow/lite:framework",
],
)

View File

@ -17,10 +17,16 @@
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#if !defined(__EMSCRIPTEN__)
#include "mediapipe/util/cpu_util.h"
#endif // !__EMSCRIPTEN__
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
@ -50,7 +56,7 @@
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h"
#endif // iOS
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#if defined(MEDIAPIPE_ANDROID)
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#endif // ANDROID
@ -113,6 +119,23 @@ struct GPUData {
};
#endif
// Returns number of threads to configure XNNPACK delegate with.
// (Equal to user provided value if specified. Otherwise, it returns number of
// high cores (hard-coded to 1 for __EMSCRIPTEN__))
int GetXnnpackNumThreads(
const mediapipe::TfLiteInferenceCalculatorOptions& opts) {
static constexpr int kDefaultNumThreads = -1;
if (opts.has_delegate() && opts.delegate().has_xnnpack() &&
opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) {
return opts.delegate().xnnpack().num_threads();
}
#if !defined(__EMSCRIPTEN__)
return InferHigherCoreIds().size();
#else
return 1;
#endif // !__EMSCRIPTEN__
}
// Calculator Header Section
// Runs inference on the provided input TFLite tensors and TFLite model.
@ -139,6 +162,9 @@ struct GPUData {
// Input side packet:
// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
// instead of the builtin one.
// MODEL (optional) - Use to specify TfLite model
// (std::unique_ptr<tflite::FlatBufferModel,
// std::function<void(tflite::FlatBufferModel*)>>)
//
// Example use:
// node {
@ -153,6 +179,20 @@ struct GPUData {
// }
// }
//
// or
//
// node {
// calculator: "TfLiteInferenceCalculator"
// input_stream: "TENSORS:tensor_image"
// input_side_packet: "MODEL:model"
// output_stream: "TENSORS:tensors"
// options: {
// [mediapipe.TfLiteInferenceCalculatorOptions.ext] {
// delegate { gpu {} }
// }
// }
// }
//
// IMPORTANT Notes:
// Tensors are assumed to be ordered correctly (sequentially added to model).
// Input tensors are assumed to be of the correct size and already normalized.
@ -165,6 +205,9 @@ class TfLiteInferenceCalculator : public CalculatorBase {
public:
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
using TfLiteModelPtr =
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
@ -173,12 +216,12 @@ class TfLiteInferenceCalculator : public CalculatorBase {
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::Status LoadOptions(CalculatorContext* cc);
::mediapipe::Status LoadModel(CalculatorContext* cc);
::mediapipe::StatusOr<Packet> GetModelAsPacket(const CalculatorContext& cc);
::mediapipe::Status LoadDelegate(CalculatorContext* cc);
Packet model_packet_;
std::unique_ptr<tflite::Interpreter> interpreter_;
std::unique_ptr<tflite::FlatBufferModel> model_;
TfLiteDelegatePtr delegate_;
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
@ -198,7 +241,6 @@ class TfLiteInferenceCalculator : public CalculatorBase {
edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice();
#endif
std::string model_path_ = "";
bool gpu_inference_ = false;
bool gpu_input_ = false;
bool gpu_output_ = false;
@ -217,6 +259,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
const auto& options =
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
RET_CHECK(!options.model_path().empty() ^
cc->InputSidePackets().HasTag("MODEL"))
<< "Either model as side packet or model path in options is required.";
bool use_gpu =
options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
@ -249,6 +295,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
.Tag("CUSTOM_OP_RESOLVER")
.Set<tflite::ops::builtin::BuiltinOpResolver>();
}
if (cc->InputSidePackets().HasTag("MODEL")) {
cc->InputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
}
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
@ -267,7 +316,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
MP_RETURN_IF_ERROR(LoadOptions(cc));
const auto& options =
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
gpu_inference_ = options.use_gpu();
if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__EMSCRIPTEN__)
@ -492,34 +543,10 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Calculator Auxiliary Section
::mediapipe::Status TfLiteInferenceCalculator::LoadOptions(
CalculatorContext* cc) {
// Get calculator options specified in the graph.
const auto& options =
cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
// Get model name.
if (!options.model_path().empty()) {
std::string model_path = options.model_path();
ASSIGN_OR_RETURN(model_path_, mediapipe::PathToResourceAsFile(model_path));
} else {
LOG(ERROR) << "Must specify path to TFLite model.";
return ::mediapipe::Status(::mediapipe::StatusCode::kNotFound,
"Must specify path to TFLite model.");
}
// Get execution modes.
gpu_inference_ =
options.has_delegate() ? options.delegate().has_gpu() : options.use_gpu();
return ::mediapipe::OkStatus();
}
::mediapipe::Status TfLiteInferenceCalculator::LoadModel(
CalculatorContext* cc) {
model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
RET_CHECK(model_);
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc));
const auto& model = *model_packet_.Get<TfLiteModelPtr>();
tflite::ops::builtin::BuiltinOpResolver op_resolver;
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
@ -529,9 +556,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
}
#if defined(MEDIAPIPE_EDGE_TPU)
interpreter_ =
BuildEdgeTpuInterpreter(*model_, &op_resolver, edgetpu_context_.get());
BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get());
#else
tflite::InterpreterBuilder(*model_, op_resolver)(&interpreter_);
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
#endif // MEDIAPIPE_EDGE_TPU
RET_CHECK(interpreter_);
@ -557,6 +584,28 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
return ::mediapipe::OkStatus();
}
::mediapipe::StatusOr<Packet> TfLiteInferenceCalculator::GetModelAsPacket(
const CalculatorContext& cc) {
const auto& options =
cc.Options<mediapipe::TfLiteInferenceCalculatorOptions>();
if (!options.model_path().empty()) {
std::string model_path = options.model_path();
ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path));
auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
RET_CHECK(model) << "Failed to load model from path.";
return MakePacket<TfLiteModelPtr>(TfLiteModelPtr(
model.release(), [](tflite::FlatBufferModel* model) { delete model; }));
}
if (cc.InputSidePackets().HasTag("MODEL")) {
return cc.InputSidePackets().Tag("MODEL");
}
return ::mediapipe::Status(
::mediapipe::StatusCode::kNotFound,
"Must specify TFLite model as path or loaded model.");
}
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
CalculatorContext* cc) {
const auto& calculator_opts =
@ -587,6 +636,22 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
}
#endif // MEDIAPIPE_ANDROID
#if defined(__EMSCRIPTEN__)
const bool xnnpack_requested = true;
#else
const bool xnnpack_requested = calculator_opts.has_delegate() &&
calculator_opts.delegate().has_xnnpack();
#endif // __EMSCRIPTEN__
if (xnnpack_requested) {
TfLiteXNNPackDelegateOptions xnnpack_opts{};
xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts);
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
&TfLiteXNNPackDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
}
// Return, no need for GPU delegate below.
return ::mediapipe::OkStatus();
}

View File

@ -45,11 +45,17 @@ message TfLiteInferenceCalculatorOptions {
message Gpu {}
// Android only.
message Nnapi {}
message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries
// to choose optimal number of threads depending on the device.)
optional int32 num_threads = 1 [default = -1];
}
oneof delegate {
TfLite tflite = 1;
Gpu gpu = 2;
Nnapi nnapi = 3;
Xnnpack xnnpack = 4;
}
}

View File

@ -41,7 +41,7 @@ namespace mediapipe {
using ::tflite::Interpreter;
void DoSmokeTest(absl::string_view delegate) {
void DoSmokeTest(const std::string& graph_proto) {
const int width = 8;
const int height = 8;
const int channels = 3;
@ -69,24 +69,9 @@ void DoSmokeTest(absl::string_view delegate) {
auto input_vec = absl::make_unique<std::vector<TfLiteTensor>>();
input_vec->emplace_back(*tensor);
std::string graph_proto = R"(
input_stream: "tensor_in"
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:tensor_in"
output_stream: "TENSORS:tensor_out"
options {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "mediapipe/calculators/tflite/testdata/add.bin"
$delegate
}
}
}
)";
ASSERT_EQ(absl::StrReplaceAll({{"$delegate", delegate}}, &graph_proto), 1);
// Prepare single calculator graph to and wait for packets.
CalculatorGraphConfig graph_config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> output_packets;
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
CalculatorGraph graph(graph_config);
@ -119,8 +104,70 @@ void DoSmokeTest(absl::string_view delegate) {
// Tests a simple add model that adds an input tensor to itself.
TEST(TfLiteInferenceCalculatorTest, SmokeTest) {
DoSmokeTest(/*delegate=*/"");
DoSmokeTest(/*delegate=*/"delegate { tflite {} }");
std::string graph_proto = R"(
input_stream: "tensor_in"
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:tensor_in"
output_stream: "TENSORS:tensor_out"
options {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "mediapipe/calculators/tflite/testdata/add.bin"
$delegate
}
}
}
)";
DoSmokeTest(
/*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}}));
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto, {{"$delegate", "delegate { xnnpack {} }"}}));
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto,
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
}
TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
std::string graph_proto = R"(
input_stream: "tensor_in"
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { string_value: "mediapipe/calculators/tflite/testdata/add.bin" }
}
}
}
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
}
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
output_side_packet: "MODEL:model"
}
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:tensor_in"
output_stream: "TENSORS:tensor_out"
input_side_packet: "MODEL:model"
options {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
use_gpu: false
}
}
}
)";
DoSmokeTest(graph_proto);
}
} // namespace mediapipe

View File

@ -0,0 +1,86 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <memory>
#include <string>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/model.h"
namespace mediapipe {
// Loads TfLite model from model blob specified as input side packet and outputs
// corresponding side packet.
//
// Input side packets:
// MODEL_BLOB - TfLite model blob/file-contents (std::string). You can read
// model blob from file (using whatever APIs you have) and pass
// it to the graph as input side packet or you can use some of
// calculators like LocalFileContentsCalculator to get model
// blob and use it as input here.
//
// Output side packets:
// MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
// std::function<void(tflite::FlatBufferModel*)>>)
//
// Example use:
//
// node {
// calculator: "TfLiteModelCalculator"
// input_side_packet: "MODEL_BLOB:model_blob"
// output_side_packet: "MODEL:model"
// }
//
class TfLiteModelCalculator : public CalculatorBase {
public:
using TfLiteModelPtr =
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>;
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>();
cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB");
const std::string& model_blob = model_packet.Get<std::string>();
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(),
model_blob.size());
RET_CHECK(model) << "Failed to load TfLite model from blob.";
cc->OutputSidePackets().Tag("MODEL").Set(
MakePacket<TfLiteModelPtr>(TfLiteModelPtr(
model.release(), [model_packet](tflite::FlatBufferModel* model) {
// Keeping model_packet in order to keep underlying model blob
// which can be released only after TfLite model is not needed
// anymore (deleted).
delete model;
})));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(TfLiteModelCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,88 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
#include "tensorflow/lite/model.h"
namespace mediapipe {
TEST(TfLiteModelCalculatorTest, SmokeTest) {
// Prepare single calculator graph to and wait for packets.
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
CalculatorGraphConfig>(
R"(
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet {
string_value: "mediapipe/calculators/tflite/testdata/add.bin"
}
}
}
}
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
}
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
output_side_packet: "MODEL:model"
}
)");
CalculatorGraph graph(graph_config);
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
auto status_or_packet = graph.GetOutputSidePacket("model");
MP_ASSERT_OK(status_or_packet);
auto model_packet = status_or_packet.ValueOrDie();
const auto& model = model_packet.Get<
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>>();
auto expected_model = tflite::FlatBufferModel::BuildFromFile(
"mediapipe/calculators/tflite/testdata/add.bin");
EXPECT_EQ(model->GetModel()->version(),
expected_model->GetModel()->version());
EXPECT_EQ(model->GetModel()->buffers()->size(),
expected_model->GetModel()->buffers()->size());
const int num_subgraphs = expected_model->GetModel()->subgraphs()->size();
EXPECT_EQ(model->GetModel()->subgraphs()->size(), num_subgraphs);
for (int i = 0; i < num_subgraphs; ++i) {
const auto* expected_subgraph =
expected_model->GetModel()->subgraphs()->Get(i);
const auto* subgraph = model->GetModel()->subgraphs()->Get(i);
const int num_tensors = expected_subgraph->tensors()->size();
EXPECT_EQ(subgraph->tensors()->size(), num_tensors);
for (int j = 0; j < num_tensors; ++j) {
EXPECT_EQ(subgraph->tensors()->Get(j)->name()->str(),
expected_subgraph->tensors()->Get(j)->name()->str());
}
}
}
} // namespace mediapipe

View File

@ -129,22 +129,43 @@ REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator);
num_classes *= raw_score_tensor->dims->data[i];
}
if (options_.binary_classification()) {
RET_CHECK_EQ(num_classes, 1);
// Number of classes for binary classification.
num_classes = 2;
}
if (label_map_loaded_) {
RET_CHECK_EQ(num_classes, label_map_.size());
}
const float* raw_scores = raw_score_tensor->data.f;
auto classification_list = absl::make_unique<ClassificationList>();
for (int i = 0; i < num_classes; ++i) {
if (options_.has_min_score_threshold() &&
raw_scores[i] < options_.min_score_threshold()) {
continue;
}
Classification* classification = classification_list->add_classification();
classification->set_index(i);
classification->set_score(raw_scores[i]);
if (options_.binary_classification()) {
Classification* class_first = classification_list->add_classification();
Classification* class_second = classification_list->add_classification();
class_first->set_index(0);
class_second->set_index(1);
class_first->set_score(raw_scores[0]);
class_second->set_score(1. - raw_scores[0]);
if (label_map_loaded_) {
classification->set_label(label_map_[i]);
class_first->set_label(label_map_[0]);
class_second->set_label(label_map_[1]);
}
} else {
for (int i = 0; i < num_classes; ++i) {
if (options_.has_min_score_threshold() &&
raw_scores[i] < options_.min_score_threshold()) {
continue;
}
Classification* classification =
classification_list->add_classification();
classification->set_index(i);
classification->set_score(raw_scores[i]);
if (label_map_loaded_) {
classification->set_label(label_map_[i]);
}
}
}

View File

@ -32,4 +32,10 @@ message TfLiteTensorsToClassificationCalculatorOptions {
optional int32 top_k = 2;
// Path to a label map file for getting the actual name of class ids.
optional string label_map_path = 3;
// Whether the input is a single float for binary classification.
// When true, only a single float is expected in the input tensor and the
// label map, if provided, is expected to have exactly two labels.
// The single score(float) represent the probability of first label, and
// 1 - score is the probabilility of the second label.
optional bool binary_classification = 4;
}

View File

@ -998,6 +998,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
@ -1015,6 +1016,7 @@ cc_library(
deps = [
":collection_has_min_size_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
@ -1022,6 +1024,18 @@ cc_library(
alwayslink = 1,
)
cc_test(
name = "collection_has_min_size_calculator_test",
srcs = ["collection_has_min_size_calculator_test.cc"],
deps = [
":collection_has_min_size_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_library(
name = "association_calculator",
hdrs = ["association_calculator.h"],

View File

@ -15,6 +15,9 @@
#include "mediapipe/calculators/util/collection_has_min_size_calculator.h"
#include <vector>
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
@ -23,4 +26,9 @@ typedef CollectionHasMinSizeCalculator<std::vector<::mediapipe::NormalizedRect>>
NormalizedRectVectorHasMinSizeCalculator;
REGISTER_CALCULATOR(NormalizedRectVectorHasMinSizeCalculator);
typedef CollectionHasMinSizeCalculator<
std::vector<::mediapipe::NormalizedLandmarkList>>
NormalizedLandmarkListVectorHasMinSizeCalculator;
REGISTER_CALCULATOR(NormalizedLandmarkListVectorHasMinSizeCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,156 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/util/collection_has_min_size_calculator.h"
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
namespace mediapipe {
typedef CollectionHasMinSizeCalculator<std::vector<int>>
TestIntCollectionHasMinSizeCalculator;
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
void AddInputVector(const std::vector<int>& input, int64 timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()
->Tag("ITERABLE")
.packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}
TEST(TestIntCollectionHasMinSizeCalculator, DoesHaveMinSize) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestIntCollectionHasMinSizeCalculator"
input_stream: "ITERABLE:input_vector"
output_stream: "output_vector"
options {
[mediapipe.CollectionHasMinSizeCalculatorOptions.ext] { min_size: 2 }
}
)");
CalculatorRunner runner(node_config);
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
AddInputVector({1, 2}, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
EXPECT_TRUE(outputs[0].Get<bool>());
AddInputVector({1, 2, 3}, /*timestamp=*/2, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(2, outputs.size());
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
EXPECT_TRUE(outputs[1].Get<bool>());
}
TEST(TestIntCollectionHasMinSizeCalculator,
DoesHaveMinSize_MinSizeAsSidePacket) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestIntCollectionHasMinSizeCalculator"
input_stream: "ITERABLE:input_vector"
input_side_packet: "min_size"
output_stream: "output_vector"
)");
CalculatorRunner runner(node_config);
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
runner.MutableSidePackets()->Index(0) = MakePacket<int>(2);
AddInputVector({1, 2}, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
EXPECT_TRUE(outputs[0].Get<bool>());
AddInputVector({1, 2, 3}, /*timestamp=*/2, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(2, outputs.size());
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
EXPECT_TRUE(outputs[1].Get<bool>());
}
TEST(TestIntCollectionHasMinSizeCalculator, DoesNotHaveMinSize) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestIntCollectionHasMinSizeCalculator"
input_stream: "ITERABLE:input_vector"
output_stream: "output_vector"
options {
[mediapipe.CollectionHasMinSizeCalculatorOptions.ext] { min_size: 3 }
}
)");
CalculatorRunner runner(node_config);
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
AddInputVector({1}, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
EXPECT_FALSE(outputs[0].Get<bool>());
AddInputVector({1, 2}, /*timestamp=*/2, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(2, outputs.size());
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
EXPECT_FALSE(outputs[1].Get<bool>());
}
TEST(TestIntCollectionHasMinSizeCalculator,
DoesNotHaveMinSize_MinSizeAsSidePacket) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TestIntCollectionHasMinSizeCalculator"
input_stream: "ITERABLE:input_vector"
input_side_packet: "min_size"
output_stream: "output_vector"
)");
CalculatorRunner runner(node_config);
const std::vector<Packet>& outputs = runner.Outputs().Index(0).packets;
runner.MutableSidePackets()->Index(0) = MakePacket<int>(3);
AddInputVector({1}, /*timestamp=*/1, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
EXPECT_FALSE(outputs[0].Get<bool>());
AddInputVector({1, 2}, /*timestamp=*/2, &runner);
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(2, outputs.size());
EXPECT_EQ(Timestamp(2), outputs[1].Timestamp());
EXPECT_FALSE(outputs[1].Get<bool>());
}
} // namespace mediapipe

View File

@ -17,6 +17,7 @@
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -31,4 +32,8 @@ typedef FilterCollectionCalculator<
FilterLandmarkListCollectionCalculator;
REGISTER_CALCULATOR(FilterLandmarkListCollectionCalculator);
typedef FilterCollectionCalculator<std::vector<::mediapipe::ClassificationList>>
FilterClassificationListCollectionCalculator;
REGISTER_CALCULATOR(FilterClassificationListCollectionCalculator);
} // namespace mediapipe

View File

@ -29,6 +29,7 @@ namespace {
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kRenderScaleTag[] = "RENDER_SCALE";
constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kLandmarkLabel[] = "KEYPOINT";
constexpr int kMaxLandmarkThickness = 18;
@ -71,6 +72,83 @@ void SetColorSizeValueFromZ(float z, float z_min, float z_max,
render_annotation->set_thickness(thickness);
}
template <class LandmarkType>
void AddConnectionToRenderData(const LandmarkType& start,
const LandmarkType& end, int gray_val1,
int gray_val2, float thickness, bool normalized,
RenderData* render_data) {
auto* connection_annotation = render_data->add_render_annotations();
RenderAnnotation::GradientLine* line =
connection_annotation->mutable_gradient_line();
line->set_x_start(start.x());
line->set_y_start(start.y());
line->set_x_end(end.x());
line->set_y_end(end.y());
line->set_normalized(normalized);
line->mutable_color1()->set_r(gray_val1);
line->mutable_color1()->set_g(gray_val1);
line->mutable_color1()->set_b(gray_val1);
line->mutable_color2()->set_r(gray_val2);
line->mutable_color2()->set_g(gray_val2);
line->mutable_color2()->set_b(gray_val2);
connection_annotation->set_thickness(thickness);
}
template <class LandmarkListType, class LandmarkType>
void AddConnectionsWithDepth(const LandmarkListType& landmarks,
const std::vector<int>& landmark_connections,
float thickness, bool normalized, float min_z,
float max_z, RenderData* render_data) {
for (int i = 0; i < landmark_connections.size(); i += 2) {
const auto& ld0 = landmarks.landmark(landmark_connections[i]);
const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
const int gray_val1 =
255 - static_cast<int>(Remap(ld0.z(), min_z, max_z, 255));
const int gray_val2 =
255 - static_cast<int>(Remap(ld1.z(), min_z, max_z, 255));
AddConnectionToRenderData<LandmarkType>(ld0, ld1, gray_val1, gray_val2,
thickness, normalized, render_data);
}
}
template <class LandmarkType>
void AddConnectionToRenderData(const LandmarkType& start,
const LandmarkType& end,
const Color& connection_color, float thickness,
bool normalized, RenderData* render_data) {
auto* connection_annotation = render_data->add_render_annotations();
RenderAnnotation::Line* line = connection_annotation->mutable_line();
line->set_x_start(start.x());
line->set_y_start(start.y());
line->set_x_end(end.x());
line->set_y_end(end.y());
line->set_normalized(normalized);
SetColor(connection_annotation, connection_color);
connection_annotation->set_thickness(thickness);
}
template <class LandmarkListType, class LandmarkType>
void AddConnections(const LandmarkListType& landmarks,
const std::vector<int>& landmark_connections,
const Color& connection_color, float thickness,
bool normalized, RenderData* render_data) {
for (int i = 0; i < landmark_connections.size(); i += 2) {
const auto& ld0 = landmarks.landmark(landmark_connections[i]);
const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]);
AddConnectionToRenderData<LandmarkType>(ld0, ld1, connection_color,
thickness, normalized, render_data);
}
}
RenderAnnotation* AddPointRenderData(const Color& landmark_color,
float thickness, RenderData* render_data) {
auto* landmark_data_annotation = render_data->add_render_annotations();
landmark_data_annotation->set_scene_tag(kLandmarkLabel);
SetColor(landmark_data_annotation, landmark_color);
landmark_data_annotation->set_thickness(thickness);
return landmark_data_annotation;
}
} // namespace
// A calculator that converts Landmark proto to RenderData proto for
@ -107,29 +185,6 @@ class LandmarksToRenderDataCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
static void AddConnectionToRenderData(
float start_x, float start_y, float end_x, float end_y,
const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
RenderData* render_data);
static void SetRenderAnnotationColorThickness(
const LandmarksToRenderDataCalculatorOptions& options,
RenderAnnotation* render_annotation);
static RenderAnnotation* AddPointRenderData(
const LandmarksToRenderDataCalculatorOptions& options,
RenderData* render_data);
static void AddConnectionToRenderData(
float start_x, float start_y, float end_x, float end_y,
const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
int gray_val1, int gray_val2, RenderData* render_data);
template <class LandmarkListType>
void AddConnections(const LandmarkListType& landmarks, bool normalized,
RenderData* render_data);
template <class LandmarkListType>
void AddConnectionsWithDepth(const LandmarkListType& landmarks,
bool normalized, float min_z, float max_z,
RenderData* render_data);
LandmarksToRenderDataCalculatorOptions options_;
};
REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
@ -150,6 +205,9 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
if (cc->Inputs().HasTag(kNormLandmarksTag)) {
cc->Inputs().Tag(kNormLandmarksTag).Set<NormalizedLandmarkList>();
}
if (cc->Inputs().HasTag(kRenderScaleTag)) {
cc->Inputs().Tag(kRenderScaleTag).Set<float>();
}
cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
return ::mediapipe::OkStatus();
}
@ -169,11 +227,26 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
float z_min = 0.f;
float z_max = 0.f;
// Apply scale to `thickness` of rendered landmarks and connections to make
// them bigger when object (e.g. pose, hand or face) is closer/bigger and
// snaller when object is further/smaller.
float thickness = options_.thickness();
if (cc->Inputs().HasTag(kRenderScaleTag)) {
const float render_scale = cc->Inputs().Tag(kRenderScaleTag).Get<float>();
thickness *= render_scale;
}
// Parse landmarks connections to a vector.
RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0)
<< "Number of entries in landmark connections must be a multiple of 2";
std::vector<int> landmark_connections;
for (int i = 0; i < options_.landmark_connections_size(); i += 1) {
landmark_connections.push_back(options_.landmark_connections(i));
}
if (cc->Inputs().HasTag(kLandmarksTag)) {
const LandmarkList& landmarks =
cc->Inputs().Tag(kLandmarksTag).Get<LandmarkList>();
RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0)
<< "Number of entries in landmark connections must be a multiple of 2";
if (visualize_depth) {
GetMinMaxZ<LandmarkList, Landmark>(landmarks, &z_min, &z_max);
}
@ -181,8 +254,8 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
visualize_depth &= ((z_max - z_min) > 1e-3);
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const Landmark& landmark = landmarks.landmark(i);
auto* landmark_data_render =
AddPointRenderData(options_, render_data.get());
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
landmark_data_render);
@ -193,19 +266,19 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
landmark_data->set_y(landmark.y());
}
if (visualize_depth) {
AddConnectionsWithDepth<LandmarkList>(landmarks, /*normalized=*/false,
z_min, z_max, render_data.get());
AddConnectionsWithDepth<LandmarkList, Landmark>(
landmarks, landmark_connections, thickness, /*normalized=*/false,
z_min, z_max, render_data.get());
} else {
AddConnections<LandmarkList>(landmarks, /*normalized=*/false,
render_data.get());
AddConnections<LandmarkList, Landmark>(
landmarks, landmark_connections, options_.connection_color(),
thickness, /*normalized=*/false, render_data.get());
}
}
if (cc->Inputs().HasTag(kNormLandmarksTag)) {
const NormalizedLandmarkList& landmarks =
cc->Inputs().Tag(kNormLandmarksTag).Get<NormalizedLandmarkList>();
RET_CHECK_EQ(options_.landmark_connections_size() % 2, 0)
<< "Number of entries in landmark connections must be a multiple of 2";
if (visualize_depth) {
GetMinMaxZ<NormalizedLandmarkList, NormalizedLandmark>(landmarks, &z_min,
&z_max);
@ -214,8 +287,8 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
visualize_depth &= ((z_max - z_min) > 1e-3);
for (int i = 0; i < landmarks.landmark_size(); ++i) {
const NormalizedLandmark& landmark = landmarks.landmark(i);
auto* landmark_data_render =
AddPointRenderData(options_, render_data.get());
auto* landmark_data_render = AddPointRenderData(
options_.landmark_color(), thickness, render_data.get());
if (visualize_depth) {
SetColorSizeValueFromZ(landmark.z(), z_min, z_max,
landmark_data_render);
@ -226,11 +299,13 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
landmark_data->set_y(landmark.y());
}
if (visualize_depth) {
AddConnectionsWithDepth<NormalizedLandmarkList>(
landmarks, /*normalized=*/true, z_min, z_max, render_data.get());
AddConnectionsWithDepth<NormalizedLandmarkList, NormalizedLandmark>(
landmarks, landmark_connections, thickness, /*normalized=*/true,
z_min, z_max, render_data.get());
} else {
AddConnections<NormalizedLandmarkList>(landmarks, /*normalized=*/true,
render_data.get());
AddConnections<NormalizedLandmarkList, NormalizedLandmark>(
landmarks, landmark_connections, options_.connection_color(),
thickness, /*normalized=*/true, render_data.get());
}
}
@ -240,84 +315,4 @@ REGISTER_CALCULATOR(LandmarksToRenderDataCalculator);
return ::mediapipe::OkStatus();
}
template <class LandmarkListType>
void LandmarksToRenderDataCalculator::AddConnectionsWithDepth(
const LandmarkListType& landmarks, bool normalized, float min_z,
float max_z, RenderData* render_data) {
for (int i = 0; i < options_.landmark_connections_size(); i += 2) {
const auto& ld0 = landmarks.landmark(options_.landmark_connections(i));
const auto& ld1 = landmarks.landmark(options_.landmark_connections(i + 1));
const int gray_val1 =
255 - static_cast<int>(Remap(ld0.z(), min_z, max_z, 255));
const int gray_val2 =
255 - static_cast<int>(Remap(ld1.z(), min_z, max_z, 255));
AddConnectionToRenderData(ld0.x(), ld0.y(), ld1.x(), ld1.y(), options_,
normalized, gray_val1, gray_val2, render_data);
}
}
void LandmarksToRenderDataCalculator::AddConnectionToRenderData(
float start_x, float start_y, float end_x, float end_y,
const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
int gray_val1, int gray_val2, RenderData* render_data) {
auto* connection_annotation = render_data->add_render_annotations();
RenderAnnotation::GradientLine* line =
connection_annotation->mutable_gradient_line();
line->set_x_start(start_x);
line->set_y_start(start_y);
line->set_x_end(end_x);
line->set_y_end(end_y);
line->set_normalized(normalized);
line->mutable_color1()->set_r(gray_val1);
line->mutable_color1()->set_g(gray_val1);
line->mutable_color1()->set_b(gray_val1);
line->mutable_color2()->set_r(gray_val2);
line->mutable_color2()->set_g(gray_val2);
line->mutable_color2()->set_b(gray_val2);
connection_annotation->set_thickness(options.thickness());
}
template <class LandmarkListType>
void LandmarksToRenderDataCalculator::AddConnections(
const LandmarkListType& landmarks, bool normalized,
RenderData* render_data) {
for (int i = 0; i < options_.landmark_connections_size(); i += 2) {
const auto& ld0 = landmarks.landmark(options_.landmark_connections(i));
const auto& ld1 = landmarks.landmark(options_.landmark_connections(i + 1));
AddConnectionToRenderData(ld0.x(), ld0.y(), ld1.x(), ld1.y(), options_,
normalized, render_data);
}
}
void LandmarksToRenderDataCalculator::AddConnectionToRenderData(
float start_x, float start_y, float end_x, float end_y,
const LandmarksToRenderDataCalculatorOptions& options, bool normalized,
RenderData* render_data) {
auto* connection_annotation = render_data->add_render_annotations();
RenderAnnotation::Line* line = connection_annotation->mutable_line();
line->set_x_start(start_x);
line->set_y_start(start_y);
line->set_x_end(end_x);
line->set_y_end(end_y);
line->set_normalized(normalized);
SetColor(connection_annotation, options.connection_color());
connection_annotation->set_thickness(options.thickness());
}
RenderAnnotation* LandmarksToRenderDataCalculator::AddPointRenderData(
const LandmarksToRenderDataCalculatorOptions& options,
RenderData* render_data) {
auto* landmark_data_annotation = render_data->add_render_annotations();
landmark_data_annotation->set_scene_tag(kLandmarkLabel);
SetRenderAnnotationColorThickness(options, landmark_data_annotation);
return landmark_data_annotation;
}
void LandmarksToRenderDataCalculator::SetRenderAnnotationColorThickness(
const LandmarksToRenderDataCalculatorOptions& options,
RenderAnnotation* render_annotation) {
SetColor(render_annotation, options.landmark_color());
render_annotation->set_thickness(options.thickness());
}
} // namespace mediapipe

View File

@ -276,6 +276,7 @@ TEST_F(PacketLatencyCalculatorTest, DoesNotOutputUntilReferencePacketReceived) {
"delayed_packet_0", Adopt(new double()).At(Timestamp(2))));
// Send a reference packet with timestamp 10 usec.
simulation_clock_->Sleep(absl::Microseconds(1));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"camera_frames", Adopt(new double()).At(Timestamp(10))));
simulation_clock_->Sleep(absl::Microseconds(1));

View File

@ -138,7 +138,7 @@ cc_library(
srcs = ["flow_to_image_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/video:flow_to_image_calculator_cc_proto",
":flow_to_image_calculator_cc_proto",
"//mediapipe/calculators/video/tool:flow_quantizer_model",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_format_cc_proto",
@ -384,20 +384,18 @@ cc_test(
],
)
MEDIAPIPE_DEPS = [
"//mediapipe/calculators/video:box_tracker_calculator",
"//mediapipe/calculators/video:flow_packager_calculator",
"//mediapipe/calculators/video:motion_analysis_calculator",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
]
mediapipe_binary_graph(
name = "parallel_tracker_binarypb",
graph = "testdata/parallel_tracker_graph.pbtxt",
output_name = "testdata/parallel_tracker.binarypb",
visibility = ["//visibility:public"],
deps = MEDIAPIPE_DEPS,
deps = [
":box_tracker_calculator",
":flow_packager_calculator",
":motion_analysis_calculator",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
],
)
mediapipe_binary_graph(
@ -405,7 +403,13 @@ mediapipe_binary_graph(
graph = "testdata/tracker_graph.pbtxt",
output_name = "testdata/tracker.binarypb",
visibility = ["//visibility:public"],
deps = MEDIAPIPE_DEPS,
deps = [
":box_tracker_calculator",
":flow_packager_calculator",
":motion_analysis_calculator",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
],
)
cc_test(

View File

@ -0,0 +1,7 @@
tricorder: {
options: {
builder: {
config: "android_arm64"
}
}
}

View File

@ -95,7 +95,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem(
const std::vector<FocusPointFrame>& focus_point_frames,
const std::vector<FocusPointFrame>& prior_focus_point_frames,
const int original_width, const int original_height, const int output_width,
const int output_height, std::vector<cv::Mat>* all_xforms) {
const int output_height, std::vector<cv::Mat>* all_transforms) {
RET_CHECK_GE(original_width, output_width);
RET_CHECK_GE(original_height, output_height);
const bool should_solve_x_problem = original_width != output_width;
@ -138,9 +138,10 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem(
Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
Solver::Summary summary;
Solve(options, &problem_x, &summary);
all_xforms->clear();
Solver::Summary summary_x, summary_y;
Solve(options, &problem_x, &summary_x);
Solve(options, &problem_y, &summary_y);
all_transforms->clear();
for (int i = 0;
i < focus_point_frames.size() + prior_focus_point_frames.size(); i++) {
// Code below assigns values into an affine model, defined as:
@ -160,7 +161,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem(
yb_, yc_, yd_, yk_);
transform.at<float>(1, 2) = delta;
}
all_xforms->push_back(transform);
all_transforms->push_back(transform);
}
return mediapipe::OkStatus();
}

View File

@ -40,14 +40,14 @@ class PolynomialRegressionPathSolver {
// Given a series of focus points on frames, uses polynomial regression to
// compute a best guess of a 1D camera movement trajectory along x-axis and
// y-axis, such that focus points can be preserved as much as possible. The
// returned |all_xforms| hold the camera location at each timestamp
// returned |all_transforms| hold the camera location at each timestamp
// corresponding to each input frame.
::mediapipe::Status ComputeCameraPath(
const std::vector<FocusPointFrame>& focus_point_frames,
const std::vector<FocusPointFrame>& prior_focus_point_frames,
const int original_width, const int original_height,
const int output_width, const int output_height,
std::vector<cv::Mat>* all_xforms);
std::vector<cv::Mat>* all_transforms);
private:
// Adds a new cost function, constructed using |in| and |out|, into |problem|.

View File

@ -24,3 +24,17 @@ cc_binary(
"//mediapipe/graphs/hair_segmentation:mobile_calculators",
],
)
cc_binary(
name = "hair_segmentation_cpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main",
] + select({
"//mediapipe/gpu:disable_gpu": [
"//mediapipe/graphs/hair_segmentation:desktop_calculators",
],
"//conditions:default": [
"//mediapipe/graphs/hair_segmentation:mobile_calculators",
],
}),
)

View File

@ -361,6 +361,7 @@ cc_library(
"//mediapipe/framework:mediapipe_options_cc_proto",
"//mediapipe/framework:packet_generator_cc_proto",
"//mediapipe/framework:status_handler_cc_proto",
"//mediapipe/framework:stream_handler_cc_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_util",

View File

@ -84,7 +84,7 @@ class CalculatorContract {
return *output_side_packets_;
}
// Set this Node's default InputStreamHandler.
// Specifies the preferred InputStreamHandler for this Node.
// If there is an InputStreamHandler specified in the graph (.pbtxt) for this
// Node, then the graph's InputStreamHandler will take priority.
void SetInputStreamHandler(const std::string& name) {
@ -104,6 +104,29 @@ class CalculatorContract {
return input_stream_handler_options_;
}
// The next few methods are concerned with timestamp bound propagation
// (see scheduling_sync.md#input-policies). Every calculator that processes
// live inputs should specify either ProcessTimestampBounds or
// TimestampOffset. Calculators that produce output at the same timestamp as
// the input, or with a fixed offset, should declare this fact using
// SetTimestampOffset. Calculators that require custom timestamp bound
// calculations should use SetProcessTimestampBounds.
// When true, Process is called for every new timestamp bound, with or without
// new packets. A call to Process with only an input timestamp bound is
// normally used to compute a new output timestamp bound.
void SetProcessTimestampBounds(bool process_timestamps) {
process_timestamps_ = process_timestamps;
}
bool GetProcessTimestampBounds() const { return process_timestamps_; }
// Specifies the maximum difference between input and output timestamps.
// When specified, the mediapipe framework automatically computes output
// timestamp bounds based on input timestamps. The special value
// TimestampDiff::Unset disables the timestamp offset.
void SetTimestampOffset(TimestampDiff offset) { timestamp_offset_ = offset; }
TimestampDiff GetTimestampOffset() const { return timestamp_offset_; }
class GraphServiceRequest {
public:
// APIs that should be used by calculators.
@ -147,6 +170,8 @@ class CalculatorContract {
MediaPipeOptions input_stream_handler_options_;
std::string node_name_;
std::map<std::string, GraphServiceRequest> service_requests_;
bool process_timestamps_ = false;
TimestampDiff timestamp_offset_ = TimestampDiff::Unset();
};
} // namespace mediapipe

View File

@ -143,7 +143,7 @@ class CalculatorGraph {
const std::string& graph_type = "",
const Subgraph::SubgraphOptions* options = nullptr);
// Resturns the canonicalized CalculatorGraphConfig for this graph.
// Returns the canonicalized CalculatorGraphConfig for this graph.
const CalculatorGraphConfig& Config() const {
return validated_graph_->Config();
}

View File

@ -31,6 +31,17 @@ namespace {
typedef std::function<::mediapipe::Status(CalculatorContext* cc)>
CalculatorContextFunction;
// Returns the contents of a set of Packets.
// The contents must be copyable.
template <typename T>
std::vector<T> GetContents(const std::vector<Packet>& packets) {
std::vector<T> result;
for (Packet p : packets) {
result.push_back(p.Get<T>());
}
return result;
}
// A simple Semaphore for synchronizing test threads.
class AtomicSemaphore {
public:
@ -671,9 +682,9 @@ REGISTER_CALCULATOR(BoundToPacketCalculator);
// A Calculator that produces packets at timestamps beyond the input timestamp.
class FuturePacketCalculator : public CalculatorBase {
public:
static constexpr int64 kOutputFutureMicros = 3;
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
@ -742,9 +753,8 @@ TEST(CalculatorGraphBoundsTest, OffsetBoundPropagation) {
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows that bounds changes alone do not invoke Process.
// Note: Bounds changes alone will invoke Process eventually
// when SetOffset is cleared, see: go/mediapipe-realtime-graph.
// Shows that timestamp bounds changes alone do not invoke Process,
// without SetProcessTimestampBounds(true).
TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
// OffsetBoundCalculator produces only timestamp bounds.
// The BoundToPacketCalculator delivers an output packet whenever the
@ -753,8 +763,13 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
node {
calculator: 'OffsetBoundCalculator'
calculator: 'FuturePacketCalculator'
input_stream: 'input'
output_stream: 'input_2'
}
node {
calculator: 'OffsetBoundCalculator'
input_stream: 'input_2'
output_stream: 'bounds'
}
node {
@ -778,6 +793,7 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
for (int i = 0; i < kNumInputs; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(i));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
// No packets arrive, because updated timestamp bounds do not invoke
@ -1104,5 +1120,254 @@ TEST(CalculatorGraphBoundsTest, BoundsForEmptyInputs_SyncSets) {
)");
}
// A Calculator that produces a packet for each timestamp bounds update.
class ProcessBoundToPacketCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).SetAny();
}
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
cc->Outputs().Index(i).Set<Timestamp>();
}
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
cc->SetProcessTimestampBounds(true);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
Timestamp t = cc->Inputs().Index(i).Value().Timestamp();
if (t == cc->InputTimestamp() &&
t >= cc->Outputs().Index(i).NextTimestampBound()) {
cc->Outputs().Index(i).Add(new auto(t), t);
}
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(ProcessBoundToPacketCalculator);
// A Calculator that passes through each packet and timestamp immediately.
class ImmediatePassthroughCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).SetAny();
}
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
}
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
cc->SetProcessTimestampBounds(true);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
for (int i = 0; i < cc->Outputs().NumEntries(); ++i) {
if (!cc->Inputs().Index(i).IsEmpty()) {
cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value());
} else {
Timestamp input_bound =
cc->Inputs().Index(i).Value().Timestamp().NextAllowedInStream();
if (cc->Outputs().Index(i).NextTimestampBound() < input_bound) {
cc->Outputs().Index(i).SetNextTimestampBound(input_bound);
}
}
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(ImmediatePassthroughCalculator);
// Shows that Process is called for input-sets without input packets.
void TestProcessForEmptyInputs(const std::string& input_stream_handler) {
// FuturePacketCalculator and OffsetBoundCalculator produce only ts bounds,
// The ProcessBoundToPacketCalculator has SetProcessTimestampBounds(true),
// and produces an output packet for every timestamp bound update.
std::string config_str = R"(
input_stream: 'input'
node {
calculator: 'FuturePacketCalculator'
input_stream: 'input'
output_stream: 'futures'
}
node {
calculator: 'OffsetBoundCalculator'
input_stream: 'futures'
output_stream: 'bounds'
}
node {
calculator: 'ProcessBoundToPacketCalculator'
input_stream: 'bounds'
output_stream: 'bounds_ts'
input_stream_handler { $input_stream_handler }
}
)";
absl::StrReplaceAll({{"$input_stream_handler", input_stream_handler}},
&config_str);
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
std::vector<Packet> input_ts_packets;
std::vector<Packet> bounds_ts_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("bounds_ts", [&](const Packet& p) {
bounds_ts_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add four packets into the graph.
constexpr int kFutureMicros = FuturePacketCalculator::kOutputFutureMicros;
Packet p;
p = MakePacket<int>(33).At(Timestamp(0));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
p = MakePacket<int>(33).At(Timestamp(10));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
p = MakePacket<int>(33).At(Timestamp(20));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
p = MakePacket<int>(33).At(Timestamp(30));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Packets arrive.
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(bounds_ts_packets.size(), 4);
std::vector<Timestamp> expected = {
Timestamp(0 + kFutureMicros), Timestamp(10 + kFutureMicros),
Timestamp(20 + kFutureMicros), Timestamp(30 + kFutureMicros)};
EXPECT_EQ(GetContents<Timestamp>(bounds_ts_packets), expected);
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows that Process is called for input-sets without input packets
// using an DefaultInputStreamHandler.
TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Default) {
TestProcessForEmptyInputs(R"(
input_stream_handler: "DefaultInputStreamHandler")");
}
// Shows that Process is called for input-sets without input packets
// using an ImmediateInputStreamHandler.
TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Immediate) {
TestProcessForEmptyInputs(R"(
input_stream_handler: "ImmediateInputStreamHandler")");
}
// Shows that Process is called for input-sets without input packets
// using a SyncSetInputStreamHandler with a single sync-set.
TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_SyncSet) {
TestProcessForEmptyInputs(R"(
input_stream_handler: "SyncSetInputStreamHandler")");
}
// Shows that Process is called for input-sets without input packets
// using a SyncSetInputStreamHandler with multiple sync-sets.
TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_SyncSets) {
TestProcessForEmptyInputs(R"(
input_stream_handler: "SyncSetInputStreamHandler"
options {
[mediapipe.SyncSetInputStreamHandlerOptions.ext] {
sync_set { tag_index: ":0" }
}
}
)");
}
// Demonstrates the functionality of an "ImmediatePassthroughCalculator".
// The ImmediatePassthroughCalculator simply relays each input packet to
// the corresponding output stream. ProcessTimestampBounds is needed to
// relay timestamp bounds as well as packets.
TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Passthrough) {
// OffsetBoundCalculator produces timestamp bounds.
// ImmediatePassthroughCalculator relays packets and bounds.
// ProcessBoundToPacketCalculator reports packets and bounds as packets.
std::string config_str = R"(
input_stream: "input_0"
input_stream: "input_1"
node {
calculator: "OffsetBoundCalculator"
input_stream: "input_1"
output_stream: "bound_1"
}
node {
calculator: "ImmediatePassthroughCalculator"
input_stream: "input_0"
input_stream: "bound_1"
output_stream: "same_0"
output_stream: "same_1"
}
node {
calculator: "ProcessBoundToPacketCalculator"
input_stream: "same_0"
input_stream: "same_1"
output_stream: "output_0"
output_stream: "output_1"
}
)";
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
std::vector<Packet> output_0_packets;
std::vector<Packet> output_1_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
output_0_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.ObserveOutputStream("output_1", [&](const Packet& p) {
output_1_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add four packets to input_0.
for (int i = 0; i < 4; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(i * 10));
MP_ASSERT_OK(graph.AddPacketToInputStream("input_0", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
// Packets arrive.
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(output_0_packets.size(), 4);
EXPECT_EQ(output_1_packets.size(), 0);
std::vector<Timestamp> expected = //
{Timestamp(0), Timestamp(10), Timestamp(20), Timestamp(30)};
EXPECT_EQ(GetContents<Timestamp>(output_0_packets), expected);
// Add two timestamp bounds to bound_1.
for (int i = 0; i < 2; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(10 + i * 10));
MP_ASSERT_OK(graph.AddPacketToInputStream("input_1", p));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
// Bounds arrive.
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(output_0_packets.size(), 4);
EXPECT_EQ(output_1_packets.size(), 2);
expected = //
{Timestamp(10), Timestamp(20)};
EXPECT_EQ(GetContents<Timestamp>(output_1_packets), expected);
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@ -97,6 +97,7 @@ Timestamp CalculatorNode::SourceProcessOrder(
const NodeTypeInfo& node_type_info =
validated_graph_->CalculatorInfos()[node_id_];
const CalculatorContract& contract = node_type_info.Contract();
uses_gpu_ =
node_type_info.InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
@ -147,6 +148,14 @@ Timestamp CalculatorNode::SourceProcessOrder(
use_calc_specified ? handler_config : node_config.input_stream_handler(),
node_type_info.InputStreamTypes()));
for (auto& stream : output_stream_handler_->OutputStreams()) {
stream->Spec()->offset_enabled =
(contract.GetTimestampOffset() != TimestampDiff::Unset());
stream->Spec()->offset = contract.GetTimestampOffset();
}
input_stream_handler_->SetProcessTimestampBounds(
contract.GetProcessTimestampBounds());
return InitializeInputStreams(input_stream_managers, output_stream_managers);
}

View File

@ -18,6 +18,10 @@ namespace mediapipe {
namespace {
// List of namespaces that can register calculators inside the namespace
// and still refer to them using an unqualified name. This whitelist
// is meant to facilitate migration from unqualified to fully qualified
// calculator names.
constexpr char const* kTopNamespaces[] = {
"mediapipe",
};

View File

@ -49,3 +49,10 @@ mediapipe_cc_proto_library(
visibility = ["//visibility:public"],
deps = [":rasterization_proto"],
)
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -16,6 +16,9 @@ syntax = "proto2";
package mediapipe;
option java_package = "com.google.mediapipe.formats.annotation.proto";
option java_outer_classname = "RasterizationProto";
// A Region can be represented in each frame as a set of scanlines
// (compressed RLE, similar to rasterization of polygons).
// For each scanline with y-coordinate y, we save (possibly multiple) intervals

View File

@ -23,6 +23,9 @@ package mediapipe;
import "mediapipe/framework/formats/annotation/rasterization.proto";
option java_package = "com.google.mediapipe.formats.proto";
option java_outer_classname = "LocationDataProto";
message LocationData {
// The supported formats for representing location data. A single location
// must store its data in exactly one way.

View File

@ -22,6 +22,8 @@
namespace mediapipe {
using SyncSet = InputStreamHandler::SyncSet;
::mediapipe::Status InputStreamHandler::InitializeInputStreamManagers(
InputStreamManager* flat_input_stream_managers) {
for (CollectionItemId id = input_stream_managers_.BeginId();
@ -300,4 +302,92 @@ void InputStreamHandler::SetLatePreparation(bool late_preparation) {
late_preparation_ = late_preparation;
}
SyncSet::SyncSet(InputStreamHandler* input_stream_handler,
std::vector<CollectionItemId> stream_ids)
: input_stream_handler_(input_stream_handler),
stream_ids_(std::move(stream_ids)) {}
NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) {
Timestamp min_bound = Timestamp::Done();
Timestamp min_packet = Timestamp::Done();
for (CollectionItemId id : stream_ids_) {
const auto& stream = input_stream_handler_->input_stream_managers_.Get(id);
bool empty;
Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
if (empty) {
min_bound = std::min(min_bound, stream_timestamp);
} else {
min_packet = std::min(min_packet, stream_timestamp);
}
}
*min_stream_timestamp = std::min(min_packet, min_bound);
if (*min_stream_timestamp == Timestamp::Done()) {
last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream();
return NodeReadiness::kReadyForClose;
}
if (!input_stream_handler_->process_timestamps_) {
// Only an input_ts with packets can be processed.
// Note that (min_bound - 1) is the highest fully settled timestamp.
if (min_bound > min_packet) {
last_processed_ts_ = *min_stream_timestamp;
return NodeReadiness::kReadyForProcess;
}
} else {
// Any unprocessed input_ts can be processed.
// Note that (min_bound - 1) is the highest fully settled timestamp.
Timestamp input_timestamp =
std::min(min_packet, min_bound.PreviousAllowedInStream());
if (input_timestamp >
std::max(last_processed_ts_, Timestamp::Unstarted())) {
*min_stream_timestamp = input_timestamp;
last_processed_ts_ = input_timestamp;
return NodeReadiness::kReadyForProcess;
}
}
return NodeReadiness::kNotReady;
}
Timestamp SyncSet::LastProcessed() const { return last_processed_ts_; }
Timestamp SyncSet::MinPacketTimestamp() const {
Timestamp result = Timestamp::Done();
for (CollectionItemId id : stream_ids_) {
const auto& stream = input_stream_handler_->input_stream_managers_.Get(id);
bool empty;
Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
if (!empty) {
result = std::min(result, stream_timestamp);
}
}
return result;
}
void SyncSet::FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set) {
CHECK(input_timestamp.IsAllowedInStream());
CHECK(input_set);
for (CollectionItemId id : stream_ids_) {
const auto& stream = input_stream_handler_->input_stream_managers_.Get(id);
int num_packets_dropped = 0;
bool stream_is_done = false;
Packet current_packet = stream->PopPacketAtTimestamp(
input_timestamp, &num_packets_dropped, &stream_is_done);
CHECK_EQ(num_packets_dropped, 0)
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
num_packets_dropped, stream->Name());
input_stream_handler_->AddPacketToShard(
&input_set->Get(id), std::move(current_packet), stream_is_done);
}
}
void SyncSet::FillInputBounds(InputStreamShardSet* input_set) {
for (CollectionItemId id : stream_ids_) {
const auto* stream = input_stream_handler_->input_stream_managers_.Get(id);
Timestamp bound = stream->MinTimestampOrBound(nullptr);
input_stream_handler_->AddPacketToShard(
&input_set->Get(id), Packet().At(bound.PreviousAllowedInStream()),
bound == Timestamp::Done());
}
}
} // namespace mediapipe

View File

@ -74,9 +74,7 @@ class InputStreamHandler {
: input_stream_managers_(std::move(tag_map)),
calculator_context_manager_(calculator_context_manager),
options_(options),
calculator_run_in_parallel_(calculator_run_in_parallel),
late_preparation_(false),
batch_size_(1) {}
calculator_run_in_parallel_(calculator_run_in_parallel) {}
virtual ~InputStreamHandler() = default;
@ -174,6 +172,57 @@ class InputStreamHandler {
return unset_header_count_.load(std::memory_order_relaxed);
}
// When true, Calculator::Process is called for any increase in the
// timestamp bound, whether or not any packets are available.
// Calculator::Process is called when the minimum timestamp bound
// increases for any synchronized set of input streams.
// DefaultInputStreamHandler groups all input streams into a single set.
// ImmediateInputStreamHandler treats each input stream as a separate set.
void SetProcessTimestampBounds(bool process_ts) {
process_timestamps_ = process_ts;
}
// When true, Calculator::Process is called for every input timestamp bound.
bool ProcessTimestampBounds() { return process_timestamps_; }
// A helper class to build input packet sets for a certain set of streams.
//
// ReadyForProcess requires all of the streams to be fully determined
// at the same input-timestamp.
// This is the readiness policy for all streams in DefaultInputStreamHandler.
// It is also the policy for each sync-set in SyncSetInputStreamHandler.
// It is also the policy for each input-stream in ImmediateInputStreamHandler.
//
// If ProcessTimestampBounds() is set, then a fully determined input timestamp
// with only empty input packets will qualify as ReadyForProcess.
class SyncSet {
public:
// Creates a SyncSet for a certain set of streams, |stream_ids|.
SyncSet(InputStreamHandler* input_stream_handler,
std::vector<CollectionItemId> stream_ids);
// Answers whether this stream is ready for Process or Close.
NodeReadiness GetReadiness(Timestamp* min_stream_timestamp);
// Returns the latest timestamp returned for processing.
Timestamp LastProcessed() const;
// The earliest available packet timestamp, or Timestamp::Done.
Timestamp MinPacketTimestamp() const;
// Moves packets from all input streams to the input_set.
void FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set);
// Copies timestamp bounds from all input streams to the input_set.
void FillInputBounds(InputStreamShardSet* input_set);
private:
InputStreamHandler* input_stream_handler_;
std::vector<CollectionItemId> stream_ids_;
Timestamp last_processed_ts_ = Timestamp::Unset();
};
protected:
typedef internal::Collection<InputStreamManager*> InputStreamManagerSet;
@ -240,11 +289,14 @@ class InputStreamHandler {
// The variable is set to false by default. A subclass should set it to true
// with SetLatePreparation(true) in the constructor if the input sets need to
// be filled in ProcessNode().
bool late_preparation_;
bool late_preparation_ = false;
// Determines how many sets of input packets are collected before a
// CalculatorNode is scheduled.
int batch_size_;
int batch_size_ = 1;
// When true, any increase in timestamp bound invokes Calculator::Process.
bool process_timestamps_ = false;
// A callback to notify the observer when all the input stream headers
// (excluding headers of back edges) become available.

View File

@ -107,6 +107,9 @@ CalculatorContext* LegacyCalculatorSupport::Scoped<CalculatorContext>::current_;
template <>
CalculatorContract*
LegacyCalculatorSupport::Scoped<CalculatorContract>::current_;
#elif _MSC_VER
// MSVC interprets these declarations as definitions and during linking it
// generates an error about multiple definitions of current_.
#else
template <>
thread_local CalculatorContext*

View File

@ -46,6 +46,7 @@ class OutputStreamHandler {
// ids of upstream sources that affect it.
typedef std::unordered_map<std::string, std::unordered_set<int>>
OutputStreamToSourcesMap;
typedef internal::Collection<OutputStreamManager*> OutputStreamManagerSet;
// The constructor of the OutputStreamHandler takes four arguments.
// The tag_map argument holds the information needed for tag/index retrieval
@ -119,9 +120,11 @@ class OutputStreamHandler {
// collection for debugging purpose.
std::string FirstStreamName() const;
protected:
typedef internal::Collection<OutputStreamManager*> OutputStreamManagerSet;
const OutputStreamManagerSet& OutputStreams() {
return output_stream_managers_;
}
protected:
// Checks if the given input bound should be propagated or not. If any output
// streams with OffsetEnabled() need to have the timestamp bounds updated,
// then propagates the timestamp bounds of all output streams with

View File

@ -27,6 +27,9 @@ class OutputStreamPoller {
OutputStreamPoller(const OutputStreamPoller&) = delete;
OutputStreamPoller& operator=(const OutputStreamPoller&) = delete;
OutputStreamPoller(OutputStreamPoller&&) = default;
// Move assignment needs to be explicitly defaulted to allow ASSIGN_OR_RETURN
// on `StatusOr<OutputStreamPoller>`.
OutputStreamPoller& operator=(OutputStreamPoller&&) = default;
// Resets OutputStramPollerImpl and cleans the internal packet queue.
void Reset() {

View File

@ -0,0 +1,97 @@
graph_trace: {
calculator_name : ["ACalculator", "BCalculator"]
stream_name : [ "", "input1", "a_b"]
base_time : 0
base_timestamp : 100
# Fire off three input packets and have them spend time in Calculator A.
# Drop the middle packet.
calculator_trace: {
node_id: -1
input_timestamp: 100
event_type : PROCESS
finish_time : 1000
output_trace: {
packet_timestamp: 100
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: -1
input_timestamp: 101
event_type : PROCESS
finish_time : 2000
output_trace: {
packet_timestamp: 101
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: 0
input_timestamp: 100
event_type : PROCESS
start_time : 1200 # 200 after initial input (emits at 1000)
finish_time : 1500 # Speed to delivery is 500 (1500 - 1000)
input_trace: {
packet_timestamp: 100
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: 0
input_timestamp: 101
event_type : PROCESS
start_time : 2100 # 100 after initial input (emits at 2000)
finish_time : 2500 # Speed to delivery is 500 (2500 - 2000)
input_trace: {
packet_timestamp: 101
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: 1
input_timestamp: 100
event_type : PROCESS
start_time : 1600 # 600 after the initial input (emits at 1000)
finish_time : 2000 # Speed to delivery is 1000 (2000 - 1000)
input_trace: {
packet_timestamp: 100
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: 1
input_timestamp: 101
event_type : PROCESS
start_time : 2900 # 700 after the initial input (emits at 2000)
finish_time : 3100 # Speed to delivery is 1000 (3000 - 2000)
input_trace: {
packet_timestamp: 101
stream_id : 1
}
thread_id : 1
}
}
config: {
node: {
name: "ACalculator"
calculator: "ACalculator"
input_stream: "input1"
output_stream: "a_b"
}
node: {
name: "BCalculator"
calculator: "BCalculator"
input_stream: "a_b"
}
}

View File

@ -0,0 +1,122 @@
graph_trace: {
calculator_name : ["ACalculator", "BCalculator"]
stream_name : [ "", "input1"]
base_time : 0
base_timestamp : 100
# Fire off three input packets and have them spend time in Calculator A.
# Drop the middle packet.
calculator_trace: {
node_id: -1
input_timestamp: 100
event_type : PROCESS
finish_time : 1000
output_trace: {
packet_timestamp: 100
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: -1
input_timestamp: 101
event_type : PROCESS
finish_time : 2000
output_trace: {
packet_timestamp: 101
stream_id : 1
}
thread_id : 1
}
calculator_trace: {
node_id: -1
input_timestamp: 102
event_type : PROCESS
finish_time : 3000
output_trace: {
packet_timestamp: 102
stream_id : 1
}
thread_id : 1
}
# First event is disconnected. We'll see the output_trace later.
calculator_trace: {
node_id: 0
input_timestamp: 100
event_type : PROCESS
start_time : 1100
input_trace: {
packet_timestamp: 100
stream_id : 1
}
thread_id : 1
}
# # We're going to drop this packet.
calculator_trace: {
node_id: 0
input_timestamp: 101
event_type : PROCESS
start_time : 2100
input_trace: {
packet_timestamp: 101
stream_id : 1
}
thread_id : 1
}
# # Here's that matching output trace.
calculator_trace: {
node_id: 0
input_timestamp: 100
event_type : PROCESS
finish_time : 1500
input_trace: {
packet_timestamp: 100
stream_id : 1
}
thread_id : 1
}
# Third packet is processed all at the same time.
calculator_trace: {
node_id: 0
input_timestamp: 102
event_type : PROCESS
start_time : 3100
finish_time : 3600
input_trace: {
packet_timestamp: 102
stream_id : 1
}
thread_id : 1
}
# A second calculator will process an input in order to affect the
# time_percent.
calculator_trace: {
node_id: 1
input_timestamp: 102
event_type : PROCESS
start_time : 3200
finish_time : 3500
input_trace: {
packet_timestamp: 102
stream_id : 1
}
thread_id : 1
}
}
config: {
node: {
name: "ACalculator"
calculator: "ACalculator"
input_stream: "input1"
}
node: {
name: "BCalculator"
calculator: "BCalculator"
input_stream: "input1"
}
}

View File

@ -25,7 +25,11 @@
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status.h"
#ifdef __APPLE__
#define AUTORELEASEPOOL @autoreleasepool
#else
#define AUTORELEASEPOOL
#endif // __APPLE__
namespace mediapipe {
namespace internal {

View File

@ -17,16 +17,28 @@
#include <algorithm>
#include "absl/strings/substitute.h"
#include "mediapipe/framework/input_stream_handler.h"
namespace mediapipe {
REGISTER_INPUT_STREAM_HANDLER(DefaultInputStreamHandler);
// Returns all CollectionItemId's for a Collection TagMap.
std::vector<CollectionItemId> GetIds(
const std::shared_ptr<tool::TagMap>& tag_map) {
std::vector<CollectionItemId> result;
for (auto id = tag_map->BeginId(); id < tag_map->EndId(); ++id) {
result.push_back(id);
}
return result;
}
DefaultInputStreamHandler::DefaultInputStreamHandler(
std::shared_ptr<tool::TagMap> tag_map, CalculatorContextManager* cc_manager,
const MediaPipeOptions& options, bool calculator_run_in_parallel)
: InputStreamHandler(std::move(tag_map), cc_manager, options,
calculator_run_in_parallel) {
calculator_run_in_parallel),
sync_set_(this, GetIds(input_stream_managers_.TagMap())) {
if (options.HasExtension(DefaultInputStreamHandlerOptions::ext)) {
SetBatchSize(options.GetExtension(DefaultInputStreamHandlerOptions::ext)
.batch_size());
@ -35,47 +47,12 @@ DefaultInputStreamHandler::DefaultInputStreamHandler(
NodeReadiness DefaultInputStreamHandler::GetNodeReadiness(
Timestamp* min_stream_timestamp) {
DCHECK(min_stream_timestamp);
*min_stream_timestamp = Timestamp::Done();
Timestamp min_bound = Timestamp::Done();
for (const auto& stream : input_stream_managers_) {
bool empty;
Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
if (empty) {
min_bound = std::min(min_bound, stream_timestamp);
}
*min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp);
}
if (*min_stream_timestamp == Timestamp::Done()) {
return NodeReadiness::kReadyForClose;
}
if (min_bound > *min_stream_timestamp) {
return NodeReadiness::kReadyForProcess;
}
CHECK_EQ(min_bound, *min_stream_timestamp);
return NodeReadiness::kNotReady;
return sync_set_.GetReadiness(min_stream_timestamp);
}
void DefaultInputStreamHandler::FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set) {
CHECK(input_timestamp.IsAllowedInStream());
CHECK(input_set);
for (CollectionItemId id = input_stream_managers_.BeginId();
id < input_stream_managers_.EndId(); ++id) {
auto& stream = input_stream_managers_.Get(id);
int num_packets_dropped = 0;
bool stream_is_done = false;
Packet current_packet = stream->PopPacketAtTimestamp(
input_timestamp, &num_packets_dropped, &stream_is_done);
CHECK_EQ(num_packets_dropped, 0)
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
num_packets_dropped, stream->Name());
AddPacketToShard(&input_set->Get(id), std::move(current_packet),
stream_is_done);
}
sync_set_.FillInputSet(input_timestamp, input_set);
}
} // namespace mediapipe

View File

@ -45,6 +45,9 @@ class DefaultInputStreamHandler : public InputStreamHandler {
// Only invoked when associated GetNodeReadiness() returned kReadyForProcess.
void FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set) override;
// The packet-set builder.
SyncSet sync_set_;
};
} // namespace mediapipe

View File

@ -19,6 +19,8 @@
namespace mediapipe {
using SyncSet = InputStreamHandler::SyncSet;
// An input stream handler that delivers input packets to the Calculator
// immediately, with no dependency between input streams. It also invokes
// Calculator::Process when any input stream becomes done.
@ -47,8 +49,11 @@ class ImmediateInputStreamHandler : public InputStreamHandler {
void FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set) override;
// Record of the last reported timestamp bound for each input stream.
mediapipe::internal::Collection<Timestamp> timestamp_bounds_;
absl::Mutex mutex_;
// The packet-set builder for each input stream.
std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
// The input timestamp for each kReadyForProcess input stream.
std::vector<Timestamp> ready_timestamps_ ABSL_GUARDED_BY(mutex_);
};
REGISTER_INPUT_STREAM_HANDLER(ImmediateInputStreamHandler);
@ -57,31 +62,47 @@ ImmediateInputStreamHandler::ImmediateInputStreamHandler(
CalculatorContextManager* calculator_context_manager,
const MediaPipeOptions& options, bool calculator_run_in_parallel)
: InputStreamHandler(tag_map, calculator_context_manager, options,
calculator_run_in_parallel),
timestamp_bounds_(std::move(tag_map)) {}
calculator_run_in_parallel) {
for (auto id = tag_map->BeginId(); id < tag_map->EndId(); ++id) {
sync_sets_.emplace_back(this, std::vector<CollectionItemId>{id});
ready_timestamps_.push_back(Timestamp::Unset());
}
}
NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
Timestamp* min_stream_timestamp) {
*min_stream_timestamp = Timestamp::Done();
absl::MutexLock lock(&mutex_);
Timestamp input_timestamp = Timestamp::Done();
Timestamp min_bound = Timestamp::Done();
bool stream_became_done = false;
for (CollectionItemId i = input_stream_managers_.BeginId();
i < input_stream_managers_.EndId(); ++i) {
const auto& stream = input_stream_managers_.Get(i);
bool empty;
Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
if (!empty) {
input_timestamp = std::min(input_timestamp, stream_timestamp);
for (int i = 0; i < sync_sets_.size(); ++i) {
if (ready_timestamps_[i] > Timestamp::Unset()) {
min_bound = std::min(min_bound, ready_timestamps_[i]);
input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
continue;
}
*min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp);
if (stream_timestamp != timestamp_bounds_.Get(i)) {
if (stream_timestamp == Timestamp::Done()) {
Timestamp prev_ts = sync_sets_[i].LastProcessed();
Timestamp stream_ts;
NodeReadiness readiness = sync_sets_[i].GetReadiness(&stream_ts);
min_bound = std::min(min_bound, stream_ts);
if (readiness == NodeReadiness::kReadyForProcess) {
ready_timestamps_[i] = stream_ts;
input_timestamp = std::min(input_timestamp, stream_ts);
} else if (readiness == NodeReadiness::kReadyForClose) {
CHECK_EQ(stream_ts, Timestamp::Done());
if (ProcessTimestampBounds()) {
// With kReadyForClose, the timestamp-bound Done is returned.
// This bound is processed using the preceding input-timestamp.
// TODO: Make all InputStreamHandlers process Done() like this.
ready_timestamps_[i] = stream_ts.PreviousAllowedInStream();
input_timestamp = std::min(input_timestamp, ready_timestamps_[i]);
} else if (prev_ts < Timestamp::Done()) {
stream_became_done = true;
ready_timestamps_[i] = Timestamp::Done();
}
timestamp_bounds_.Get(i) = stream_timestamp;
}
}
*min_stream_timestamp = min_bound;
if (*min_stream_timestamp == Timestamp::Done()) {
return NodeReadiness::kReadyForClose;
@ -94,6 +115,8 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
}
if (stream_became_done) {
// The stream_became_done logic is kept for backward compatibility.
// Note that the minimum bound is returned in min_stream_timestamp.
return NodeReadiness::kReadyForProcess;
}
@ -102,23 +125,13 @@ NodeReadiness ImmediateInputStreamHandler::GetNodeReadiness(
void ImmediateInputStreamHandler::FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set) {
CHECK(input_timestamp.IsAllowedInStream());
CHECK(input_set);
for (CollectionItemId id = input_stream_managers_.BeginId();
id < input_stream_managers_.EndId(); ++id) {
auto& stream = input_stream_managers_.Get(id);
if (stream->QueueHead().Timestamp() == input_timestamp) {
int num_packets_dropped = 0;
bool stream_is_done = false;
Packet current_packet = stream->PopPacketAtTimestamp(
input_timestamp, &num_packets_dropped, &stream_is_done);
AddPacketToShard(&input_set->Get(id), std::move(current_packet),
stream_is_done);
absl::MutexLock lock(&mutex_);
for (int i = 0; i < sync_sets_.size(); ++i) {
if (ready_timestamps_[i] == input_timestamp) {
sync_sets_[i].FillInputSet(input_timestamp, input_set);
ready_timestamps_[i] = Timestamp::Unset();
} else {
Timestamp bound = stream->MinTimestampOrBound(nullptr);
AddPacketToShard(&input_set->Get(id),
Packet().At(bound.PreviousAllowedInStream()),
bound == Timestamp::Done());
sync_sets_[i].FillInputBounds(input_set);
}
}
}

View File

@ -17,6 +17,7 @@
// TODO: Move protos in another CL after the C++ code migration.
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/input_stream_handler.h"
#include "mediapipe/framework/mediapipe_options.pb.h"
#include "mediapipe/framework/packet_set.h"
@ -69,7 +70,7 @@ class SyncSetInputStreamHandler : public InputStreamHandler {
private:
absl::Mutex mutex_;
// The ids of each set of inputs.
std::vector<std::vector<CollectionItemId>> sync_sets_ ABSL_GUARDED_BY(mutex_);
std::vector<SyncSet> sync_sets_ ABSL_GUARDED_BY(mutex_);
// The index of the ready sync set. A value of -1 indicates that no
// sync sets are ready.
int ready_sync_set_index_ ABSL_GUARDED_BY(mutex_) = -1;
@ -98,7 +99,7 @@ void SyncSetInputStreamHandler::PrepareForRun(
sync_sets_.clear();
std::set<CollectionItemId> used_ids;
for (const auto& sync_set : handler_options.sync_set()) {
sync_sets_.emplace_back();
std::vector<CollectionItemId> stream_ids;
CHECK_LT(0, sync_set.tag_index_size());
for (const auto& tag_index : sync_set.tag_index()) {
std::string tag;
@ -109,8 +110,9 @@ void SyncSetInputStreamHandler::PrepareForRun(
CHECK(!::mediapipe::ContainsKey(used_ids, id))
<< "stream \"" << tag_index << "\" is in more than one sync set.";
used_ids.insert(id);
sync_sets_.back().push_back(id);
stream_ids.push_back(id);
}
sync_sets_.emplace_back(this, std::move(stream_ids));
}
std::vector<CollectionItemId> remaining_ids;
for (CollectionItemId id = input_stream_managers_.BeginId();
@ -120,7 +122,7 @@ void SyncSetInputStreamHandler::PrepareForRun(
}
}
if (!remaining_ids.empty()) {
sync_sets_.push_back(std::move(remaining_ids));
sync_sets_.emplace_back(this, std::move(remaining_ids));
}
ready_sync_set_index_ = -1;
ready_timestamp_ = Timestamp::Done();
@ -137,24 +139,14 @@ NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
absl::MutexLock lock(&mutex_);
if (ready_sync_set_index_ >= 0) {
*min_stream_timestamp = ready_timestamp_;
// TODO: Return kNotReady unless a new ready syncset is found.
return NodeReadiness::kReadyForProcess;
}
for (int sync_set_index = 0; sync_set_index < sync_sets_.size();
++sync_set_index) {
const std::vector<CollectionItemId>& sync_set = sync_sets_[sync_set_index];
*min_stream_timestamp = Timestamp::Done();
Timestamp min_bound = Timestamp::Done();
for (CollectionItemId id : sync_set) {
const auto& stream = input_stream_managers_.Get(id);
bool empty;
Timestamp stream_timestamp = stream->MinTimestampOrBound(&empty);
if (empty) {
min_bound = std::min(min_bound, stream_timestamp);
}
*min_stream_timestamp = std::min(*min_stream_timestamp, stream_timestamp);
}
if (*min_stream_timestamp == Timestamp::Done()) {
NodeReadiness readiness =
sync_sets_[sync_set_index].GetReadiness(min_stream_timestamp);
if (readiness == NodeReadiness::kReadyForClose) {
// This sync set is done, remove it. Note that this invalidates
// sync set indexes higher than sync_set_index. However, we are
// guaranteed that we were not ready before entering the outer
@ -165,15 +157,14 @@ NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
continue;
}
if (min_bound > *min_stream_timestamp) {
if (readiness == NodeReadiness::kReadyForProcess) {
// TODO: Prioritize sync-sets to avoid starvation.
if (*min_stream_timestamp < ready_timestamp_) {
// Store the timestamp and corresponding sync set index for the
// sync set with the earliest arrival timestamp.
ready_timestamp_ = *min_stream_timestamp;
ready_sync_set_index_ = sync_set_index;
}
} else {
CHECK_EQ(min_bound, *min_stream_timestamp);
}
}
if (ready_sync_set_index_ >= 0) {
@ -188,44 +179,17 @@ NodeReadiness SyncSetInputStreamHandler::GetNodeReadiness(
return NodeReadiness::kNotReady;
}
void SyncSetInputStreamHandler::FillInputBounds(
Timestamp input_timestamp, InputStreamShardSet* input_set) {
for (int i = 0; i < sync_sets_.size(); ++i) {
if (i != ready_sync_set_index_) {
// Set the input streams for the not-ready sync sets.
for (CollectionItemId id : sync_sets_[i]) {
const auto stream = input_stream_managers_.Get(id);
Timestamp bound = stream->MinTimestampOrBound(nullptr);
AddPacketToShard(&input_set->Get(id),
Packet().At(bound.PreviousAllowedInStream()),
bound == Timestamp::Done());
}
}
}
}
void SyncSetInputStreamHandler::FillInputSet(Timestamp input_timestamp,
InputStreamShardSet* input_set) {
// Assume that all current packets are already cleared.
CHECK(input_timestamp.IsAllowedInStream());
CHECK(input_set);
absl::MutexLock lock(&mutex_);
CHECK_LE(0, ready_sync_set_index_);
CHECK_EQ(input_timestamp, ready_timestamp_);
// Set the input streams for the ready sync set.
for (CollectionItemId id : sync_sets_[ready_sync_set_index_]) {
const auto& stream = input_stream_managers_.Get(id);
int num_packets_dropped = 0;
bool stream_is_done = false;
Packet current_packet = stream->PopPacketAtTimestamp(
input_timestamp, &num_packets_dropped, &stream_is_done);
CHECK_EQ(num_packets_dropped, 0)
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
num_packets_dropped, stream->Name());
AddPacketToShard(&input_set->Get(id), std::move(current_packet),
stream_is_done);
sync_sets_[ready_sync_set_index_].FillInputSet(input_timestamp, input_set);
for (int i = 0; i < sync_sets_.size(); ++i) {
if (i != ready_sync_set_index_) {
sync_sets_[i].FillInputBounds(input_set);
}
}
FillInputBounds(input_timestamp, input_set);
ready_sync_set_index_ = -1;
ready_timestamp_ = Timestamp::Done();
}

View File

@ -122,7 +122,6 @@ std::string TimestampDiff::DebugString() const {
}
Timestamp Timestamp::NextAllowedInStream() const {
CHECK(IsAllowedInStream()) << "Timestamp is: " << DebugString();
if (*this >= Max() || *this == PreStream()) {
// Indicates that no further timestamps may occur.
return OneOverPostStream();

View File

@ -247,6 +247,12 @@ class TimestampDiff {
TimestampDiff operator-(const TimestampDiff other) const;
Timestamp operator+(const Timestamp other) const;
// Special values.
static TimestampDiff Unset() {
return TimestampDiff(Timestamp::Unset().Value());
}
private:
TimestampBaseType timestamp_;
};

View File

@ -815,16 +815,25 @@ NodeTypeInfo::NodeRef ValidatedGraphConfig::NodeForSorterIndex(
sorted_nodes_.push_back(&tmp_calculators.back());
}
}
if (cyclic) {
// This reads from partilly altered config_ (by node Swap()) but we assume
// the nodes in the cycle are not altered, as TopologicalSorter reports
// cyclicity before processing any node in cycle.
auto node_name_formatter = [this](std::string* out, int i) {
const auto& n = NodeForSorterIndex(i);
absl::StrAppend(out, n.type == NodeTypeInfo::NodeType::CALCULATOR
? tool::CanonicalNodeName(Config(), n.index)
: DebugName(Config(), n.type, n.index));
};
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Generator side packet cycle or calculator stream cycle detected "
"in graph: ["
<< absl::StrJoin(cycle_indexes, ", ", node_name_formatter) << "]";
}
generator_configs.Swap(config_.mutable_packet_generator());
tmp_generators.swap(generators_);
node_configs.Swap(config_.mutable_node());
tmp_calculators.swap(calculators_);
if (cyclic) {
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Generator side packet cycle or calculator stream cycle detected "
"in graph. Cycle indexes: "
<< absl::StrJoin(cycle_indexes, ", ");
}
#if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE))
VLOG(2) << "AFTER TOPOLOGICAL SORT:\n" << config_.DebugString();
#endif // !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE)

View File

@ -57,15 +57,14 @@
#include <EGL/egl.h>
#include <GLES2/gl2.h>
#include <GLES2/gl2ext.h>
#ifdef __ANDROID__
#if defined(__ANDROID__)
// Weak-link all GL APIs included from this point on.
// TODO: Annotate these with availability attributes for the
// appropriate versions of Android, by including gl{3,31,31}.h and resetting
// GL_APICALL for each.
#undef GL_APICALL
#define GL_APICALL __attribute__((weak_import)) KHRONOS_APICALL
#endif // __ANDROID__
#endif // defined(__ANDROID__)
#include <GLES3/gl32.h>

View File

@ -83,6 +83,10 @@ class GlCalculatorHelperImpl {
GLuint framebuffer_ = 0;
GpuResources& gpu_resources_;
// Necessary to compute for a given GlContext in order to properly enforce the
// SetStandardTextureParams.
bool can_linear_filter_float_textures_;
};
} // namespace mediapipe

View File

@ -22,6 +22,17 @@ GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc,
GpuResources* gpu_resources)
: gpu_resources_(*gpu_resources) {
gl_context_ = gpu_resources_.gl_context(cc);
// GL_ES_VERSION_2_0 and up (at least through ES 3.2) may contain the extension.
// Checking against one also checks against higher ES versions. So this checks
// against GLES >= 2.0.
#if GL_ES_VERSION_2_0
// No linear float filtering by default, check extensions.
can_linear_filter_float_textures_ =
gl_context_->HasGlExtension("OES_texture_float_linear");
#else
// Any float32 texture we create should automatically have linear filtering.
can_linear_filter_float_textures_ = true;
#endif // GL_ES_VERSION_2_0
}
GlCalculatorHelperImpl::~GlCalculatorHelperImpl() {
@ -89,13 +100,15 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target,
GLint internal_format) {
// Default to using linear filter everywhere. For float32 textures, fall back
// to GL_NEAREST if linear filtering unsupported.
GLint filter;
switch (internal_format) {
case GL_R32F:
case GL_RGBA32F:
// 32F (unlike 16f) textures do not support texture filtering
// 32F (unlike 16f) textures do not always support texture filtering
// (According to OpenGL ES specification [TEXTURE IMAGE SPECIFICATION])
filter = GL_NEAREST;
filter = can_linear_filter_float_textures_ ? GL_LINEAR : GL_NEAREST;
break;
default:
filter = GL_LINEAR;

View File

@ -203,6 +203,69 @@ bool GlContext::ParseGlVersion(absl::string_view version_string, GLint* major,
return true;
}
bool GlContext::HasGlExtension(absl::string_view extension) const {
return gl_extensions_.find(extension) != gl_extensions_.end();
}
// Function for GL3.0+ to query for and store all of our available GL extensions
// in an easily-accessible set. The glGetString call is actually *not* required
// to work with GL_EXTENSIONS for newer GL versions, so we must maintain both
// variations of this function.
::mediapipe::Status GlContext::GetGlExtensions() {
gl_extensions_.clear();
// glGetStringi only introduced in GL 3.0+; so we exit out this function if
// we don't have that function defined, regardless of version number reported.
// The function itself is also fully stubbed out if we're linking against an
// API version without a glGetStringi declaration. Although Emscripten
// sometimes provides this function, its default library implementation
// appears to only provide glGetString, so we skip this for Emscripten
// platforms to avoid possible undefined symbol or runtime errors.
#if (GL_VERSION_3_0 || GL_ES_VERSION_3_0) && !defined(__EMSCRIPTEN__)
if (!SymbolAvailable(&glGetStringi)) {
LOG(ERROR) << "GL major version > 3.0 indicated, but glGetStringi not "
<< "defined. Falling back to deprecated GL extensions querying "
<< "method.";
return ::mediapipe::InternalError("glGetStringi not defined, but queried");
}
int num_extensions = 0;
glGetIntegerv(GL_NUM_EXTENSIONS, &num_extensions);
if (glGetError() != 0) {
return ::mediapipe::InternalError(
"Error querying for number of extensions");
}
for (int i = 0; i < num_extensions; ++i) {
const GLubyte* res = glGetStringi(GL_EXTENSIONS, i);
if (glGetError() != 0 || res == nullptr) {
return ::mediapipe::InternalError(
"Error querying for an extension by index");
}
const char* signed_res = reinterpret_cast<const char*>(res);
gl_extensions_.insert(signed_res);
}
return ::mediapipe::OkStatus();
#else
return ::mediapipe::InternalError("GL version mismatch in GlGetExtensions");
#endif // (GL_VERSION_3_0 || GL_ES_VERSION_3_0) && !defined(__EMSCRIPTEN__)
}
// Same as GetGlExtensions() above, but for pre-GL3.0, where glGetStringi did
// not exist.
::mediapipe::Status GlContext::GetGlExtensionsCompat() {
gl_extensions_.clear();
const GLubyte* res = glGetString(GL_EXTENSIONS);
if (glGetError() != 0 || res == nullptr) {
LOG(ERROR) << "Error querying for GL extensions";
return ::mediapipe::InternalError("Error querying for GL extensions");
}
const char* signed_res = reinterpret_cast<const char*>(res);
gl_extensions_ = absl::StrSplit(signed_res, ' ');
return ::mediapipe::OkStatus();
}
::mediapipe::Status GlContext::FinishInitialization(bool create_thread) {
if (create_thread) {
thread_ = absl::make_unique<GlContext::DedicatedThread>();
@ -232,8 +295,13 @@ bool GlContext::ParseGlVersion(absl::string_view version_string, GLint* major,
LOG(INFO) << "GL version: " << gl_major_version_ << "." << gl_minor_version_
<< " (" << glGetString(GL_VERSION) << ")";
return ::mediapipe::OkStatus();
if (gl_major_version_ >= 3) {
auto status = GetGlExtensions();
if (status.ok()) {
return ::mediapipe::OkStatus();
}
}
return GetGlExtensionsCompat();
});
}

View File

@ -237,6 +237,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
static bool ParseGlVersion(absl::string_view version_string, GLint* major,
GLint* minor);
// Simple query for GL extension support; only valid after GlContext has
// finished its initialization successfully.
bool HasGlExtension(absl::string_view extension) const;
int64_t gl_finish_count() { return gl_finish_count_; }
// Used by GlFinishSyncPoint. The count_to_pass cannot exceed the current
@ -346,6 +350,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
bool HasContext() const;
bool CheckForGlErrors();
void LogUncheckedGlErrors(bool had_gl_errors);
::mediapipe::Status GetGlExtensions();
::mediapipe::Status GetGlExtensionsCompat();
// The following ContextBinding functions have platform-specific
// implementations.
@ -366,6 +372,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
GLint gl_major_version_ = 0;
GLint gl_minor_version_ = 0;
// glGetString and glGetStringi both return pointers to static strings,
// so we should be fine storing the extension pieces as string_view's.
std::set<absl::string_view> gl_extensions_;
// Number of glFinish calls completed on the GL thread.
// Changes should be guarded by mutex_. However, we use simple atomic
// loads for efficiency on the fast path.

View File

@ -24,6 +24,19 @@ namespace mediapipe {
#define _STRINGIFY(_x) __STRINGIFY(_x)
#endif
// Our fragment shaders use DEFAULT_PRECISION to define the default precision
// for a type. The macro strips out the precision declaration on desktop GL,
// where it's not supported.
//
// Note: this does not use a raw std::string because some compilers don't handle
// raw strings inside macros correctly. It uses a macro because we want to be
// able to concatenate strings by juxtaposition. We want to concatenate strings
// by juxtaposition so we can export const char* static data containing the
// pre-expanded strings.
//
// TODO: this was written before we could rely on C++11 support.
// Consider replacing it with constexpr std::string concatenation, or replacing
// the static variables with functions.
#define PRECISION_COMPAT \
GLES_VERSION_COMPAT \
"#ifdef GL_ES \n" \
@ -42,10 +55,15 @@ namespace mediapipe {
"#define out varying\n" \
"#endif // __VERSION__ < 130\n"
#define FRAGMENT_PREAMBLE \
PRECISION_COMPAT \
"#if __VERSION__ < 130\n" \
"#define in varying\n" \
// Note: on systems where highp precision for floats is not supported (look up
// GL_FRAGMENT_PRECISION_HIGH), we replace it with mediump.
#define FRAGMENT_PREAMBLE \
PRECISION_COMPAT \
"#if __VERSION__ < 130\n" \
"#define in varying\n" \
"#if GL_ES && !GL_FRAGMENT_PRECISION_HIGH\n" \
"#define highp mediump\n" \
"#endif // GL_ES && !GL_FRAGMENT_PRECISION_HIGH\n" \
"#endif // __VERSION__ < 130\n"
const GLchar* const kMediaPipeVertexShaderPreamble = VERTEX_PREAMBLE;

View File

@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_binary_graph",
)
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
@ -33,9 +38,19 @@ cc_library(
],
)
load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_binary_graph",
cc_library(
name = "desktop_calculators",
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/image:recolor_calculator",
"//mediapipe/calculators/image:set_alpha_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_segmentation_calculator",
],
)
mediapipe_binary_graph(

View File

@ -0,0 +1,152 @@
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on CPU.
# Used in the example in
# mediapipie/examples/desktop/hair_segmentation:hair_segmentation_cpu
# Images on CPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# TfLiteTensorsToSegmentationCalculator downstream in the graph to finish
# generating the corresponding hair mask before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# TfLiteTensorsToSegmentationCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessarily computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy
# processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:hair_mask"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Transforms the input image on CPU to a 512x512 image. To scale the image, by
# default it uses the STRETCH scale mode that maps the entire input image to the
# entire transformed image. As a result, image aspect ratio may be changed and
# objects in the image may be deformed (stretched or squeezed), but the hair
# segmentation model used in this graph is agnostic to that deformation.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:throttled_input_video"
output_stream: "IMAGE:transformed_input_video"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 512
output_height: 512
}
}
}
# Caches a mask fed back from the previous round of hair segmentation, and upon
# the arrival of the next input image sends out the cached mask with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous mask. Note that upon the arrival of the very first
# input image, an empty packet is sent out to jump start the feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:throttled_input_video"
input_stream: "LOOP:hair_mask"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:previous_hair_mask"
}
# Embeds the hair mask generated from the previous round of hair segmentation
# as the alpha channel of the current input image.
node {
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:transformed_input_video"
input_stream: "ALPHA:previous_hair_mask"
output_stream: "IMAGE:mask_embedded_input_video"
}
# Converts the transformed input image on CPU into an image tensor stored in
# TfLiteTensor. The zero_center option is set to false to normalize the
# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. With the
# max_num_channels option set to 4, all 4 RGBA channels are contained in the
# image tensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:mask_embedded_input_video"
output_stream: "TENSORS:image_tensor"
node_options: {
[type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
zero_center: false
max_num_channels: 4
}
}
}
# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
calculator: "TfLiteCustomOpResolverCalculator"
output_side_packet: "op_resolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
use_gpu: false
}
}
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# tensor representing the hair segmentation, which has the same width and height
# as the input image tensor.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:segmentation_tensor"
input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/hair_segmentation.tflite"
use_gpu: false
}
}
}
# Decodes the segmentation tensor generated by the TensorFlow Lite model into a
# mask of values in [0, 255], stored in a CPU buffer. It also
# takes the mask generated previously as another input to improve the temporal
# consistency.
node {
calculator: "TfLiteTensorsToSegmentationCalculator"
input_stream: "TENSORS:segmentation_tensor"
input_stream: "PREV_MASK:previous_hair_mask"
output_stream: "MASK:hair_mask"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] {
tensor_width: 512
tensor_height: 512
tensor_channels: 2
combine_with_previous_ratio: 0.9
output_layer_index: 1
}
}
}
# Colors the hair segmentation with the color specified in the option.
node {
calculator: "RecolorCalculator"
input_stream: "IMAGE:throttled_input_video"
input_stream: "MASK:hair_mask"
output_stream: "IMAGE:output_video"
node_options: {
[type.googleapis.com/mediapipe.RecolorCalculatorOptions] {
color { r: 0 g: 0 b: 255 }
mask_channel: RED
}
}
}

View File

@ -78,6 +78,33 @@ cat > $(OUTS) <<EOF
srcs = ["//mediapipe/framework/formats:protos_src"],
)
_proto_java_src_generator(
name = "rasterization_proto",
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
)
_proto_java_src_generator(
name = "location_data_proto",
proto_src = "mediapipe/framework/formats/location_data.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "detection_proto",
proto_src = "mediapipe/framework/formats/detection.proto",
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
android_library(
name = name + "_android_lib",
srcs = [
@ -86,6 +113,9 @@ cat > $(OUTS) <<EOF
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/proto/CalculatorProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
],
manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],

View File

@ -43,31 +43,6 @@ MEDIAPIPE_IOS_HDRS = [
"NSError+util_status.h",
]
MEDIAPIPE_IOS_CC_DEPS = [
":CFHolder",
":util",
"//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:mediapipe_profiling",
"//mediapipe/gpu:MPPGraphGPUData",
"//mediapipe/gpu:pixel_buffer_pool_util",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gl_base",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/gpu:graph_support",
# Other deps
"//mediapipe/util:cpu_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
]
objc_library(
name = "mediapipe_framework_ios",
srcs = MEDIAPIPE_IOS_SRCS,
@ -80,8 +55,28 @@ objc_library(
"Accelerate",
],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = MEDIAPIPE_IOS_CC_DEPS + [
# These are objc_library deps.
deps = [
":CFHolder",
":util",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:mediapipe_profiling",
"//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/port:threadpool",
"//mediapipe/gpu:MPPGraphGPUData",
"//mediapipe/gpu:gl_base",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_shared_data_internal",
"//mediapipe/gpu:graph_support",
"//mediapipe/gpu:pixel_buffer_pool_util",
"//mediapipe/util:cpu_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@google_toolbox_for_mac//:GTM_Defines",
],
)

View File

@ -426,7 +426,7 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
cv::Point point_to_draw(x, y);
const cv::Scalar color = MediapipeColorToOpenCVColor(annotation.color());
const int thickness = annotation.thickness();
cv::circle(mat_image_, point_to_draw, thickness, color, thickness);
cv::circle(mat_image_, point_to_draw, thickness, color, -1);
}
void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {

View File

@ -22,7 +22,6 @@ cc_library(
hdrs = ["media_sequence_util.h"],
visibility = [
"//mediapipe:__subpackages__",
"//research/action_recognition/sequence:__subpackages__",
],
deps = [
"//mediapipe/framework/port:core_proto",
@ -38,7 +37,6 @@ cc_library(
hdrs = ["media_sequence.h"],
visibility = [
"//mediapipe:__subpackages__",
"//research/action_recognition/sequence:__subpackages__",
],
deps = [
":media_sequence_util",

View File

@ -36,7 +36,11 @@ cd /tmp/build_opencv
git clone https://github.com/opencv/opencv_contrib.git
git clone https://github.com/opencv/opencv.git
mkdir opencv/release
cd opencv/release
cd opencv_contrib
git checkout 3.4
cd ../opencv
git checkout 3.4
cd release
cmake .. -DCMAKE_BUILD_TYPE=RELEASE -DCMAKE_INSTALL_PREFIX=/usr/local \
-DBUILD_TESTS=OFF -DBUILD_PERF_TESTS=OFF -DBUILD_opencv_ts=OFF \
-DOPENCV_EXTRA_MODULES_PATH=/tmp/build_opencv/opencv_contrib/modules \

View File

@ -0,0 +1,112 @@
diff --git a/third_party/cpuinfo/BUILD.bazel b/third_party/cpuinfo/BUILD.bazel
index 8d89521612..6ea60acdda 100644
--- a/third_party/cpuinfo/BUILD.bazel
+++ b/third_party/cpuinfo/BUILD.bazel
@@ -116,6 +111,8 @@ cc_library(
":watchos_x86": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
":watchos_armv7k": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
":watchos_arm64_32": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
+ ":tvos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS,
+ ":tvos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS,
":emscripten_wasm": COMMON_SRCS + EMSCRIPTEN_SRCS,
}),
copts = C99OPTS + [
@@ -212,7 +209,7 @@ config_setting(
config_setting(
name = "ios_armv7",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "ios",
"cpu": "ios_armv7",
},
)
@@ -220,7 +217,7 @@ config_setting(
config_setting(
name = "ios_arm64",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "ios",
"cpu": "ios_arm64",
},
)
@@ -228,7 +225,7 @@ config_setting(
config_setting(
name = "ios_arm64e",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "ios",
"cpu": "ios_arm64e",
},
)
@@ -236,7 +233,7 @@ config_setting(
config_setting(
name = "ios_x86",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "ios",
"cpu": "ios_i386",
},
)
@@ -244,7 +241,7 @@ config_setting(
config_setting(
name = "ios_x86_64",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "ios",
"cpu": "ios_x86_64",
},
)
@@ -252,7 +249,7 @@ config_setting(
config_setting(
name = "watchos_armv7k",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "watchos",
"cpu": "watchos_armv7k",
},
)
@@ -260,7 +257,7 @@ config_setting(
config_setting(
name = "watchos_arm64_32",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "watchos",
"cpu": "watchos_arm64_32",
},
)
@@ -268,7 +265,7 @@ config_setting(
config_setting(
name = "watchos_x86",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "watchos",
"cpu": "watchos_i386",
},
)
@@ -276,7 +273,7 @@ config_setting(
config_setting(
name = "watchos_x86_64",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "watchos",
"cpu": "watchos_x86_64",
},
)
@@ -284,7 +281,7 @@ config_setting(
config_setting(
name = "tvos_arm64",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "tvos",
"cpu": "tvos_arm64",
},
)
@@ -292,7 +289,7 @@ config_setting(
config_setting(
name = "tvos_x86_64",
values = {
- "crosstool_top": "//tools/osx/crosstool:crosstool",
+ "apple_platform_type": "tvos",
"cpu": "tvos_x86_64",
},
)

File diff suppressed because it is too large Load Diff