Project import generated by Copybara.
GitOrigin-RevId: 8e1da4611d93ccb7d9674713157d43be0348d98f
This commit is contained in:
parent
50c92c6623
commit
b899d17f18
|
@ -220,6 +220,7 @@ import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
mp_drawing = mp.solutions.drawing_utils
|
mp_drawing = mp.solutions.drawing_utils
|
||||||
mp_hands = mp.solutions.hands
|
mp_hands = mp.solutions.hands
|
||||||
|
drawing_styles = mp.solutions.drawing_styles
|
||||||
|
|
||||||
# For static images:
|
# For static images:
|
||||||
IMAGE_FILES = []
|
IMAGE_FILES = []
|
||||||
|
@ -248,7 +249,9 @@ with mp_hands.Hands(
|
||||||
f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})'
|
f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})'
|
||||||
)
|
)
|
||||||
mp_drawing.draw_landmarks(
|
mp_drawing.draw_landmarks(
|
||||||
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
|
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
|
||||||
|
drawing_styles.get_default_hand_landmark_style(),
|
||||||
|
drawing_styles.get_default_hand_connection_style())
|
||||||
cv2.imwrite(
|
cv2.imwrite(
|
||||||
'/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1))
|
'/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1))
|
||||||
|
|
||||||
|
@ -278,7 +281,9 @@ with mp_hands.Hands(
|
||||||
if results.multi_hand_landmarks:
|
if results.multi_hand_landmarks:
|
||||||
for hand_landmarks in results.multi_hand_landmarks:
|
for hand_landmarks in results.multi_hand_landmarks:
|
||||||
mp_drawing.draw_landmarks(
|
mp_drawing.draw_landmarks(
|
||||||
image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
|
image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
|
||||||
|
drawing_styles.get_default_hand_landmark_style(),
|
||||||
|
drawing_styles.get_default_hand_connection_style())
|
||||||
cv2.imshow('MediaPipe Hands', image)
|
cv2.imshow('MediaPipe Hands', image)
|
||||||
if cv2.waitKey(5) & 0xFF == 27:
|
if cv2.waitKey(5) & 0xFF == 27:
|
||||||
break
|
break
|
||||||
|
|
|
@ -24,6 +24,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kDataTag[] = "DATA";
|
||||||
|
constexpr char kHeaderTag[] = "HEADER";
|
||||||
|
|
||||||
class AddHeaderCalculatorTest : public ::testing::Test {};
|
class AddHeaderCalculatorTest : public ::testing::Test {};
|
||||||
|
|
||||||
TEST_F(AddHeaderCalculatorTest, HeaderStream) {
|
TEST_F(AddHeaderCalculatorTest, HeaderStream) {
|
||||||
|
@ -36,11 +39,11 @@ TEST_F(AddHeaderCalculatorTest, HeaderStream) {
|
||||||
CalculatorRunner runner(node);
|
CalculatorRunner runner(node);
|
||||||
|
|
||||||
// Set header and add 5 packets.
|
// Set header and add 5 packets.
|
||||||
runner.MutableInputs()->Tag("HEADER").header =
|
runner.MutableInputs()->Tag(kHeaderTag).header =
|
||||||
Adopt(new std::string("my_header"));
|
Adopt(new std::string("my_header"));
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run calculator.
|
// Run calculator.
|
||||||
|
@ -85,13 +88,14 @@ TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
|
||||||
CalculatorRunner runner(node);
|
CalculatorRunner runner(node);
|
||||||
|
|
||||||
// Set header and add 5 packets.
|
// Set header and add 5 packets.
|
||||||
runner.MutableInputs()->Tag("HEADER").header =
|
runner.MutableInputs()->Tag(kHeaderTag).header =
|
||||||
Adopt(new std::string("my_header"));
|
Adopt(new std::string("my_header"));
|
||||||
runner.MutableInputs()->Tag("HEADER").packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(new std::string("not allowed")));
|
->Tag(kHeaderTag)
|
||||||
|
.packets.push_back(Adopt(new std::string("not allowed")));
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run calculator.
|
// Run calculator.
|
||||||
|
@ -108,11 +112,11 @@ TEST_F(AddHeaderCalculatorTest, InputSidePacket) {
|
||||||
CalculatorRunner runner(node);
|
CalculatorRunner runner(node);
|
||||||
|
|
||||||
// Set header and add 5 packets.
|
// Set header and add 5 packets.
|
||||||
runner.MutableSidePackets()->Tag("HEADER") =
|
runner.MutableSidePackets()->Tag(kHeaderTag) =
|
||||||
Adopt(new std::string("my_header"));
|
Adopt(new std::string("my_header"));
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run calculator.
|
// Run calculator.
|
||||||
|
@ -143,13 +147,13 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) {
|
||||||
CalculatorRunner runner(node);
|
CalculatorRunner runner(node);
|
||||||
|
|
||||||
// Set both headers and add 5 packets.
|
// Set both headers and add 5 packets.
|
||||||
runner.MutableSidePackets()->Tag("HEADER") =
|
runner.MutableSidePackets()->Tag(kHeaderTag) =
|
||||||
Adopt(new std::string("my_header"));
|
Adopt(new std::string("my_header"));
|
||||||
runner.MutableSidePackets()->Tag("HEADER") =
|
runner.MutableSidePackets()->Tag(kHeaderTag) =
|
||||||
Adopt(new std::string("my_header"));
|
Adopt(new std::string("my_header"));
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run should fail because header can only be provided one way.
|
// Run should fail because header can only be provided one way.
|
||||||
|
|
|
@ -19,6 +19,13 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kIncrementTag[] = "INCREMENT";
|
||||||
|
constexpr char kInitialValueTag[] = "INITIAL_VALUE";
|
||||||
|
constexpr char kBatchSizeTag[] = "BATCH_SIZE";
|
||||||
|
constexpr char kErrorCountTag[] = "ERROR_COUNT";
|
||||||
|
constexpr char kMaxCountTag[] = "MAX_COUNT";
|
||||||
|
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
|
||||||
|
|
||||||
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
|
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
|
||||||
// sequential numbers from INITIAL_VALUE (default 0) with a common
|
// sequential numbers from INITIAL_VALUE (default 0) with a common
|
||||||
// difference of INCREMENT (default 1) between successive numbers (with
|
// difference of INCREMENT (default 1) between successive numbers (with
|
||||||
|
@ -33,53 +40,53 @@ class CountingSourceCalculator : public CalculatorBase {
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Outputs().Index(0).Set<int>();
|
cc->Outputs().Index(0).Set<int>();
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) {
|
if (cc->InputSidePackets().HasTag(kErrorOnOpenTag)) {
|
||||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
|
cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") ||
|
RET_CHECK(cc->InputSidePackets().HasTag(kMaxCountTag) ||
|
||||||
cc->InputSidePackets().HasTag("ERROR_COUNT"));
|
cc->InputSidePackets().HasTag(kErrorCountTag));
|
||||||
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
|
if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
|
||||||
cc->InputSidePackets().Tag("MAX_COUNT").Set<int>();
|
cc->InputSidePackets().Tag(kMaxCountTag).Set<int>();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
|
if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
|
||||||
cc->InputSidePackets().Tag("ERROR_COUNT").Set<int>();
|
cc->InputSidePackets().Tag(kErrorCountTag).Set<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
|
if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
|
||||||
cc->InputSidePackets().Tag("BATCH_SIZE").Set<int>();
|
cc->InputSidePackets().Tag(kBatchSizeTag).Set<int>();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
|
if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
|
||||||
cc->InputSidePackets().Tag("INITIAL_VALUE").Set<int>();
|
cc->InputSidePackets().Tag(kInitialValueTag).Set<int>();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("INCREMENT")) {
|
if (cc->InputSidePackets().HasTag(kIncrementTag)) {
|
||||||
cc->InputSidePackets().Tag("INCREMENT").Set<int>();
|
cc->InputSidePackets().Tag(kIncrementTag).Set<int>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") &&
|
if (cc->InputSidePackets().HasTag(kErrorOnOpenTag) &&
|
||||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
|
cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
|
||||||
return absl::NotFoundError("expected error");
|
return absl::NotFoundError("expected error");
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
|
if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
|
||||||
error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get<int>();
|
error_count_ = cc->InputSidePackets().Tag(kErrorCountTag).Get<int>();
|
||||||
RET_CHECK_LE(0, error_count_);
|
RET_CHECK_LE(0, error_count_);
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
|
if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
|
||||||
max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>();
|
max_count_ = cc->InputSidePackets().Tag(kMaxCountTag).Get<int>();
|
||||||
RET_CHECK_LE(0, max_count_);
|
RET_CHECK_LE(0, max_count_);
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
|
if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
|
||||||
batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get<int>();
|
batch_size_ = cc->InputSidePackets().Tag(kBatchSizeTag).Get<int>();
|
||||||
RET_CHECK_LT(0, batch_size_);
|
RET_CHECK_LT(0, batch_size_);
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
|
if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
|
||||||
counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get<int>();
|
counter_ = cc->InputSidePackets().Tag(kInitialValueTag).Get<int>();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("INCREMENT")) {
|
if (cc->InputSidePackets().HasTag(kIncrementTag)) {
|
||||||
increment_ = cc->InputSidePackets().Tag("INCREMENT").Get<int>();
|
increment_ = cc->InputSidePackets().Tag(kIncrementTag).Get<int>();
|
||||||
RET_CHECK_LT(0, increment_);
|
RET_CHECK_LT(0, increment_);
|
||||||
}
|
}
|
||||||
RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
|
RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
|
||||||
|
|
|
@ -35,11 +35,14 @@
|
||||||
// }
|
// }
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||||
|
constexpr char kEncodedTag[] = "ENCODED";
|
||||||
|
|
||||||
class DequantizeByteArrayCalculator : public CalculatorBase {
|
class DequantizeByteArrayCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("ENCODED").Set<std::string>();
|
cc->Inputs().Tag(kEncodedTag).Set<std::string>();
|
||||||
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
cc->Outputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +69,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
const std::string& encoded =
|
const std::string& encoded =
|
||||||
cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
|
cc->Inputs().Tag(kEncodedTag).Value().Get<std::string>();
|
||||||
std::vector<float> float_vector;
|
std::vector<float> float_vector;
|
||||||
float_vector.reserve(encoded.length());
|
float_vector.reserve(encoded.length());
|
||||||
for (int i = 0; i < encoded.length(); ++i) {
|
for (int i = 0; i < encoded.length(); ++i) {
|
||||||
|
@ -74,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
|
||||||
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
|
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
|
||||||
}
|
}
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("FLOAT_VECTOR")
|
.Tag(kFloatVectorTag)
|
||||||
.AddPacket(MakePacket<std::vector<float>>(float_vector)
|
.AddPacket(MakePacket<std::vector<float>>(float_vector)
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -25,6 +25,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||||
|
constexpr char kEncodedTag[] = "ENCODED";
|
||||||
|
|
||||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||||
CalculatorGraphConfig::Node node_config =
|
CalculatorGraphConfig::Node node_config =
|
||||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
@ -39,7 +42,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||||
)pb");
|
)pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::string empty_string;
|
std::string empty_string;
|
||||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kEncodedTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
|
@ -64,7 +69,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
|
||||||
)pb");
|
)pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::string empty_string;
|
std::string empty_string;
|
||||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kEncodedTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
|
@ -89,7 +96,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
|
||||||
)pb");
|
)pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::string empty_string;
|
std::string empty_string;
|
||||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kEncodedTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
EXPECT_FALSE(status.ok());
|
EXPECT_FALSE(status.ok());
|
||||||
|
@ -114,14 +123,16 @@ TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
|
||||||
)pb");
|
)pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
|
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
|
||||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kEncodedTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::string>(
|
MakePacket<std::string>(
|
||||||
std::string(reinterpret_cast<char const*>(input), 4))
|
std::string(reinterpret_cast<char const*>(input), 4))
|
||||||
.At(Timestamp(0)));
|
.At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& outputs =
|
const std::vector<Packet>& outputs =
|
||||||
runner.Outputs().Tag("FLOAT_VECTOR").packets;
|
runner.Outputs().Tag(kFloatVectorTag).packets;
|
||||||
EXPECT_EQ(1, outputs.size());
|
EXPECT_EQ(1, outputs.size());
|
||||||
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
|
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
|
||||||
ASSERT_FALSE(result.empty());
|
ASSERT_FALSE(result.empty());
|
||||||
|
|
|
@ -24,6 +24,11 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kFinishedTag[] = "FINISHED";
|
||||||
|
constexpr char kAllowTag[] = "ALLOW";
|
||||||
|
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
|
||||||
|
constexpr char kOptionsTag[] = "OPTIONS";
|
||||||
|
|
||||||
// FlowLimiterCalculator is used to limit the number of frames in flight
|
// FlowLimiterCalculator is used to limit the number of frames in flight
|
||||||
// by dropping input frames when necessary.
|
// by dropping input frames when necessary.
|
||||||
//
|
//
|
||||||
|
@ -69,16 +74,19 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
auto& side_inputs = cc->InputSidePackets();
|
auto& side_inputs = cc->InputSidePackets();
|
||||||
side_inputs.Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
|
side_inputs.Tag(kOptionsTag).Set<FlowLimiterCalculatorOptions>().Optional();
|
||||||
cc->Inputs().Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
|
cc->Inputs()
|
||||||
|
.Tag(kOptionsTag)
|
||||||
|
.Set<FlowLimiterCalculatorOptions>()
|
||||||
|
.Optional();
|
||||||
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
|
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
|
||||||
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
|
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
|
||||||
cc->Inputs().Get("", i).SetAny();
|
cc->Inputs().Get("", i).SetAny();
|
||||||
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
||||||
}
|
}
|
||||||
cc->Inputs().Get("FINISHED", 0).SetAny();
|
cc->Inputs().Get("FINISHED", 0).SetAny();
|
||||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>().Optional();
|
cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>().Optional();
|
||||||
cc->Outputs().Tag("ALLOW").Set<bool>().Optional();
|
cc->Outputs().Tag(kAllowTag).Set<bool>().Optional();
|
||||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||||
cc->SetProcessTimestampBounds(true);
|
cc->SetProcessTimestampBounds(true);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -87,9 +95,9 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
options_ = cc->Options<FlowLimiterCalculatorOptions>();
|
options_ = cc->Options<FlowLimiterCalculatorOptions>();
|
||||||
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
|
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
|
||||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
|
||||||
options_.set_max_in_flight(
|
options_.set_max_in_flight(
|
||||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>());
|
cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
|
||||||
}
|
}
|
||||||
input_queues_.resize(cc->Inputs().NumEntries(""));
|
input_queues_.resize(cc->Inputs().NumEntries(""));
|
||||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
|
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
|
||||||
|
@ -104,8 +112,8 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
|
|
||||||
// Outputs a packet indicating whether a frame was sent or dropped.
|
// Outputs a packet indicating whether a frame was sent or dropped.
|
||||||
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
|
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
|
||||||
if (cc->Outputs().HasTag("ALLOW")) {
|
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||||
cc->Outputs().Tag("ALLOW").AddPacket(MakePacket<bool>(allow).At(ts));
|
cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +163,7 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
options_ = tool::RetrieveOptions(options_, cc->Inputs());
|
options_ = tool::RetrieveOptions(options_, cc->Inputs());
|
||||||
|
|
||||||
// Process the FINISHED input stream.
|
// Process the FINISHED input stream.
|
||||||
Packet finished_packet = cc->Inputs().Tag("FINISHED").Value();
|
Packet finished_packet = cc->Inputs().Tag(kFinishedTag).Value();
|
||||||
if (finished_packet.Timestamp() == cc->InputTimestamp()) {
|
if (finished_packet.Timestamp() == cc->InputTimestamp()) {
|
||||||
while (!frames_in_flight_.empty() &&
|
while (!frames_in_flight_.empty() &&
|
||||||
frames_in_flight_.front() <= finished_packet.Timestamp()) {
|
frames_in_flight_.front() <= finished_packet.Timestamp()) {
|
||||||
|
@ -210,8 +218,8 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
Timestamp bound =
|
Timestamp bound =
|
||||||
cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream();
|
cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream();
|
||||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
|
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
|
||||||
if (cc->Outputs().HasTag("ALLOW")) {
|
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||||
SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW"));
|
SetNextTimestampBound(bound, &cc->Outputs().Tag(kAllowTag));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,13 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kDropTimestampsTag[] = "DROP_TIMESTAMPS";
|
||||||
|
constexpr char kClockTag[] = "CLOCK";
|
||||||
|
constexpr char kWarmupTimeTag[] = "WARMUP_TIME";
|
||||||
|
constexpr char kSleepTimeTag[] = "SLEEP_TIME";
|
||||||
|
constexpr char kPacketTag[] = "PACKET";
|
||||||
|
|
||||||
// A simple Semaphore for synchronizing test threads.
|
// A simple Semaphore for synchronizing test threads.
|
||||||
class AtomicSemaphore {
|
class AtomicSemaphore {
|
||||||
public:
|
public:
|
||||||
|
@ -204,17 +211,17 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) {
|
||||||
class SleepCalculator : public CalculatorBase {
|
class SleepCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("PACKET").SetAny();
|
cc->Inputs().Tag(kPacketTag).SetAny();
|
||||||
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
|
cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
|
||||||
cc->InputSidePackets().Tag("SLEEP_TIME").Set<int64>();
|
cc->InputSidePackets().Tag(kSleepTimeTag).Set<int64>();
|
||||||
cc->InputSidePackets().Tag("WARMUP_TIME").Set<int64>();
|
cc->InputSidePackets().Tag(kWarmupTimeTag).Set<int64>();
|
||||||
cc->InputSidePackets().Tag("CLOCK").Set<mediapipe::Clock*>();
|
cc->InputSidePackets().Tag(kClockTag).Set<mediapipe::Clock*>();
|
||||||
cc->SetTimestampOffset(0);
|
cc->SetTimestampOffset(0);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<mediapipe::Clock*>();
|
clock_ = cc->InputSidePackets().Tag(kClockTag).Get<mediapipe::Clock*>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,10 +229,12 @@ class SleepCalculator : public CalculatorBase {
|
||||||
++packet_count;
|
++packet_count;
|
||||||
absl::Duration sleep_time = absl::Microseconds(
|
absl::Duration sleep_time = absl::Microseconds(
|
||||||
packet_count == 1
|
packet_count == 1
|
||||||
? cc->InputSidePackets().Tag("WARMUP_TIME").Get<int64>()
|
? cc->InputSidePackets().Tag(kWarmupTimeTag).Get<int64>()
|
||||||
: cc->InputSidePackets().Tag("SLEEP_TIME").Get<int64>());
|
: cc->InputSidePackets().Tag(kSleepTimeTag).Get<int64>());
|
||||||
clock_->Sleep(sleep_time);
|
clock_->Sleep(sleep_time);
|
||||||
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
|
cc->Outputs()
|
||||||
|
.Tag(kPacketTag)
|
||||||
|
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -240,24 +249,27 @@ REGISTER_CALCULATOR(SleepCalculator);
|
||||||
class DropCalculator : public CalculatorBase {
|
class DropCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("PACKET").SetAny();
|
cc->Inputs().Tag(kPacketTag).SetAny();
|
||||||
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
|
cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
|
||||||
cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set<bool>();
|
cc->InputSidePackets().Tag(kDropTimestampsTag).Set<bool>();
|
||||||
cc->SetProcessTimestampBounds(true);
|
cc->SetProcessTimestampBounds(true);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
|
if (!cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
|
||||||
++packet_count;
|
++packet_count;
|
||||||
}
|
}
|
||||||
bool drop = (packet_count == 3);
|
bool drop = (packet_count == 3);
|
||||||
if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
|
if (!drop && !cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
|
||||||
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
|
cc->Outputs()
|
||||||
|
.Tag(kPacketTag)
|
||||||
|
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
|
||||||
}
|
}
|
||||||
if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get<bool>()) {
|
if (!drop || !cc->InputSidePackets().Tag(kDropTimestampsTag).Get<bool>()) {
|
||||||
cc->Outputs().Tag("PACKET").SetNextTimestampBound(
|
cc->Outputs()
|
||||||
cc->InputTimestamp().NextAllowedInStream());
|
.Tag(kPacketTag)
|
||||||
|
.SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,11 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kStateChangeTag[] = "STATE_CHANGE";
|
||||||
|
constexpr char kDisallowTag[] = "DISALLOW";
|
||||||
|
constexpr char kAllowTag[] = "ALLOW";
|
||||||
|
|
||||||
enum GateState {
|
enum GateState {
|
||||||
GATE_UNINITIALIZED,
|
GATE_UNINITIALIZED,
|
||||||
GATE_ALLOW,
|
GATE_ALLOW,
|
||||||
|
@ -83,30 +88,31 @@ class GateCalculator : public CalculatorBase {
|
||||||
GateCalculator() {}
|
GateCalculator() {}
|
||||||
|
|
||||||
static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
|
static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
|
||||||
bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") ||
|
bool input_via_side_packet = cc->InputSidePackets().HasTag(kAllowTag) ||
|
||||||
cc->InputSidePackets().HasTag("DISALLOW");
|
cc->InputSidePackets().HasTag(kDisallowTag);
|
||||||
bool input_via_stream =
|
bool input_via_stream =
|
||||||
cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW");
|
cc->Inputs().HasTag(kAllowTag) || cc->Inputs().HasTag(kDisallowTag);
|
||||||
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
|
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
|
||||||
// input.
|
// input.
|
||||||
RET_CHECK(input_via_side_packet ^ input_via_stream);
|
RET_CHECK(input_via_side_packet ^ input_via_stream);
|
||||||
|
|
||||||
if (input_via_side_packet) {
|
if (input_via_side_packet) {
|
||||||
RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^
|
RET_CHECK(cc->InputSidePackets().HasTag(kAllowTag) ^
|
||||||
cc->InputSidePackets().HasTag("DISALLOW"));
|
cc->InputSidePackets().HasTag(kDisallowTag));
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("ALLOW")) {
|
if (cc->InputSidePackets().HasTag(kAllowTag)) {
|
||||||
cc->InputSidePackets().Tag("ALLOW").Set<bool>();
|
cc->InputSidePackets().Tag(kAllowTag).Set<bool>();
|
||||||
} else {
|
} else {
|
||||||
cc->InputSidePackets().Tag("DISALLOW").Set<bool>();
|
cc->InputSidePackets().Tag(kDisallowTag).Set<bool>();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
|
RET_CHECK(cc->Inputs().HasTag(kAllowTag) ^
|
||||||
|
cc->Inputs().HasTag(kDisallowTag));
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("ALLOW")) {
|
if (cc->Inputs().HasTag(kAllowTag)) {
|
||||||
cc->Inputs().Tag("ALLOW").Set<bool>();
|
cc->Inputs().Tag(kAllowTag).Set<bool>();
|
||||||
} else {
|
} else {
|
||||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
cc->Inputs().Tag(kDisallowTag).Set<bool>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -125,8 +131,8 @@ class GateCalculator : public CalculatorBase {
|
||||||
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
||||||
cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
|
cc->Outputs().Tag(kStateChangeTag).Set<bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -134,14 +140,14 @@ class GateCalculator : public CalculatorBase {
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
use_side_packet_for_allow_disallow_ = false;
|
use_side_packet_for_allow_disallow_ = false;
|
||||||
if (cc->InputSidePackets().HasTag("ALLOW")) {
|
if (cc->InputSidePackets().HasTag(kAllowTag)) {
|
||||||
use_side_packet_for_allow_disallow_ = true;
|
use_side_packet_for_allow_disallow_ = true;
|
||||||
allow_by_side_packet_decision_ =
|
allow_by_side_packet_decision_ =
|
||||||
cc->InputSidePackets().Tag("ALLOW").Get<bool>();
|
cc->InputSidePackets().Tag(kAllowTag).Get<bool>();
|
||||||
} else if (cc->InputSidePackets().HasTag("DISALLOW")) {
|
} else if (cc->InputSidePackets().HasTag(kDisallowTag)) {
|
||||||
use_side_packet_for_allow_disallow_ = true;
|
use_side_packet_for_allow_disallow_ = true;
|
||||||
allow_by_side_packet_decision_ =
|
allow_by_side_packet_decision_ =
|
||||||
!cc->InputSidePackets().Tag("DISALLOW").Get<bool>();
|
!cc->InputSidePackets().Tag(kDisallowTag).Get<bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
|
@ -160,18 +166,18 @@ class GateCalculator : public CalculatorBase {
|
||||||
if (use_side_packet_for_allow_disallow_) {
|
if (use_side_packet_for_allow_disallow_) {
|
||||||
allow = allow_by_side_packet_decision_;
|
allow = allow_by_side_packet_decision_;
|
||||||
} else {
|
} else {
|
||||||
if (cc->Inputs().HasTag("ALLOW") &&
|
if (cc->Inputs().HasTag(kAllowTag) &&
|
||||||
!cc->Inputs().Tag("ALLOW").IsEmpty()) {
|
!cc->Inputs().Tag(kAllowTag).IsEmpty()) {
|
||||||
allow = cc->Inputs().Tag("ALLOW").Get<bool>();
|
allow = cc->Inputs().Tag(kAllowTag).Get<bool>();
|
||||||
}
|
}
|
||||||
if (cc->Inputs().HasTag("DISALLOW") &&
|
if (cc->Inputs().HasTag(kDisallowTag) &&
|
||||||
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
|
!cc->Inputs().Tag(kDisallowTag).IsEmpty()) {
|
||||||
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
|
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
||||||
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
||||||
last_gate_state_ != new_gate_state) {
|
last_gate_state_ != new_gate_state) {
|
||||||
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
||||||
|
@ -179,7 +185,7 @@ class GateCalculator : public CalculatorBase {
|
||||||
<< ToString(last_gate_state_) << " to "
|
<< ToString(last_gate_state_) << " to "
|
||||||
<< ToString(new_gate_state);
|
<< ToString(new_gate_state);
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("STATE_CHANGE")
|
.Tag(kStateChangeTag)
|
||||||
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
|
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,9 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kDisallowTag[] = "DISALLOW";
|
||||||
|
constexpr char kAllowTag[] = "ALLOW";
|
||||||
|
|
||||||
class GateCalculatorTest : public ::testing::Test {
|
class GateCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
// Helper to run a graph and return status.
|
// Helper to run a graph and return status.
|
||||||
|
@ -117,7 +120,7 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
|
||||||
input_stream: "test_input"
|
input_stream: "test_input"
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true));
|
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64 kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
|
@ -139,7 +142,7 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
|
||||||
input_stream: "test_input"
|
input_stream: "test_input"
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false));
|
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64 kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
|
@ -161,7 +164,7 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
|
||||||
input_stream: "test_input"
|
input_stream: "test_input"
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false));
|
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64 kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
|
@ -179,7 +182,7 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
|
||||||
input_stream: "test_input"
|
input_stream: "test_input"
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true));
|
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64 kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
|
|
|
@ -39,20 +39,24 @@ using testing::ElementsAre;
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kClockTag[] = "CLOCK";
|
||||||
|
|
||||||
using mediapipe::Clock;
|
using mediapipe::Clock;
|
||||||
|
|
||||||
// A Calculator with a fixed Process call latency.
|
// A Calculator with a fixed Process call latency.
|
||||||
class SleepCalculator : public CalculatorBase {
|
class SleepCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
|
cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
|
||||||
cc->Inputs().Index(0).SetAny();
|
cc->Inputs().Index(0).SetAny();
|
||||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
cc->SetTimestampOffset(TimestampDiff(0));
|
cc->SetTimestampOffset(TimestampDiff(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
|
clock_ =
|
||||||
|
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,9 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kMinuendTag[] = "MINUEND";
|
||||||
|
constexpr char kSubtrahendTag[] = "SUBTRAHEND";
|
||||||
|
|
||||||
// A 3x4 Matrix of random integers in [0,1000).
|
// A 3x4 Matrix of random integers in [0,1000).
|
||||||
const char kMatrixText[] =
|
const char kMatrixText[] =
|
||||||
"rows: 3\n"
|
"rows: 3\n"
|
||||||
|
@ -104,12 +107,13 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
Matrix* side_matrix = new Matrix();
|
Matrix* side_matrix = new Matrix();
|
||||||
MatrixFromTextProto(kMatrixText, side_matrix);
|
MatrixFromTextProto(kMatrixText, side_matrix);
|
||||||
runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix);
|
runner.MutableSidePackets()->Tag(kSubtrahendTag) = Adopt(side_matrix);
|
||||||
|
|
||||||
Matrix* input_matrix = new Matrix();
|
Matrix* input_matrix = new Matrix();
|
||||||
MatrixFromTextProto(kMatrixText2, input_matrix);
|
MatrixFromTextProto(kMatrixText2, input_matrix);
|
||||||
runner.MutableInputs()->Tag("MINUEND").packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_matrix).At(Timestamp(0)));
|
->Tag(kMinuendTag)
|
||||||
|
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
|
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
|
||||||
|
@ -133,12 +137,12 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
Matrix* side_matrix = new Matrix();
|
Matrix* side_matrix = new Matrix();
|
||||||
MatrixFromTextProto(kMatrixText, side_matrix);
|
MatrixFromTextProto(kMatrixText, side_matrix);
|
||||||
runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix);
|
runner.MutableSidePackets()->Tag(kMinuendTag) = Adopt(side_matrix);
|
||||||
|
|
||||||
Matrix* input_matrix = new Matrix();
|
Matrix* input_matrix = new Matrix();
|
||||||
MatrixFromTextProto(kMatrixText2, input_matrix);
|
MatrixFromTextProto(kMatrixText2, input_matrix);
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("SUBTRAHEND")
|
->Tag(kSubtrahendTag)
|
||||||
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
|
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
|
@ -17,6 +17,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kPresenceTag[] = "PRESENCE";
|
||||||
|
constexpr char kPacketTag[] = "PACKET";
|
||||||
|
|
||||||
// For each non empty input packet, emits a single output packet containing a
|
// For each non empty input packet, emits a single output packet containing a
|
||||||
// boolean value "true", "false" in response to empty packets (a.k.a. timestamp
|
// boolean value "true", "false" in response to empty packets (a.k.a. timestamp
|
||||||
// bound updates) This can be used to "flag" the presence of an arbitrary packet
|
// bound updates) This can be used to "flag" the presence of an arbitrary packet
|
||||||
|
@ -58,8 +61,8 @@ namespace mediapipe {
|
||||||
class PacketPresenceCalculator : public CalculatorBase {
|
class PacketPresenceCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("PACKET").SetAny();
|
cc->Inputs().Tag(kPacketTag).SetAny();
|
||||||
cc->Outputs().Tag("PRESENCE").Set<bool>();
|
cc->Outputs().Tag(kPresenceTag).Set<bool>();
|
||||||
// Process() function is invoked in response to input stream timestamp
|
// Process() function is invoked in response to input stream timestamp
|
||||||
// bound updates.
|
// bound updates.
|
||||||
cc->SetProcessTimestampBounds(true);
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
@ -73,8 +76,8 @@ class PacketPresenceCalculator : public CalculatorBase {
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("PRESENCE")
|
.Tag(kPresenceTag)
|
||||||
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag("PACKET").IsEmpty())
|
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag(kPacketTag).IsEmpty())
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,11 @@ namespace mediapipe {
|
||||||
|
|
||||||
REGISTER_CALCULATOR(PacketResamplerCalculator);
|
REGISTER_CALCULATOR(PacketResamplerCalculator);
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSeedTag[] = "SEED";
|
||||||
|
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
|
||||||
|
constexpr char kOptionsTag[] = "OPTIONS";
|
||||||
|
|
||||||
// Returns a TimestampDiff (assuming microseconds) corresponding to the
|
// Returns a TimestampDiff (assuming microseconds) corresponding to the
|
||||||
// given time in seconds.
|
// given time in seconds.
|
||||||
TimestampDiff TimestampDiffFromSeconds(double seconds) {
|
TimestampDiff TimestampDiffFromSeconds(double seconds) {
|
||||||
|
@ -50,16 +55,16 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
|
||||||
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
|
||||||
const auto& resampler_options =
|
const auto& resampler_options =
|
||||||
cc->Options<PacketResamplerCalculatorOptions>();
|
cc->Options<PacketResamplerCalculatorOptions>();
|
||||||
if (cc->InputSidePackets().HasTag("OPTIONS")) {
|
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||||
cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>();
|
cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
|
||||||
}
|
}
|
||||||
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
|
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
|
||||||
if (!input_data_id.IsValid()) {
|
if (!input_data_id.IsValid()) {
|
||||||
input_data_id = cc->Inputs().GetId("", 0);
|
input_data_id = cc->Inputs().GetId("", 0);
|
||||||
}
|
}
|
||||||
cc->Inputs().Get(input_data_id).SetAny();
|
cc->Inputs().Get(input_data_id).SetAny();
|
||||||
if (cc->Inputs().HasTag("VIDEO_HEADER")) {
|
if (cc->Inputs().HasTag(kVideoHeaderTag)) {
|
||||||
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
|
cc->Inputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
|
||||||
}
|
}
|
||||||
|
|
||||||
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
|
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
|
||||||
|
@ -67,15 +72,15 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
|
||||||
output_data_id = cc->Outputs().GetId("", 0);
|
output_data_id = cc->Outputs().GetId("", 0);
|
||||||
}
|
}
|
||||||
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
|
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
|
||||||
if (cc->Outputs().HasTag("VIDEO_HEADER")) {
|
if (cc->Outputs().HasTag(kVideoHeaderTag)) {
|
||||||
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
|
cc->Outputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (resampler_options.jitter() != 0.0) {
|
if (resampler_options.jitter() != 0.0) {
|
||||||
RET_CHECK_GT(resampler_options.jitter(), 0.0);
|
RET_CHECK_GT(resampler_options.jitter(), 0.0);
|
||||||
RET_CHECK_LE(resampler_options.jitter(), 1.0);
|
RET_CHECK_LE(resampler_options.jitter(), 1.0);
|
||||||
RET_CHECK(cc->InputSidePackets().HasTag("SEED"));
|
RET_CHECK(cc->InputSidePackets().HasTag(kSeedTag));
|
||||||
cc->InputSidePackets().Tag("SEED").Set<std::string>();
|
cc->InputSidePackets().Tag(kSeedTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -143,9 +148,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
|
||||||
|
|
||||||
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
|
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
|
||||||
if (cc->InputTimestamp() == Timestamp::PreStream() &&
|
if (cc->InputTimestamp() == Timestamp::PreStream() &&
|
||||||
cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") &&
|
cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) &&
|
||||||
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
|
!cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) {
|
||||||
video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
|
video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get<VideoHeader>();
|
||||||
video_header_.frame_rate = frame_rate_;
|
video_header_.frame_rate = frame_rate_;
|
||||||
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
|
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -234,7 +239,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
|
||||||
"ignored, because we are adding jitter.";
|
"ignored, because we are adding jitter.";
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
|
||||||
random_ = CreateSecureRandom(seed);
|
random_ = CreateSecureRandom(seed);
|
||||||
if (random_ == nullptr) {
|
if (random_ == nullptr) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
|
@ -357,7 +362,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open(
|
||||||
"ignored, because we are adding jitter.";
|
"ignored, because we are adding jitter.";
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
|
||||||
random_ = CreateSecureRandom(seed);
|
random_ = CreateSecureRandom(seed);
|
||||||
if (random_ == nullptr) {
|
if (random_ == nullptr) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
|
@ -504,7 +509,7 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
|
||||||
"ignored, because we are adding jitter.";
|
"ignored, because we are adding jitter.";
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
|
||||||
random_ = CreateSecureRandom(seed);
|
random_ = CreateSecureRandom(seed);
|
||||||
if (random_ == nullptr) {
|
if (random_ == nullptr) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
|
@ -635,9 +640,9 @@ absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
|
||||||
base_timestamp_ +
|
base_timestamp_ +
|
||||||
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
|
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
|
||||||
}
|
}
|
||||||
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
|
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag(kVideoHeaderTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("VIDEO_HEADER")
|
.Tag(kVideoHeaderTag)
|
||||||
.Add(new VideoHeader(calculator_->video_header_),
|
.Add(new VideoHeader(calculator_->video_header_),
|
||||||
Timestamp::PreStream());
|
Timestamp::PreStream());
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,12 @@ namespace mediapipe {
|
||||||
|
|
||||||
using ::testing::ElementsAre;
|
using ::testing::ElementsAre;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kOptionsTag[] = "OPTIONS";
|
||||||
|
constexpr char kSeedTag[] = "SEED";
|
||||||
|
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
|
||||||
|
constexpr char kDataTag[] = "DATA";
|
||||||
|
|
||||||
// A simple version of CalculatorRunner with built-in convenience
|
// A simple version of CalculatorRunner with built-in convenience
|
||||||
// methods for setting inputs from a vector and checking outputs
|
// methods for setting inputs from a vector and checking outputs
|
||||||
// against expected outputs (both timestamps and contents).
|
// against expected outputs (both timestamps and contents).
|
||||||
|
@ -464,7 +470,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
||||||
)pb"));
|
)pb"));
|
||||||
|
|
||||||
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
|
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
|
||||||
runner.MutableInputs()->Tag("DATA").packets.push_back(
|
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
|
||||||
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
||||||
}
|
}
|
||||||
VideoHeader video_header_in;
|
VideoHeader video_header_in;
|
||||||
|
@ -474,16 +480,16 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
||||||
video_header_in.duration = 1.0;
|
video_header_in.duration = 1.0;
|
||||||
video_header_in.format = ImageFormat::SRGB;
|
video_header_in.format = ImageFormat::SRGB;
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("VIDEO_HEADER")
|
->Tag(kVideoHeaderTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
|
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size());
|
ASSERT_EQ(1, runner.Outputs().Tag(kVideoHeaderTag).packets.size());
|
||||||
EXPECT_EQ(Timestamp::PreStream(),
|
EXPECT_EQ(Timestamp::PreStream(),
|
||||||
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp());
|
runner.Outputs().Tag(kVideoHeaderTag).packets[0].Timestamp());
|
||||||
const VideoHeader& video_header_out =
|
const VideoHeader& video_header_out =
|
||||||
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>();
|
runner.Outputs().Tag(kVideoHeaderTag).packets[0].Get<VideoHeader>();
|
||||||
EXPECT_EQ(video_header_in.width, video_header_out.width);
|
EXPECT_EQ(video_header_in.width, video_header_out.width);
|
||||||
EXPECT_EQ(video_header_in.height, video_header_out.height);
|
EXPECT_EQ(video_header_in.height, video_header_out.height);
|
||||||
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
|
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
|
||||||
|
@ -725,7 +731,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
|
||||||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||||
frame_rate: 30
|
frame_rate: 30
|
||||||
})pb"));
|
})pb"));
|
||||||
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
|
runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
|
||||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
|
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
|
||||||
|
@ -740,7 +746,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
|
||||||
frame_rate: 30
|
frame_rate: 30
|
||||||
base_timestamp: 0
|
base_timestamp: 0
|
||||||
})pb"));
|
})pb"));
|
||||||
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
|
runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
|
||||||
|
|
||||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
|
@ -29,6 +29,8 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kPeriodTag[] = "PERIOD";
|
||||||
|
|
||||||
// A simple version of CalculatorRunner with built-in convenience methods for
|
// A simple version of CalculatorRunner with built-in convenience methods for
|
||||||
// setting inputs from a vector and checking outputs against a vector of
|
// setting inputs from a vector and checking outputs against a vector of
|
||||||
// expected outputs.
|
// expected outputs.
|
||||||
|
@ -121,7 +123,7 @@ TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) {
|
||||||
|
|
||||||
SimpleRunner runner(node);
|
SimpleRunner runner(node);
|
||||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||||
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
|
runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const std::vector<int64> expected_timestamps = {2, 8, 14};
|
const std::vector<int64> expected_timestamps = {2, 8, 14};
|
||||||
|
@ -160,7 +162,7 @@ TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) {
|
||||||
|
|
||||||
SimpleRunner runner(node);
|
SimpleRunner runner(node);
|
||||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||||
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
|
runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
|
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
|
||||||
|
|
|
@ -39,6 +39,8 @@ using ::testing::Pair;
|
||||||
using ::testing::Value;
|
using ::testing::Value;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kDisallowTag[] = "DISALLOW";
|
||||||
|
|
||||||
// Returns the timestamp values for a vector of Packets.
|
// Returns the timestamp values for a vector of Packets.
|
||||||
// TODO: puth this kind of test util in a common place.
|
// TODO: puth this kind of test util in a common place.
|
||||||
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
||||||
|
@ -702,14 +704,14 @@ class DroppingGateCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Index(0).SetAny();
|
cc->Inputs().Index(0).SetAny();
|
||||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
cc->Inputs().Tag(kDisallowTag).Set<bool>();
|
||||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
if (!cc->Inputs().Index(0).IsEmpty() &&
|
if (!cc->Inputs().Index(0).IsEmpty() &&
|
||||||
!cc->Inputs().Tag("DISALLOW").Get<bool>()) {
|
!cc->Inputs().Tag(kDisallowTag).Get<bool>()) {
|
||||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -41,11 +41,14 @@
|
||||||
// }
|
// }
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kEncodedTag[] = "ENCODED";
|
||||||
|
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||||
|
|
||||||
class QuantizeFloatVectorCalculator : public CalculatorBase {
|
class QuantizeFloatVectorCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
cc->Inputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
|
||||||
cc->Outputs().Tag("ENCODED").Set<std::string>();
|
cc->Outputs().Tag(kEncodedTag).Set<std::string>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +73,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
const std::vector<float>& float_vector =
|
const std::vector<float>& float_vector =
|
||||||
cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
|
cc->Inputs().Tag(kFloatVectorTag).Value().Get<std::vector<float>>();
|
||||||
int feature_size = float_vector.size();
|
int feature_size = float_vector.size();
|
||||||
std::string encoded_features;
|
std::string encoded_features;
|
||||||
encoded_features.reserve(feature_size);
|
encoded_features.reserve(feature_size);
|
||||||
|
@ -86,7 +89,9 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
|
||||||
(old_value - min_quantized_value_) * (255.0 / range_));
|
(old_value - min_quantized_value_) * (255.0 / range_));
|
||||||
encoded_features += encoded;
|
encoded_features += encoded;
|
||||||
}
|
}
|
||||||
cc->Outputs().Tag("ENCODED").AddPacket(
|
cc->Outputs()
|
||||||
|
.Tag(kEncodedTag)
|
||||||
|
.AddPacket(
|
||||||
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
|
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kEncodedTag[] = "ENCODED";
|
||||||
|
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||||
|
|
||||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||||
CalculatorGraphConfig::Node node_config =
|
CalculatorGraphConfig::Node node_config =
|
||||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
@ -40,7 +43,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::vector<float> empty_vector;
|
std::vector<float> empty_vector;
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FLOAT_VECTOR")
|
->Tag(kFloatVectorTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
|
@ -67,7 +70,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::vector<float> empty_vector;
|
std::vector<float> empty_vector;
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FLOAT_VECTOR")
|
->Tag(kFloatVectorTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
|
@ -94,7 +97,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::vector<float> empty_vector;
|
std::vector<float> empty_vector;
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FLOAT_VECTOR")
|
->Tag(kFloatVectorTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||||
auto status = runner.Run();
|
auto status = runner.Run();
|
||||||
|
@ -121,11 +124,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::vector<float> empty_vector;
|
std::vector<float> empty_vector;
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FLOAT_VECTOR")
|
->Tag(kFloatVectorTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
const std::vector<Packet>& outputs =
|
||||||
|
runner.Outputs().Tag(kEncodedTag).packets;
|
||||||
EXPECT_EQ(1, outputs.size());
|
EXPECT_EQ(1, outputs.size());
|
||||||
EXPECT_TRUE(outputs[0].Get<std::string>().empty());
|
EXPECT_TRUE(outputs[0].Get<std::string>().empty());
|
||||||
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
||||||
|
@ -147,11 +151,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
|
std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FLOAT_VECTOR")
|
->Tag(kFloatVectorTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
const std::vector<Packet>& outputs =
|
||||||
|
runner.Outputs().Tag(kEncodedTag).packets;
|
||||||
EXPECT_EQ(1, outputs.size());
|
EXPECT_EQ(1, outputs.size());
|
||||||
const std::string& result = outputs[0].Get<std::string>();
|
const std::string& result = outputs[0].Get<std::string>();
|
||||||
ASSERT_FALSE(result.empty());
|
ASSERT_FALSE(result.empty());
|
||||||
|
@ -185,11 +190,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
std::vector<float> vector = {-65.0f, 65.0f};
|
std::vector<float> vector = {-65.0f, 65.0f};
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FLOAT_VECTOR")
|
->Tag(kFloatVectorTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
const std::vector<Packet>& outputs =
|
||||||
|
runner.Outputs().Tag(kEncodedTag).packets;
|
||||||
EXPECT_EQ(1, outputs.size());
|
EXPECT_EQ(1, outputs.size());
|
||||||
const std::string& result = outputs[0].Get<std::string>();
|
const std::string& result = outputs[0].Get<std::string>();
|
||||||
ASSERT_FALSE(result.empty());
|
ASSERT_FALSE(result.empty());
|
||||||
|
|
|
@ -23,6 +23,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kAllowTag[] = "ALLOW";
|
||||||
|
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
|
||||||
|
|
||||||
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
|
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
|
||||||
// processing operations in a section of the graph.
|
// processing operations in a section of the graph.
|
||||||
//
|
//
|
||||||
|
@ -86,11 +89,11 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
|
||||||
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
||||||
}
|
}
|
||||||
cc->Inputs().Get("FINISHED", 0).SetAny();
|
cc->Inputs().Get("FINISHED", 0).SetAny();
|
||||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
|
||||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>();
|
cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("ALLOW")) {
|
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||||
cc->Outputs().Tag("ALLOW").Set<bool>();
|
cc->Outputs().Tag(kAllowTag).Set<bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||||
|
@ -101,8 +104,8 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
finished_id_ = cc->Inputs().GetId("FINISHED", 0);
|
finished_id_ = cc->Inputs().GetId("FINISHED", 0);
|
||||||
max_in_flight_ = 1;
|
max_in_flight_ = 1;
|
||||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
|
||||||
max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>();
|
max_in_flight_ = cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>();
|
||||||
}
|
}
|
||||||
RET_CHECK_GE(max_in_flight_, 1);
|
RET_CHECK_GE(max_in_flight_, 1);
|
||||||
num_in_flight_ = 0;
|
num_in_flight_ = 0;
|
||||||
|
|
|
@ -33,6 +33,9 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kFinishedTag[] = "FINISHED";
|
||||||
|
|
||||||
// A simple Semaphore for synchronizing test threads.
|
// A simple Semaphore for synchronizing test threads.
|
||||||
class AtomicSemaphore {
|
class AtomicSemaphore {
|
||||||
public:
|
public:
|
||||||
|
@ -112,7 +115,7 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) {
|
||||||
Timestamp timestamp =
|
Timestamp timestamp =
|
||||||
Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
|
Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("FINISHED")
|
->Tag(kFinishedTag)
|
||||||
.packets.push_back(MakePacket<bool>(true).At(timestamp));
|
.packets.push_back(MakePacket<bool>(true).At(timestamp));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kPacketOffsetTag[] = "PACKET_OFFSET";
|
||||||
|
|
||||||
// Adds packets containing integers equal to their original timestamp.
|
// Adds packets containing integers equal to their original timestamp.
|
||||||
void AddPackets(CalculatorRunner* runner) {
|
void AddPackets(CalculatorRunner* runner) {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
@ -111,7 +113,7 @@ TEST(SequenceShiftCalculatorTest, SidePacketOffset) {
|
||||||
|
|
||||||
CalculatorRunner runner(node);
|
CalculatorRunner runner(node);
|
||||||
AddPackets(&runner);
|
AddPackets(&runner);
|
||||||
runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2));
|
runner.MutableSidePackets()->Tag(kPacketOffsetTag) = Adopt(new int(-2));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& input_packets =
|
const std::vector<Packet>& input_packets =
|
||||||
runner.MutableInputs()->Index(0).packets;
|
runner.MutableInputs()->Index(0).packets;
|
||||||
|
|
|
@ -102,6 +102,10 @@ mediapipe::ScaleMode_Mode ParseScaleMode(
|
||||||
// IMAGE: ImageFrame representing the input image.
|
// IMAGE: ImageFrame representing the input image.
|
||||||
// IMAGE_GPU: GpuBuffer representing the input image.
|
// IMAGE_GPU: GpuBuffer representing the input image.
|
||||||
//
|
//
|
||||||
|
// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as
|
||||||
|
// pair<int, int>. If set, it will override corresponding field in calculator
|
||||||
|
// options and input side packet.
|
||||||
|
//
|
||||||
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in
|
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in
|
||||||
// degrees. This allows different rotation angles for different frames. It has
|
// degrees. This allows different rotation angles for different frames. It has
|
||||||
// to be a multiple of 90 degrees. If provided, it overrides the
|
// to be a multiple of 90 degrees. If provided, it overrides the
|
||||||
|
@ -221,6 +225,10 @@ absl::Status ImageTransformationCalculator::GetContract(
|
||||||
}
|
}
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
|
||||||
|
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<std::pair<int, int>>();
|
||||||
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
|
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
|
||||||
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
|
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
|
||||||
}
|
}
|
||||||
|
@ -329,6 +337,13 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) {
|
||||||
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
|
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
|
||||||
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
|
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
|
||||||
}
|
}
|
||||||
|
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS") &&
|
||||||
|
!cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
|
||||||
|
const auto& image_size =
|
||||||
|
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<std::pair<int, int>>();
|
||||||
|
output_width_ = image_size.first;
|
||||||
|
output_height_ = image_size.second;
|
||||||
|
}
|
||||||
|
|
||||||
if (use_gpu_) {
|
if (use_gpu_) {
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
|
@ -88,6 +88,13 @@ proto_library(
|
||||||
deps = ["//mediapipe/framework:calculator_proto"],
|
deps = ["//mediapipe/framework:calculator_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
proto_library(
|
||||||
|
name = "tensor_to_vector_string_calculator_options_proto",
|
||||||
|
srcs = ["tensor_to_vector_string_calculator_options.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = ["//mediapipe/framework:calculator_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
proto_library(
|
proto_library(
|
||||||
name = "unpack_media_sequence_calculator_proto",
|
name = "unpack_media_sequence_calculator_proto",
|
||||||
srcs = ["unpack_media_sequence_calculator.proto"],
|
srcs = ["unpack_media_sequence_calculator.proto"],
|
||||||
|
@ -257,6 +264,14 @@ mediapipe_cc_proto_library(
|
||||||
deps = [":tensor_to_vector_float_calculator_options_proto"],
|
deps = [":tensor_to_vector_float_calculator_options_proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_cc_proto_library(
|
||||||
|
name = "tensor_to_vector_string_calculator_options_cc_proto",
|
||||||
|
srcs = ["tensor_to_vector_string_calculator_options.proto"],
|
||||||
|
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [":tensor_to_vector_string_calculator_options_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_cc_proto_library(
|
mediapipe_cc_proto_library(
|
||||||
name = "unpack_media_sequence_calculator_cc_proto",
|
name = "unpack_media_sequence_calculator_cc_proto",
|
||||||
srcs = ["unpack_media_sequence_calculator.proto"],
|
srcs = ["unpack_media_sequence_calculator.proto"],
|
||||||
|
@ -694,6 +709,26 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_to_vector_string_calculator",
|
||||||
|
srcs = ["tensor_to_vector_string_calculator.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
":tensor_to_vector_string_calculator_options_cc_proto",
|
||||||
|
] + select({
|
||||||
|
"//conditions:default": [
|
||||||
|
"@org_tensorflow//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
"//mediapipe:android": [
|
||||||
|
"@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "unpack_media_sequence_calculator",
|
name = "unpack_media_sequence_calculator",
|
||||||
srcs = ["unpack_media_sequence_calculator.cc"],
|
srcs = ["unpack_media_sequence_calculator.cc"],
|
||||||
|
@ -1059,6 +1094,20 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "tensor_to_vector_string_calculator_test",
|
||||||
|
srcs = ["tensor_to_vector_string_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":tensor_to_vector_string_calculator",
|
||||||
|
":tensor_to_vector_string_calculator_options_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"@org_tensorflow//tensorflow/core:framework",
|
||||||
|
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "unpack_media_sequence_calculator_test",
|
name = "unpack_media_sequence_calculator_test",
|
||||||
srcs = ["unpack_media_sequence_calculator_test.cc"],
|
srcs = ["unpack_media_sequence_calculator_test.cc"],
|
||||||
|
|
|
@ -40,6 +40,24 @@ namespace {
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
namespace mpms = mediapipe::mediasequence;
|
namespace mpms = mediapipe::mediasequence;
|
||||||
|
|
||||||
|
constexpr char kBboxTag[] = "BBOX";
|
||||||
|
constexpr char kEncodedMediaStartTimestampTag[] =
|
||||||
|
"ENCODED_MEDIA_START_TIMESTAMP";
|
||||||
|
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
|
||||||
|
constexpr char kClassSegmentationTag[] = "CLASS_SEGMENTATION";
|
||||||
|
constexpr char kKeypointsTestTag[] = "KEYPOINTS_TEST";
|
||||||
|
constexpr char kBboxPredictedTag[] = "BBOX_PREDICTED";
|
||||||
|
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
|
||||||
|
constexpr char kAudioTestTag[] = "AUDIO_TEST";
|
||||||
|
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
|
||||||
|
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
|
||||||
|
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
|
||||||
|
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||||
|
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||||
|
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
||||||
|
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||||
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
|
||||||
class PackMediaSequenceCalculatorTest : public ::testing::Test {
|
class PackMediaSequenceCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUpCalculator(const std::vector<std::string>& input_streams,
|
void SetUpCalculator(const std::vector<std::string>& input_streams,
|
||||||
|
@ -83,17 +101,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
auto image_ptr =
|
auto image_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
Adopt(image_ptr.release()).At(Timestamp(i)));
|
Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -127,17 +145,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
|
||||||
auto image_ptr =
|
auto image_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("IMAGE_PREFIX")
|
->Tag(kImagePrefixTag)
|
||||||
.packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -161,21 +179,21 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
|
||||||
for (int i = 0; i < num_timesteps; ++i) {
|
for (int i = 0; i < num_timesteps; ++i) {
|
||||||
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
|
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FLOAT_FEATURE_TEST")
|
->Tag(kFloatFeatureTestTag)
|
||||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
||||||
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
|
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FLOAT_FEATURE_OTHER")
|
->Tag(kFloatFeatureOtherTag)
|
||||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -228,20 +246,20 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
|
||||||
|
|
||||||
auto vf_ptr = absl::make_unique<std::vector<float>>(2, 3);
|
auto vf_ptr = absl::make_unique<std::vector<float>>(2, 3);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FLOAT_CONTEXT_FEATURE_TEST")
|
->Tag(kFloatContextFeatureTestTag)
|
||||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
|
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
|
||||||
vf_ptr = absl::make_unique<std::vector<float>>(2, 4);
|
vf_ptr = absl::make_unique<std::vector<float>>(2, 4);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FLOAT_CONTEXT_FEATURE_OTHER")
|
->Tag(kFloatContextFeatureOtherTag)
|
||||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
|
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -259,7 +277,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
||||||
SetUpCalculator({"IMAGE:images"}, context, false, true);
|
SetUpCalculator({"IMAGE:images"}, context, false, true);
|
||||||
|
|
||||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
|
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
|
||||||
std::vector<uchar> bytes;
|
std::vector<uchar> bytes;
|
||||||
|
@ -268,13 +286,13 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
||||||
encoded_image.set_encoded_image(bytes.data(), bytes.size());
|
encoded_image.set_encoded_image(bytes.data(), bytes.size());
|
||||||
auto image_ptr =
|
auto image_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
Adopt(image_ptr.release()).At(Timestamp(0)));
|
Adopt(image_ptr.release()).At(Timestamp(0)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -307,17 +325,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
|
||||||
auto flow_ptr =
|
auto flow_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FORWARD_FLOW_ENCODED")
|
->Tag(kForwardFlowEncodedTag)
|
||||||
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -371,17 +389,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
|
||||||
detections->push_back(detection);
|
detections->push_back(detection);
|
||||||
|
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("BBOX_PREDICTED")
|
->Tag(kBboxPredictedTag)
|
||||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -450,11 +468,11 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) {
|
||||||
detections->push_back(detection);
|
detections->push_back(detection);
|
||||||
|
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("BBOX_PREDICTED")
|
->Tag(kBboxPredictedTag)
|
||||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
auto status = runner_->Run();
|
auto status = runner_->Run();
|
||||||
|
@ -498,7 +516,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
|
||||||
detections->push_back(detection);
|
detections->push_back(detection);
|
||||||
|
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("BBOX_PREDICTED")
|
->Tag(kBboxPredictedTag)
|
||||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
|
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
|
||||||
|
@ -513,16 +531,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
auto image_ptr =
|
auto image_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
Adopt(image_ptr.release()).At(Timestamp(i)));
|
Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -564,18 +582,18 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
|
||||||
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
|
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
|
||||||
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
|
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("KEYPOINTS_TEST")
|
->Tag(kKeypointsTestTag)
|
||||||
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
|
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("KEYPOINTS_TEST")
|
->Tag(kKeypointsTestTag)
|
||||||
.packets.push_back(PointToForeign(&points).At(Timestamp(1)));
|
.packets.push_back(PointToForeign(&points).At(Timestamp(1)));
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -615,17 +633,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
|
||||||
detections->push_back(detection);
|
detections->push_back(detection);
|
||||||
|
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("CLASS_SEGMENTATION")
|
->Tag(kClassSegmentationTag)
|
||||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -664,17 +682,17 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
|
||||||
auto flow_ptr =
|
auto flow_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FORWARD_FLOW_ENCODED")
|
->Tag(kForwardFlowEncodedTag)
|
||||||
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -710,11 +728,11 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
|
||||||
auto flow_ptr =
|
auto flow_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FORWARD_FLOW_ENCODED")
|
->Tag(kForwardFlowEncodedTag)
|
||||||
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
absl::Status status = runner_->Run();
|
absl::Status status = runner_->Run();
|
||||||
|
@ -731,13 +749,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) {
|
||||||
mpms::AddImageTimestamp(1, input_sequence.get());
|
mpms::AddImageTimestamp(1, input_sequence.get());
|
||||||
mpms::AddImageTimestamp(2, input_sequence.get());
|
mpms::AddImageTimestamp(2, input_sequence.get());
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -757,13 +775,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) {
|
||||||
mpms::AddForwardFlowTimestamp(1, input_sequence.get());
|
mpms::AddForwardFlowTimestamp(1, input_sequence.get());
|
||||||
mpms::AddForwardFlowTimestamp(2, input_sequence.get());
|
mpms::AddForwardFlowTimestamp(2, input_sequence.get());
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -794,13 +812,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) {
|
||||||
mpms::GetFeatureTimestampSize("OTHER", *input_sequence));
|
mpms::GetFeatureTimestampSize("OTHER", *input_sequence));
|
||||||
ASSERT_EQ(num_timesteps,
|
ASSERT_EQ(num_timesteps,
|
||||||
mpms::GetFeatureFloatsSize("OTHER", *input_sequence));
|
mpms::GetFeatureFloatsSize("OTHER", *input_sequence));
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -826,7 +844,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
auto image_ptr =
|
auto image_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10)));
|
Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -838,11 +856,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
|
||||||
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
|
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
|
||||||
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
|
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
ASSERT_EQ(1, output_packets.size());
|
ASSERT_EQ(1, output_packets.size());
|
||||||
const tf::SequenceExample& output_sequence =
|
const tf::SequenceExample& output_sequence =
|
||||||
output_packets[0].Get<tf::SequenceExample>();
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
@ -879,7 +897,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
auto image_ptr =
|
auto image_ptr =
|
||||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
Adopt(image_ptr.release()).At(Timestamp(i)));
|
Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -893,7 +911,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
||||||
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
|
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
|
||||||
.ConvertToProto(detection.mutable_location_data());
|
.ConvertToProto(detection.mutable_location_data());
|
||||||
detections->push_back(detection);
|
detections->push_back(detection);
|
||||||
runner_->MutableInputs()->Tag("BBOX").packets.push_back(
|
runner_->MutableInputs()->Tag(kBboxTag).packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp(i)));
|
Adopt(detections.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -909,7 +927,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
||||||
mpms::AddBBoxTrackIndex({-1}, input_sequence.get());
|
mpms::AddBBoxTrackIndex({-1}, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
// If all the previous values aren't cleared, this assert will fail.
|
// If all the previous values aren't cleared, this assert will fail.
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
@ -925,11 +943,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) {
|
||||||
for (int i = 0; i < num_timesteps; ++i) {
|
for (int i = 0; i < num_timesteps; ++i) {
|
||||||
auto vf_ptr = ::absl::make_unique<std::vector<float>>(1000000, i);
|
auto vf_ptr = ::absl::make_unique<std::vector<float>>(1000000, i);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("FLOAT_FEATURE_TEST")
|
->Tag(kFloatFeatureTestTag)
|
||||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
ASSERT_FALSE(runner_->Run().ok());
|
ASSERT_FALSE(runner_->Run().ok());
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,8 @@ namespace mediapipe {
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kReferenceTag[] = "REFERENCE";
|
||||||
|
|
||||||
constexpr char kMatrix[] = "MATRIX";
|
constexpr char kMatrix[] = "MATRIX";
|
||||||
constexpr char kTensor[] = "TENSOR";
|
constexpr char kTensor[] = "TENSOR";
|
||||||
|
|
||||||
|
@ -68,7 +70,8 @@ class TensorToMatrixCalculatorTest : public ::testing::Test {
|
||||||
if (include_rate) {
|
if (include_rate) {
|
||||||
header->set_packet_rate(1.0);
|
header->set_packet_rate(1.0);
|
||||||
}
|
}
|
||||||
runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release());
|
runner_->MutableInputs()->Tag(kReferenceTag).header =
|
||||||
|
Adopt(header.release());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<CalculatorRunner> runner_;
|
std::unique_ptr<CalculatorRunner> runner_;
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright 2019 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Calculator converts from one-dimensional Tensor of DT_STRING to
|
||||||
|
// vector<std::string> OR from (batched) two-dimensional Tensor of DT_STRING to
|
||||||
|
// vector<vector<std::string>>.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
class TensorToVectorStringCalculator : public CalculatorBase {
|
||||||
|
public:
|
||||||
|
static absl::Status GetContract(CalculatorContract* cc);
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TensorToVectorStringCalculatorOptions options_;
|
||||||
|
};
|
||||||
|
REGISTER_CALCULATOR(TensorToVectorStringCalculator);
|
||||||
|
|
||||||
|
absl::Status TensorToVectorStringCalculator::GetContract(
|
||||||
|
CalculatorContract* cc) {
|
||||||
|
// Start with only one input packet.
|
||||||
|
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
||||||
|
<< "Only one input stream is supported.";
|
||||||
|
cc->Inputs().Index(0).Set<tf::Tensor>(
|
||||||
|
// Input Tensor
|
||||||
|
);
|
||||||
|
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
|
||||||
|
<< "Only one output stream is supported.";
|
||||||
|
const auto& options = cc->Options<TensorToVectorStringCalculatorOptions>();
|
||||||
|
if (options.tensor_is_2d()) {
|
||||||
|
RET_CHECK(!options.flatten_nd());
|
||||||
|
cc->Outputs().Index(0).Set<std::vector<std::vector<std::string>>>(
|
||||||
|
/* "Output vector<vector<std::string>>." */);
|
||||||
|
} else {
|
||||||
|
cc->Outputs().Index(0).Set<std::vector<std::string>>(
|
||||||
|
// Output vector<std::string>.
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status TensorToVectorStringCalculator::Open(CalculatorContext* cc) {
|
||||||
|
options_ = cc->Options<TensorToVectorStringCalculatorOptions>();
|
||||||
|
|
||||||
|
// Inform mediapipe that this calculator produces an output at time t for
|
||||||
|
// each input received at time t (i.e. this calculator does not buffer
|
||||||
|
// inputs). This enables mediapipe to propagate time of arrival estimates in
|
||||||
|
// mediapipe graphs through this calculator.
|
||||||
|
cc->SetOffset(/*offset=*/0);
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status TensorToVectorStringCalculator::Process(CalculatorContext* cc) {
|
||||||
|
const tf::Tensor& input_tensor =
|
||||||
|
cc->Inputs().Index(0).Value().Get<tf::Tensor>();
|
||||||
|
RET_CHECK(tf::DT_STRING == input_tensor.dtype())
|
||||||
|
<< "expected DT_STRING input but got "
|
||||||
|
<< tensorflow::DataTypeString(input_tensor.dtype());
|
||||||
|
|
||||||
|
if (options_.tensor_is_2d()) {
|
||||||
|
RET_CHECK(2 == input_tensor.dims())
|
||||||
|
<< "Expected 2-dimensional Tensor, but the tensor shape is: "
|
||||||
|
<< input_tensor.shape().DebugString();
|
||||||
|
auto output = absl::make_unique<std::vector<std::vector<std::string>>>(
|
||||||
|
input_tensor.dim_size(0),
|
||||||
|
std::vector<std::string>(input_tensor.dim_size(1)));
|
||||||
|
for (int i = 0; i < input_tensor.dim_size(0); ++i) {
|
||||||
|
auto& instance_output = output->at(i);
|
||||||
|
const auto& slice =
|
||||||
|
input_tensor.Slice(i, i + 1).unaligned_flat<tensorflow::tstring>();
|
||||||
|
for (int j = 0; j < input_tensor.dim_size(1); ++j) {
|
||||||
|
instance_output.at(j) = slice(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||||
|
} else {
|
||||||
|
if (!options_.flatten_nd()) {
|
||||||
|
RET_CHECK(1 == input_tensor.dims())
|
||||||
|
<< "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the "
|
||||||
|
<< "tensor shape is: " << input_tensor.shape().DebugString();
|
||||||
|
}
|
||||||
|
auto output =
|
||||||
|
absl::make_unique<std::vector<std::string>>(input_tensor.NumElements());
|
||||||
|
const auto& tensor_values = input_tensor.flat<tensorflow::tstring>();
|
||||||
|
for (int i = 0; i < input_tensor.NumElements(); ++i) {
|
||||||
|
output->at(i) = tensor_values(i);
|
||||||
|
}
|
||||||
|
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,33 @@
|
||||||
|
// Copyright 2019 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message TensorToVectorStringCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional TensorToVectorStringCalculatorOptions ext = 386534187;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If true, unpack a 2d tensor (matrix) into a vector<vector<string>>. If
|
||||||
|
// false, convert a 1d tensor (vector) into a vector<string>.
|
||||||
|
optional bool tensor_is_2d = 1 [default = false];
|
||||||
|
|
||||||
|
// If true, an N-D tensor will be flattened to a vector<string>. This is
|
||||||
|
// exclusive with tensor_is_2d.
|
||||||
|
optional bool flatten_nd = 2 [default = false];
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
// Copyright 2018 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
class TensorToVectorStringCalculatorTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) {
|
||||||
|
CalculatorGraphConfig::Node config;
|
||||||
|
config.set_calculator("TensorToVectorStringCalculator");
|
||||||
|
config.add_input_stream("input_tensor");
|
||||||
|
config.add_output_stream("output_tensor");
|
||||||
|
auto options = config.mutable_options()->MutableExtension(
|
||||||
|
TensorToVectorStringCalculatorOptions::ext);
|
||||||
|
options->set_tensor_is_2d(tensor_is_2d);
|
||||||
|
options->set_flatten_nd(flatten_nd);
|
||||||
|
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<CalculatorRunner> runner_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TensorToVectorStringCalculatorTest, ConvertsToVectorFloat) {
|
||||||
|
SetUpRunner(false, false);
|
||||||
|
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
|
||||||
|
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
|
||||||
|
auto tensor_vec = tensor->vec<tensorflow::tstring>();
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
tensor_vec(i) = absl::StrCat("foo", i);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64 time = 1234;
|
||||||
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Index(0).packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
|
const std::vector<std::string>& output_vector =
|
||||||
|
output_packets[0].Get<std::vector<std::string>>();
|
||||||
|
|
||||||
|
EXPECT_EQ(5, output_vector.size());
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
const std::string expected = absl::StrCat("foo", i);
|
||||||
|
EXPECT_EQ(expected, output_vector[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorToVectorStringCalculatorTest, ConvertsBatchedToVectorVectorFloat) {
|
||||||
|
SetUpRunner(true, false);
|
||||||
|
const tf::TensorShape tensor_shape(std::vector<tf::int64>{1, 5});
|
||||||
|
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
|
||||||
|
auto slice = tensor->Slice(0, 1).flat<tensorflow::tstring>();
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
slice(i) = absl::StrCat("foo", i);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64 time = 1234;
|
||||||
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Index(0).packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
|
const std::vector<std::vector<std::string>>& output_vectors =
|
||||||
|
output_packets[0].Get<std::vector<std::vector<std::string>>>();
|
||||||
|
ASSERT_EQ(1, output_vectors.size());
|
||||||
|
const std::vector<std::string>& output_vector = output_vectors[0];
|
||||||
|
EXPECT_EQ(5, output_vector.size());
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
const std::string expected = absl::StrCat("foo", i);
|
||||||
|
EXPECT_EQ(expected, output_vector[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TensorToVectorStringCalculatorTest, FlattenShouldTakeAllDimensions) {
|
||||||
|
SetUpRunner(false, true);
|
||||||
|
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 2, 2});
|
||||||
|
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
|
||||||
|
auto slice = tensor->flat<tensorflow::tstring>();
|
||||||
|
for (int i = 0; i < 2 * 2 * 2; ++i) {
|
||||||
|
slice(i) = absl::StrCat("foo", i);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64 time = 1234;
|
||||||
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Index(0).packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
|
const std::vector<std::string>& output_vector =
|
||||||
|
output_packets[0].Get<std::vector<std::string>>();
|
||||||
|
EXPECT_EQ(2 * 2 * 2, output_vector.size());
|
||||||
|
for (int i = 0; i < 2 * 2 * 2; ++i) {
|
||||||
|
const std::string expected = absl::StrCat("foo", i);
|
||||||
|
EXPECT_EQ(expected, output_vector[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -49,6 +49,11 @@ namespace tf = ::tensorflow;
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS";
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
|
||||||
|
|
||||||
// This is a simple implementation of a semaphore using standard C++ libraries.
|
// This is a simple implementation of a semaphore using standard C++ libraries.
|
||||||
// It is supposed to be used only by TensorflowInferenceCalculator to throttle
|
// It is supposed to be used only by TensorflowInferenceCalculator to throttle
|
||||||
// the concurrent calls of Tensorflow Session::Run. This is useful when multiple
|
// the concurrent calls of Tensorflow Session::Run. This is useful when multiple
|
||||||
|
@ -252,10 +257,10 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
||||||
// For this calculator it must include a tag_to_tensor_map.
|
// For this calculator it must include a tag_to_tensor_map.
|
||||||
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
cc->InputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
|
||||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
|
if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) {
|
||||||
cc->InputSidePackets()
|
cc->InputSidePackets()
|
||||||
.Tag("RECURRENT_INIT_TENSORS")
|
.Tag(kRecurrentInitTensorsTag)
|
||||||
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -265,11 +270,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||||
std::unique_ptr<InferenceState> inference_state =
|
std::unique_ptr<InferenceState> inference_state =
|
||||||
absl::make_unique<InferenceState>();
|
absl::make_unique<InferenceState>();
|
||||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
|
if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) &&
|
||||||
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
|
!cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) {
|
||||||
std::map<std::string, tf::Tensor>* init_tensor_map;
|
std::map<std::string, tf::Tensor>* init_tensor_map;
|
||||||
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
||||||
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
|
cc->InputSidePackets().Tag(kRecurrentInitTensorsTag));
|
||||||
for (const auto& p : *init_tensor_map) {
|
for (const auto& p : *init_tensor_map) {
|
||||||
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
||||||
}
|
}
|
||||||
|
@ -280,13 +285,13 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||||
|
|
||||||
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
|
RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag));
|
||||||
session_ = cc->InputSidePackets()
|
session_ = cc->InputSidePackets()
|
||||||
.Tag("SESSION")
|
.Tag(kSessionTag)
|
||||||
.Get<TensorFlowSession>()
|
.Get<TensorFlowSession>()
|
||||||
.session.get();
|
.session.get();
|
||||||
tag_to_tensor_map_ = cc->InputSidePackets()
|
tag_to_tensor_map_ = cc->InputSidePackets()
|
||||||
.Tag("SESSION")
|
.Tag(kSessionTag)
|
||||||
.Get<TensorFlowSession>()
|
.Get<TensorFlowSession>()
|
||||||
.tag_to_tensor_map;
|
.tag_to_tensor_map;
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,11 @@ namespace mediapipe {
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kMultipliedTag[] = "MULTIPLIED";
|
||||||
|
constexpr char kBTag[] = "B";
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
std::string GetGraphDefPath() {
|
std::string GetGraphDefPath() {
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
char path[1024];
|
char path[1024];
|
||||||
|
@ -86,8 +91,8 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
|
||||||
MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes(
|
MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes(
|
||||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options,
|
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options,
|
||||||
input_side_packets, &output_side_packets));
|
input_side_packets, &output_side_packets));
|
||||||
runner_->MutableSidePackets()->Tag("SESSION") =
|
runner_->MutableSidePackets()->Tag(kSessionTag) =
|
||||||
output_side_packets.Tag("SESSION");
|
output_side_packets.Tag(kSessionTag);
|
||||||
}
|
}
|
||||||
|
|
||||||
Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) {
|
Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) {
|
||||||
|
@ -140,7 +145,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_b =
|
const std::vector<Packet>& output_packets_b =
|
||||||
runner_->Outputs().Tag("B").packets;
|
runner_->Outputs().Tag(kBTag).packets;
|
||||||
ASSERT_EQ(output_packets_b.size(), 1);
|
ASSERT_EQ(output_packets_b.size(), 1);
|
||||||
const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>();
|
||||||
tf::TensorShape expected_shape({1, 3});
|
tf::TensorShape expected_shape({1, 3});
|
||||||
|
@ -148,7 +153,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
|
||||||
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b);
|
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b);
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(1, output_packets_mult.size());
|
ASSERT_EQ(1, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape);
|
expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape);
|
||||||
|
@ -181,7 +186,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(1, output_packets_mult.size());
|
ASSERT_EQ(1, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
tf::TensorShape expected_shape({3});
|
tf::TensorShape expected_shape({3});
|
||||||
|
@ -220,7 +225,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(1, output_packets_mult.size());
|
ASSERT_EQ(1, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
tf::TensorShape expected_shape({3});
|
tf::TensorShape expected_shape({3});
|
||||||
|
@ -274,7 +279,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(2, output_packets_mult.size());
|
ASSERT_EQ(2, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||||
|
@ -311,7 +316,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(2, output_packets_mult.size());
|
ASSERT_EQ(2, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||||
|
@ -351,7 +356,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(3, output_packets_mult.size());
|
ASSERT_EQ(3, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||||
|
@ -392,7 +397,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(2, output_packets_mult.size());
|
ASSERT_EQ(2, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||||
|
@ -430,7 +435,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(2, output_packets_mult.size());
|
ASSERT_EQ(2, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||||
|
@ -481,7 +486,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(5, output_packets_mult.size());
|
ASSERT_EQ(5, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||||
|
@ -528,7 +533,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(2, output_packets_mult.size());
|
ASSERT_EQ(2, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
LOG(INFO) << "timestamp: " << 0;
|
LOG(INFO) << "timestamp: " << 0;
|
||||||
|
@ -569,7 +574,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(2, output_packets_mult.size());
|
ASSERT_EQ(2, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
LOG(INFO) << "timestamp: " << 0;
|
LOG(INFO) << "timestamp: " << 0;
|
||||||
|
@ -662,7 +667,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) {
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(0, output_packets_mult.size());
|
ASSERT_EQ(0, output_packets_mult.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -691,7 +696,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_mult =
|
const std::vector<Packet>& output_packets_mult =
|
||||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||||
ASSERT_EQ(1, output_packets_mult.size());
|
ASSERT_EQ(1, output_packets_mult.size());
|
||||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||||
auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15});
|
auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15});
|
||||||
|
|
|
@ -47,6 +47,11 @@ namespace mediapipe {
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||||
|
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||||
|
|
||||||
// Updates the graph nodes to use the device as specified by device_id.
|
// Updates the graph nodes to use the device as specified by device_id.
|
||||||
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
|
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
|
||||||
for (auto& node : *graph_def->mutable_node()) {
|
for (auto& node : *graph_def->mutable_node()) {
|
||||||
|
@ -64,27 +69,29 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
|
||||||
cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
|
cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
|
||||||
bool has_exactly_one_model =
|
bool has_exactly_one_model =
|
||||||
!options.graph_proto_path().empty()
|
!options.graph_proto_path().empty()
|
||||||
? !(cc->InputSidePackets().HasTag("STRING_MODEL") |
|
? !(cc->InputSidePackets().HasTag(kStringModelTag) |
|
||||||
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH"))
|
cc->InputSidePackets().HasTag(kStringModelFilePathTag))
|
||||||
: (cc->InputSidePackets().HasTag("STRING_MODEL") ^
|
: (cc->InputSidePackets().HasTag(kStringModelTag) ^
|
||||||
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH"));
|
cc->InputSidePackets().HasTag(kStringModelFilePathTag));
|
||||||
RET_CHECK(has_exactly_one_model)
|
RET_CHECK(has_exactly_one_model)
|
||||||
<< "Must have exactly one of graph_proto_path in options or "
|
<< "Must have exactly one of graph_proto_path in options or "
|
||||||
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
|
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
|
||||||
if (cc->InputSidePackets().HasTag("STRING_MODEL")) {
|
if (cc->InputSidePackets().HasTag(kStringModelTag)) {
|
||||||
cc->InputSidePackets()
|
cc->InputSidePackets()
|
||||||
.Tag("STRING_MODEL")
|
.Tag(kStringModelTag)
|
||||||
.Set<std::string>(
|
.Set<std::string>(
|
||||||
// String model from embedded path
|
// String model from embedded path
|
||||||
);
|
);
|
||||||
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) {
|
} else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
|
||||||
cc->InputSidePackets()
|
cc->InputSidePackets()
|
||||||
.Tag("STRING_MODEL_FILE_PATH")
|
.Tag(kStringModelFilePathTag)
|
||||||
.Set<std::string>(
|
.Set<std::string>(
|
||||||
// Filename of std::string model.
|
// Filename of std::string model.
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>(
|
cc->OutputSidePackets()
|
||||||
|
.Tag(kSessionTag)
|
||||||
|
.Set<TensorFlowSession>(
|
||||||
// A TensorFlow model loaded and ready for use along with
|
// A TensorFlow model loaded and ready for use along with
|
||||||
// a map from tags to tensor names.
|
// a map from tags to tensor names.
|
||||||
);
|
);
|
||||||
|
@ -111,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
|
||||||
session->session.reset(tf::NewSession(session_options));
|
session->session.reset(tf::NewSession(session_options));
|
||||||
|
|
||||||
std::string graph_def_serialized;
|
std::string graph_def_serialized;
|
||||||
if (cc->InputSidePackets().HasTag("STRING_MODEL")) {
|
if (cc->InputSidePackets().HasTag(kStringModelTag)) {
|
||||||
graph_def_serialized =
|
graph_def_serialized =
|
||||||
cc->InputSidePackets().Tag("STRING_MODEL").Get<std::string>();
|
cc->InputSidePackets().Tag(kStringModelTag).Get<std::string>();
|
||||||
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) {
|
} else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
|
||||||
const std::string& frozen_graph = cc->InputSidePackets()
|
const std::string& frozen_graph = cc->InputSidePackets()
|
||||||
.Tag("STRING_MODEL_FILE_PATH")
|
.Tag(kStringModelFilePathTag)
|
||||||
.Get<std::string>();
|
.Get<std::string>();
|
||||||
RET_CHECK_OK(
|
RET_CHECK_OK(
|
||||||
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
|
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
|
||||||
|
@ -147,7 +154,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
|
||||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
|
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
|
||||||
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
|
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
|
||||||
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
|
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
|
||||||
<< " microseconds.";
|
<< " microseconds.";
|
||||||
|
|
|
@ -37,6 +37,10 @@ namespace {
|
||||||
|
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||||
|
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
std::string GetGraphDefPath() {
|
std::string GetGraphDefPath() {
|
||||||
return mediapipe::file::JoinPath("./",
|
return mediapipe::file::JoinPath("./",
|
||||||
"mediapipe/calculators/tensorflow/"
|
"mediapipe/calculators/tensorflow/"
|
||||||
|
@ -112,7 +116,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
VerifySignatureMap(session);
|
VerifySignatureMap(session);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,12 +194,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
||||||
std::string serialized_graph_contents;
|
std::string serialized_graph_contents;
|
||||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||||
&serialized_graph_contents));
|
&serialized_graph_contents));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL") =
|
runner.MutableSidePackets()->Tag(kStringModelTag) =
|
||||||
Adopt(new std::string(serialized_graph_contents));
|
Adopt(new std::string(serialized_graph_contents));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
VerifySignatureMap(session);
|
VerifySignatureMap(session);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,12 +217,12 @@ TEST_F(
|
||||||
}
|
}
|
||||||
})",
|
})",
|
||||||
calculator_options_->DebugString()));
|
calculator_options_->DebugString()));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
VerifySignatureMap(session);
|
VerifySignatureMap(session);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,7 +238,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
||||||
}
|
}
|
||||||
})",
|
})",
|
||||||
calculator_options_->DebugString()));
|
calculator_options_->DebugString()));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
auto run_status = runner.Run();
|
auto run_status = runner.Run();
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
@ -255,12 +259,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
||||||
}
|
}
|
||||||
})",
|
})",
|
||||||
calculator_options_->DebugString()));
|
calculator_options_->DebugString()));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
std::string serialized_graph_contents;
|
std::string serialized_graph_contents;
|
||||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||||
&serialized_graph_contents));
|
&serialized_graph_contents));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL") =
|
runner.MutableSidePackets()->Tag(kStringModelTag) =
|
||||||
Adopt(new std::string(serialized_graph_contents));
|
Adopt(new std::string(serialized_graph_contents));
|
||||||
auto run_status = runner.Run();
|
auto run_status = runner.Run();
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
@ -282,12 +286,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
||||||
}
|
}
|
||||||
})",
|
})",
|
||||||
calculator_options_->DebugString()));
|
calculator_options_->DebugString()));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
std::string serialized_graph_contents;
|
std::string serialized_graph_contents;
|
||||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||||
&serialized_graph_contents));
|
&serialized_graph_contents));
|
||||||
runner.MutableSidePackets()->Tag("STRING_MODEL") =
|
runner.MutableSidePackets()->Tag(kStringModelTag) =
|
||||||
Adopt(new std::string(serialized_graph_contents));
|
Adopt(new std::string(serialized_graph_contents));
|
||||||
auto run_status = runner.Run();
|
auto run_status = runner.Run();
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
@ -310,7 +314,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
VerifySignatureMap(session);
|
VerifySignatureMap(session);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,11 @@ namespace mediapipe {
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||||
|
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||||
|
|
||||||
// Updates the graph nodes to use the device as specified by device_id.
|
// Updates the graph nodes to use the device as specified by device_id.
|
||||||
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
|
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
|
||||||
for (auto& node : *graph_def->mutable_node()) {
|
for (auto& node : *graph_def->mutable_node()) {
|
||||||
|
@ -64,25 +69,26 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
|
||||||
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
|
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
|
||||||
bool has_exactly_one_model =
|
bool has_exactly_one_model =
|
||||||
!options.graph_proto_path().empty()
|
!options.graph_proto_path().empty()
|
||||||
? !(input_side_packets->HasTag("STRING_MODEL") |
|
? !(input_side_packets->HasTag(kStringModelTag) |
|
||||||
input_side_packets->HasTag("STRING_MODEL_FILE_PATH"))
|
input_side_packets->HasTag(kStringModelFilePathTag))
|
||||||
: (input_side_packets->HasTag("STRING_MODEL") ^
|
: (input_side_packets->HasTag(kStringModelTag) ^
|
||||||
input_side_packets->HasTag("STRING_MODEL_FILE_PATH"));
|
input_side_packets->HasTag(kStringModelFilePathTag));
|
||||||
RET_CHECK(has_exactly_one_model)
|
RET_CHECK(has_exactly_one_model)
|
||||||
<< "Must have exactly one of graph_proto_path in options or "
|
<< "Must have exactly one of graph_proto_path in options or "
|
||||||
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
|
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
|
||||||
if (input_side_packets->HasTag("STRING_MODEL")) {
|
if (input_side_packets->HasTag(kStringModelTag)) {
|
||||||
input_side_packets->Tag("STRING_MODEL")
|
input_side_packets->Tag(kStringModelTag)
|
||||||
.Set<std::string>(
|
.Set<std::string>(
|
||||||
// String model from embedded path
|
// String model from embedded path
|
||||||
);
|
);
|
||||||
} else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) {
|
} else if (input_side_packets->HasTag(kStringModelFilePathTag)) {
|
||||||
input_side_packets->Tag("STRING_MODEL_FILE_PATH")
|
input_side_packets->Tag(kStringModelFilePathTag)
|
||||||
.Set<std::string>(
|
.Set<std::string>(
|
||||||
// Filename of std::string model.
|
// Filename of std::string model.
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
output_side_packets->Tag("SESSION").Set<TensorFlowSession>(
|
output_side_packets->Tag(kSessionTag)
|
||||||
|
.Set<TensorFlowSession>(
|
||||||
// A TensorFlow model loaded and ready for use along with
|
// A TensorFlow model loaded and ready for use along with
|
||||||
// a map from tags to tensor names.
|
// a map from tags to tensor names.
|
||||||
);
|
);
|
||||||
|
@ -112,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
|
||||||
session->session.reset(tf::NewSession(session_options));
|
session->session.reset(tf::NewSession(session_options));
|
||||||
|
|
||||||
std::string graph_def_serialized;
|
std::string graph_def_serialized;
|
||||||
if (input_side_packets.HasTag("STRING_MODEL")) {
|
if (input_side_packets.HasTag(kStringModelTag)) {
|
||||||
graph_def_serialized =
|
graph_def_serialized =
|
||||||
input_side_packets.Tag("STRING_MODEL").Get<std::string>();
|
input_side_packets.Tag(kStringModelTag).Get<std::string>();
|
||||||
} else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) {
|
} else if (input_side_packets.HasTag(kStringModelFilePathTag)) {
|
||||||
const std::string& frozen_graph =
|
const std::string& frozen_graph =
|
||||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get<std::string>();
|
input_side_packets.Tag(kStringModelFilePathTag).Get<std::string>();
|
||||||
RET_CHECK_OK(
|
RET_CHECK_OK(
|
||||||
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
|
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
|
||||||
} else {
|
} else {
|
||||||
|
@ -147,7 +153,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
|
||||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
output_side_packets->Tag("SESSION") = Adopt(session.release());
|
output_side_packets->Tag(kSessionTag) = Adopt(session.release());
|
||||||
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
|
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
|
||||||
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
|
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
|
||||||
<< " microseconds.";
|
<< " microseconds.";
|
||||||
|
|
|
@ -37,6 +37,10 @@ namespace {
|
||||||
|
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||||
|
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
std::string GetGraphDefPath() {
|
std::string GetGraphDefPath() {
|
||||||
return mediapipe::file::JoinPath("./",
|
return mediapipe::file::JoinPath("./",
|
||||||
"mediapipe/calculators/tensorflow/"
|
"mediapipe/calculators/tensorflow/"
|
||||||
|
@ -72,7 +76,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test {
|
||||||
|
|
||||||
void VerifySignatureMap(PacketSet* output_side_packets) {
|
void VerifySignatureMap(PacketSet* output_side_packets) {
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
output_side_packets->Tag("SESSION").Get<TensorFlowSession>();
|
output_side_packets->Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
|
|
||||||
|
@ -179,7 +183,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
||||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||||
&serialized_graph_contents));
|
&serialized_graph_contents));
|
||||||
generator_options_->clear_graph_proto_path();
|
generator_options_->clear_graph_proto_path();
|
||||||
input_side_packets.Tag("STRING_MODEL") =
|
input_side_packets.Tag(kStringModelTag) =
|
||||||
Adopt(new std::string(serialized_graph_contents));
|
Adopt(new std::string(serialized_graph_contents));
|
||||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
||||||
|
@ -196,7 +200,7 @@ TEST_F(
|
||||||
PacketSet output_side_packets(
|
PacketSet output_side_packets(
|
||||||
tool::CreateTagMap({"SESSION:session"}).value());
|
tool::CreateTagMap({"SESSION:session"}).value());
|
||||||
generator_options_->clear_graph_proto_path();
|
generator_options_->clear_graph_proto_path();
|
||||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
||||||
|
@ -211,7 +215,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
||||||
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value());
|
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value());
|
||||||
PacketSet output_side_packets(
|
PacketSet output_side_packets(
|
||||||
tool::CreateTagMap({"SESSION:session"}).value());
|
tool::CreateTagMap({"SESSION:session"}).value());
|
||||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
||||||
|
@ -233,9 +237,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
||||||
std::string serialized_graph_contents;
|
std::string serialized_graph_contents;
|
||||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||||
&serialized_graph_contents));
|
&serialized_graph_contents));
|
||||||
input_side_packets.Tag("STRING_MODEL") =
|
input_side_packets.Tag(kStringModelTag) =
|
||||||
Adopt(new std::string(serialized_graph_contents));
|
Adopt(new std::string(serialized_graph_contents));
|
||||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
|
|
||||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||||
|
@ -258,9 +262,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
||||||
std::string serialized_graph_contents;
|
std::string serialized_graph_contents;
|
||||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||||
&serialized_graph_contents));
|
&serialized_graph_contents));
|
||||||
input_side_packets.Tag("STRING_MODEL") =
|
input_side_packets.Tag(kStringModelTag) =
|
||||||
Adopt(new std::string(serialized_graph_contents));
|
Adopt(new std::string(serialized_graph_contents));
|
||||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||||
Adopt(new std::string(GetGraphDefPath()));
|
Adopt(new std::string(GetGraphDefPath()));
|
||||||
generator_options_->clear_graph_proto_path();
|
generator_options_->clear_graph_proto_path();
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
||||||
|
|
||||||
// Given the path to a directory containing multiple tensorflow saved models
|
// Given the path to a directory containing multiple tensorflow saved models
|
||||||
|
@ -108,7 +111,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
|
||||||
cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
|
cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
|
||||||
}
|
}
|
||||||
// A TensorFlow model loaded and ready for use along with tensor
|
// A TensorFlow model loaded and ready for use along with tensor
|
||||||
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
cc->OutputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +163,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
|
||||||
output_signature.first, options)] = output_signature.second.name();
|
output_signature.first, options)] = output_signature.second.name();
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
|
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,9 @@ namespace {
|
||||||
|
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
std::string GetSavedModelDir() {
|
std::string GetSavedModelDir() {
|
||||||
std::string out_path =
|
std::string out_path =
|
||||||
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
|
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
|
||||||
|
@ -79,7 +82,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
||||||
options_->DebugString()));
|
options_->DebugString()));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
|
|
||||||
|
@ -119,11 +122,11 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
||||||
}
|
}
|
||||||
})",
|
})",
|
||||||
options_->DebugString()));
|
options_->DebugString()));
|
||||||
runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") =
|
runner.MutableSidePackets()->Tag(kStringSavedModelPathTag) =
|
||||||
MakePacket<std::string>(GetSavedModelDir());
|
MakePacket<std::string>(GetSavedModelDir());
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
}
|
}
|
||||||
|
@ -201,7 +204,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
||||||
options_->DebugString()));
|
options_->DebugString()));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
}
|
}
|
||||||
|
@ -224,7 +227,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
||||||
options_->DebugString()));
|
options_->DebugString()));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
std::vector<tensorflow::DeviceAttributes> devices;
|
std::vector<tensorflow::DeviceAttributes> devices;
|
||||||
|
|
|
@ -33,6 +33,9 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
||||||
|
|
||||||
// Given the path to a directory containing multiple tensorflow saved models
|
// Given the path to a directory containing multiple tensorflow saved models
|
||||||
|
@ -100,7 +103,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
||||||
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
|
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
|
||||||
}
|
}
|
||||||
// A TensorFlow model loaded and ready for use along with tensor
|
// A TensorFlow model loaded and ready for use along with tensor
|
||||||
output_side_packets->Tag("SESSION").Set<TensorFlowSession>();
|
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +156,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
||||||
output_signature.first, options)] = output_signature.second.name();
|
output_signature.first, options)] = output_signature.second.name();
|
||||||
}
|
}
|
||||||
|
|
||||||
output_side_packets->Tag("SESSION") = Adopt(session.release());
|
output_side_packets->Tag(kSessionTag) = Adopt(session.release());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -34,6 +34,9 @@ namespace {
|
||||||
|
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
||||||
|
constexpr char kSessionTag[] = "SESSION";
|
||||||
|
|
||||||
std::string GetSavedModelDir() {
|
std::string GetSavedModelDir() {
|
||||||
std::string out_path =
|
std::string out_path =
|
||||||
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
|
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
|
||||||
|
@ -75,7 +78,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
input_side_packets, &output_side_packets);
|
input_side_packets, &output_side_packets);
|
||||||
MP_EXPECT_OK(run_status) << run_status.message();
|
MP_EXPECT_OK(run_status) << run_status.message();
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
|
|
||||||
|
@ -107,7 +110,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
generator_options_->clear_saved_model_path();
|
generator_options_->clear_saved_model_path();
|
||||||
PacketSet input_side_packets(
|
PacketSet input_side_packets(
|
||||||
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value());
|
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value());
|
||||||
input_side_packets.Tag("STRING_SAVED_MODEL_PATH") =
|
input_side_packets.Tag(kStringSavedModelPathTag) =
|
||||||
Adopt(new std::string(GetSavedModelDir()));
|
Adopt(new std::string(GetSavedModelDir()));
|
||||||
PacketSet output_side_packets(
|
PacketSet output_side_packets(
|
||||||
tool::CreateTagMap({"SESSION:session"}).value());
|
tool::CreateTagMap({"SESSION:session"}).value());
|
||||||
|
@ -116,7 +119,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
input_side_packets, &output_side_packets);
|
input_side_packets, &output_side_packets);
|
||||||
MP_EXPECT_OK(run_status) << run_status.message();
|
MP_EXPECT_OK(run_status) << run_status.message();
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
}
|
}
|
||||||
|
@ -192,7 +195,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
input_side_packets, &output_side_packets);
|
input_side_packets, &output_side_packets);
|
||||||
MP_EXPECT_OK(run_status) << run_status.message();
|
MP_EXPECT_OK(run_status) << run_status.message();
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
}
|
}
|
||||||
|
@ -213,7 +216,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
||||||
input_side_packets, &output_side_packets);
|
input_side_packets, &output_side_packets);
|
||||||
MP_EXPECT_OK(run_status) << run_status.message();
|
MP_EXPECT_OK(run_status) << run_status.message();
|
||||||
const TensorFlowSession& session =
|
const TensorFlowSession& session =
|
||||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||||
// Session must be set.
|
// Session must be set.
|
||||||
ASSERT_NE(session.session, nullptr);
|
ASSERT_NE(session.session, nullptr);
|
||||||
std::vector<tensorflow::DeviceAttributes> devices;
|
std::vector<tensorflow::DeviceAttributes> devices;
|
||||||
|
|
|
@ -33,6 +33,31 @@ namespace {
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
namespace mpms = mediapipe::mediasequence;
|
namespace mpms = mediapipe::mediasequence;
|
||||||
|
|
||||||
|
constexpr char kImageFrameRateTag[] = "IMAGE_FRAME_RATE";
|
||||||
|
constexpr char kEncodedMediaStartTimestampTag[] =
|
||||||
|
"ENCODED_MEDIA_START_TIMESTAMP";
|
||||||
|
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
|
||||||
|
constexpr char kResamplerOptionsTag[] = "RESAMPLER_OPTIONS";
|
||||||
|
constexpr char kSandboxedDecoderOptionsTag[] = "SANDBOXED_DECODER_OPTIONS";
|
||||||
|
constexpr char kDecoderOptionsTag[] = "DECODER_OPTIONS";
|
||||||
|
constexpr char kAudioDecoderOptionsTag[] = "AUDIO_DECODER_OPTIONS";
|
||||||
|
constexpr char kDataPathTag[] = "DATA_PATH";
|
||||||
|
constexpr char kDatasetRootTag[] = "DATASET_ROOT";
|
||||||
|
constexpr char kMediaIdTag[] = "MEDIA_ID";
|
||||||
|
constexpr char kFloatFeatureFdenseMaxTag[] = "FLOAT_FEATURE_FDENSE_MAX";
|
||||||
|
constexpr char kFloatFeatureFdenseAvgTag[] = "FLOAT_FEATURE_FDENSE_AVG";
|
||||||
|
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
|
||||||
|
constexpr char kAudioTestTag[] = "AUDIO_TEST";
|
||||||
|
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||||
|
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||||
|
constexpr char kBboxPrefixTag[] = "BBOX_PREFIX";
|
||||||
|
constexpr char kKeypointsTag[] = "KEYPOINTS";
|
||||||
|
constexpr char kBboxTag[] = "BBOX";
|
||||||
|
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
|
||||||
|
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
||||||
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||||
|
|
||||||
class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
|
class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUpCalculator(const std::vector<std::string>& output_streams,
|
void SetUpCalculator(const std::vector<std::string>& output_streams,
|
||||||
|
@ -95,13 +120,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) {
|
||||||
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("IMAGE").packets;
|
runner_->Outputs().Tag(kImageTag).packets;
|
||||||
ASSERT_EQ(num_images, output_packets.size());
|
ASSERT_EQ(num_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
@ -124,13 +149,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) {
|
||||||
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("IMAGE").packets;
|
runner_->Outputs().Tag(kImageTag).packets;
|
||||||
ASSERT_EQ(num_images, output_packets.size());
|
ASSERT_EQ(num_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
@ -154,13 +179,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) {
|
||||||
mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get());
|
mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("IMAGE_PREFIX").packets;
|
runner_->Outputs().Tag(kImagePrefixTag).packets;
|
||||||
ASSERT_EQ(num_images, output_packets.size());
|
ASSERT_EQ(num_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
@ -182,12 +207,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) {
|
||||||
mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get());
|
mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets;
|
runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
|
||||||
ASSERT_EQ(num_forward_flow_images, output_packets.size());
|
ASSERT_EQ(num_forward_flow_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_forward_flow_images; ++i) {
|
for (int i = 0; i < num_forward_flow_images; ++i) {
|
||||||
|
@ -211,12 +236,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) {
|
||||||
mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get());
|
mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets;
|
runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
|
||||||
ASSERT_EQ(num_forward_flow_images, output_packets.size());
|
ASSERT_EQ(num_forward_flow_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_forward_flow_images; ++i) {
|
for (int i = 0; i < num_forward_flow_images; ++i) {
|
||||||
|
@ -240,13 +265,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) {
|
||||||
mpms::AddBBoxTimestamp(i, input_sequence.get());
|
mpms::AddBBoxTimestamp(i, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("BBOX").packets;
|
runner_->Outputs().Tag(kBboxTag).packets;
|
||||||
ASSERT_EQ(bboxes.size(), output_packets.size());
|
ASSERT_EQ(bboxes.size(), output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < bboxes.size(); ++i) {
|
for (int i = 0; i < bboxes.size(); ++i) {
|
||||||
|
@ -274,13 +299,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) {
|
||||||
mpms::AddBBoxTimestamp(prefix, i, input_sequence.get());
|
mpms::AddBBoxTimestamp(prefix, i, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("BBOX_PREFIX").packets;
|
runner_->Outputs().Tag(kBboxPrefixTag).packets;
|
||||||
ASSERT_EQ(bboxes.size(), output_packets.size());
|
ASSERT_EQ(bboxes.size(), output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < bboxes.size(); ++i) {
|
for (int i = 0; i < bboxes.size(); ++i) {
|
||||||
|
@ -306,13 +331,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
|
||||||
mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get());
|
mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets;
|
runner_->Outputs().Tag(kFloatFeatureTestTag).packets;
|
||||||
ASSERT_EQ(num_float_lists, output_packets.size());
|
ASSERT_EQ(num_float_lists, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_float_lists; ++i) {
|
for (int i = 0; i < num_float_lists; ++i) {
|
||||||
|
@ -322,7 +347,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_other =
|
const std::vector<Packet>& output_packets_other =
|
||||||
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets;
|
runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
|
||||||
ASSERT_EQ(num_float_lists, output_packets_other.size());
|
ASSERT_EQ(num_float_lists, output_packets_other.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_float_lists; ++i) {
|
for (int i = 0; i < num_float_lists; ++i) {
|
||||||
|
@ -352,12 +377,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
|
||||||
mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get());
|
mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("IMAGE").packets;
|
runner_->Outputs().Tag(kImageTag).packets;
|
||||||
ASSERT_EQ(num_images, output_packets.size());
|
ASSERT_EQ(num_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
@ -366,7 +391,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets_other =
|
const std::vector<Packet>& output_packets_other =
|
||||||
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets;
|
runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
|
||||||
ASSERT_EQ(num_float_lists, output_packets_other.size());
|
ASSERT_EQ(num_float_lists, output_packets_other.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_float_lists; ++i) {
|
for (int i = 0; i < num_float_lists; ++i) {
|
||||||
|
@ -389,12 +414,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
|
||||||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||||
input_sequence.get());
|
input_sequence.get());
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& fdense_avg_packets =
|
const std::vector<Packet>& fdense_avg_packets =
|
||||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets;
|
runner_->Outputs().Tag(kFloatFeatureFdenseAvgTag).packets;
|
||||||
ASSERT_EQ(fdense_avg_packets.size(), 1);
|
ASSERT_EQ(fdense_avg_packets.size(), 1);
|
||||||
const auto& fdense_avg_vector =
|
const auto& fdense_avg_vector =
|
||||||
fdense_avg_packets[0].Get<std::vector<float>>();
|
fdense_avg_packets[0].Get<std::vector<float>>();
|
||||||
|
@ -403,7 +428,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
|
||||||
::testing::Eq(Timestamp::PostStream()));
|
::testing::Eq(Timestamp::PostStream()));
|
||||||
|
|
||||||
const std::vector<Packet>& fdense_max_packets =
|
const std::vector<Packet>& fdense_max_packets =
|
||||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
|
runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
|
||||||
ASSERT_EQ(fdense_max_packets.size(), 1);
|
ASSERT_EQ(fdense_max_packets.size(), 1);
|
||||||
const auto& fdense_max_vector =
|
const auto& fdense_max_vector =
|
||||||
fdense_max_packets[0].Get<std::vector<float>>();
|
fdense_max_packets[0].Get<std::vector<float>>();
|
||||||
|
@ -430,13 +455,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
|
||||||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||||
input_sequence.get());
|
input_sequence.get());
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("IMAGE").packets;
|
runner_->Outputs().Tag(kImageTag).packets;
|
||||||
ASSERT_EQ(num_images, output_packets.size());
|
ASSERT_EQ(num_images, output_packets.size());
|
||||||
|
|
||||||
for (int i = 0; i < num_images; ++i) {
|
for (int i = 0; i < num_images; ++i) {
|
||||||
|
@ -463,13 +488,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
|
||||||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||||
input_sequence.get());
|
input_sequence.get());
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
const std::vector<Packet>& fdense_max_packets =
|
const std::vector<Packet>& fdense_max_packets =
|
||||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
|
runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
|
||||||
ASSERT_EQ(fdense_max_packets.size(), 1);
|
ASSERT_EQ(fdense_max_packets.size(), 1);
|
||||||
const auto& fdense_max_vector =
|
const auto& fdense_max_vector =
|
||||||
fdense_max_packets[0].Get<std::vector<float>>();
|
fdense_max_packets[0].Get<std::vector<float>>();
|
||||||
|
@ -481,17 +506,17 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
|
||||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
|
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
|
||||||
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
|
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
|
|
||||||
std::string root = "test_root";
|
std::string root = "test_root";
|
||||||
runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root);
|
runner_->MutableSidePackets()->Tag(kDatasetRootTag) = PointToForeign(&root);
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->OutputSidePackets()
|
MP_ASSERT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("DATA_PATH")
|
.Tag(kDataPathTag)
|
||||||
.ValidateAsType<std::string>());
|
.ValidateAsType<std::string>());
|
||||||
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
|
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
|
||||||
root + "/" + data_path_);
|
root + "/" + data_path_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -501,28 +526,28 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) {
|
||||||
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
|
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
|
||||||
->set_dataset_root_directory(root);
|
->set_dataset_root_directory(root);
|
||||||
SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options);
|
SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options);
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->OutputSidePackets()
|
MP_ASSERT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("DATA_PATH")
|
.Tag(kDataPathTag)
|
||||||
.ValidateAsType<std::string>());
|
.ValidateAsType<std::string>());
|
||||||
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
|
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
|
||||||
root + "/" + data_path_);
|
root + "/" + data_path_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
|
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
|
||||||
SetUpCalculator({}, {"DATA_PATH:data_path"});
|
SetUpCalculator({}, {"DATA_PATH:data_path"});
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
MP_ASSERT_OK(runner_->OutputSidePackets()
|
MP_ASSERT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("DATA_PATH")
|
.Tag(kDataPathTag)
|
||||||
.ValidateAsType<std::string>());
|
.ValidateAsType<std::string>());
|
||||||
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
|
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
|
||||||
data_path_);
|
data_path_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,20 +559,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) {
|
||||||
->set_padding_after_label(2);
|
->set_padding_after_label(2);
|
||||||
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
|
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
|
||||||
&options);
|
&options);
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("AUDIO_DECODER_OPTIONS")
|
.Tag(kAudioDecoderOptionsTag)
|
||||||
.ValidateAsType<AudioDecoderOptions>());
|
.ValidateAsType<AudioDecoderOptions>());
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("AUDIO_DECODER_OPTIONS")
|
.Tag(kAudioDecoderOptionsTag)
|
||||||
.Get<AudioDecoderOptions>()
|
.Get<AudioDecoderOptions>()
|
||||||
.start_time(),
|
.start_time(),
|
||||||
2.0, 1e-5);
|
2.0, 1e-5);
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("AUDIO_DECODER_OPTIONS")
|
.Tag(kAudioDecoderOptionsTag)
|
||||||
.Get<AudioDecoderOptions>()
|
.Get<AudioDecoderOptions>()
|
||||||
.end_time(),
|
.end_time(),
|
||||||
7.0, 1e-5);
|
7.0, 1e-5);
|
||||||
|
@ -563,20 +588,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) {
|
||||||
->set_force_decoding_from_start_of_media(true);
|
->set_force_decoding_from_start_of_media(true);
|
||||||
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
|
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
|
||||||
&options);
|
&options);
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("AUDIO_DECODER_OPTIONS")
|
.Tag(kAudioDecoderOptionsTag)
|
||||||
.ValidateAsType<AudioDecoderOptions>());
|
.ValidateAsType<AudioDecoderOptions>());
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("AUDIO_DECODER_OPTIONS")
|
.Tag(kAudioDecoderOptionsTag)
|
||||||
.Get<AudioDecoderOptions>()
|
.Get<AudioDecoderOptions>()
|
||||||
.start_time(),
|
.start_time(),
|
||||||
0.0, 1e-5);
|
0.0, 1e-5);
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("AUDIO_DECODER_OPTIONS")
|
.Tag(kAudioDecoderOptionsTag)
|
||||||
.Get<AudioDecoderOptions>()
|
.Get<AudioDecoderOptions>()
|
||||||
.end_time(),
|
.end_time(),
|
||||||
7.0, 1e-5);
|
7.0, 1e-5);
|
||||||
|
@ -594,27 +619,27 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
|
||||||
->mutable_base_packet_resampler_options()
|
->mutable_base_packet_resampler_options()
|
||||||
->set_frame_rate(1.0);
|
->set_frame_rate(1.0);
|
||||||
SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options);
|
SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options);
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("RESAMPLER_OPTIONS")
|
.Tag(kResamplerOptionsTag)
|
||||||
.ValidateAsType<CalculatorOptions>());
|
.ValidateAsType<CalculatorOptions>());
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("RESAMPLER_OPTIONS")
|
.Tag(kResamplerOptionsTag)
|
||||||
.Get<CalculatorOptions>()
|
.Get<CalculatorOptions>()
|
||||||
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
||||||
.start_time(),
|
.start_time(),
|
||||||
2000000, 1);
|
2000000, 1);
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("RESAMPLER_OPTIONS")
|
.Tag(kResamplerOptionsTag)
|
||||||
.Get<CalculatorOptions>()
|
.Get<CalculatorOptions>()
|
||||||
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
||||||
.end_time(),
|
.end_time(),
|
||||||
7000000, 1);
|
7000000, 1);
|
||||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||||
.Tag("RESAMPLER_OPTIONS")
|
.Tag(kResamplerOptionsTag)
|
||||||
.Get<CalculatorOptions>()
|
.Get<CalculatorOptions>()
|
||||||
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
||||||
.frame_rate(),
|
.frame_rate(),
|
||||||
|
@ -623,13 +648,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
|
||||||
|
|
||||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) {
|
TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) {
|
||||||
SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"});
|
SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"});
|
||||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(sequence_.release());
|
Adopt(sequence_.release());
|
||||||
MP_ASSERT_OK(runner_->Run());
|
MP_ASSERT_OK(runner_->Run());
|
||||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||||
.Tag("IMAGE_FRAME_RATE")
|
.Tag(kImageFrameRateTag)
|
||||||
.ValidateAsType<double>());
|
.ValidateAsType<double>());
|
||||||
EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get<double>(),
|
EXPECT_EQ(runner_->OutputSidePackets().Tag(kImageFrameRateTag).Get<double>(),
|
||||||
image_frame_rate_);
|
image_frame_rate_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,10 @@ namespace {
|
||||||
|
|
||||||
namespace tf = ::tensorflow;
|
namespace tf = ::tensorflow;
|
||||||
|
|
||||||
|
constexpr char kSingleIntTag[] = "SINGLE_INT";
|
||||||
|
constexpr char kTensorOutTag[] = "TENSOR_OUT";
|
||||||
|
constexpr char kVectorIntTag[] = "VECTOR_INT";
|
||||||
|
|
||||||
class VectorIntToTensorCalculatorTest : public ::testing::Test {
|
class VectorIntToTensorCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUpRunner(
|
void SetUpRunner(
|
||||||
|
@ -61,13 +65,13 @@ class VectorIntToTensorCalculatorTest : public ::testing::Test {
|
||||||
|
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("VECTOR_INT")
|
->Tag(kVectorIntTag)
|
||||||
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||||
|
@ -95,13 +99,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
|
||||||
tensorflow::DT_INT32, false, true);
|
tensorflow::DT_INT32, false, true);
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("SINGLE_INT")
|
->Tag(kSingleIntTag)
|
||||||
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
|
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||||
|
@ -121,13 +125,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) {
|
||||||
}
|
}
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("VECTOR_INT")
|
->Tag(kVectorIntTag)
|
||||||
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||||
|
@ -152,13 +156,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
|
||||||
tensorflow::DT_INT64, false, true);
|
tensorflow::DT_INT64, false, true);
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("SINGLE_INT")
|
->Tag(kSingleIntTag)
|
||||||
.packets.push_back(MakePacket<int>(1LL << 31).At(Timestamp(time)));
|
.packets.push_back(MakePacket<int>(1LL << 31).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||||
|
@ -179,13 +183,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
|
||||||
}
|
}
|
||||||
const int64 time = 1234;
|
const int64 time = 1234;
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag("VECTOR_INT")
|
->Tag(kVectorIntTag)
|
||||||
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
EXPECT_TRUE(runner_->Run().ok());
|
EXPECT_TRUE(runner_->Run().ok());
|
||||||
|
|
||||||
const std::vector<Packet>& output_packets =
|
const std::vector<Packet>& output_packets =
|
||||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||||
EXPECT_EQ(1, output_packets.size());
|
EXPECT_EQ(1, output_packets.size());
|
||||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||||
|
|
|
@ -18,6 +18,10 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kFloatsTag[] = "FLOATS";
|
||||||
|
constexpr char kFloatTag[] = "FLOAT";
|
||||||
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
|
||||||
// A calculator for converting TFLite tensors to a float or a float vector.
|
// A calculator for converting TFLite tensors to a float or a float vector.
|
||||||
//
|
//
|
||||||
// Input:
|
// Input:
|
||||||
|
@ -48,15 +52,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);
|
||||||
|
|
||||||
absl::Status TfLiteTensorsToFloatsCalculator::GetContract(
|
absl::Status TfLiteTensorsToFloatsCalculator::GetContract(
|
||||||
CalculatorContract* cc) {
|
CalculatorContract* cc) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("TENSORS"));
|
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
|
||||||
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT"));
|
RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
|
||||||
|
cc->Outputs().HasTag(kFloatTag));
|
||||||
|
|
||||||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||||
if (cc->Outputs().HasTag("FLOATS")) {
|
if (cc->Outputs().HasTag(kFloatsTag)) {
|
||||||
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>();
|
cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("FLOAT")) {
|
if (cc->Outputs().HasTag(kFloatTag)) {
|
||||||
cc->Outputs().Tag("FLOAT").Set<float>();
|
cc->Outputs().Tag(kFloatTag).Set<float>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -69,10 +74,10 @@ absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
||||||
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());
|
RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty());
|
||||||
|
|
||||||
const auto& input_tensors =
|
const auto& input_tensors =
|
||||||
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
|
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
|
||||||
// TODO: Add option to specify which tensor to take from.
|
// TODO: Add option to specify which tensor to take from.
|
||||||
const TfLiteTensor* raw_tensor = &input_tensors[0];
|
const TfLiteTensor* raw_tensor = &input_tensors[0];
|
||||||
const float* raw_floats = raw_tensor->data.f;
|
const float* raw_floats = raw_tensor->data.f;
|
||||||
|
@ -82,18 +87,19 @@ absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
||||||
num_values *= raw_tensor->dims->data[i];
|
num_values *= raw_tensor->dims->data[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("FLOAT")) {
|
if (cc->Outputs().HasTag(kFloatTag)) {
|
||||||
// TODO: Could add an index in the option to specify returning one
|
// TODO: Could add an index in the option to specify returning one
|
||||||
// value of a float array.
|
// value of a float array.
|
||||||
RET_CHECK_EQ(num_values, 1);
|
RET_CHECK_EQ(num_values, 1);
|
||||||
cc->Outputs().Tag("FLOAT").AddPacket(
|
cc->Outputs().Tag(kFloatTag).AddPacket(
|
||||||
MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
|
MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("FLOATS")) {
|
if (cc->Outputs().HasTag(kFloatsTag)) {
|
||||||
auto output_floats = absl::make_unique<std::vector<float>>(
|
auto output_floats = absl::make_unique<std::vector<float>>(
|
||||||
raw_floats, raw_floats + num_values);
|
raw_floats, raw_floats + num_values);
|
||||||
cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
|
cc->Outputs()
|
||||||
cc->InputTimestamp());
|
.Tag(kFloatsTag)
|
||||||
|
.Add(output_floats.release(), cc->InputTimestamp());
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -87,7 +87,7 @@ absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) {
|
||||||
// Initialize the clock.
|
// Initialize the clock.
|
||||||
if (cc->InputSidePackets().HasTag(kClockTag)) {
|
if (cc->InputSidePackets().HasTag(kClockTag)) {
|
||||||
clock_ = cc->InputSidePackets()
|
clock_ = cc->InputSidePackets()
|
||||||
.Tag("CLOCK")
|
.Tag(kClockTag)
|
||||||
.Get<std::shared_ptr<::mediapipe::Clock>>();
|
.Get<std::shared_ptr<::mediapipe::Clock>>();
|
||||||
} else {
|
} else {
|
||||||
clock_.reset(
|
clock_.reset(
|
||||||
|
|
|
@ -27,6 +27,8 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kIterableTag[] = "ITERABLE";
|
||||||
|
|
||||||
typedef CollectionHasMinSizeCalculator<std::vector<int>>
|
typedef CollectionHasMinSizeCalculator<std::vector<int>>
|
||||||
TestIntCollectionHasMinSizeCalculator;
|
TestIntCollectionHasMinSizeCalculator;
|
||||||
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
|
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
|
||||||
|
@ -34,7 +36,7 @@ REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
|
||||||
void AddInputVector(const std::vector<int>& input, int64 timestamp,
|
void AddInputVector(const std::vector<int>& input, int64 timestamp,
|
||||||
CalculatorRunner* runner) {
|
CalculatorRunner* runner) {
|
||||||
runner->MutableInputs()
|
runner->MutableInputs()
|
||||||
->Tag("ITERABLE")
|
->Tag(kIterableTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
|
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -144,7 +144,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("DETECTIONS")
|
.Tag(kDetectionsTag)
|
||||||
.Add(output_detections.release(), cc->InputTimestamp());
|
.Add(output_detections.release(), cc->InputTimestamp());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
|
||||||
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
|
|
||||||
LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
|
LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
|
||||||
double height) {
|
double height) {
|
||||||
LocationData location_data;
|
LocationData location_data;
|
||||||
|
@ -76,19 +79,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) {
|
||||||
detections->push_back(
|
detections->push_back(
|
||||||
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
|
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||||
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
|
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("LETTERBOX_PADDING")
|
->Tag(kLetterboxPaddingTag)
|
||||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("DETECTIONS").packets;
|
runner.Outputs().Tag(kDetectionsTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
||||||
|
|
||||||
|
@ -124,19 +127,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) {
|
||||||
detections->push_back(
|
detections->push_back(
|
||||||
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
|
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||||
std::array<float, 4>{0.f, 0.2f, 0.f, 0.3f});
|
std::array<float, 4>{0.f, 0.2f, 0.f, 0.3f});
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("LETTERBOX_PADDING")
|
->Tag(kLetterboxPaddingTag)
|
||||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("DETECTIONS").packets;
|
runner.Outputs().Tag(kDetectionsTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,9 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
|
||||||
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
|
|
||||||
using ::testing::ElementsAre;
|
using ::testing::ElementsAre;
|
||||||
using ::testing::FloatNear;
|
using ::testing::FloatNear;
|
||||||
|
|
||||||
|
@ -74,19 +77,19 @@ absl::StatusOr<Detection> RunProjectionCalculator(
|
||||||
)pb"));
|
)pb"));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(MakePacket<std::vector<Detection>>(
|
.packets.push_back(MakePacket<std::vector<Detection>>(
|
||||||
std::vector<Detection>({std::move(detection)}))
|
std::vector<Detection>({std::move(detection)}))
|
||||||
.At(Timestamp::PostStream()));
|
.At(Timestamp::PostStream()));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("PROJECTION_MATRIX")
|
->Tag(kProjectionMatrixTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<std::array<float, 16>>(std::move(project_mat))
|
MakePacket<std::array<float, 16>>(std::move(project_mat))
|
||||||
.At(Timestamp::PostStream()));
|
.At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(runner.Run());
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("DETECTIONS").packets;
|
runner.Outputs().Tag(kDetectionsTag).packets;
|
||||||
RET_CHECK_EQ(output.size(), 1);
|
RET_CHECK_EQ(output.size(), 1);
|
||||||
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,14 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kNormRectsTag[] = "NORM_RECTS";
|
||||||
|
constexpr char kRectsTag[] = "RECTS";
|
||||||
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
|
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||||
|
constexpr char kRectTag[] = "RECT";
|
||||||
|
constexpr char kDetectionTag[] = "DETECTION";
|
||||||
|
|
||||||
MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
|
MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
|
||||||
return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
|
return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
|
||||||
testing::Value(arg.y_center(), testing::Eq(y_center)) &&
|
testing::Value(arg.y_center(), testing::Eq(y_center)) &&
|
||||||
|
@ -94,12 +102,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) {
|
||||||
DetectionWithLocationData(100, 200, 300, 400));
|
DetectionWithLocationData(100, 200, 300, 400));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION")
|
->Tag(kDetectionTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
|
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rect = output[0].Get<Rect>();
|
const auto& rect = output[0].Get<Rect>();
|
||||||
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
|
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
|
||||||
|
@ -120,16 +128,16 @@ absl::StatusOr<Rect> RunDetectionKeyPointsToRectCalculation(
|
||||||
)pb"));
|
)pb"));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION")
|
->Tag(kDetectionTag)
|
||||||
.packets.push_back(MakePacket<Detection>(std::move(detection))
|
.packets.push_back(MakePacket<Detection>(std::move(detection))
|
||||||
.At(Timestamp::PostStream()));
|
.At(Timestamp::PostStream()));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("IMAGE_SIZE")
|
->Tag(kImageSizeTag)
|
||||||
.packets.push_back(MakePacket<std::pair<int, int>>(image_size)
|
.packets.push_back(MakePacket<std::pair<int, int>>(image_size)
|
||||||
.At(Timestamp::PostStream()));
|
.At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(runner.Run());
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
|
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
|
||||||
RET_CHECK_EQ(output.size(), 1);
|
RET_CHECK_EQ(output.size(), 1);
|
||||||
return output[0].Get<Rect>();
|
return output[0].Get<Rect>();
|
||||||
}
|
}
|
||||||
|
@ -176,12 +184,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) {
|
||||||
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION")
|
->Tag(kDetectionTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
|
const std::vector<Packet>& output =
|
||||||
|
runner.Outputs().Tag(kNormRectTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rect = output[0].Get<NormalizedRect>();
|
const auto& rect = output[0].Get<NormalizedRect>();
|
||||||
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
|
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
|
||||||
|
@ -201,12 +210,13 @@ absl::StatusOr<NormalizedRect> RunDetectionKeyPointsToNormRectCalculation(
|
||||||
)pb"));
|
)pb"));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION")
|
->Tag(kDetectionTag)
|
||||||
.packets.push_back(MakePacket<Detection>(std::move(detection))
|
.packets.push_back(MakePacket<Detection>(std::move(detection))
|
||||||
.At(Timestamp::PostStream()));
|
.At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(runner.Run());
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
|
const std::vector<Packet>& output =
|
||||||
|
runner.Outputs().Tag(kNormRectTag).packets;
|
||||||
RET_CHECK_EQ(output.size(), 1);
|
RET_CHECK_EQ(output.size(), 1);
|
||||||
return output[0].Get<NormalizedRect>();
|
return output[0].Get<NormalizedRect>();
|
||||||
}
|
}
|
||||||
|
@ -248,12 +258,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) {
|
||||||
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
|
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
|
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rect = output[0].Get<Rect>();
|
const auto& rect = output[0].Get<Rect>();
|
||||||
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
|
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
|
||||||
|
@ -271,12 +281,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) {
|
||||||
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
|
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
|
const std::vector<Packet>& output =
|
||||||
|
runner.Outputs().Tag(kNormRectTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rect = output[0].Get<NormalizedRect>();
|
const auto& rect = output[0].Get<NormalizedRect>();
|
||||||
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
|
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
|
||||||
|
@ -294,12 +305,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) {
|
||||||
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
|
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets;
|
const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rects = output[0].Get<std::vector<Rect>>();
|
const auto& rects = output[0].Get<std::vector<Rect>>();
|
||||||
ASSERT_EQ(rects.size(), 2);
|
ASSERT_EQ(rects.size(), 2);
|
||||||
|
@ -319,13 +330,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) {
|
||||||
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
|
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("NORM_RECTS").packets;
|
runner.Outputs().Tag(kNormRectsTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
|
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||||
ASSERT_EQ(rects.size(), 2);
|
ASSERT_EQ(rects.size(), 2);
|
||||||
|
@ -344,12 +355,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) {
|
||||||
DetectionWithLocationData(100, 200, 300, 400));
|
DetectionWithLocationData(100, 200, 300, 400));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION")
|
->Tag(kDetectionTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets;
|
const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rects = output[0].Get<std::vector<Rect>>();
|
const auto& rects = output[0].Get<std::vector<Rect>>();
|
||||||
EXPECT_EQ(rects.size(), 1);
|
EXPECT_EQ(rects.size(), 1);
|
||||||
|
@ -367,13 +378,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) {
|
||||||
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION")
|
->Tag(kDetectionTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("NORM_RECTS").packets;
|
runner.Outputs().Tag(kNormRectsTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
|
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||||
ASSERT_EQ(rects.size(), 1);
|
ASSERT_EQ(rects.size(), 1);
|
||||||
|
@ -391,7 +402,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) {
|
||||||
detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
|
@ -411,7 +422,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) {
|
||||||
detections->push_back(DetectionWithLocationData(100, 200, 300, 400));
|
detections->push_back(DetectionWithLocationData(100, 200, 300, 400));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,10 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
|
constexpr char kRenderDataTag[] = "RENDER_DATA";
|
||||||
|
constexpr char kDetectionListTag[] = "DETECTION_LIST";
|
||||||
|
|
||||||
using ::testing::DoubleNear;
|
using ::testing::DoubleNear;
|
||||||
|
|
||||||
// Error tolerance for pixels, distances, etc.
|
// Error tolerance for pixels, distances, etc.
|
||||||
|
@ -97,13 +101,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) {
|
||||||
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag");
|
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag");
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION_LIST")
|
->Tag(kDetectionListTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("RENDER_DATA").packets;
|
runner.Outputs().Tag(kRenderDataTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& actual = output[0].Get<RenderData>();
|
const auto& actual = output[0].Get<RenderData>();
|
||||||
EXPECT_EQ(actual.render_annotations_size(), 3);
|
EXPECT_EQ(actual.render_annotations_size(), 3);
|
||||||
|
@ -131,13 +135,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) {
|
||||||
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"));
|
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"));
|
||||||
|
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
runner.Outputs().Tag("RENDER_DATA").packets;
|
runner.Outputs().Tag(kRenderDataTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& actual = output[0].Get<RenderData>();
|
const auto& actual = output[0].Get<RenderData>();
|
||||||
EXPECT_EQ(actual.render_annotations_size(), 3);
|
EXPECT_EQ(actual.render_annotations_size(), 3);
|
||||||
|
@ -165,7 +169,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
|
||||||
*(detection_list->add_detection()) =
|
*(detection_list->add_detection()) =
|
||||||
CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1");
|
CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1");
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTION_LIST")
|
->Tag(kDetectionListTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection_list.release()).At(Timestamp::PostStream()));
|
Adopt(detection_list.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
|
@ -174,13 +178,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
|
||||||
detections->push_back(
|
detections->push_back(
|
||||||
CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2"));
|
CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2"));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& actual =
|
const std::vector<Packet>& actual =
|
||||||
runner.Outputs().Tag("RENDER_DATA").packets;
|
runner.Outputs().Tag(kRenderDataTag).packets;
|
||||||
ASSERT_EQ(1, actual.size());
|
ASSERT_EQ(1, actual.size());
|
||||||
// Check the feature tag for item from detection list.
|
// Check the feature tag for item from detection list.
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
|
@ -209,19 +213,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
|
||||||
|
|
||||||
auto detection_list1(absl::make_unique<DetectionList>());
|
auto detection_list1(absl::make_unique<DetectionList>());
|
||||||
runner1.MutableInputs()
|
runner1.MutableInputs()
|
||||||
->Tag("DETECTION_LIST")
|
->Tag(kDetectionListTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection_list1.release()).At(Timestamp::PostStream()));
|
Adopt(detection_list1.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
auto detections1(absl::make_unique<std::vector<Detection>>());
|
auto detections1(absl::make_unique<std::vector<Detection>>());
|
||||||
runner1.MutableInputs()
|
runner1.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections1.release()).At(Timestamp::PostStream()));
|
Adopt(detections1.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& exact1 =
|
const std::vector<Packet>& exact1 =
|
||||||
runner1.Outputs().Tag("RENDER_DATA").packets;
|
runner1.Outputs().Tag(kRenderDataTag).packets;
|
||||||
ASSERT_EQ(0, exact1.size());
|
ASSERT_EQ(0, exact1.size());
|
||||||
|
|
||||||
// Check when produce_empty_packet is true.
|
// Check when produce_empty_packet is true.
|
||||||
|
@ -240,19 +244,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
|
||||||
|
|
||||||
auto detection_list2(absl::make_unique<DetectionList>());
|
auto detection_list2(absl::make_unique<DetectionList>());
|
||||||
runner2.MutableInputs()
|
runner2.MutableInputs()
|
||||||
->Tag("DETECTION_LIST")
|
->Tag(kDetectionListTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detection_list2.release()).At(Timestamp::PostStream()));
|
Adopt(detection_list2.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
auto detections2(absl::make_unique<std::vector<Detection>>());
|
auto detections2(absl::make_unique<std::vector<Detection>>());
|
||||||
runner2.MutableInputs()
|
runner2.MutableInputs()
|
||||||
->Tag("DETECTIONS")
|
->Tag(kDetectionsTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(detections2.release()).At(Timestamp::PostStream()));
|
Adopt(detections2.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& exact2 =
|
const std::vector<Packet>& exact2 =
|
||||||
runner2.Outputs().Tag("RENDER_DATA").packets;
|
runner2.Outputs().Tag(kRenderDataTag).packets;
|
||||||
ASSERT_EQ(1, exact2.size());
|
ASSERT_EQ(1, exact2.size());
|
||||||
EXPECT_EQ(exact2[0].Get<RenderData>().render_annotations_size(), 0);
|
EXPECT_EQ(exact2[0].Get<RenderData>().render_annotations_size(), 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,12 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kRenderDataTag[] = "RENDER_DATA";
|
||||||
|
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||||
|
constexpr char kScoresTag[] = "SCORES";
|
||||||
|
constexpr char kLabelsTag[] = "LABELS";
|
||||||
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
|
|
||||||
constexpr float kFontHeightScale = 1.25f;
|
constexpr float kFontHeightScale = 1.25f;
|
||||||
|
|
||||||
// A calculator takes in pairs of labels and scores or classifications, outputs
|
// A calculator takes in pairs of labels and scores or classifications, outputs
|
||||||
|
@ -74,20 +80,20 @@ class LabelsToRenderDataCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
|
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
|
||||||
|
|
||||||
absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
|
if (cc->Inputs().HasTag(kClassificationsTag)) {
|
||||||
cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
|
cc->Inputs().Tag(kClassificationsTag).Set<ClassificationList>();
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(cc->Inputs().HasTag("LABELS"))
|
RET_CHECK(cc->Inputs().HasTag(kLabelsTag))
|
||||||
<< "Must provide input stream \"LABELS\"";
|
<< "Must provide input stream \"LABELS\"";
|
||||||
cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
|
cc->Inputs().Tag(kLabelsTag).Set<std::vector<std::string>>();
|
||||||
if (cc->Inputs().HasTag("SCORES")) {
|
if (cc->Inputs().HasTag(kScoresTag)) {
|
||||||
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
|
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
|
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
|
||||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||||
}
|
}
|
||||||
cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
|
cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,10 +106,10 @@ absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
|
if (cc->Inputs().HasTag(kVideoPrestreamTag) &&
|
||||||
cc->InputTimestamp() == Timestamp::PreStream()) {
|
cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||||
const VideoHeader& video_header =
|
const VideoHeader& video_header =
|
||||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
|
cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
|
||||||
video_width_ = video_header.width;
|
video_width_ = video_header.width;
|
||||||
video_height_ = video_header.height;
|
video_height_ = video_header.height;
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -114,9 +120,9 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||||
|
|
||||||
std::vector<std::string> labels;
|
std::vector<std::string> labels;
|
||||||
std::vector<float> scores;
|
std::vector<float> scores;
|
||||||
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
|
if (cc->Inputs().HasTag(kClassificationsTag)) {
|
||||||
const ClassificationList& classifications =
|
const ClassificationList& classifications =
|
||||||
cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
|
cc->Inputs().Tag(kClassificationsTag).Get<ClassificationList>();
|
||||||
labels.resize(classifications.classification_size());
|
labels.resize(classifications.classification_size());
|
||||||
scores.resize(classifications.classification_size());
|
scores.resize(classifications.classification_size());
|
||||||
for (int i = 0; i < classifications.classification_size(); ++i) {
|
for (int i = 0; i < classifications.classification_size(); ++i) {
|
||||||
|
@ -129,15 +135,15 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const std::vector<std::string>& label_vector =
|
const std::vector<std::string>& label_vector =
|
||||||
cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
|
cc->Inputs().Tag(kLabelsTag).Get<std::vector<std::string>>();
|
||||||
labels.resize(label_vector.size());
|
labels.resize(label_vector.size());
|
||||||
for (int i = 0; i < label_vector.size(); ++i) {
|
for (int i = 0; i < label_vector.size(); ++i) {
|
||||||
labels[i] = label_vector[i];
|
labels[i] = label_vector[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("SCORES")) {
|
if (cc->Inputs().HasTag(kScoresTag)) {
|
||||||
std::vector<float> score_vector =
|
std::vector<float> score_vector =
|
||||||
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
|
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
|
||||||
CHECK_EQ(label_vector.size(), score_vector.size());
|
CHECK_EQ(label_vector.size(), score_vector.size());
|
||||||
scores.resize(label_vector.size());
|
scores.resize(label_vector.size());
|
||||||
for (int i = 0; i < label_vector.size(); ++i) {
|
for (int i = 0; i < label_vector.size(); ++i) {
|
||||||
|
@ -169,7 +175,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||||
|
|
||||||
auto* text = label_annotation->mutable_text();
|
auto* text = label_annotation->mutable_text();
|
||||||
std::string display_text = labels[i];
|
std::string display_text = labels[i];
|
||||||
if (cc->Inputs().HasTag("SCORES")) {
|
if (cc->Inputs().HasTag(kScoresTag)) {
|
||||||
absl::StrAppend(&display_text, ":", scores[i]);
|
absl::StrAppend(&display_text, ":", scores[i]);
|
||||||
}
|
}
|
||||||
text->set_display_text(display_text);
|
text->set_display_text(display_text);
|
||||||
|
@ -179,7 +185,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||||
text->set_font_face(options_.font_face());
|
text->set_font_face(options_.font_face());
|
||||||
}
|
}
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("RENDER_DATA")
|
.Tag(kRenderDataTag)
|
||||||
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
|
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -24,6 +24,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
|
||||||
|
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||||
|
|
||||||
NormalizedLandmark CreateLandmark(float x, float y) {
|
NormalizedLandmark CreateLandmark(float x, float y) {
|
||||||
NormalizedLandmark landmark;
|
NormalizedLandmark landmark;
|
||||||
landmark.set_x(x);
|
landmark.set_x(x);
|
||||||
|
@ -48,18 +51,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) {
|
||||||
*landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f);
|
*landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f);
|
||||||
*landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f);
|
*landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f);
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("LANDMARKS")
|
->Tag(kLandmarksTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(landmarks.release()).At(Timestamp::PostStream()));
|
Adopt(landmarks.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||||
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
|
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("LETTERBOX_PADDING")
|
->Tag(kLetterboxPaddingTag)
|
||||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets;
|
const std::vector<Packet>& output =
|
||||||
|
runner.Outputs().Tag(kLandmarksTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
|
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
|
||||||
|
|
||||||
|
@ -84,18 +88,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) {
|
||||||
landmark = landmarks->add_landmark();
|
landmark = landmarks->add_landmark();
|
||||||
*landmark = CreateLandmark(0.7f, 0.7f);
|
*landmark = CreateLandmark(0.7f, 0.7f);
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("LANDMARKS")
|
->Tag(kLandmarksTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
Adopt(landmarks.release()).At(Timestamp::PostStream()));
|
Adopt(landmarks.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||||
std::array<float, 4>{0.0f, 0.2f, 0.0f, 0.3f});
|
std::array<float, 4>{0.0f, 0.2f, 0.0f, 0.3f});
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("LETTERBOX_PADDING")
|
->Tag(kLetterboxPaddingTag)
|
||||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets;
|
const std::vector<Packet>& output =
|
||||||
|
runner.Outputs().Tag(kLandmarksTag).packets;
|
||||||
ASSERT_EQ(1, output.size());
|
ASSERT_EQ(1, output.size());
|
||||||
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
|
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,10 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
|
||||||
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
|
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
|
||||||
|
|
||||||
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||||
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
|
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
|
||||||
mediapipe::CalculatorRunner runner(
|
mediapipe::CalculatorRunner runner(
|
||||||
|
@ -26,17 +30,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||||
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||||
)pb"));
|
)pb"));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("NORM_LANDMARKS")
|
->Tag(kNormLandmarksTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||||
.At(Timestamp(1)));
|
.At(Timestamp(1)));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("NORM_RECT")
|
->Tag(kNormRectTag)
|
||||||
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
|
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
|
||||||
.At(Timestamp(1)));
|
.At(Timestamp(1)));
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(runner.Run());
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
|
||||||
RET_CHECK_EQ(output_packets.size(), 1);
|
RET_CHECK_EQ(output_packets.size(), 1);
|
||||||
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||||
}
|
}
|
||||||
|
@ -104,17 +108,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||||
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||||
)pb"));
|
)pb"));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("NORM_LANDMARKS")
|
->Tag(kNormLandmarksTag)
|
||||||
.packets.push_back(
|
.packets.push_back(
|
||||||
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||||
.At(Timestamp(1)));
|
.At(Timestamp(1)));
|
||||||
runner.MutableInputs()
|
runner.MutableInputs()
|
||||||
->Tag("PROJECTION_MATRIX")
|
->Tag(kProjectionMatrixTag)
|
||||||
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
|
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
|
||||||
.At(Timestamp(1)));
|
.At(Timestamp(1)));
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(runner.Run());
|
MP_RETURN_IF_ERROR(runner.Run());
|
||||||
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
|
||||||
RET_CHECK_EQ(output_packets.size(), 1);
|
RET_CHECK_EQ(output_packets.size(), 1);
|
||||||
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,11 @@
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kContentsTag[] = "CONTENTS";
|
||||||
|
constexpr char kFileSuffixTag[] = "FILE_SUFFIX";
|
||||||
|
constexpr char kFileDirectoryTag[] = "FILE_DIRECTORY";
|
||||||
|
|
||||||
// The calculator takes the path to local directory and desired file suffix to
|
// The calculator takes the path to local directory and desired file suffix to
|
||||||
// match as input side packets, and outputs the contents of those files that
|
// match as input side packets, and outputs the contents of those files that
|
||||||
// match the pattern. Those matched files will be sent sequentially through the
|
// match the pattern. Those matched files will be sent sequentially through the
|
||||||
|
@ -35,16 +40,16 @@ namespace mediapipe {
|
||||||
class LocalFilePatternContentsCalculator : public CalculatorBase {
|
class LocalFilePatternContentsCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->InputSidePackets().Tag("FILE_DIRECTORY").Set<std::string>();
|
cc->InputSidePackets().Tag(kFileDirectoryTag).Set<std::string>();
|
||||||
cc->InputSidePackets().Tag("FILE_SUFFIX").Set<std::string>();
|
cc->InputSidePackets().Tag(kFileSuffixTag).Set<std::string>();
|
||||||
cc->Outputs().Tag("CONTENTS").Set<std::string>();
|
cc->Outputs().Tag(kContentsTag).Set<std::string>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory(
|
MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory(
|
||||||
cc->InputSidePackets().Tag("FILE_DIRECTORY").Get<std::string>(),
|
cc->InputSidePackets().Tag(kFileDirectoryTag).Get<std::string>(),
|
||||||
cc->InputSidePackets().Tag("FILE_SUFFIX").Get<std::string>(),
|
cc->InputSidePackets().Tag(kFileSuffixTag).Get<std::string>(),
|
||||||
&filenames_));
|
&filenames_));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -57,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase {
|
||||||
filenames_[current_output_], contents.get()));
|
filenames_[current_output_], contents.get()));
|
||||||
++current_output_;
|
++current_output_;
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("CONTENTS")
|
.Tag(kContentsTag)
|
||||||
.Add(contents.release(), Timestamp(current_output_));
|
.Add(contents.release(), Timestamp(current_output_));
|
||||||
} else {
|
} else {
|
||||||
return tool::StatusStop();
|
return tool::StatusStop();
|
||||||
|
|
|
@ -217,7 +217,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) {
|
||||||
// Initialize the clock.
|
// Initialize the clock.
|
||||||
if (cc->InputSidePackets().HasTag(kClockTag)) {
|
if (cc->InputSidePackets().HasTag(kClockTag)) {
|
||||||
clock_ = cc->InputSidePackets()
|
clock_ = cc->InputSidePackets()
|
||||||
.Tag("CLOCK")
|
.Tag(kClockTag)
|
||||||
.Get<std::shared_ptr<::mediapipe::Clock>>();
|
.Get<std::shared_ptr<::mediapipe::Clock>>();
|
||||||
} else {
|
} else {
|
||||||
clock_ = std::shared_ptr<::mediapipe::Clock>(
|
clock_ = std::shared_ptr<::mediapipe::Clock>(
|
||||||
|
|
|
@ -17,6 +17,12 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kThresholdTag[] = "THRESHOLD";
|
||||||
|
constexpr char kRejectTag[] = "REJECT";
|
||||||
|
constexpr char kAcceptTag[] = "ACCEPT";
|
||||||
|
constexpr char kFlagTag[] = "FLAG";
|
||||||
|
constexpr char kFloatTag[] = "FLOAT";
|
||||||
|
|
||||||
// Applies a threshold on a stream of numeric values and outputs a flag and/or
|
// Applies a threshold on a stream of numeric values and outputs a flag and/or
|
||||||
// accept/reject stream. The threshold can be specified by one of the following:
|
// accept/reject stream. The threshold can be specified by one of the following:
|
||||||
// 1) Input stream.
|
// 1) Input stream.
|
||||||
|
@ -61,24 +67,24 @@ class ThresholdingCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(ThresholdingCalculator);
|
REGISTER_CALCULATOR(ThresholdingCalculator);
|
||||||
|
|
||||||
absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("FLOAT"));
|
RET_CHECK(cc->Inputs().HasTag(kFloatTag));
|
||||||
cc->Inputs().Tag("FLOAT").Set<float>();
|
cc->Inputs().Tag(kFloatTag).Set<float>();
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("FLAG")) {
|
if (cc->Outputs().HasTag(kFlagTag)) {
|
||||||
cc->Outputs().Tag("FLAG").Set<bool>();
|
cc->Outputs().Tag(kFlagTag).Set<bool>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("ACCEPT")) {
|
if (cc->Outputs().HasTag(kAcceptTag)) {
|
||||||
cc->Outputs().Tag("ACCEPT").Set<bool>();
|
cc->Outputs().Tag(kAcceptTag).Set<bool>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("REJECT")) {
|
if (cc->Outputs().HasTag(kRejectTag)) {
|
||||||
cc->Outputs().Tag("REJECT").Set<bool>();
|
cc->Outputs().Tag(kRejectTag).Set<bool>();
|
||||||
}
|
}
|
||||||
if (cc->Inputs().HasTag("THRESHOLD")) {
|
if (cc->Inputs().HasTag(kThresholdTag)) {
|
||||||
cc->Inputs().Tag("THRESHOLD").Set<double>();
|
cc->Inputs().Tag(kThresholdTag).Set<double>();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
if (cc->InputSidePackets().HasTag(kThresholdTag)) {
|
||||||
cc->InputSidePackets().Tag("THRESHOLD").Set<double>();
|
cc->InputSidePackets().Tag(kThresholdTag).Set<double>();
|
||||||
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD"))
|
RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
|
||||||
<< "Using both the threshold input side packet and input stream is not "
|
<< "Using both the threshold input side packet and input stream is not "
|
||||||
"supported.";
|
"supported.";
|
||||||
}
|
}
|
||||||
|
@ -92,43 +98,45 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
|
||||||
const auto& options =
|
const auto& options =
|
||||||
cc->Options<::mediapipe::ThresholdingCalculatorOptions>();
|
cc->Options<::mediapipe::ThresholdingCalculatorOptions>();
|
||||||
if (options.has_threshold()) {
|
if (options.has_threshold()) {
|
||||||
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD"))
|
RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
|
||||||
<< "Using both the threshold option and input stream is not supported.";
|
<< "Using both the threshold option and input stream is not supported.";
|
||||||
RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD"))
|
RET_CHECK(!cc->InputSidePackets().HasTag(kThresholdTag))
|
||||||
<< "Using both the threshold option and input side packet is not "
|
<< "Using both the threshold option and input side packet is not "
|
||||||
"supported.";
|
"supported.";
|
||||||
threshold_ = options.threshold();
|
threshold_ = options.threshold();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
if (cc->InputSidePackets().HasTag(kThresholdTag)) {
|
||||||
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>();
|
threshold_ = cc->InputSidePackets().Tag(kThresholdTag).Get<double>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) {
|
absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) {
|
||||||
if (cc->Inputs().HasTag("THRESHOLD") &&
|
if (cc->Inputs().HasTag(kThresholdTag) &&
|
||||||
!cc->Inputs().Tag("THRESHOLD").IsEmpty()) {
|
!cc->Inputs().Tag(kThresholdTag).IsEmpty()) {
|
||||||
threshold_ = cc->Inputs().Tag("THRESHOLD").Get<double>();
|
threshold_ = cc->Inputs().Tag(kThresholdTag).Get<double>();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool accept = false;
|
bool accept = false;
|
||||||
RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty());
|
RET_CHECK(!cc->Inputs().Tag(kFloatTag).IsEmpty());
|
||||||
accept =
|
accept = static_cast<double>(cc->Inputs().Tag(kFloatTag).Get<float>()) >
|
||||||
static_cast<double>(cc->Inputs().Tag("FLOAT").Get<float>()) > threshold_;
|
threshold_;
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("FLAG")) {
|
if (cc->Outputs().HasTag(kFlagTag)) {
|
||||||
cc->Outputs().Tag("FLAG").AddPacket(
|
cc->Outputs().Tag(kFlagTag).AddPacket(
|
||||||
MakePacket<bool>(accept).At(cc->InputTimestamp()));
|
MakePacket<bool>(accept).At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (accept && cc->Outputs().HasTag("ACCEPT")) {
|
if (accept && cc->Outputs().HasTag(kAcceptTag)) {
|
||||||
cc->Outputs().Tag("ACCEPT").AddPacket(
|
cc->Outputs()
|
||||||
MakePacket<bool>(true).At(cc->InputTimestamp()));
|
.Tag(kAcceptTag)
|
||||||
|
.AddPacket(MakePacket<bool>(true).At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
if (!accept && cc->Outputs().HasTag("REJECT")) {
|
if (!accept && cc->Outputs().HasTag(kRejectTag)) {
|
||||||
cc->Outputs().Tag("REJECT").AddPacket(
|
cc->Outputs()
|
||||||
MakePacket<bool>(false).At(cc->InputTimestamp()));
|
.Tag(kRejectTag)
|
||||||
|
.AddPacket(MakePacket<bool>(false).At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -39,6 +39,14 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
|
||||||
|
constexpr char kSummaryTag[] = "SUMMARY";
|
||||||
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
|
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
|
||||||
|
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||||
|
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||||
|
constexpr char kScoresTag[] = "SCORES";
|
||||||
|
|
||||||
// A calculator that takes a vector of scores and returns the indexes, scores,
|
// A calculator that takes a vector of scores and returns the indexes, scores,
|
||||||
// labels of the top k elements, classification protos, and summary std::string
|
// labels of the top k elements, classification protos, and summary std::string
|
||||||
// (in csv format).
|
// (in csv format).
|
||||||
|
@ -79,22 +87,22 @@ class TopKScoresCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(TopKScoresCalculator);
|
REGISTER_CALCULATOR(TopKScoresCalculator);
|
||||||
|
|
||||||
absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("SCORES"));
|
RET_CHECK(cc->Inputs().HasTag(kScoresTag));
|
||||||
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
|
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
|
||||||
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
|
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
|
||||||
cc->Outputs().Tag("TOP_K_INDEXES").Set<std::vector<int>>();
|
cc->Outputs().Tag(kTopKIndexesTag).Set<std::vector<int>>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
|
if (cc->Outputs().HasTag(kTopKScoresTag)) {
|
||||||
cc->Outputs().Tag("TOP_K_SCORES").Set<std::vector<float>>();
|
cc->Outputs().Tag(kTopKScoresTag).Set<std::vector<float>>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||||
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
|
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
|
if (cc->Outputs().HasTag(kClassificationsTag)) {
|
||||||
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
|
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("SUMMARY")) {
|
if (cc->Outputs().HasTag(kSummaryTag)) {
|
||||||
cc->Outputs().Tag("SUMMARY").Set<std::string>();
|
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -114,7 +122,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
|
||||||
if (options.has_label_map_path()) {
|
if (options.has_label_map_path()) {
|
||||||
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
|
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||||
RET_CHECK(!label_map_.empty());
|
RET_CHECK(!label_map_.empty());
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -122,7 +130,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
|
||||||
|
|
||||||
absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
||||||
const std::vector<float>& input_vector =
|
const std::vector<float>& input_vector =
|
||||||
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
|
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
|
||||||
std::vector<int> top_k_indexes;
|
std::vector<int> top_k_indexes;
|
||||||
|
|
||||||
std::vector<float> top_k_scores;
|
std::vector<float> top_k_scores;
|
||||||
|
@ -166,26 +174,26 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
||||||
top_k_labels.push_back(label_map_[index]);
|
top_k_labels.push_back(label_map_[index]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
|
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("TOP_K_INDEXES")
|
.Tag(kTopKIndexesTag)
|
||||||
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
|
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
|
if (cc->Outputs().HasTag(kTopKScoresTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("TOP_K_SCORES")
|
.Tag(kTopKScoresTag)
|
||||||
.AddPacket(MakePacket<std::vector<float>>(top_k_scores)
|
.AddPacket(MakePacket<std::vector<float>>(top_k_scores)
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("TOP_K_LABELS")
|
.Tag(kTopKLabelsTag)
|
||||||
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
|
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("SUMMARY")) {
|
if (cc->Outputs().HasTag(kSummaryTag)) {
|
||||||
std::vector<std::string> results;
|
std::vector<std::string> results;
|
||||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||||
if (label_map_loaded_) {
|
if (label_map_loaded_) {
|
||||||
|
@ -196,12 +204,13 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
||||||
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
|
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cc->Outputs().Tag("SUMMARY").AddPacket(
|
cc->Outputs()
|
||||||
MakePacket<std::string>(absl::StrJoin(results, ","))
|
.Tag(kSummaryTag)
|
||||||
|
.AddPacket(MakePacket<std::string>(absl::StrJoin(results, ","))
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) {
|
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
|
||||||
auto classification_list = absl::make_unique<ClassificationList>();
|
auto classification_list = absl::make_unique<ClassificationList>();
|
||||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||||
Classification* classification =
|
Classification* classification =
|
||||||
|
|
|
@ -23,6 +23,10 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||||
|
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||||
|
constexpr char kScoresTag[] = "SCORES";
|
||||||
|
|
||||||
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
|
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "TopKScoresCalculator"
|
calculator: "TopKScoresCalculator"
|
||||||
|
@ -55,19 +59,21 @@ TEST(TopKScoresCalculatorTest, TestTopKOnly) {
|
||||||
|
|
||||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||||
|
|
||||||
runner.MutableInputs()->Tag("SCORES").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kScoresTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& indexes_outputs =
|
const std::vector<Packet>& indexes_outputs =
|
||||||
runner.Outputs().Tag("TOP_K_INDEXES").packets;
|
runner.Outputs().Tag(kTopKIndexesTag).packets;
|
||||||
ASSERT_EQ(1, indexes_outputs.size());
|
ASSERT_EQ(1, indexes_outputs.size());
|
||||||
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
||||||
EXPECT_EQ(2, indexes.size());
|
EXPECT_EQ(2, indexes.size());
|
||||||
EXPECT_EQ(3, indexes[0]);
|
EXPECT_EQ(3, indexes[0]);
|
||||||
EXPECT_EQ(0, indexes[1]);
|
EXPECT_EQ(0, indexes[1]);
|
||||||
const std::vector<Packet>& scores_outputs =
|
const std::vector<Packet>& scores_outputs =
|
||||||
runner.Outputs().Tag("TOP_K_SCORES").packets;
|
runner.Outputs().Tag(kTopKScoresTag).packets;
|
||||||
ASSERT_EQ(1, scores_outputs.size());
|
ASSERT_EQ(1, scores_outputs.size());
|
||||||
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
||||||
EXPECT_EQ(2, scores.size());
|
EXPECT_EQ(2, scores.size());
|
||||||
|
@ -88,12 +94,14 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
|
||||||
|
|
||||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||||
|
|
||||||
runner.MutableInputs()->Tag("SCORES").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kScoresTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& indexes_outputs =
|
const std::vector<Packet>& indexes_outputs =
|
||||||
runner.Outputs().Tag("TOP_K_INDEXES").packets;
|
runner.Outputs().Tag(kTopKIndexesTag).packets;
|
||||||
ASSERT_EQ(1, indexes_outputs.size());
|
ASSERT_EQ(1, indexes_outputs.size());
|
||||||
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
||||||
EXPECT_EQ(4, indexes.size());
|
EXPECT_EQ(4, indexes.size());
|
||||||
|
@ -102,7 +110,7 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
|
||||||
EXPECT_EQ(2, indexes[2]);
|
EXPECT_EQ(2, indexes[2]);
|
||||||
EXPECT_EQ(1, indexes[3]);
|
EXPECT_EQ(1, indexes[3]);
|
||||||
const std::vector<Packet>& scores_outputs =
|
const std::vector<Packet>& scores_outputs =
|
||||||
runner.Outputs().Tag("TOP_K_SCORES").packets;
|
runner.Outputs().Tag(kTopKScoresTag).packets;
|
||||||
ASSERT_EQ(1, scores_outputs.size());
|
ASSERT_EQ(1, scores_outputs.size());
|
||||||
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
||||||
EXPECT_EQ(4, scores.size());
|
EXPECT_EQ(4, scores.size());
|
||||||
|
@ -125,12 +133,14 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
|
||||||
|
|
||||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||||
|
|
||||||
runner.MutableInputs()->Tag("SCORES").packets.push_back(
|
runner.MutableInputs()
|
||||||
|
->Tag(kScoresTag)
|
||||||
|
.packets.push_back(
|
||||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
const std::vector<Packet>& indexes_outputs =
|
const std::vector<Packet>& indexes_outputs =
|
||||||
runner.Outputs().Tag("TOP_K_INDEXES").packets;
|
runner.Outputs().Tag(kTopKIndexesTag).packets;
|
||||||
ASSERT_EQ(1, indexes_outputs.size());
|
ASSERT_EQ(1, indexes_outputs.size());
|
||||||
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
||||||
EXPECT_EQ(3, indexes.size());
|
EXPECT_EQ(3, indexes.size());
|
||||||
|
@ -138,7 +148,7 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
|
||||||
EXPECT_EQ(0, indexes[1]);
|
EXPECT_EQ(0, indexes[1]);
|
||||||
EXPECT_EQ(2, indexes[2]);
|
EXPECT_EQ(2, indexes[2]);
|
||||||
const std::vector<Packet>& scores_outputs =
|
const std::vector<Packet>& scores_outputs =
|
||||||
runner.Outputs().Tag("TOP_K_SCORES").packets;
|
runner.Outputs().Tag(kTopKScoresTag).packets;
|
||||||
ASSERT_EQ(1, scores_outputs.size());
|
ASSERT_EQ(1, scores_outputs.size());
|
||||||
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
||||||
EXPECT_EQ(3, scores.size());
|
EXPECT_EQ(3, scores.size());
|
||||||
|
|
|
@ -47,6 +47,21 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kFrameAlignmentTag[] = "FRAME_ALIGNMENT";
|
||||||
|
constexpr char kOutputIndexFilenameTag[] = "OUTPUT_INDEX_FILENAME";
|
||||||
|
constexpr char kIndexProtoStringTag[] = "INDEX_PROTO_STRING";
|
||||||
|
constexpr char kVizTag[] = "VIZ";
|
||||||
|
constexpr char kBoxesTag[] = "BOXES";
|
||||||
|
constexpr char kReacqSwitchTag[] = "REACQ_SWITCH";
|
||||||
|
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
|
||||||
|
constexpr char kAddIndexTag[] = "ADD_INDEX";
|
||||||
|
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||||
|
constexpr char kDescriptorsTag[] = "DESCRIPTORS";
|
||||||
|
constexpr char kFeaturesTag[] = "FEATURES";
|
||||||
|
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
constexpr char kTrackedBoxesTag[] = "TRACKED_BOXES";
|
||||||
|
constexpr char kTrackingTag[] = "TRACKING";
|
||||||
|
|
||||||
// A calculator to detect reappeared box positions from single frame.
|
// A calculator to detect reappeared box positions from single frame.
|
||||||
//
|
//
|
||||||
// Input stream:
|
// Input stream:
|
||||||
|
@ -110,66 +125,66 @@ class BoxDetectorCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(BoxDetectorCalculator);
|
REGISTER_CALCULATOR(BoxDetectorCalculator);
|
||||||
|
|
||||||
absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->Inputs().HasTag("TRACKING")) {
|
if (cc->Inputs().HasTag(kTrackingTag)) {
|
||||||
cc->Inputs().Tag("TRACKING").Set<TrackingData>();
|
cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("TRACKED_BOXES")) {
|
if (cc->Inputs().HasTag(kTrackedBoxesTag)) {
|
||||||
cc->Inputs().Tag("TRACKED_BOXES").Set<TimedBoxProtoList>();
|
cc->Inputs().Tag(kTrackedBoxesTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("VIDEO")) {
|
if (cc->Inputs().HasTag(kVideoTag)) {
|
||||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("FEATURES")) {
|
if (cc->Inputs().HasTag(kFeaturesTag)) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("DESCRIPTORS"))
|
RET_CHECK(cc->Inputs().HasTag(kDescriptorsTag))
|
||||||
<< "FEATURES and DESCRIPTORS need to be specified together.";
|
<< "FEATURES and DESCRIPTORS need to be specified together.";
|
||||||
cc->Inputs().Tag("FEATURES").Set<std::vector<cv::KeyPoint>>();
|
cc->Inputs().Tag(kFeaturesTag).Set<std::vector<cv::KeyPoint>>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("DESCRIPTORS")) {
|
if (cc->Inputs().HasTag(kDescriptorsTag)) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("FEATURES"))
|
RET_CHECK(cc->Inputs().HasTag(kFeaturesTag))
|
||||||
<< "FEATURES and DESCRIPTORS need to be specified together.";
|
<< "FEATURES and DESCRIPTORS need to be specified together.";
|
||||||
cc->Inputs().Tag("DESCRIPTORS").Set<std::vector<float>>();
|
cc->Inputs().Tag(kDescriptorsTag).Set<std::vector<float>>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("IMAGE_SIZE")) {
|
if (cc->Inputs().HasTag(kImageSizeTag)) {
|
||||||
cc->Inputs().Tag("IMAGE_SIZE").Set<std::pair<int, int>>();
|
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("ADD_INDEX")) {
|
if (cc->Inputs().HasTag(kAddIndexTag)) {
|
||||||
cc->Inputs().Tag("ADD_INDEX").Set<std::string>();
|
cc->Inputs().Tag(kAddIndexTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) {
|
if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
|
||||||
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>();
|
cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("REACQ_SWITCH")) {
|
if (cc->Inputs().HasTag(kReacqSwitchTag)) {
|
||||||
cc->Inputs().Tag("REACQ_SWITCH").Set<bool>();
|
cc->Inputs().Tag(kReacqSwitchTag).Set<bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("BOXES")) {
|
if (cc->Outputs().HasTag(kBoxesTag)) {
|
||||||
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>();
|
cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIZ")) {
|
if (cc->Outputs().HasTag(kVizTag)) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("VIDEO"))
|
RET_CHECK(cc->Inputs().HasTag(kVideoTag))
|
||||||
<< "Output stream VIZ requires VIDEO to be present.";
|
<< "Output stream VIZ requires VIDEO to be present.";
|
||||||
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
|
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) {
|
if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
|
||||||
cc->InputSidePackets().Tag("INDEX_PROTO_STRING").Set<std::string>();
|
cc->InputSidePackets().Tag(kIndexProtoStringTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) {
|
if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
|
||||||
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Set<std::string>();
|
cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) {
|
if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
|
||||||
cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set<int>();
|
cc->InputSidePackets().Tag(kFrameAlignmentTag).Set<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -179,10 +194,10 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
|
||||||
options_ = cc->Options<BoxDetectorCalculatorOptions>();
|
options_ = cc->Options<BoxDetectorCalculatorOptions>();
|
||||||
box_detector_ = BoxDetectorInterface::Create(options_.detector_options());
|
box_detector_ = BoxDetectorInterface::Create(options_.detector_options());
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) {
|
if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
|
||||||
BoxDetectorIndex predefined_index;
|
BoxDetectorIndex predefined_index;
|
||||||
if (!predefined_index.ParseFromString(cc->InputSidePackets()
|
if (!predefined_index.ParseFromString(cc->InputSidePackets()
|
||||||
.Tag("INDEX_PROTO_STRING")
|
.Tag(kIndexProtoStringTag)
|
||||||
.Get<std::string>())) {
|
.Get<std::string>())) {
|
||||||
LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING";
|
LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING";
|
||||||
}
|
}
|
||||||
|
@ -202,12 +217,13 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
|
||||||
box_detector_->AddBoxDetectorIndex(predefined_index);
|
box_detector_->AddBoxDetectorIndex(predefined_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) {
|
if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
|
||||||
write_index_ = true;
|
write_index_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) {
|
if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
|
||||||
frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get<int>();
|
frame_alignment_ =
|
||||||
|
cc->InputSidePackets().Tag(kFrameAlignmentTag).Get<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -218,16 +234,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
const int64 timestamp_msec = timestamp.Value() / 1000;
|
const int64 timestamp_msec = timestamp.Value() / 1000;
|
||||||
|
|
||||||
InputStream* cancel_object_id_stream =
|
InputStream* cancel_object_id_stream =
|
||||||
cc->Inputs().HasTag("CANCEL_OBJECT_ID")
|
cc->Inputs().HasTag(kCancelObjectIdTag)
|
||||||
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID"))
|
? &(cc->Inputs().Tag(kCancelObjectIdTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
|
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
|
||||||
const int cancel_object_id = cancel_object_id_stream->Get<int>();
|
const int cancel_object_id = cancel_object_id_stream->Get<int>();
|
||||||
box_detector_->CancelBoxDetection(cancel_object_id);
|
box_detector_->CancelBoxDetection(cancel_object_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* add_index_stream = cc->Inputs().HasTag("ADD_INDEX")
|
InputStream* add_index_stream = cc->Inputs().HasTag(kAddIndexTag)
|
||||||
? &(cc->Inputs().Tag("ADD_INDEX"))
|
? &(cc->Inputs().Tag(kAddIndexTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (add_index_stream && !add_index_stream->IsEmpty()) {
|
if (add_index_stream && !add_index_stream->IsEmpty()) {
|
||||||
BoxDetectorIndex predefined_index;
|
BoxDetectorIndex predefined_index;
|
||||||
|
@ -238,8 +254,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
box_detector_->AddBoxDetectorIndex(predefined_index);
|
box_detector_->AddBoxDetectorIndex(predefined_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* reacq_switch_stream = cc->Inputs().HasTag("REACQ_SWITCH")
|
InputStream* reacq_switch_stream = cc->Inputs().HasTag(kReacqSwitchTag)
|
||||||
? &(cc->Inputs().Tag("REACQ_SWITCH"))
|
? &(cc->Inputs().Tag(kReacqSwitchTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) {
|
if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) {
|
||||||
detector_switch_ = reacq_switch_stream->Get<bool>();
|
detector_switch_ = reacq_switch_stream->Get<bool>();
|
||||||
|
@ -249,16 +265,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* track_stream = cc->Inputs().HasTag("TRACKING")
|
InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
|
||||||
? &(cc->Inputs().Tag("TRACKING"))
|
? &(cc->Inputs().Tag(kTrackingTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
InputStream* video_stream =
|
InputStream* video_stream =
|
||||||
cc->Inputs().HasTag("VIDEO") ? &(cc->Inputs().Tag("VIDEO")) : nullptr;
|
cc->Inputs().HasTag(kVideoTag) ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
|
||||||
InputStream* feature_stream = cc->Inputs().HasTag("FEATURES")
|
InputStream* feature_stream = cc->Inputs().HasTag(kFeaturesTag)
|
||||||
? &(cc->Inputs().Tag("FEATURES"))
|
? &(cc->Inputs().Tag(kFeaturesTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
InputStream* descriptor_stream = cc->Inputs().HasTag("DESCRIPTORS")
|
InputStream* descriptor_stream = cc->Inputs().HasTag(kDescriptorsTag)
|
||||||
? &(cc->Inputs().Tag("DESCRIPTORS"))
|
? &(cc->Inputs().Tag(kDescriptorsTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
CHECK(track_stream != nullptr || video_stream != nullptr ||
|
CHECK(track_stream != nullptr || video_stream != nullptr ||
|
||||||
|
@ -266,8 +282,9 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
<< "One and only one of {tracking_data, input image frame, "
|
<< "One and only one of {tracking_data, input image frame, "
|
||||||
"feature/descriptor} need to be valid.";
|
"feature/descriptor} need to be valid.";
|
||||||
|
|
||||||
InputStream* tracked_boxes_stream = cc->Inputs().HasTag("TRACKED_BOXES")
|
InputStream* tracked_boxes_stream =
|
||||||
? &(cc->Inputs().Tag("TRACKED_BOXES"))
|
cc->Inputs().HasTag(kTrackedBoxesTag)
|
||||||
|
? &(cc->Inputs().Tag(kTrackedBoxesTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
std::unique_ptr<TimedBoxProtoList> detected_boxes(new TimedBoxProtoList());
|
std::unique_ptr<TimedBoxProtoList> detected_boxes(new TimedBoxProtoList());
|
||||||
|
|
||||||
|
@ -309,7 +326,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& image_size =
|
const auto& image_size =
|
||||||
cc->Inputs().Tag("IMAGE_SIZE").Get<std::pair<int, int>>();
|
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||||
float inv_scale = 1.0f / std::max(image_size.first, image_size.second);
|
float inv_scale = 1.0f / std::max(image_size.first, image_size.second);
|
||||||
|
|
||||||
TimedBoxProtoList tracked_boxes;
|
TimedBoxProtoList tracked_boxes;
|
||||||
|
@ -359,7 +376,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
detected_boxes.get());
|
detected_boxes.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIZ")) {
|
if (cc->Outputs().HasTag(kVizTag)) {
|
||||||
cv::Mat viz_view;
|
cv::Mat viz_view;
|
||||||
std::unique_ptr<ImageFrame> viz_frame;
|
std::unique_ptr<ImageFrame> viz_frame;
|
||||||
if (video_stream != nullptr && !video_stream->IsEmpty()) {
|
if (video_stream != nullptr && !video_stream->IsEmpty()) {
|
||||||
|
@ -370,11 +387,11 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
||||||
for (const auto& box : detected_boxes->box()) {
|
for (const auto& box : detected_boxes->box()) {
|
||||||
RenderBox(box, &viz_view);
|
RenderBox(box, &viz_view);
|
||||||
}
|
}
|
||||||
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp);
|
cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("BOXES")) {
|
if (cc->Outputs().HasTag(kBoxesTag)) {
|
||||||
cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp);
|
cc->Outputs().Tag(kBoxesTag).Add(detected_boxes.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -384,7 +401,7 @@ absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) {
|
||||||
if (write_index_) {
|
if (write_index_) {
|
||||||
BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex();
|
BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex();
|
||||||
MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents(
|
MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents(
|
||||||
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get<std::string>(),
|
cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Get<std::string>(),
|
||||||
index.SerializeAsString()));
|
index.SerializeAsString()));
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -293,6 +293,22 @@ const int BoxTrackerCalculator::kMotionBoxPathMinQueueSize = 2;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kCacheDirTag[] = "CACHE_DIR";
|
||||||
|
constexpr char kInitialPosTag[] = "INITIAL_POS";
|
||||||
|
constexpr char kRaBoxesTag[] = "RA_BOXES";
|
||||||
|
constexpr char kBoxesTag[] = "BOXES";
|
||||||
|
constexpr char kVizTag[] = "VIZ";
|
||||||
|
constexpr char kRaTrackProtoStringTag[] = "RA_TRACK_PROTO_STRING";
|
||||||
|
constexpr char kRaTrackTag[] = "RA_TRACK";
|
||||||
|
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
|
||||||
|
constexpr char kRestartPosTag[] = "RESTART_POS";
|
||||||
|
constexpr char kStartPosProtoStringTag[] = "START_POS_PROTO_STRING";
|
||||||
|
constexpr char kStartPosTag[] = "START_POS";
|
||||||
|
constexpr char kStartTag[] = "START";
|
||||||
|
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
constexpr char kTrackTimeTag[] = "TRACK_TIME";
|
||||||
|
constexpr char kTrackingTag[] = "TRACKING";
|
||||||
|
|
||||||
// Convert box position according to rotation angle in degrees.
|
// Convert box position according to rotation angle in degrees.
|
||||||
void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom,
|
void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom,
|
||||||
float in_right, int rotation, float* out_top,
|
float in_right, int rotation, float* out_top,
|
||||||
|
@ -374,78 +390,78 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec,
|
||||||
} // namespace.
|
} // namespace.
|
||||||
|
|
||||||
absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->Inputs().HasTag("TRACKING")) {
|
if (cc->Inputs().HasTag(kTrackingTag)) {
|
||||||
cc->Inputs().Tag("TRACKING").Set<TrackingData>();
|
cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("TRACK_TIME")) {
|
if (cc->Inputs().HasTag(kTrackTimeTag)) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("TRACKING"))
|
RET_CHECK(cc->Inputs().HasTag(kTrackingTag))
|
||||||
<< "TRACK_TIME needs TRACKING input";
|
<< "TRACK_TIME needs TRACKING input";
|
||||||
cc->Inputs().Tag("TRACK_TIME").SetAny();
|
cc->Inputs().Tag(kTrackTimeTag).SetAny();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("VIDEO")) {
|
if (cc->Inputs().HasTag(kVideoTag)) {
|
||||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("START")) {
|
if (cc->Inputs().HasTag(kStartTag)) {
|
||||||
// Actual packet content does not matter.
|
// Actual packet content does not matter.
|
||||||
cc->Inputs().Tag("START").SetAny();
|
cc->Inputs().Tag(kStartTag).SetAny();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("START_POS")) {
|
if (cc->Inputs().HasTag(kStartPosTag)) {
|
||||||
cc->Inputs().Tag("START_POS").Set<TimedBoxProtoList>();
|
cc->Inputs().Tag(kStartPosTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("START_POS_PROTO_STRING")) {
|
if (cc->Inputs().HasTag(kStartPosProtoStringTag)) {
|
||||||
cc->Inputs().Tag("START_POS_PROTO_STRING").Set<std::string>();
|
cc->Inputs().Tag(kStartPosProtoStringTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("RESTART_POS")) {
|
if (cc->Inputs().HasTag(kRestartPosTag)) {
|
||||||
cc->Inputs().Tag("RESTART_POS").Set<TimedBoxProtoList>();
|
cc->Inputs().Tag(kRestartPosTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) {
|
if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
|
||||||
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>();
|
cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("RA_TRACK")) {
|
if (cc->Inputs().HasTag(kRaTrackTag)) {
|
||||||
cc->Inputs().Tag("RA_TRACK").Set<TimedBoxProtoList>();
|
cc->Inputs().Tag(kRaTrackTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")) {
|
if (cc->Inputs().HasTag(kRaTrackProtoStringTag)) {
|
||||||
cc->Inputs().Tag("RA_TRACK_PROTO_STRING").Set<std::string>();
|
cc->Inputs().Tag(kRaTrackProtoStringTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIZ")) {
|
if (cc->Outputs().HasTag(kVizTag)) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("VIDEO"))
|
RET_CHECK(cc->Inputs().HasTag(kVideoTag))
|
||||||
<< "Output stream VIZ requires VIDEO to be present.";
|
<< "Output stream VIZ requires VIDEO to be present.";
|
||||||
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
|
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("BOXES")) {
|
if (cc->Outputs().HasTag(kBoxesTag)) {
|
||||||
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>();
|
cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("RA_BOXES")) {
|
if (cc->Outputs().HasTag(kRaBoxesTag)) {
|
||||||
cc->Outputs().Tag("RA_BOXES").Set<TimedBoxProtoList>();
|
cc->Outputs().Tag(kRaBoxesTag).Set<TimedBoxProtoList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
#if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
||||||
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS"))
|
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag))
|
||||||
<< "Unsupported on mobile";
|
<< "Unsupported on mobile";
|
||||||
#else
|
#else
|
||||||
if (cc->InputSidePackets().HasTag("INITIAL_POS")) {
|
if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
|
||||||
cc->InputSidePackets().Tag("INITIAL_POS").Set<std::string>();
|
cc->InputSidePackets().Tag(kInitialPosTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
#endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
#endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
|
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
|
||||||
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>();
|
cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
RET_CHECK(cc->Inputs().HasTag("TRACKING") !=
|
RET_CHECK(cc->Inputs().HasTag(kTrackingTag) !=
|
||||||
cc->InputSidePackets().HasTag("CACHE_DIR"))
|
cc->InputSidePackets().HasTag(kCacheDirTag))
|
||||||
<< "Either TRACKING or CACHE_DIR needs to be specified.";
|
<< "Either TRACKING or CACHE_DIR needs to be specified.";
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||||
|
@ -459,7 +475,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
||||||
options_ = tool::RetrieveOptions(cc->Options<BoxTrackerCalculatorOptions>(),
|
options_ = tool::RetrieveOptions(cc->Options<BoxTrackerCalculatorOptions>(),
|
||||||
cc->InputSidePackets(), kOptionsTag);
|
cc->InputSidePackets(), kOptionsTag);
|
||||||
|
|
||||||
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS") ||
|
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag) ||
|
||||||
!options_.has_initial_position())
|
!options_.has_initial_position())
|
||||||
<< "Can not specify initial position as side packet and via options";
|
<< "Can not specify initial position as side packet and via options";
|
||||||
|
|
||||||
|
@ -468,11 +484,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__)
|
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__)
|
||||||
if (cc->InputSidePackets().HasTag("INITIAL_POS")) {
|
if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
|
||||||
LOG(INFO) << "Parsing: "
|
LOG(INFO) << "Parsing: "
|
||||||
<< cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>();
|
<< cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>();
|
||||||
initial_pos_ = ParseTextProtoOrDie<TimedBoxProtoList>(
|
initial_pos_ = ParseTextProtoOrDie<TimedBoxProtoList>(
|
||||||
cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>());
|
cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>());
|
||||||
}
|
}
|
||||||
#endif // !defined(__ANDROID__) && !defined(__APPLE__) &&
|
#endif // !defined(__ANDROID__) && !defined(__APPLE__) &&
|
||||||
// !defined(__EMSCRIPTEN__)
|
// !defined(__EMSCRIPTEN__)
|
||||||
|
@ -484,10 +500,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
visualize_tracking_data_ =
|
visualize_tracking_data_ =
|
||||||
options_.visualize_tracking_data() && cc->Outputs().HasTag("VIZ");
|
options_.visualize_tracking_data() && cc->Outputs().HasTag(kVizTag);
|
||||||
visualize_state_ = options_.visualize_state() && cc->Outputs().HasTag("VIZ");
|
visualize_state_ =
|
||||||
|
options_.visualize_state() && cc->Outputs().HasTag(kVizTag);
|
||||||
visualize_internal_state_ =
|
visualize_internal_state_ =
|
||||||
options_.visualize_internal_state() && cc->Outputs().HasTag("VIZ");
|
options_.visualize_internal_state() && cc->Outputs().HasTag(kVizTag);
|
||||||
|
|
||||||
// Force recording of internal state for rendering.
|
// Force recording of internal state for rendering.
|
||||||
if (visualize_internal_state_) {
|
if (visualize_internal_state_) {
|
||||||
|
@ -500,8 +517,8 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
||||||
options_.mutable_tracker_options()->set_record_path_states(true);
|
options_.mutable_tracker_options()->set_record_path_states(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
|
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
|
||||||
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>();
|
cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
|
||||||
RET_CHECK(!cache_dir_.empty());
|
RET_CHECK(!cache_dir_.empty());
|
||||||
box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options()));
|
box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options()));
|
||||||
} else {
|
} else {
|
||||||
|
@ -511,7 +528,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options_.streaming_track_data_cache_size() > 0) {
|
if (options_.streaming_track_data_cache_size() > 0) {
|
||||||
RET_CHECK(!cc->InputSidePackets().HasTag("CACHE_DIR"))
|
RET_CHECK(!cc->InputSidePackets().HasTag(kCacheDirTag))
|
||||||
<< "Streaming mode not compatible with cache dir.";
|
<< "Streaming mode not compatible with cache dir.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -533,11 +550,11 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* track_stream = cc->Inputs().HasTag("TRACKING")
|
InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
|
||||||
? &(cc->Inputs().Tag("TRACKING"))
|
? &(cc->Inputs().Tag(kTrackingTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
InputStream* track_time_stream = cc->Inputs().HasTag("TRACK_TIME")
|
InputStream* track_time_stream = cc->Inputs().HasTag(kTrackTimeTag)
|
||||||
? &(cc->Inputs().Tag("TRACK_TIME"))
|
? &(cc->Inputs().Tag(kTrackTimeTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
// Cache tracking data if possible.
|
// Cache tracking data if possible.
|
||||||
|
@ -562,8 +579,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* start_pos_stream = cc->Inputs().HasTag("START_POS")
|
InputStream* start_pos_stream = cc->Inputs().HasTag(kStartPosTag)
|
||||||
? &(cc->Inputs().Tag("START_POS"))
|
? &(cc->Inputs().Tag(kStartPosTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
MotionBoxMap fast_forward_boxes;
|
MotionBoxMap fast_forward_boxes;
|
||||||
|
@ -575,8 +592,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* start_pos_proto_string_stream =
|
InputStream* start_pos_proto_string_stream =
|
||||||
cc->Inputs().HasTag("START_POS_PROTO_STRING")
|
cc->Inputs().HasTag(kStartPosProtoStringTag)
|
||||||
? &(cc->Inputs().Tag("START_POS_PROTO_STRING"))
|
? &(cc->Inputs().Tag(kStartPosProtoStringTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) {
|
if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) {
|
||||||
if (start_pos_proto_string_stream &&
|
if (start_pos_proto_string_stream &&
|
||||||
|
@ -589,8 +606,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* restart_pos_stream = cc->Inputs().HasTag("RESTART_POS")
|
InputStream* restart_pos_stream = cc->Inputs().HasTag(kRestartPosTag)
|
||||||
? &(cc->Inputs().Tag("RESTART_POS"))
|
? &(cc->Inputs().Tag(kRestartPosTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
if (restart_pos_stream && !restart_pos_stream->IsEmpty()) {
|
if (restart_pos_stream && !restart_pos_stream->IsEmpty()) {
|
||||||
|
@ -600,8 +617,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* cancel_object_id_stream =
|
InputStream* cancel_object_id_stream =
|
||||||
cc->Inputs().HasTag("CANCEL_OBJECT_ID")
|
cc->Inputs().HasTag(kCancelObjectIdTag)
|
||||||
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID"))
|
? &(cc->Inputs().Tag(kCancelObjectIdTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
|
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
|
||||||
const int cancel_object_id = cancel_object_id_stream->Get<int>();
|
const int cancel_object_id = cancel_object_id_stream->Get<int>();
|
||||||
|
@ -616,8 +633,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
|
|
||||||
TrackingData track_data_to_render;
|
TrackingData track_data_to_render;
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIZ")) {
|
if (cc->Outputs().HasTag(kVizTag)) {
|
||||||
InputStream* video_stream = &(cc->Inputs().Tag("VIDEO"));
|
InputStream* video_stream = &(cc->Inputs().Tag(kVideoTag));
|
||||||
if (!video_stream->IsEmpty()) {
|
if (!video_stream->IsEmpty()) {
|
||||||
input_view = formats::MatView(&video_stream->Get<ImageFrame>());
|
input_view = formats::MatView(&video_stream->Get<ImageFrame>());
|
||||||
|
|
||||||
|
@ -745,7 +762,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
++frame_num_since_reset_;
|
++frame_num_since_reset_;
|
||||||
|
|
||||||
// Generate results for queued up request.
|
// Generate results for queued up request.
|
||||||
if (cc->Outputs().HasTag("BOXES") && !queued_track_requests_.empty()) {
|
if (cc->Outputs().HasTag(kBoxesTag) && !queued_track_requests_.empty()) {
|
||||||
for (int j = 0; j < queued_track_requests_.size(); ++j) {
|
for (int j = 0; j < queued_track_requests_.size(); ++j) {
|
||||||
const Timestamp& past_time = queued_track_requests_[j];
|
const Timestamp& past_time = queued_track_requests_[j];
|
||||||
RET_CHECK(past_time.Value() < timestamp.Value())
|
RET_CHECK(past_time.Value() < timestamp.Value())
|
||||||
|
@ -770,7 +787,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output for every time.
|
// Output for every time.
|
||||||
cc->Outputs().Tag("BOXES").Add(past_box_list.release(), past_time);
|
cc->Outputs().Tag(kBoxesTag).Add(past_box_list.release(), past_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
queued_track_requests_.clear();
|
queued_track_requests_.clear();
|
||||||
|
@ -845,8 +862,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle random access track requests.
|
// Handle random access track requests.
|
||||||
InputStream* ra_track_stream = cc->Inputs().HasTag("RA_TRACK")
|
InputStream* ra_track_stream = cc->Inputs().HasTag(kRaTrackTag)
|
||||||
? &(cc->Inputs().Tag("RA_TRACK"))
|
? &(cc->Inputs().Tag(kRaTrackTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
if (ra_track_stream && !ra_track_stream->IsEmpty()) {
|
if (ra_track_stream && !ra_track_stream->IsEmpty()) {
|
||||||
|
@ -861,8 +878,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* ra_track_proto_string_stream =
|
InputStream* ra_track_proto_string_stream =
|
||||||
cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")
|
cc->Inputs().HasTag(kRaTrackProtoStringTag)
|
||||||
? &(cc->Inputs().Tag("RA_TRACK_PROTO_STRING"))
|
? &(cc->Inputs().Tag(kRaTrackProtoStringTag))
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) {
|
if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) {
|
||||||
if (ra_track_proto_string_stream &&
|
if (ra_track_proto_string_stream &&
|
||||||
|
@ -881,15 +898,15 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Always output in batch, only output in streaming if tracking data
|
// Always output in batch, only output in streaming if tracking data
|
||||||
// is present (might be in fast forward mode instead).
|
// is present (might be in fast forward mode instead).
|
||||||
if (cc->Outputs().HasTag("BOXES") &&
|
if (cc->Outputs().HasTag(kBoxesTag) &&
|
||||||
(box_tracker_ || !track_stream->IsEmpty())) {
|
(box_tracker_ || !track_stream->IsEmpty())) {
|
||||||
std::unique_ptr<TimedBoxProtoList> boxes(new TimedBoxProtoList());
|
std::unique_ptr<TimedBoxProtoList> boxes(new TimedBoxProtoList());
|
||||||
*boxes = std::move(box_track_list);
|
*boxes = std::move(box_track_list);
|
||||||
cc->Outputs().Tag("BOXES").Add(boxes.release(), timestamp);
|
cc->Outputs().Tag(kBoxesTag).Add(boxes.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (viz_frame) {
|
if (viz_frame) {
|
||||||
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp);
|
cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -1001,7 +1018,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack(
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("RA_BOXES")
|
.Tag(kRaBoxesTag)
|
||||||
.Add(result_list.release(), cc->InputTimestamp());
|
.Add(result_list.release(), cc->InputTimestamp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,13 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kCacheDirTag[] = "CACHE_DIR";
|
||||||
|
constexpr char kCompleteTag[] = "COMPLETE";
|
||||||
|
constexpr char kTrackingChunkTag[] = "TRACKING_CHUNK";
|
||||||
|
constexpr char kTrackingTag[] = "TRACKING";
|
||||||
|
constexpr char kCameraTag[] = "CAMERA";
|
||||||
|
constexpr char kFlowTag[] = "FLOW";
|
||||||
|
|
||||||
using mediapipe::CameraMotion;
|
using mediapipe::CameraMotion;
|
||||||
using mediapipe::FlowPackager;
|
using mediapipe::FlowPackager;
|
||||||
using mediapipe::RegionFlowFeatureList;
|
using mediapipe::RegionFlowFeatureList;
|
||||||
|
@ -91,27 +98,27 @@ class FlowPackagerCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(FlowPackagerCalculator);
|
REGISTER_CALCULATOR(FlowPackagerCalculator);
|
||||||
|
|
||||||
absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (!cc->Inputs().HasTag("FLOW")) {
|
if (!cc->Inputs().HasTag(kFlowTag)) {
|
||||||
return tool::StatusFail("No input flow was specified.");
|
return tool::StatusFail("No input flow was specified.");
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->Inputs().Tag("FLOW").Set<RegionFlowFeatureList>();
|
cc->Inputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
|
||||||
|
|
||||||
if (cc->Inputs().HasTag("CAMERA")) {
|
if (cc->Inputs().HasTag(kCameraTag)) {
|
||||||
cc->Inputs().Tag("CAMERA").Set<CameraMotion>();
|
cc->Inputs().Tag(kCameraTag).Set<CameraMotion>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TRACKING")) {
|
if (cc->Outputs().HasTag(kTrackingTag)) {
|
||||||
cc->Outputs().Tag("TRACKING").Set<TrackingData>();
|
cc->Outputs().Tag(kTrackingTag).Set<TrackingData>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
|
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
|
||||||
cc->Outputs().Tag("TRACKING_CHUNK").Set<TrackingDataChunk>();
|
cc->Outputs().Tag(kTrackingChunkTag).Set<TrackingDataChunk>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("COMPLETE")) {
|
if (cc->Outputs().HasTag(kCompleteTag)) {
|
||||||
cc->Outputs().Tag("COMPLETE").Set<bool>();
|
cc->Outputs().Tag(kCompleteTag).Set<bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
|
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
|
||||||
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>();
|
cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -122,24 +129,24 @@ absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) {
|
||||||
|
|
||||||
flow_packager_.reset(new FlowPackager(options_.flow_packager_options()));
|
flow_packager_.reset(new FlowPackager(options_.flow_packager_options()));
|
||||||
|
|
||||||
use_caching_ = cc->InputSidePackets().HasTag("CACHE_DIR");
|
use_caching_ = cc->InputSidePackets().HasTag(kCacheDirTag);
|
||||||
build_chunk_ = use_caching_ || cc->Outputs().HasTag("TRACKING_CHUNK");
|
build_chunk_ = use_caching_ || cc->Outputs().HasTag(kTrackingChunkTag);
|
||||||
if (use_caching_) {
|
if (use_caching_) {
|
||||||
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>();
|
cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
||||||
InputStream* flow_stream = &(cc->Inputs().Tag("FLOW"));
|
InputStream* flow_stream = &(cc->Inputs().Tag(kFlowTag));
|
||||||
const RegionFlowFeatureList& flow = flow_stream->Get<RegionFlowFeatureList>();
|
const RegionFlowFeatureList& flow = flow_stream->Get<RegionFlowFeatureList>();
|
||||||
|
|
||||||
const Timestamp timestamp = flow_stream->Value().Timestamp();
|
const Timestamp timestamp = flow_stream->Value().Timestamp();
|
||||||
|
|
||||||
const CameraMotion* camera_motion = nullptr;
|
const CameraMotion* camera_motion = nullptr;
|
||||||
if (cc->Inputs().HasTag("CAMERA")) {
|
if (cc->Inputs().HasTag(kCameraTag)) {
|
||||||
InputStream* camera_stream = &(cc->Inputs().Tag("CAMERA"));
|
InputStream* camera_stream = &(cc->Inputs().Tag(kCameraTag));
|
||||||
camera_motion = &camera_stream->Get<CameraMotion>();
|
camera_motion = &camera_stream->Get<CameraMotion>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +168,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
||||||
if (frame_idx_ > 0) {
|
if (frame_idx_ > 0) {
|
||||||
item->set_prev_timestamp_usec(prev_timestamp_.Value());
|
item->set_prev_timestamp_usec(prev_timestamp_.Value());
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("TRACKING")) {
|
if (cc->Outputs().HasTag(kTrackingTag)) {
|
||||||
// Need to copy as output is requested.
|
// Need to copy as output is requested.
|
||||||
*item->mutable_tracking_data() = *tracking_data;
|
*item->mutable_tracking_data() = *tracking_data;
|
||||||
} else {
|
} else {
|
||||||
|
@ -172,9 +179,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
||||||
options_.caching_chunk_size_msec() * (chunk_idx_ + 1);
|
options_.caching_chunk_size_msec() * (chunk_idx_ + 1);
|
||||||
|
|
||||||
if (timestamp.Value() / 1000 >= next_chunk_msec) {
|
if (timestamp.Value() / 1000 >= next_chunk_msec) {
|
||||||
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
|
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("TRACKING_CHUNK")
|
.Tag(kTrackingChunkTag)
|
||||||
.Add(new TrackingDataChunk(tracking_chunk_),
|
.Add(new TrackingDataChunk(tracking_chunk_),
|
||||||
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
|
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
|
||||||
}
|
}
|
||||||
|
@ -185,9 +192,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("TRACKING")) {
|
if (cc->Outputs().HasTag(kTrackingTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("TRACKING")
|
.Tag(kTrackingTag)
|
||||||
.Add(tracking_data.release(), flow_stream->Value().Timestamp());
|
.Add(tracking_data.release(), flow_stream->Value().Timestamp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,9 +206,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
||||||
absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
|
absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
|
||||||
if (frame_idx_ > 0) {
|
if (frame_idx_ > 0) {
|
||||||
tracking_chunk_.set_last_chunk(true);
|
tracking_chunk_.set_last_chunk(true);
|
||||||
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
|
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("TRACKING_CHUNK")
|
.Tag(kTrackingChunkTag)
|
||||||
.Add(new TrackingDataChunk(tracking_chunk_),
|
.Add(new TrackingDataChunk(tracking_chunk_),
|
||||||
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
|
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
|
||||||
}
|
}
|
||||||
|
@ -211,8 +218,8 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("COMPLETE")) {
|
if (cc->Outputs().HasTag(kCompleteTag)) {
|
||||||
cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream());
|
cc->Outputs().Tag(kCompleteTag).Add(new bool(true), Timestamp::PreStream());
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -38,6 +38,18 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kDownsampleTag[] = "DOWNSAMPLE";
|
||||||
|
constexpr char kCsvFileTag[] = "CSV_FILE";
|
||||||
|
constexpr char kGrayVideoOutTag[] = "GRAY_VIDEO_OUT";
|
||||||
|
constexpr char kVideoOutTag[] = "VIDEO_OUT";
|
||||||
|
constexpr char kDenseFgTag[] = "DENSE_FG";
|
||||||
|
constexpr char kVizTag[] = "VIZ";
|
||||||
|
constexpr char kSaliencyTag[] = "SALIENCY";
|
||||||
|
constexpr char kCameraTag[] = "CAMERA";
|
||||||
|
constexpr char kFlowTag[] = "FLOW";
|
||||||
|
constexpr char kSelectionTag[] = "SELECTION";
|
||||||
|
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
|
||||||
using mediapipe::AffineAdapter;
|
using mediapipe::AffineAdapter;
|
||||||
using mediapipe::CameraMotion;
|
using mediapipe::CameraMotion;
|
||||||
using mediapipe::FrameSelectionResult;
|
using mediapipe::FrameSelectionResult;
|
||||||
|
@ -190,55 +202,56 @@ class MotionAnalysisCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(MotionAnalysisCalculator);
|
REGISTER_CALCULATOR(MotionAnalysisCalculator);
|
||||||
|
|
||||||
absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->Inputs().HasTag("VIDEO")) {
|
if (cc->Inputs().HasTag(kVideoTag)) {
|
||||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional input stream from frame selection calculator.
|
// Optional input stream from frame selection calculator.
|
||||||
if (cc->Inputs().HasTag("SELECTION")) {
|
if (cc->Inputs().HasTag(kSelectionTag)) {
|
||||||
cc->Inputs().Tag("SELECTION").Set<FrameSelectionResult>();
|
cc->Inputs().Tag(kSelectionTag).Set<FrameSelectionResult>();
|
||||||
}
|
}
|
||||||
|
|
||||||
RET_CHECK(cc->Inputs().HasTag("VIDEO") || cc->Inputs().HasTag("SELECTION"))
|
RET_CHECK(cc->Inputs().HasTag(kVideoTag) ||
|
||||||
|
cc->Inputs().HasTag(kSelectionTag))
|
||||||
<< "Either VIDEO, SELECTION must be specified.";
|
<< "Either VIDEO, SELECTION must be specified.";
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("FLOW")) {
|
if (cc->Outputs().HasTag(kFlowTag)) {
|
||||||
cc->Outputs().Tag("FLOW").Set<RegionFlowFeatureList>();
|
cc->Outputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("CAMERA")) {
|
if (cc->Outputs().HasTag(kCameraTag)) {
|
||||||
cc->Outputs().Tag("CAMERA").Set<CameraMotion>();
|
cc->Outputs().Tag(kCameraTag).Set<CameraMotion>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("SALIENCY")) {
|
if (cc->Outputs().HasTag(kSaliencyTag)) {
|
||||||
cc->Outputs().Tag("SALIENCY").Set<SalientPointFrame>();
|
cc->Outputs().Tag(kSaliencyTag).Set<SalientPointFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIZ")) {
|
if (cc->Outputs().HasTag(kVizTag)) {
|
||||||
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
|
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("DENSE_FG")) {
|
if (cc->Outputs().HasTag(kDenseFgTag)) {
|
||||||
cc->Outputs().Tag("DENSE_FG").Set<ImageFrame>();
|
cc->Outputs().Tag(kDenseFgTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIDEO_OUT")) {
|
if (cc->Outputs().HasTag(kVideoOutTag)) {
|
||||||
cc->Outputs().Tag("VIDEO_OUT").Set<ImageFrame>();
|
cc->Outputs().Tag(kVideoOutTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("GRAY_VIDEO_OUT")) {
|
if (cc->Outputs().HasTag(kGrayVideoOutTag)) {
|
||||||
// We only output grayscale video if we're actually performing full region-
|
// We only output grayscale video if we're actually performing full region-
|
||||||
// flow analysis on the video.
|
// flow analysis on the video.
|
||||||
RET_CHECK(cc->Inputs().HasTag("VIDEO") &&
|
RET_CHECK(cc->Inputs().HasTag(kVideoTag) &&
|
||||||
!cc->Inputs().HasTag("SELECTION"));
|
!cc->Inputs().HasTag(kSelectionTag));
|
||||||
cc->Outputs().Tag("GRAY_VIDEO_OUT").Set<ImageFrame>();
|
cc->Outputs().Tag(kGrayVideoOutTag).Set<ImageFrame>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("CSV_FILE")) {
|
if (cc->InputSidePackets().HasTag(kCsvFileTag)) {
|
||||||
cc->InputSidePackets().Tag("CSV_FILE").Set<std::string>();
|
cc->InputSidePackets().Tag(kCsvFileTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) {
|
if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
|
||||||
cc->InputSidePackets().Tag("DOWNSAMPLE").Set<float>();
|
cc->InputSidePackets().Tag(kDownsampleTag).Set<float>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||||
|
@ -253,16 +266,16 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
||||||
tool::RetrieveOptions(cc->Options<MotionAnalysisCalculatorOptions>(),
|
tool::RetrieveOptions(cc->Options<MotionAnalysisCalculatorOptions>(),
|
||||||
cc->InputSidePackets(), kOptionsTag);
|
cc->InputSidePackets(), kOptionsTag);
|
||||||
|
|
||||||
video_input_ = cc->Inputs().HasTag("VIDEO");
|
video_input_ = cc->Inputs().HasTag(kVideoTag);
|
||||||
selection_input_ = cc->Inputs().HasTag("SELECTION");
|
selection_input_ = cc->Inputs().HasTag(kSelectionTag);
|
||||||
region_flow_feature_output_ = cc->Outputs().HasTag("FLOW");
|
region_flow_feature_output_ = cc->Outputs().HasTag(kFlowTag);
|
||||||
camera_motion_output_ = cc->Outputs().HasTag("CAMERA");
|
camera_motion_output_ = cc->Outputs().HasTag(kCameraTag);
|
||||||
saliency_output_ = cc->Outputs().HasTag("SALIENCY");
|
saliency_output_ = cc->Outputs().HasTag(kSaliencyTag);
|
||||||
visualize_output_ = cc->Outputs().HasTag("VIZ");
|
visualize_output_ = cc->Outputs().HasTag(kVizTag);
|
||||||
dense_foreground_output_ = cc->Outputs().HasTag("DENSE_FG");
|
dense_foreground_output_ = cc->Outputs().HasTag(kDenseFgTag);
|
||||||
video_output_ = cc->Outputs().HasTag("VIDEO_OUT");
|
video_output_ = cc->Outputs().HasTag(kVideoOutTag);
|
||||||
grayscale_output_ = cc->Outputs().HasTag("GRAY_VIDEO_OUT");
|
grayscale_output_ = cc->Outputs().HasTag(kGrayVideoOutTag);
|
||||||
csv_file_input_ = cc->InputSidePackets().HasTag("CSV_FILE");
|
csv_file_input_ = cc->InputSidePackets().HasTag(kCsvFileTag);
|
||||||
hybrid_meta_analysis_ = options_.meta_analysis() ==
|
hybrid_meta_analysis_ = options_.meta_analysis() ==
|
||||||
MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID;
|
MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID;
|
||||||
|
|
||||||
|
@ -310,7 +323,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
||||||
if (csv_file_input_) {
|
if (csv_file_input_) {
|
||||||
// Read from file and parse.
|
// Read from file and parse.
|
||||||
const std::string filename =
|
const std::string filename =
|
||||||
cc->InputSidePackets().Tag("CSV_FILE").Get<std::string>();
|
cc->InputSidePackets().Tag(kCsvFileTag).Get<std::string>();
|
||||||
|
|
||||||
std::string file_contents;
|
std::string file_contents;
|
||||||
std::ifstream input_file(filename, std::ios::in);
|
std::ifstream input_file(filename, std::ios::in);
|
||||||
|
@ -327,11 +340,12 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Get video header from video or selection input if present.
|
// Get video header from video or selection input if present.
|
||||||
const VideoHeader* video_header = nullptr;
|
const VideoHeader* video_header = nullptr;
|
||||||
if (video_input_ && !cc->Inputs().Tag("VIDEO").Header().IsEmpty()) {
|
if (video_input_ && !cc->Inputs().Tag(kVideoTag).Header().IsEmpty()) {
|
||||||
video_header = &(cc->Inputs().Tag("VIDEO").Header().Get<VideoHeader>());
|
video_header = &(cc->Inputs().Tag(kVideoTag).Header().Get<VideoHeader>());
|
||||||
} else if (selection_input_ &&
|
} else if (selection_input_ &&
|
||||||
!cc->Inputs().Tag("SELECTION").Header().IsEmpty()) {
|
!cc->Inputs().Tag(kSelectionTag).Header().IsEmpty()) {
|
||||||
video_header = &(cc->Inputs().Tag("SELECTION").Header().Get<VideoHeader>());
|
video_header =
|
||||||
|
&(cc->Inputs().Tag(kSelectionTag).Header().Get<VideoHeader>());
|
||||||
} else {
|
} else {
|
||||||
LOG(WARNING) << "No input video header found. Downstream calculators "
|
LOG(WARNING) << "No input video header found. Downstream calculators "
|
||||||
"expecting video headers are likely to fail.";
|
"expecting video headers are likely to fail.";
|
||||||
|
@ -339,7 +353,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
||||||
|
|
||||||
with_saliency_ = options_.analysis_options().compute_motion_saliency();
|
with_saliency_ = options_.analysis_options().compute_motion_saliency();
|
||||||
// Force computation of saliency if requested as output.
|
// Force computation of saliency if requested as output.
|
||||||
if (cc->Outputs().HasTag("SALIENCY")) {
|
if (cc->Outputs().HasTag(kSaliencyTag)) {
|
||||||
with_saliency_ = true;
|
with_saliency_ = true;
|
||||||
if (!options_.analysis_options().compute_motion_saliency()) {
|
if (!options_.analysis_options().compute_motion_saliency()) {
|
||||||
LOG(WARNING) << "Enable saliency computation. Set "
|
LOG(WARNING) << "Enable saliency computation. Set "
|
||||||
|
@ -353,11 +367,11 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) {
|
if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
|
||||||
options_.mutable_analysis_options()
|
options_.mutable_analysis_options()
|
||||||
->mutable_flow_options()
|
->mutable_flow_options()
|
||||||
->set_downsample_factor(
|
->set_downsample_factor(
|
||||||
cc->InputSidePackets().Tag("DOWNSAMPLE").Get<float>());
|
cc->InputSidePackets().Tag(kDownsampleTag).Get<float>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no video header is provided, just return and initialize on the first
|
// If no video header is provided, just return and initialize on the first
|
||||||
|
@ -369,30 +383,33 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
||||||
////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE ///////////////
|
////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE ///////////////
|
||||||
|
|
||||||
if (visualize_output_) {
|
if (visualize_output_) {
|
||||||
cc->Outputs().Tag("VIZ").SetHeader(Adopt(new VideoHeader(*video_header)));
|
cc->Outputs().Tag(kVizTag).SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (video_output_) {
|
if (video_output_) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("VIDEO_OUT")
|
.Tag(kVideoOutTag)
|
||||||
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("DENSE_FG")) {
|
if (cc->Outputs().HasTag(kDenseFgTag)) {
|
||||||
std::unique_ptr<VideoHeader> foreground_header(
|
std::unique_ptr<VideoHeader> foreground_header(
|
||||||
new VideoHeader(*video_header));
|
new VideoHeader(*video_header));
|
||||||
foreground_header->format = ImageFormat::GRAY8;
|
foreground_header->format = ImageFormat::GRAY8;
|
||||||
cc->Outputs().Tag("DENSE_FG").SetHeader(Adopt(foreground_header.release()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("CAMERA")) {
|
|
||||||
cc->Outputs().Tag("CAMERA").SetHeader(
|
|
||||||
Adopt(new VideoHeader(*video_header)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("SALIENCY")) {
|
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("SALIENCY")
|
.Tag(kDenseFgTag)
|
||||||
|
.SetHeader(Adopt(foreground_header.release()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cc->Outputs().HasTag(kCameraTag)) {
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kCameraTag)
|
||||||
|
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cc->Outputs().HasTag(kSaliencyTag)) {
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kSaliencyTag)
|
||||||
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -405,9 +422,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream* video_stream =
|
InputStream* video_stream =
|
||||||
video_input_ ? &(cc->Inputs().Tag("VIDEO")) : nullptr;
|
video_input_ ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
|
||||||
InputStream* selection_stream =
|
InputStream* selection_stream =
|
||||||
selection_input_ ? &(cc->Inputs().Tag("SELECTION")) : nullptr;
|
selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr;
|
||||||
|
|
||||||
// Checked on Open.
|
// Checked on Open.
|
||||||
CHECK(video_stream || selection_stream);
|
CHECK(video_stream || selection_stream);
|
||||||
|
@ -425,8 +442,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
||||||
CameraMotion output_motion = meta_motions_.front();
|
CameraMotion output_motion = meta_motions_.front();
|
||||||
meta_motions_.pop_front();
|
meta_motions_.pop_front();
|
||||||
output_motion.set_timestamp_usec(timestamp.Value());
|
output_motion.set_timestamp_usec(timestamp.Value());
|
||||||
cc->Outputs().Tag("CAMERA").Add(new CameraMotion(output_motion),
|
cc->Outputs()
|
||||||
timestamp);
|
.Tag(kCameraTag)
|
||||||
|
.Add(new CameraMotion(output_motion), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (region_flow_feature_output_) {
|
if (region_flow_feature_output_) {
|
||||||
|
@ -435,8 +453,8 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
||||||
meta_features_.pop_front();
|
meta_features_.pop_front();
|
||||||
|
|
||||||
output_features.set_timestamp_usec(timestamp.Value());
|
output_features.set_timestamp_usec(timestamp.Value());
|
||||||
cc->Outputs().Tag("FLOW").Add(new RegionFlowFeatureList(output_features),
|
cc->Outputs().Tag(kFlowTag).Add(
|
||||||
timestamp);
|
new RegionFlowFeatureList(output_features), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
++frame_idx_;
|
++frame_idx_;
|
||||||
|
@ -478,16 +496,17 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
||||||
MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) {
|
MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) {
|
||||||
// Output concatenated results, nothing to compute here.
|
// Output concatenated results, nothing to compute here.
|
||||||
if (camera_motion_output_) {
|
if (camera_motion_output_) {
|
||||||
cc->Outputs().Tag("CAMERA").Add(
|
cc->Outputs()
|
||||||
frame_selection_result->release_camera_motion(), timestamp);
|
.Tag(kCameraTag)
|
||||||
|
.Add(frame_selection_result->release_camera_motion(), timestamp);
|
||||||
}
|
}
|
||||||
if (region_flow_feature_output_) {
|
if (region_flow_feature_output_) {
|
||||||
cc->Outputs().Tag("FLOW").Add(frame_selection_result->release_features(),
|
cc->Outputs().Tag(kFlowTag).Add(
|
||||||
timestamp);
|
frame_selection_result->release_features(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (video_output_) {
|
if (video_output_) {
|
||||||
cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value());
|
cc->Outputs().Tag(kVideoOutTag).AddPacket(video_stream->Value());
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -565,7 +584,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
||||||
grayscale_mat.copyTo(image_frame_mat);
|
grayscale_mat.copyTo(image_frame_mat);
|
||||||
|
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("GRAY_VIDEO_OUT")
|
.Tag(kGrayVideoOutTag)
|
||||||
.Add(grayscale_image.release(), timestamp);
|
.Add(grayscale_image.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -640,7 +659,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
|
||||||
*feature_list, *camera_motion,
|
*feature_list, *camera_motion,
|
||||||
with_saliency_ ? saliency[k].get() : nullptr, &visualization);
|
with_saliency_ ? saliency[k].get() : nullptr, &visualization);
|
||||||
|
|
||||||
cc->Outputs().Tag("VIZ").Add(visualization_frame.release(), timestamp);
|
cc->Outputs().Tag(kVizTag).Add(visualization_frame.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output dense foreground mask.
|
// Output dense foreground mask.
|
||||||
|
@ -650,26 +669,26 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
|
||||||
cv::Mat foreground = formats::MatView(foreground_frame.get());
|
cv::Mat foreground = formats::MatView(foreground_frame.get());
|
||||||
motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion,
|
motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion,
|
||||||
&foreground);
|
&foreground);
|
||||||
cc->Outputs().Tag("DENSE_FG").Add(foreground_frame.release(), timestamp);
|
cc->Outputs().Tag(kDenseFgTag).Add(foreground_frame.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output flow features if requested.
|
// Output flow features if requested.
|
||||||
if (region_flow_feature_output_) {
|
if (region_flow_feature_output_) {
|
||||||
cc->Outputs().Tag("FLOW").Add(feature_list.release(), timestamp);
|
cc->Outputs().Tag(kFlowTag).Add(feature_list.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output camera motion.
|
// Output camera motion.
|
||||||
if (camera_motion_output_) {
|
if (camera_motion_output_) {
|
||||||
cc->Outputs().Tag("CAMERA").Add(camera_motion.release(), timestamp);
|
cc->Outputs().Tag(kCameraTag).Add(camera_motion.release(), timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (video_output_) {
|
if (video_output_) {
|
||||||
cc->Outputs().Tag("VIDEO_OUT").AddPacket(packet_buffer_[k]);
|
cc->Outputs().Tag(kVideoOutTag).AddPacket(packet_buffer_[k]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output saliency.
|
// Output saliency.
|
||||||
if (saliency_output_) {
|
if (saliency_output_) {
|
||||||
cc->Outputs().Tag("SALIENCY").Add(saliency[k].release(), timestamp);
|
cc->Outputs().Tag(kSaliencyTag).Add(saliency[k].release(), timestamp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,12 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSavedAudioPathTag[] = "SAVED_AUDIO_PATH";
|
||||||
|
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||||
|
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
|
||||||
|
|
||||||
// cv::VideoCapture set data type to unsigned char by default. Therefore, the
|
// cv::VideoCapture set data type to unsigned char by default. Therefore, the
|
||||||
// image format is only related to the number of channles the cv::Mat has.
|
// image format is only related to the number of channles the cv::Mat has.
|
||||||
ImageFormat::Format GetImageFormat(int num_channels) {
|
ImageFormat::Format GetImageFormat(int num_channels) {
|
||||||
|
@ -87,20 +93,20 @@ ImageFormat::Format GetImageFormat(int num_channels) {
|
||||||
class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
|
cc->InputSidePackets().Tag(kInputFilePathTag).Set<std::string>();
|
||||||
cc->Outputs().Tag("VIDEO").Set<ImageFrame>();
|
cc->Outputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||||
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) {
|
if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
|
||||||
cc->Outputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
cc->Outputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||||
}
|
}
|
||||||
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) {
|
if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
|
||||||
cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set<std::string>();
|
cc->OutputSidePackets().Tag(kSavedAudioPathTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
const std::string& input_file_path =
|
const std::string& input_file_path =
|
||||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
|
cc->InputSidePackets().Tag(kInputFilePathTag).Get<std::string>();
|
||||||
cap_ = absl::make_unique<cv::VideoCapture>(input_file_path);
|
cap_ = absl::make_unique<cv::VideoCapture>(input_file_path);
|
||||||
if (!cap_->isOpened()) {
|
if (!cap_->isOpened()) {
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
|
@ -140,16 +146,16 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
||||||
header->frame_rate = fps;
|
header->frame_rate = fps;
|
||||||
header->duration = frame_count_ / fps;
|
header->duration = frame_count_ / fps;
|
||||||
|
|
||||||
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) {
|
if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("VIDEO_PRESTREAM")
|
.Tag(kVideoPrestreamTag)
|
||||||
.Add(header.release(), Timestamp::PreStream());
|
.Add(header.release(), Timestamp::PreStream());
|
||||||
cc->Outputs().Tag("VIDEO_PRESTREAM").Close();
|
cc->Outputs().Tag(kVideoPrestreamTag).Close();
|
||||||
}
|
}
|
||||||
// Rewind to the very first frame.
|
// Rewind to the very first frame.
|
||||||
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);
|
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);
|
||||||
|
|
||||||
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) {
|
if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
|
||||||
#ifdef HAVE_FFMPEG
|
#ifdef HAVE_FFMPEG
|
||||||
std::string saved_audio_path = std::tmpnam(nullptr);
|
std::string saved_audio_path = std::tmpnam(nullptr);
|
||||||
std::string ffmpeg_command =
|
std::string ffmpeg_command =
|
||||||
|
@ -159,14 +165,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
||||||
int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str());
|
int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str());
|
||||||
if (status_code == 0) {
|
if (status_code == 0) {
|
||||||
cc->OutputSidePackets()
|
cc->OutputSidePackets()
|
||||||
.Tag("SAVED_AUDIO_PATH")
|
.Tag(kSavedAudioPathTag)
|
||||||
.Set(MakePacket<std::string>(saved_audio_path));
|
.Set(MakePacket<std::string>(saved_audio_path));
|
||||||
} else {
|
} else {
|
||||||
LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path
|
LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path
|
||||||
<< " by executing the following command: "
|
<< " by executing the following command: "
|
||||||
<< ffmpeg_command;
|
<< ffmpeg_command;
|
||||||
cc->OutputSidePackets()
|
cc->OutputSidePackets()
|
||||||
.Tag("SAVED_AUDIO_PATH")
|
.Tag(kSavedAudioPathTag)
|
||||||
.Set(MakePacket<std::string>(std::string()));
|
.Set(MakePacket<std::string>(std::string()));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
@ -208,7 +214,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
||||||
// If the timestamp of the current frame is not greater than the one of the
|
// If the timestamp of the current frame is not greater than the one of the
|
||||||
// previous frame, the new frame will be discarded.
|
// previous frame, the new frame will be discarded.
|
||||||
if (prev_timestamp_ < timestamp) {
|
if (prev_timestamp_ < timestamp) {
|
||||||
cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp);
|
cc->Outputs().Tag(kVideoTag).Add(image_frame.release(), timestamp);
|
||||||
prev_timestamp_ = timestamp;
|
prev_timestamp_ = timestamp;
|
||||||
decoded_frames_++;
|
decoded_frames_++;
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,10 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||||
|
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
|
||||||
|
|
||||||
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
||||||
CalculatorGraphConfig::Node node_config =
|
CalculatorGraphConfig::Node node_config =
|
||||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
@ -37,19 +41,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
||||||
output_stream: "VIDEO:video"
|
output_stream: "VIDEO:video"
|
||||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||||
file::JoinPath("./",
|
file::JoinPath("./",
|
||||||
"/mediapipe/calculators/video/"
|
"/mediapipe/calculators/video/"
|
||||||
"testdata/format_MP4_AVC720P_AAC.video"));
|
"testdata/format_MP4_AVC720P_AAC.video"));
|
||||||
MP_EXPECT_OK(runner.Run());
|
MP_EXPECT_OK(runner.Run());
|
||||||
|
|
||||||
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
|
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||||
MP_EXPECT_OK(runner.Outputs()
|
MP_EXPECT_OK(runner.Outputs()
|
||||||
.Tag("VIDEO_PRESTREAM")
|
.Tag(kVideoPrestreamTag)
|
||||||
.packets[0]
|
.packets[0]
|
||||||
.ValidateAsType<VideoHeader>());
|
.ValidateAsType<VideoHeader>());
|
||||||
const mediapipe::VideoHeader& header =
|
const mediapipe::VideoHeader& header =
|
||||||
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
|
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
|
||||||
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
||||||
EXPECT_EQ(1280, header.width);
|
EXPECT_EQ(1280, header.width);
|
||||||
EXPECT_EQ(640, header.height);
|
EXPECT_EQ(640, header.height);
|
||||||
|
@ -58,10 +62,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
||||||
// The number of the output packets should be 180.
|
// The number of the output packets should be 180.
|
||||||
// Some OpenCV version returns the first two frames with the same timestamp on
|
// Some OpenCV version returns the first two frames with the same timestamp on
|
||||||
// macos and we might miss one frame here.
|
// macos and we might miss one frame here.
|
||||||
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size();
|
int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
|
||||||
EXPECT_GE(num_of_packets, 179);
|
EXPECT_GE(num_of_packets, 179);
|
||||||
for (int i = 0; i < num_of_packets; ++i) {
|
for (int i = 0; i < num_of_packets; ++i) {
|
||||||
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
|
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
|
||||||
cv::Mat output_mat =
|
cv::Mat output_mat =
|
||||||
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
||||||
EXPECT_EQ(1280, output_mat.size().width);
|
EXPECT_EQ(1280, output_mat.size().width);
|
||||||
|
@ -83,19 +87,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
|
||||||
output_stream: "VIDEO:video"
|
output_stream: "VIDEO:video"
|
||||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||||
file::JoinPath("./",
|
file::JoinPath("./",
|
||||||
"/mediapipe/calculators/video/"
|
"/mediapipe/calculators/video/"
|
||||||
"testdata/format_FLV_H264_AAC.video"));
|
"testdata/format_FLV_H264_AAC.video"));
|
||||||
MP_EXPECT_OK(runner.Run());
|
MP_EXPECT_OK(runner.Run());
|
||||||
|
|
||||||
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
|
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||||
MP_EXPECT_OK(runner.Outputs()
|
MP_EXPECT_OK(runner.Outputs()
|
||||||
.Tag("VIDEO_PRESTREAM")
|
.Tag(kVideoPrestreamTag)
|
||||||
.packets[0]
|
.packets[0]
|
||||||
.ValidateAsType<VideoHeader>());
|
.ValidateAsType<VideoHeader>());
|
||||||
const mediapipe::VideoHeader& header =
|
const mediapipe::VideoHeader& header =
|
||||||
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
|
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
|
||||||
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
||||||
EXPECT_EQ(640, header.width);
|
EXPECT_EQ(640, header.width);
|
||||||
EXPECT_EQ(320, header.height);
|
EXPECT_EQ(320, header.height);
|
||||||
|
@ -103,9 +107,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
|
||||||
// can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4).
|
// can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4).
|
||||||
// EXPECT_FLOAT_EQ(6.0f, header.duration);
|
// EXPECT_FLOAT_EQ(6.0f, header.duration);
|
||||||
// EXPECT_FLOAT_EQ(30.0f, header.frame_rate);
|
// EXPECT_FLOAT_EQ(30.0f, header.frame_rate);
|
||||||
EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size());
|
EXPECT_EQ(180, runner.Outputs().Tag(kVideoTag).packets.size());
|
||||||
for (int i = 0; i < 180; ++i) {
|
for (int i = 0; i < 180; ++i) {
|
||||||
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
|
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
|
||||||
cv::Mat output_mat =
|
cv::Mat output_mat =
|
||||||
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
||||||
EXPECT_EQ(640, output_mat.size().width);
|
EXPECT_EQ(640, output_mat.size().width);
|
||||||
|
@ -127,19 +131,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
|
||||||
output_stream: "VIDEO:video"
|
output_stream: "VIDEO:video"
|
||||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||||
CalculatorRunner runner(node_config);
|
CalculatorRunner runner(node_config);
|
||||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||||
file::JoinPath("./",
|
file::JoinPath("./",
|
||||||
"/mediapipe/calculators/video/"
|
"/mediapipe/calculators/video/"
|
||||||
"testdata/format_MKV_VP8_VORBIS.video"));
|
"testdata/format_MKV_VP8_VORBIS.video"));
|
||||||
MP_EXPECT_OK(runner.Run());
|
MP_EXPECT_OK(runner.Run());
|
||||||
|
|
||||||
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
|
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||||
MP_EXPECT_OK(runner.Outputs()
|
MP_EXPECT_OK(runner.Outputs()
|
||||||
.Tag("VIDEO_PRESTREAM")
|
.Tag(kVideoPrestreamTag)
|
||||||
.packets[0]
|
.packets[0]
|
||||||
.ValidateAsType<VideoHeader>());
|
.ValidateAsType<VideoHeader>());
|
||||||
const mediapipe::VideoHeader& header =
|
const mediapipe::VideoHeader& header =
|
||||||
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
|
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
|
||||||
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
||||||
EXPECT_EQ(640, header.width);
|
EXPECT_EQ(640, header.width);
|
||||||
EXPECT_EQ(320, header.height);
|
EXPECT_EQ(320, header.height);
|
||||||
|
@ -148,10 +152,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
|
||||||
// The number of the output packets should be 180.
|
// The number of the output packets should be 180.
|
||||||
// Some OpenCV version returns the first two frames with the same timestamp on
|
// Some OpenCV version returns the first two frames with the same timestamp on
|
||||||
// macos and we might miss one frame here.
|
// macos and we might miss one frame here.
|
||||||
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size();
|
int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
|
||||||
EXPECT_GE(num_of_packets, 179);
|
EXPECT_GE(num_of_packets, 179);
|
||||||
for (int i = 0; i < num_of_packets; ++i) {
|
for (int i = 0; i < num_of_packets; ++i) {
|
||||||
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
|
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
|
||||||
cv::Mat output_mat =
|
cv::Mat output_mat =
|
||||||
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
||||||
EXPECT_EQ(640, output_mat.size().width);
|
EXPECT_EQ(640, output_mat.size().width);
|
||||||
|
|
|
@ -36,6 +36,11 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kAudioFilePathTag[] = "AUDIO_FILE_PATH";
|
||||||
|
constexpr char kOutputFilePathTag[] = "OUTPUT_FILE_PATH";
|
||||||
|
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||||
|
constexpr char kVideoTag[] = "VIDEO";
|
||||||
|
|
||||||
// Encodes the input video stream and produces a media file.
|
// Encodes the input video stream and produces a media file.
|
||||||
// The media file can be output to the output_file_path specified as a side
|
// The media file can be output to the output_file_path specified as a side
|
||||||
// packet. Currently, the calculator only supports one video stream (in
|
// packet. Currently, the calculator only supports one video stream (in
|
||||||
|
@ -90,15 +95,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase {
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK(cc->Inputs().HasTag("VIDEO"));
|
RET_CHECK(cc->Inputs().HasTag(kVideoTag));
|
||||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
|
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
|
||||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||||
}
|
}
|
||||||
RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH"));
|
RET_CHECK(cc->InputSidePackets().HasTag(kOutputFilePathTag));
|
||||||
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set<std::string>();
|
cc->InputSidePackets().Tag(kOutputFilePathTag).Set<std::string>();
|
||||||
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) {
|
if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
|
||||||
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set<std::string>();
|
cc->InputSidePackets().Tag(kAudioFilePathTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -116,7 +121,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
|
||||||
<< "Video format must be specified in "
|
<< "Video format must be specified in "
|
||||||
"OpenCvVideoEncoderCalculatorOptions";
|
"OpenCvVideoEncoderCalculatorOptions";
|
||||||
output_file_path_ =
|
output_file_path_ =
|
||||||
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get<std::string>();
|
cc->InputSidePackets().Tag(kOutputFilePathTag).Get<std::string>();
|
||||||
std::vector<std::string> splited_file_path =
|
std::vector<std::string> splited_file_path =
|
||||||
absl::StrSplit(output_file_path_, '.');
|
absl::StrSplit(output_file_path_, '.');
|
||||||
RET_CHECK(splited_file_path.size() >= 2 &&
|
RET_CHECK(splited_file_path.size() >= 2 &&
|
||||||
|
@ -126,7 +131,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
|
||||||
// If the video header will be available, the video metadata will be fetched
|
// If the video header will be available, the video metadata will be fetched
|
||||||
// from the video header directly. The calculator will receive the video
|
// from the video header directly. The calculator will receive the video
|
||||||
// header packet at timestamp prestream.
|
// header packet at timestamp prestream.
|
||||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
|
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
return SetUpVideoWriter(options.fps(), options.width(), options.height());
|
return SetUpVideoWriter(options.fps(), options.width(), options.height());
|
||||||
|
@ -135,13 +140,13 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
|
||||||
absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
||||||
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||||
const VideoHeader& video_header =
|
const VideoHeader& video_header =
|
||||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
|
cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
|
||||||
return SetUpVideoWriter(video_header.frame_rate, video_header.width,
|
return SetUpVideoWriter(video_header.frame_rate, video_header.width,
|
||||||
video_header.height);
|
video_header.height);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ImageFrame& image_frame =
|
const ImageFrame& image_frame =
|
||||||
cc->Inputs().Tag("VIDEO").Value().Get<ImageFrame>();
|
cc->Inputs().Tag(kVideoTag).Value().Get<ImageFrame>();
|
||||||
ImageFormat::Format format = image_frame.Format();
|
ImageFormat::Format format = image_frame.Format();
|
||||||
cv::Mat frame;
|
cv::Mat frame;
|
||||||
if (format == ImageFormat::GRAY8) {
|
if (format == ImageFormat::GRAY8) {
|
||||||
|
@ -149,7 +154,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
||||||
if (frame.empty()) {
|
if (frame.empty()) {
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "Receive empty frame at timestamp "
|
<< "Receive empty frame at timestamp "
|
||||||
<< cc->Inputs().Tag("VIDEO").Value().Timestamp()
|
<< cc->Inputs().Tag(kVideoTag).Value().Timestamp()
|
||||||
<< " in OpenCvVideoEncoderCalculator::Process()";
|
<< " in OpenCvVideoEncoderCalculator::Process()";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -157,7 +162,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
||||||
if (tmp_frame.empty()) {
|
if (tmp_frame.empty()) {
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "Receive empty frame at timestamp "
|
<< "Receive empty frame at timestamp "
|
||||||
<< cc->Inputs().Tag("VIDEO").Value().Timestamp()
|
<< cc->Inputs().Tag(kVideoTag).Value().Timestamp()
|
||||||
<< " in OpenCvVideoEncoderCalculator::Process()";
|
<< " in OpenCvVideoEncoderCalculator::Process()";
|
||||||
}
|
}
|
||||||
if (format == ImageFormat::SRGB) {
|
if (format == ImageFormat::SRGB) {
|
||||||
|
@ -177,10 +182,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) {
|
||||||
if (writer_ && writer_->isOpened()) {
|
if (writer_ && writer_->isOpened()) {
|
||||||
writer_->release();
|
writer_->release();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) {
|
if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
|
||||||
#ifdef HAVE_FFMPEG
|
#ifdef HAVE_FFMPEG
|
||||||
const std::string& audio_file_path =
|
const std::string& audio_file_path =
|
||||||
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Get<std::string>();
|
cc->InputSidePackets().Tag(kAudioFilePathTag).Get<std::string>();
|
||||||
if (audio_file_path.empty()) {
|
if (audio_file_path.empty()) {
|
||||||
LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the "
|
LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the "
|
||||||
"audio tracks to the generated video because the audio "
|
"audio tracks to the generated video because the audio "
|
||||||
|
|
|
@ -23,6 +23,11 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kBackwardFlowTag[] = "BACKWARD_FLOW";
|
||||||
|
constexpr char kForwardFlowTag[] = "FORWARD_FLOW";
|
||||||
|
constexpr char kSecondFrameTag[] = "SECOND_FRAME";
|
||||||
|
constexpr char kFirstFrameTag[] = "FIRST_FRAME";
|
||||||
|
|
||||||
// Checks that img1 and img2 have the same dimensions.
|
// Checks that img1 and img2 have the same dimensions.
|
||||||
bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) {
|
bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) {
|
||||||
return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height());
|
return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height());
|
||||||
|
@ -94,19 +99,19 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase {
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) {
|
absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (!cc->Inputs().HasTag("FIRST_FRAME") ||
|
if (!cc->Inputs().HasTag(kFirstFrameTag) ||
|
||||||
!cc->Inputs().HasTag("SECOND_FRAME")) {
|
!cc->Inputs().HasTag(kSecondFrameTag)) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"Missing required input streams. Both FIRST_FRAME and SECOND_FRAME "
|
"Missing required input streams. Both FIRST_FRAME and SECOND_FRAME "
|
||||||
"must be specified.");
|
"must be specified.");
|
||||||
}
|
}
|
||||||
cc->Inputs().Tag("FIRST_FRAME").Set<ImageFrame>();
|
cc->Inputs().Tag(kFirstFrameTag).Set<ImageFrame>();
|
||||||
cc->Inputs().Tag("SECOND_FRAME").Set<ImageFrame>();
|
cc->Inputs().Tag(kSecondFrameTag).Set<ImageFrame>();
|
||||||
if (cc->Outputs().HasTag("FORWARD_FLOW")) {
|
if (cc->Outputs().HasTag(kForwardFlowTag)) {
|
||||||
cc->Outputs().Tag("FORWARD_FLOW").Set<OpticalFlowField>();
|
cc->Outputs().Tag(kForwardFlowTag).Set<OpticalFlowField>();
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("BACKWARD_FLOW")) {
|
if (cc->Outputs().HasTag(kBackwardFlowTag)) {
|
||||||
cc->Outputs().Tag("BACKWARD_FLOW").Set<OpticalFlowField>();
|
cc->Outputs().Tag(kBackwardFlowTag).Set<OpticalFlowField>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -116,10 +121,10 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
|
||||||
absl::MutexLock lock(&mutex_);
|
absl::MutexLock lock(&mutex_);
|
||||||
tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1());
|
tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1());
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("FORWARD_FLOW")) {
|
if (cc->Outputs().HasTag(kForwardFlowTag)) {
|
||||||
forward_requested_ = true;
|
forward_requested_ = true;
|
||||||
}
|
}
|
||||||
if (cc->Outputs().HasTag("BACKWARD_FLOW")) {
|
if (cc->Outputs().HasTag(kBackwardFlowTag)) {
|
||||||
backward_requested_ = true;
|
backward_requested_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,15 +133,15 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
|
||||||
|
|
||||||
absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
|
absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
|
||||||
const ImageFrame& first_frame =
|
const ImageFrame& first_frame =
|
||||||
cc->Inputs().Tag("FIRST_FRAME").Value().Get<ImageFrame>();
|
cc->Inputs().Tag(kFirstFrameTag).Value().Get<ImageFrame>();
|
||||||
const ImageFrame& second_frame =
|
const ImageFrame& second_frame =
|
||||||
cc->Inputs().Tag("SECOND_FRAME").Value().Get<ImageFrame>();
|
cc->Inputs().Tag(kSecondFrameTag).Value().Get<ImageFrame>();
|
||||||
if (forward_requested_) {
|
if (forward_requested_) {
|
||||||
auto forward_optical_flow_field = absl::make_unique<OpticalFlowField>();
|
auto forward_optical_flow_field = absl::make_unique<OpticalFlowField>();
|
||||||
MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame,
|
MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame,
|
||||||
forward_optical_flow_field.get()));
|
forward_optical_flow_field.get()));
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("FORWARD_FLOW")
|
.Tag(kForwardFlowTag)
|
||||||
.Add(forward_optical_flow_field.release(), cc->InputTimestamp());
|
.Add(forward_optical_flow_field.release(), cc->InputTimestamp());
|
||||||
}
|
}
|
||||||
if (backward_requested_) {
|
if (backward_requested_) {
|
||||||
|
@ -144,7 +149,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
|
||||||
MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame,
|
MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame,
|
||||||
backward_optical_flow_field.get()));
|
backward_optical_flow_field.get()));
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Tag("BACKWARD_FLOW")
|
.Tag(kBackwardFlowTag)
|
||||||
.Add(backward_optical_flow_field.release(), cc->InputTimestamp());
|
.Add(backward_optical_flow_field.release(), cc->InputTimestamp());
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
|
@ -19,6 +19,9 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||||
|
constexpr char kFrameTag[] = "FRAME";
|
||||||
|
|
||||||
// Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp
|
// Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp
|
||||||
// PreStream. Note that this calculator only fills in format, width, and height,
|
// PreStream. Note that this calculator only fills in format, width, and height,
|
||||||
// i.e. frame_rate and duration will not be filled, unless:
|
// i.e. frame_rate and duration will not be filled, unless:
|
||||||
|
@ -64,8 +67,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (!cc->Inputs().UsesTags()) {
|
if (!cc->Inputs().UsesTags()) {
|
||||||
cc->Inputs().Index(0).Set<ImageFrame>();
|
cc->Inputs().Index(0).Set<ImageFrame>();
|
||||||
} else {
|
} else {
|
||||||
cc->Inputs().Tag("FRAME").Set<ImageFrame>();
|
cc->Inputs().Tag(kFrameTag).Set<ImageFrame>();
|
||||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||||
}
|
}
|
||||||
cc->Outputs().Index(0).Set<VideoHeader>();
|
cc->Outputs().Index(0).Set<VideoHeader>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -73,8 +76,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
|
||||||
|
|
||||||
absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) {
|
absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) {
|
||||||
frame_rate_in_prestream_ = cc->Inputs().UsesTags() &&
|
frame_rate_in_prestream_ = cc->Inputs().UsesTags() &&
|
||||||
cc->Inputs().HasTag("FRAME") &&
|
cc->Inputs().HasTag(kFrameTag) &&
|
||||||
cc->Inputs().HasTag("VIDEO_PRESTREAM");
|
cc->Inputs().HasTag(kVideoPrestreamTag);
|
||||||
header_ = absl::make_unique<VideoHeader>();
|
header_ = absl::make_unique<VideoHeader>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -82,15 +85,15 @@ absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment();
|
cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment();
|
||||||
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||||
RET_CHECK(cc->Inputs().Tag("FRAME").IsEmpty());
|
RET_CHECK(cc->Inputs().Tag(kFrameTag).IsEmpty());
|
||||||
RET_CHECK(!cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty());
|
RET_CHECK(!cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty());
|
||||||
*header_ = cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
|
*header_ = cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
|
||||||
RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero";
|
RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero";
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty())
|
RET_CHECK(cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty())
|
||||||
<< "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream().";
|
<< "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream().";
|
||||||
RET_CHECK(!cc->Inputs().Tag("FRAME").IsEmpty());
|
RET_CHECK(!cc->Inputs().Tag(kFrameTag).IsEmpty());
|
||||||
const auto& frame = cc->Inputs().Tag("FRAME").Get<ImageFrame>();
|
const auto& frame = cc->Inputs().Tag(kFrameTag).Get<ImageFrame>();
|
||||||
header_->format = frame.Format();
|
header_->format = frame.Format();
|
||||||
header_->width = frame.Width();
|
header_->width = frame.Width();
|
||||||
header_->height = frame.Height();
|
header_->height = frame.Height();
|
||||||
|
|
|
@ -44,28 +44,32 @@ using mediapipe::MakePacket;
|
||||||
using mediapipe::OutputStreamShardSet;
|
using mediapipe::OutputStreamShardSet;
|
||||||
using mediapipe::Timestamp;
|
using mediapipe::Timestamp;
|
||||||
namespace proto_ns = mediapipe::proto_ns;
|
namespace proto_ns = mediapipe::proto_ns;
|
||||||
|
|
||||||
|
constexpr char kEventTag[] = "EVENT";
|
||||||
|
constexpr char kOutTag[] = "OUT";
|
||||||
|
|
||||||
using mediapipe::CalculatorGraph;
|
using mediapipe::CalculatorGraph;
|
||||||
using mediapipe::Packet;
|
using mediapipe::Packet;
|
||||||
|
|
||||||
class InfiniteSequenceCalculator : public mediapipe::CalculatorBase {
|
class InfiniteSequenceCalculator : public mediapipe::CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
|
static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
|
||||||
cc->Outputs().Tag("OUT").Set<int>();
|
cc->Outputs().Tag(kOutTag).Set<int>();
|
||||||
cc->Outputs().Tag("EVENT").Set<int>();
|
cc->Outputs().Tag(kEventTag).Set<int>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
cc->Outputs().Tag("OUT").AddPacket(
|
cc->Outputs().Tag(kOutTag).AddPacket(
|
||||||
MakePacket<int>(count_).At(Timestamp(count_)));
|
MakePacket<int>(count_).At(Timestamp(count_)));
|
||||||
count_++;
|
count_++;
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Close(CalculatorContext* cc) override {
|
absl::Status Close(CalculatorContext* cc) override {
|
||||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,11 +85,11 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
|
||||||
cc->Inputs().Get("", i).SetAny();
|
cc->Inputs().Get("", i).SetAny();
|
||||||
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
||||||
}
|
}
|
||||||
cc->Outputs().Tag("EVENT").Set<int>();
|
cc->Outputs().Tag(kEventTag).Set<int>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
@ -98,7 +102,7 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
|
||||||
: mediapipe::tool::StatusStop();
|
: mediapipe::tool::StatusStop();
|
||||||
}
|
}
|
||||||
absl::Status Close(CalculatorContext* cc) override {
|
absl::Status Close(CalculatorContext* cc) override {
|
||||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,16 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kCounter2Tag[] = "COUNTER2";
|
||||||
|
constexpr char kCounter1Tag[] = "COUNTER1";
|
||||||
|
constexpr char kExtraTag[] = "EXTRA";
|
||||||
|
constexpr char kWaitSemTag[] = "WAIT_SEM";
|
||||||
|
constexpr char kPostSemTag[] = "POST_SEM";
|
||||||
|
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
|
||||||
|
constexpr char kOutputTag[] = "OUTPUT";
|
||||||
|
constexpr char kInputTag[] = "INPUT";
|
||||||
|
constexpr char kSelectTag[] = "SELECT";
|
||||||
|
|
||||||
using testing::ElementsAre;
|
using testing::ElementsAre;
|
||||||
using testing::HasSubstr;
|
using testing::HasSubstr;
|
||||||
|
|
||||||
|
@ -125,8 +135,8 @@ class DemuxTimedCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
|
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
|
||||||
cc->Inputs().Tag("SELECT").Set<int>();
|
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||||
PacketType* data_input = &cc->Inputs().Tag("INPUT");
|
PacketType* data_input = &cc->Inputs().Tag(kInputTag);
|
||||||
data_input->SetAny();
|
data_input->SetAny();
|
||||||
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
|
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
|
||||||
id < cc->Outputs().EndId("OUTPUT"); ++id) {
|
id < cc->Outputs().EndId("OUTPUT"); ++id) {
|
||||||
|
@ -182,7 +192,7 @@ REGISTER_CALCULATOR(DemuxTimedCalculator);
|
||||||
class MuxTimedCalculator : public CalculatorBase {
|
class MuxTimedCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("SELECT").Set<int>();
|
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||||
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT");
|
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT");
|
||||||
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
|
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
|
||||||
data_input0->SetAny();
|
data_input0->SetAny();
|
||||||
|
@ -191,7 +201,7 @@ class MuxTimedCalculator : public CalculatorBase {
|
||||||
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
|
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
|
||||||
}
|
}
|
||||||
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
|
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
|
||||||
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
|
cc->Outputs().Tag(kOutputTag).SetSameAs(data_input0);
|
||||||
cc->SetTimestampOffset(TimestampDiff(0));
|
cc->SetTimestampOffset(TimestampDiff(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -598,12 +608,12 @@ class ErrorOnOpenCalculator : public CalculatorBase {
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Index(0).SetAny();
|
cc->Inputs().Index(0).SetAny();
|
||||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
|
cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
if (cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
|
if (cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
|
||||||
return absl::NotFoundError("expected error");
|
return absl::NotFoundError("expected error");
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -920,8 +930,8 @@ class SemaphoreCalculator : public CalculatorBase {
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Index(0).SetAny();
|
cc->Inputs().Index(0).SetAny();
|
||||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
cc->InputSidePackets().Tag("POST_SEM").Set<Semaphore*>();
|
cc->InputSidePackets().Tag(kPostSemTag).Set<Semaphore*>();
|
||||||
cc->InputSidePackets().Tag("WAIT_SEM").Set<Semaphore*>();
|
cc->InputSidePackets().Tag(kWaitSemTag).Set<Semaphore*>();
|
||||||
cc->SetTimestampOffset(TimestampDiff(0));
|
cc->SetTimestampOffset(TimestampDiff(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -929,8 +939,8 @@ class SemaphoreCalculator : public CalculatorBase {
|
||||||
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
|
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
cc->InputSidePackets().Tag("POST_SEM").Get<Semaphore*>()->Release(1);
|
cc->InputSidePackets().Tag(kPostSemTag).Get<Semaphore*>()->Release(1);
|
||||||
cc->InputSidePackets().Tag("WAIT_SEM").Get<Semaphore*>()->Acquire(1);
|
cc->InputSidePackets().Tag(kWaitSemTag).Get<Semaphore*>()->Acquire(1);
|
||||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -1177,9 +1187,9 @@ class IncrementingStatusHandler : public StatusHandler {
|
||||||
static absl::Status FillExpectations(
|
static absl::Status FillExpectations(
|
||||||
const MediaPipeOptions& extendable_options,
|
const MediaPipeOptions& extendable_options,
|
||||||
PacketTypeSet* input_side_packets) {
|
PacketTypeSet* input_side_packets) {
|
||||||
input_side_packets->Tag("EXTRA").SetAny().Optional();
|
input_side_packets->Tag(kExtraTag).SetAny().Optional();
|
||||||
input_side_packets->Tag("COUNTER1").Set<std::unique_ptr<int>>();
|
input_side_packets->Tag(kCounter1Tag).Set<std::unique_ptr<int>>();
|
||||||
input_side_packets->Tag("COUNTER2").Set<std::unique_ptr<int>>();
|
input_side_packets->Tag(kCounter2Tag).Set<std::unique_ptr<int>>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1187,7 +1197,7 @@ class IncrementingStatusHandler : public StatusHandler {
|
||||||
const MediaPipeOptions& extendable_options,
|
const MediaPipeOptions& extendable_options,
|
||||||
const PacketSet& input_side_packets, //
|
const PacketSet& input_side_packets, //
|
||||||
const absl::Status& pre_run_status) {
|
const absl::Status& pre_run_status) {
|
||||||
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER1"));
|
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter1Tag));
|
||||||
(*counter)++;
|
(*counter)++;
|
||||||
return pre_run_status_result_;
|
return pre_run_status_result_;
|
||||||
}
|
}
|
||||||
|
@ -1195,7 +1205,7 @@ class IncrementingStatusHandler : public StatusHandler {
|
||||||
static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
|
static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
|
||||||
const PacketSet& input_side_packets, //
|
const PacketSet& input_side_packets, //
|
||||||
const absl::Status& run_status) {
|
const absl::Status& run_status) {
|
||||||
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER2"));
|
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter2Tag));
|
||||||
(*counter)++;
|
(*counter)++;
|
||||||
return post_run_status_result_;
|
return post_run_status_result_;
|
||||||
}
|
}
|
||||||
|
@ -2228,20 +2238,20 @@ class DemuxUntimedCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
|
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
|
||||||
cc->Inputs().Tag("INPUT").SetAny();
|
cc->Inputs().Tag(kInputTag).SetAny();
|
||||||
cc->Inputs().Tag("SELECT").Set<int>();
|
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||||
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
|
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
|
||||||
id < cc->Outputs().EndId("OUTPUT"); ++id) {
|
id < cc->Outputs().EndId("OUTPUT"); ++id) {
|
||||||
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag("INPUT"));
|
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag(kInputTag));
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
int index = cc->Inputs().Tag("SELECT").Get<int>();
|
int index = cc->Inputs().Tag(kSelectTag).Get<int>();
|
||||||
if (!cc->Inputs().Tag("INPUT").IsEmpty()) {
|
if (!cc->Inputs().Tag(kInputTag).IsEmpty()) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Get("OUTPUT", index)
|
.Get("OUTPUT", index)
|
||||||
.AddPacket(cc->Inputs().Tag("INPUT").Value());
|
.AddPacket(cc->Inputs().Tag(kInputTag).Value());
|
||||||
} else {
|
} else {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Get("OUTPUT", index)
|
.Get("OUTPUT", index)
|
||||||
|
|
|
@ -32,6 +32,11 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kTag[] = "";
|
||||||
|
constexpr char kBTag[] = "B";
|
||||||
|
constexpr char kATag[] = "A";
|
||||||
|
constexpr char kSideOutputTag[] = "SIDE_OUTPUT";
|
||||||
|
|
||||||
// Inputs: 2 streams with ints. Headers are strings.
|
// Inputs: 2 streams with ints. Headers are strings.
|
||||||
// Input side packets: 1.
|
// Input side packets: 1.
|
||||||
// Outputs: 3 streams with ints. #0 and #1 will contain the negated values from
|
// Outputs: 3 streams with ints. #0 and #1 will contain the negated values from
|
||||||
|
@ -48,7 +53,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
|
||||||
cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0));
|
cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0));
|
||||||
cc->InputSidePackets().Index(0).SetAny();
|
cc->InputSidePackets().Index(0).SetAny();
|
||||||
cc->OutputSidePackets()
|
cc->OutputSidePackets()
|
||||||
.Tag("SIDE_OUTPUT")
|
.Tag(kSideOutputTag)
|
||||||
.SetSameAs(&cc->InputSidePackets().Index(0));
|
.SetSameAs(&cc->InputSidePackets().Index(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -64,7 +69,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
|
||||||
Adopt(new std::string(absl::StrCat(input_header_string, i))));
|
Adopt(new std::string(absl::StrCat(input_header_string, i))));
|
||||||
}
|
}
|
||||||
cc->OutputSidePackets()
|
cc->OutputSidePackets()
|
||||||
.Tag("SIDE_OUTPUT")
|
.Tag(kSideOutputTag)
|
||||||
.Set(cc->InputSidePackets().Index(0));
|
.Set(cc->InputSidePackets().Index(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -152,7 +157,7 @@ TEST(CalculatorRunner, RunsCalculator) {
|
||||||
Adopt(new int(input_side_packet_content));
|
Adopt(new int(input_side_packet_content));
|
||||||
MP_ASSERT_OK(runner.Run());
|
MP_ASSERT_OK(runner.Run());
|
||||||
EXPECT_EQ(input_side_packet_content,
|
EXPECT_EQ(input_side_packet_content,
|
||||||
runner.OutputSidePackets().Tag("SIDE_OUTPUT").Get<int>());
|
runner.OutputSidePackets().Tag(kSideOutputTag).Get<int>());
|
||||||
const auto& outputs = runner.Outputs();
|
const auto& outputs = runner.Outputs();
|
||||||
ASSERT_EQ(3, outputs.NumEntries());
|
ASSERT_EQ(3, outputs.NumEntries());
|
||||||
|
|
||||||
|
@ -209,9 +214,9 @@ TEST(CalculatorRunner, MultiTagTestCalculatorOk) {
|
||||||
const auto& outputs = runner.Outputs();
|
const auto& outputs = runner.Outputs();
|
||||||
ASSERT_EQ(3, outputs.NumEntries());
|
ASSERT_EQ(3, outputs.NumEntries());
|
||||||
for (int ts = 0; ts < 5; ++ts) {
|
for (int ts = 0; ts < 5; ++ts) {
|
||||||
const std::vector<Packet>& a_packets = outputs.Tag("A").packets;
|
const std::vector<Packet>& a_packets = outputs.Tag(kATag).packets;
|
||||||
const std::vector<Packet>& b_packets = outputs.Tag("B").packets;
|
const std::vector<Packet>& b_packets = outputs.Tag(kBTag).packets;
|
||||||
const std::vector<Packet>& c_packets = outputs.Tag("").packets;
|
const std::vector<Packet>& c_packets = outputs.Tag(kTag).packets;
|
||||||
EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp());
|
EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp());
|
||||||
EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp());
|
EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp());
|
||||||
EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp());
|
EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp());
|
||||||
|
|
|
@ -24,6 +24,10 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kTag2Tag[] = "TAG_2";
|
||||||
|
constexpr char kTag0Tag[] = "TAG_0";
|
||||||
|
constexpr char kTag1Tag[] = "TAG_1";
|
||||||
|
|
||||||
TEST(CollectionTest, BasicByIndex) {
|
TEST(CollectionTest, BasicByIndex) {
|
||||||
tool::TagAndNameInfo info;
|
tool::TagAndNameInfo info;
|
||||||
info.names.push_back("name_1");
|
info.names.push_back("name_1");
|
||||||
|
@ -55,14 +59,14 @@ TEST(CollectionTest, BasicByTag) {
|
||||||
info.names.push_back("name_2");
|
info.names.push_back("name_2");
|
||||||
info.tags.push_back("TAG_2");
|
info.tags.push_back("TAG_2");
|
||||||
internal::Collection<int> collection(info);
|
internal::Collection<int> collection(info);
|
||||||
collection.Tag("TAG_1") = 101;
|
collection.Tag(kTag1Tag) = 101;
|
||||||
collection.Tag("TAG_0") = 100;
|
collection.Tag(kTag0Tag) = 100;
|
||||||
collection.Tag("TAG_2") = 102;
|
collection.Tag(kTag2Tag) = 102;
|
||||||
|
|
||||||
// Test the stored values.
|
// Test the stored values.
|
||||||
EXPECT_EQ(100, collection.Tag("TAG_0"));
|
EXPECT_EQ(100, collection.Tag(kTag0Tag));
|
||||||
EXPECT_EQ(101, collection.Tag("TAG_1"));
|
EXPECT_EQ(101, collection.Tag(kTag1Tag));
|
||||||
EXPECT_EQ(102, collection.Tag("TAG_2"));
|
EXPECT_EQ(102, collection.Tag(kTag2Tag));
|
||||||
// Test access using a range based for.
|
// Test access using a range based for.
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (int num : collection) {
|
for (int num : collection) {
|
||||||
|
|
|
@ -134,6 +134,21 @@ void Tensor::AllocateMtlBuffer(id<MTLDevice> device) const {
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
bool Tensor::NeedsHalfFloatRenderTarget() const {
|
||||||
|
static bool has_color_buffer_float =
|
||||||
|
gl_context_->HasGlExtension("WEBGL_color_buffer_float") ||
|
||||||
|
gl_context_->HasGlExtension("EXT_color_buffer_float");
|
||||||
|
if (!has_color_buffer_float) {
|
||||||
|
static bool has_color_buffer_half_float =
|
||||||
|
gl_context_->HasGlExtension("EXT_color_buffer_half_float");
|
||||||
|
LOG_IF(FATAL, !has_color_buffer_half_float)
|
||||||
|
<< "EXT_color_buffer_half_float or WEBGL_color_buffer_float "
|
||||||
|
<< "required on web to use MP tensor";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
||||||
LOG_IF(FATAL, valid_ == kValidNone)
|
LOG_IF(FATAL, valid_ == kValidNone)
|
||||||
<< "Tensor must be written prior to read from.";
|
<< "Tensor must be written prior to read from.";
|
||||||
|
@ -164,8 +179,24 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
||||||
// Set alignment for the proper value (default) to avoid address sanitizer
|
// Set alignment for the proper value (default) to avoid address sanitizer
|
||||||
// error "out of boundary reading".
|
// error "out of boundary reading".
|
||||||
glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
|
glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
|
||||||
|
#ifdef __EMSCRIPTEN__
|
||||||
|
// Under WebGL1, format must match in order to use glTexSubImage2D, so if we
|
||||||
|
// have a half-float texture, then uploading from GL_FLOAT here would fail.
|
||||||
|
// We change the texture's data type to float here to accommodate.
|
||||||
|
// Furthermore, for a full-image replacement operation, glTexImage2D is
|
||||||
|
// expected to be more performant than glTexSubImage2D. Note that for WebGL2
|
||||||
|
// we cannot use glTexImage2D, because we allocate using glTexStorage2D in
|
||||||
|
// that case, which is incompatible.
|
||||||
|
if (gl_context_->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
|
||||||
|
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
|
||||||
|
0, GL_RGBA, GL_FLOAT, temp_buffer.get());
|
||||||
|
texture_is_half_float_ = false;
|
||||||
|
} else
|
||||||
|
#endif // __EMSCRIPTEN__
|
||||||
|
{
|
||||||
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_,
|
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_,
|
||||||
GL_RGBA, GL_FLOAT, temp_buffer.get());
|
GL_RGBA, GL_FLOAT, temp_buffer.get());
|
||||||
|
}
|
||||||
glBindTexture(GL_TEXTURE_2D, 0);
|
glBindTexture(GL_TEXTURE_2D, 0);
|
||||||
valid_ |= kValidOpenGlTexture2d;
|
valid_ |= kValidOpenGlTexture2d;
|
||||||
}
|
}
|
||||||
|
@ -175,6 +206,16 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
||||||
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const {
|
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const {
|
||||||
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
|
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
|
||||||
AllocateOpenGlTexture2d();
|
AllocateOpenGlTexture2d();
|
||||||
|
#ifdef __EMSCRIPTEN__
|
||||||
|
// On web, we may have to change type from float to half-float
|
||||||
|
if (!texture_is_half_float_ && NeedsHalfFloatRenderTarget()) {
|
||||||
|
glBindTexture(GL_TEXTURE_2D, opengl_texture2d_);
|
||||||
|
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_, 0,
|
||||||
|
GL_RGBA, GL_HALF_FLOAT_OES, 0 /* data */);
|
||||||
|
glBindTexture(GL_TEXTURE_2D, 0);
|
||||||
|
texture_is_half_float_ = true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
valid_ = kValidOpenGlTexture2d;
|
valid_ = kValidOpenGlTexture2d;
|
||||||
return {opengl_texture2d_, std::move(lock)};
|
return {opengl_texture2d_, std::move(lock)};
|
||||||
}
|
}
|
||||||
|
@ -255,8 +296,18 @@ void Tensor::AllocateOpenGlTexture2d() const {
|
||||||
<< "with GLES 2.0";
|
<< "with GLES 2.0";
|
||||||
// Allocate the image data; note that it's no longer RGBA32F, so will be
|
// Allocate the image data; note that it's no longer RGBA32F, so will be
|
||||||
// lower precision.
|
// lower precision.
|
||||||
|
auto type = GL_FLOAT;
|
||||||
|
// On web, we might need to change type to half-float (e.g. for iOS-
|
||||||
|
// Safari) in order to have a valid framebuffer. See b/194442743 for more
|
||||||
|
// details.
|
||||||
|
#ifdef __EMSCRIPTEN__
|
||||||
|
if (NeedsHalfFloatRenderTarget()) {
|
||||||
|
type = GL_HALF_FLOAT_OES;
|
||||||
|
texture_is_half_float_ = true;
|
||||||
|
}
|
||||||
|
#endif // __EMSCRIPTEN
|
||||||
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
|
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
|
||||||
0, GL_RGBA, GL_FLOAT, 0 /* data */);
|
0, GL_RGBA, type, 0 /* data */);
|
||||||
}
|
}
|
||||||
glBindTexture(GL_TEXTURE_2D, 0);
|
glBindTexture(GL_TEXTURE_2D, 0);
|
||||||
glGenFramebuffers(1, &frame_buffer_);
|
glGenFramebuffers(1, &frame_buffer_);
|
||||||
|
@ -443,7 +494,6 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
||||||
glPixelStorei(GL_PACK_ALIGNMENT, 4);
|
glPixelStorei(GL_PACK_ALIGNMENT, 4);
|
||||||
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
|
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
|
||||||
buffer);
|
buffer);
|
||||||
|
|
||||||
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
|
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
|
||||||
const int actual_depth_size =
|
const int actual_depth_size =
|
||||||
BhwcDepthFromShape(shape_) * element_size();
|
BhwcDepthFromShape(shape_) * element_size();
|
||||||
|
|
|
@ -266,11 +266,15 @@ class Tensor {
|
||||||
mutable GLuint frame_buffer_ = GL_INVALID_INDEX;
|
mutable GLuint frame_buffer_ = GL_INVALID_INDEX;
|
||||||
mutable int texture_width_;
|
mutable int texture_width_;
|
||||||
mutable int texture_height_;
|
mutable int texture_height_;
|
||||||
|
#ifdef __EMSCRIPTEN__
|
||||||
|
mutable bool texture_is_half_float_ = false;
|
||||||
|
#endif // __EMSCRIPTEN__
|
||||||
void AllocateOpenGlTexture2d() const;
|
void AllocateOpenGlTexture2d() const;
|
||||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||||
mutable GLuint opengl_buffer_ = GL_INVALID_INDEX;
|
mutable GLuint opengl_buffer_ = GL_INVALID_INDEX;
|
||||||
void AllocateOpenGlBuffer() const;
|
void AllocateOpenGlBuffer() const;
|
||||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||||
|
bool NeedsHalfFloatRenderTarget() const;
|
||||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,11 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kOutputTag[] = "OUTPUT";
|
||||||
|
constexpr char kEnableTag[] = "ENABLE";
|
||||||
|
constexpr char kSelectTag[] = "SELECT";
|
||||||
|
constexpr char kSideinputTag[] = "SIDEINPUT";
|
||||||
|
|
||||||
// Shows validation success for a graph and a subgraph.
|
// Shows validation success for a graph and a subgraph.
|
||||||
TEST(GraphValidationTest, InitializeGraphFromProtos) {
|
TEST(GraphValidationTest, InitializeGraphFromProtos) {
|
||||||
auto config_1 = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
auto config_1 = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
@ -323,20 +328,21 @@ TEST(GraphValidationTest, OptionalSubgraphStreamsMismatched) {
|
||||||
class OptionalSideInputTestCalculator : public CalculatorBase {
|
class OptionalSideInputTestCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->InputSidePackets().Tag("SIDEINPUT").Set<std::string>().Optional();
|
cc->InputSidePackets().Tag(kSideinputTag).Set<std::string>().Optional();
|
||||||
cc->Inputs().Tag("SELECT").Set<int>().Optional();
|
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
||||||
cc->Inputs().Tag("ENABLE").Set<bool>().Optional();
|
cc->Inputs().Tag(kEnableTag).Set<bool>().Optional();
|
||||||
cc->Outputs().Tag("OUTPUT").Set<std::string>();
|
cc->Outputs().Tag(kOutputTag).Set<std::string>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
std::string value("default");
|
std::string value("default");
|
||||||
if (cc->InputSidePackets().HasTag("SIDEINPUT")) {
|
if (cc->InputSidePackets().HasTag(kSideinputTag)) {
|
||||||
value = cc->InputSidePackets().Tag("SIDEINPUT").Get<std::string>();
|
value = cc->InputSidePackets().Tag(kSideinputTag).Get<std::string>();
|
||||||
}
|
}
|
||||||
cc->Outputs().Tag("OUTPUT").Add(new std::string(value),
|
cc->Outputs()
|
||||||
cc->InputTimestamp());
|
.Tag(kOutputTag)
|
||||||
|
.Add(new std::string(value), cc->InputTimestamp());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -26,17 +26,20 @@ namespace {
|
||||||
|
|
||||||
namespace test_ns {
|
namespace test_ns {
|
||||||
|
|
||||||
|
constexpr char kOutTag[] = "OUT";
|
||||||
|
constexpr char kInTag[] = "IN";
|
||||||
|
|
||||||
class TestSinkCalculator : public CalculatorBase {
|
class TestSinkCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
|
cc->Inputs().Tag(kInTag).Set<mediapipe::InputOnlyProto>();
|
||||||
cc->Outputs().Tag("OUT").Set<int>();
|
cc->Outputs().Tag(kOutTag).Set<int>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
|
int x = cc->Inputs().Tag(kInTag).Get<mediapipe::InputOnlyProto>().x();
|
||||||
cc->Outputs().Tag("OUT").AddPacket(
|
cc->Outputs().Tag(kOutTag).AddPacket(
|
||||||
MakePacket<int>(x).At(cc->InputTimestamp()));
|
MakePacket<int>(x).At(cc->InputTimestamp()));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,19 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
constexpr char kOutTag[] = "OUT";
|
||||||
|
constexpr char kClockTag[] = "CLOCK";
|
||||||
|
constexpr char kSleepMicrosTag[] = "SLEEP_MICROS";
|
||||||
|
constexpr char kCloseTag[] = "CLOSE";
|
||||||
|
constexpr char kProcessTag[] = "PROCESS";
|
||||||
|
constexpr char kOpenTag[] = "OPEN";
|
||||||
|
constexpr char kTag[] = "";
|
||||||
|
constexpr char kMeanTag[] = "MEAN";
|
||||||
|
constexpr char kDataTag[] = "DATA";
|
||||||
|
constexpr char kPairTag[] = "PAIR";
|
||||||
|
constexpr char kLowTag[] = "LOW";
|
||||||
|
constexpr char kHighTag[] = "HIGH";
|
||||||
|
|
||||||
using RandomEngine = std::mt19937_64;
|
using RandomEngine = std::mt19937_64;
|
||||||
|
|
||||||
// A Calculator that outputs twice the value of its input packet (an int).
|
// A Calculator that outputs twice the value of its input packet (an int).
|
||||||
|
@ -95,9 +108,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
|
||||||
PacketTypeSet* input_side_packets, //
|
PacketTypeSet* input_side_packets, //
|
||||||
PacketTypeSet* output_side_packets) {
|
PacketTypeSet* output_side_packets) {
|
||||||
input_side_packets->Index(0).Set<uint64>();
|
input_side_packets->Index(0).Set<uint64>();
|
||||||
output_side_packets->Tag("HIGH").Set<uint32>();
|
output_side_packets->Tag(kHighTag).Set<uint32>();
|
||||||
output_side_packets->Tag("LOW").Set<uint32>();
|
output_side_packets->Tag(kLowTag).Set<uint32>();
|
||||||
output_side_packets->Tag("PAIR").Set<std::pair<uint32, uint32>>();
|
output_side_packets->Tag(kPairTag).Set<std::pair<uint32, uint32>>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,9 +121,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
|
||||||
uint64 value = input_side_packets.Index(0).Get<uint64>();
|
uint64 value = input_side_packets.Index(0).Get<uint64>();
|
||||||
uint32 high = value >> 32;
|
uint32 high = value >> 32;
|
||||||
uint32 low = value & 0xFFFFFFFF;
|
uint32 low = value & 0xFFFFFFFF;
|
||||||
output_side_packets->Tag("HIGH") = Adopt(new uint32(high));
|
output_side_packets->Tag(kHighTag) = Adopt(new uint32(high));
|
||||||
output_side_packets->Tag("LOW") = Adopt(new uint32(low));
|
output_side_packets->Tag(kLowTag) = Adopt(new uint32(low));
|
||||||
output_side_packets->Tag("PAIR") =
|
output_side_packets->Tag(kPairTag) =
|
||||||
Adopt(new std::pair<uint32, uint32>(high, low));
|
Adopt(new std::pair<uint32, uint32>(high, low));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -221,8 +234,8 @@ class StdDevCalculator : public CalculatorBase {
|
||||||
StdDevCalculator() {}
|
StdDevCalculator() {}
|
||||||
|
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Tag("DATA").Set<int>();
|
cc->Inputs().Tag(kDataTag).Set<int>();
|
||||||
cc->Inputs().Tag("MEAN").Set<double>();
|
cc->Inputs().Tag(kMeanTag).Set<double>();
|
||||||
cc->Outputs().Index(0).Set<int>();
|
cc->Outputs().Index(0).Set<int>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -234,15 +247,15 @@ class StdDevCalculator : public CalculatorBase {
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||||
RET_CHECK(cc->Inputs().Tag("DATA").Value().IsEmpty());
|
RET_CHECK(cc->Inputs().Tag(kDataTag).Value().IsEmpty());
|
||||||
RET_CHECK(!cc->Inputs().Tag("MEAN").Value().IsEmpty());
|
RET_CHECK(!cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
|
||||||
mean_ = cc->Inputs().Tag("MEAN").Get<double>();
|
mean_ = cc->Inputs().Tag(kMeanTag).Get<double>();
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK(initialized_);
|
RET_CHECK(initialized_);
|
||||||
RET_CHECK(!cc->Inputs().Tag("DATA").Value().IsEmpty());
|
RET_CHECK(!cc->Inputs().Tag(kDataTag).Value().IsEmpty());
|
||||||
RET_CHECK(cc->Inputs().Tag("MEAN").Value().IsEmpty());
|
RET_CHECK(cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
|
||||||
double diff = cc->Inputs().Tag("DATA").Get<int>() - mean_;
|
double diff = cc->Inputs().Tag(kDataTag).Get<int>() - mean_;
|
||||||
cummulative_variance_ += diff * diff;
|
cummulative_variance_ += diff * diff;
|
||||||
++count_;
|
++count_;
|
||||||
}
|
}
|
||||||
|
@ -564,8 +577,8 @@ class LambdaCalculator : public CalculatorBase {
|
||||||
id < cc->Outputs().EndId(); ++id) {
|
id < cc->Outputs().EndId(); ++id) {
|
||||||
cc->Outputs().Get(id).SetAny();
|
cc->Outputs().Get(id).SetAny();
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("") > 0) {
|
if (cc->InputSidePackets().HasTag(kTag) > 0) {
|
||||||
cc->InputSidePackets().Tag("").Set<ProcessFunction>();
|
cc->InputSidePackets().Tag(kTag).Set<ProcessFunction>();
|
||||||
}
|
}
|
||||||
for (const std::string& tag : {"OPEN", "PROCESS", "CLOSE"}) {
|
for (const std::string& tag : {"OPEN", "PROCESS", "CLOSE"}) {
|
||||||
if (cc->InputSidePackets().HasTag(tag)) {
|
if (cc->InputSidePackets().HasTag(tag)) {
|
||||||
|
@ -576,24 +589,24 @@ class LambdaCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
if (cc->InputSidePackets().HasTag("OPEN")) {
|
if (cc->InputSidePackets().HasTag(kOpenTag)) {
|
||||||
return GetContextFn(cc, "OPEN")(cc);
|
return GetContextFn(cc, "OPEN")(cc);
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
if (cc->InputSidePackets().HasTag("PROCESS")) {
|
if (cc->InputSidePackets().HasTag(kProcessTag)) {
|
||||||
return GetContextFn(cc, "PROCESS")(cc);
|
return GetContextFn(cc, "PROCESS")(cc);
|
||||||
}
|
}
|
||||||
if (cc->InputSidePackets().HasTag("") > 0) {
|
if (cc->InputSidePackets().HasTag(kTag) > 0) {
|
||||||
return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs());
|
return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs());
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Close(CalculatorContext* cc) final {
|
absl::Status Close(CalculatorContext* cc) final {
|
||||||
if (cc->InputSidePackets().HasTag("CLOSE")) {
|
if (cc->InputSidePackets().HasTag(kCloseTag)) {
|
||||||
return GetContextFn(cc, "CLOSE")(cc);
|
return GetContextFn(cc, "CLOSE")(cc);
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -645,17 +658,18 @@ class PassThroughWithSleepCalculator : public CalculatorBase {
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Index(0).Set<int>();
|
cc->Inputs().Index(0).Set<int>();
|
||||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
|
cc->InputSidePackets().Tag(kSleepMicrosTag).Set<int>();
|
||||||
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
|
cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
|
sleep_micros_ = cc->InputSidePackets().Tag(kSleepMicrosTag).Get<int>();
|
||||||
if (sleep_micros_ < 0) {
|
if (sleep_micros_ < 0) {
|
||||||
return absl::InternalError("SLEEP_MICROS should be >= 0");
|
return absl::InternalError("SLEEP_MICROS should be >= 0");
|
||||||
}
|
}
|
||||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
|
clock_ =
|
||||||
|
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
@ -678,8 +692,8 @@ class MultiplyIntCalculator : public CalculatorBase {
|
||||||
cc->Inputs().Index(0).Set<int>();
|
cc->Inputs().Index(0).Set<int>();
|
||||||
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
|
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
|
||||||
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||||
RET_CHECK(cc->Outputs().HasTag("OUT"));
|
RET_CHECK(cc->Outputs().HasTag(kOutTag));
|
||||||
cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
|
cc->Outputs().Tag(kOutTag).SetSameAs(&cc->Inputs().Index(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
|
@ -689,7 +703,7 @@ class MultiplyIntCalculator : public CalculatorBase {
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
int x = cc->Inputs().Index(0).Value().Get<int>();
|
int x = cc->Inputs().Index(0).Value().Get<int>();
|
||||||
int y = cc->Inputs().Index(1).Value().Get<int>();
|
int y = cc->Inputs().Index(1).Value().Get<int>();
|
||||||
cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
|
cc->Outputs().Tag(kOutTag).Add(new int(x * y), cc->InputTimestamp());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -60,6 +60,18 @@ def mediapipe_aar(
|
||||||
assets: additional assets to be included into the archive.
|
assets: additional assets to be included into the archive.
|
||||||
assets_dir: path where the assets will the packaged.
|
assets_dir: path where the assets will the packaged.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
|
||||||
|
# the OpenCV so libraries will be excluded from the AAR package to
|
||||||
|
# save the package size.
|
||||||
|
native.config_setting(
|
||||||
|
name = "exclude_opencv_so_lib",
|
||||||
|
define_values = {
|
||||||
|
"EXCLUDE_OPENCV_SO_LIB": "1",
|
||||||
|
},
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
_mediapipe_jni(
|
_mediapipe_jni(
|
||||||
name = name + "_jni",
|
name = name + "_jni",
|
||||||
gen_libmediapipe = gen_libmediapipe,
|
gen_libmediapipe = gen_libmediapipe,
|
||||||
|
@ -133,6 +145,7 @@ EOF
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
|
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
|
||||||
"//mediapipe/framework/port:disable_opencv": [],
|
"//mediapipe/framework/port:disable_opencv": [],
|
||||||
|
"exclude_opencv_so_lib": [],
|
||||||
}),
|
}),
|
||||||
assets = assets,
|
assets = assets,
|
||||||
assets_dir = assets_dir,
|
assets_dir = assets_dir,
|
||||||
|
|
|
@ -245,14 +245,6 @@ void Box::Fit(const std::vector<T>& vertices) {
|
||||||
auto system_g = system_h.colPivHouseholderQr();
|
auto system_g = system_h.colPivHouseholderQr();
|
||||||
auto solution = system_g.solve(v).eval();
|
auto solution = system_g.solve(v).eval();
|
||||||
transformation_.topLeftCorner<3, 4>() = solution.transpose();
|
transformation_.topLeftCorner<3, 4>() = solution.transpose();
|
||||||
|
|
||||||
// Adjust rotation matrix to its nearest orthogonal matrix.
|
|
||||||
const auto rotation = transformation_.topLeftCorner<3, 3>();
|
|
||||||
Eigen::JacobiSVD<Eigen::Matrix3f> svd(
|
|
||||||
rotation, Eigen::ComputeFullV | Eigen::ComputeFullU);
|
|
||||||
const Eigen::Matrix3f matrix_u = svd.matrixU();
|
|
||||||
const Eigen::Matrix3f matrix_v = svd.matrixV();
|
|
||||||
transformation_.topLeftCorner<3, 3>() = matrix_u * matrix_v.transpose();
|
|
||||||
Update();
|
Update();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -254,11 +254,11 @@ class SolutionBase:
|
||||||
for stream_name in self._output_stream_type_info.keys():
|
for stream_name in self._output_stream_type_info.keys():
|
||||||
self._graph.observe_output_stream(stream_name, callback, True)
|
self._graph.observe_output_stream(stream_name, callback, True)
|
||||||
|
|
||||||
input_side_packets = {
|
self._input_side_packets = {
|
||||||
name: self._make_packet(self._side_input_type_info[name], data)
|
name: self._make_packet(self._side_input_type_info[name], data)
|
||||||
for name, data in (side_inputs or {}).items()
|
for name, data in (side_inputs or {}).items()
|
||||||
}
|
}
|
||||||
self._graph.start_run(input_side_packets)
|
self._graph.start_run(self._input_side_packets)
|
||||||
|
|
||||||
# TODO: Use "inspect.Parameter" to fetch the input argument names and
|
# TODO: Use "inspect.Parameter" to fetch the input argument names and
|
||||||
# types from "_input_stream_type_info" and then auto generate the process
|
# types from "_input_stream_type_info" and then auto generate the process
|
||||||
|
@ -353,6 +353,12 @@ class SolutionBase:
|
||||||
self._input_stream_type_info = None
|
self._input_stream_type_info = None
|
||||||
self._output_stream_type_info = None
|
self._output_stream_type_info = None
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Resets the graph for another run."""
|
||||||
|
if self._graph:
|
||||||
|
self._graph.close()
|
||||||
|
self._graph.start_run(self._input_side_packets)
|
||||||
|
|
||||||
def _initialize_graph_interface(
|
def _initialize_graph_interface(
|
||||||
self,
|
self,
|
||||||
validated_graph: validated_graph_config.ValidatedGraphConfig,
|
validated_graph: validated_graph_config.ValidatedGraphConfig,
|
||||||
|
|
|
@ -298,6 +298,56 @@ class SolutionBaseTest(parameterized.TestCase):
|
||||||
'ImageTransformation.output_height': 0
|
'ImageTransformation.output_height': 0
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@parameterized.named_parameters(('graph_without_side_packets', """
|
||||||
|
input_stream: 'image_in'
|
||||||
|
output_stream: 'image_out'
|
||||||
|
node {
|
||||||
|
calculator: 'ImageTransformationCalculator'
|
||||||
|
input_stream: 'IMAGE:image_in'
|
||||||
|
output_stream: 'IMAGE:transformed_image_in'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'ImageTransformationCalculator'
|
||||||
|
input_stream: 'IMAGE:transformed_image_in'
|
||||||
|
output_stream: 'IMAGE:image_out'
|
||||||
|
}
|
||||||
|
""", None), ('graph_with_side_packets', """
|
||||||
|
input_stream: 'image_in'
|
||||||
|
input_side_packet: 'allow_signal'
|
||||||
|
input_side_packet: 'rotation_degrees'
|
||||||
|
output_stream: 'image_out'
|
||||||
|
node {
|
||||||
|
calculator: 'ImageTransformationCalculator'
|
||||||
|
input_stream: 'IMAGE:image_in'
|
||||||
|
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||||
|
output_stream: 'IMAGE:transformed_image_in'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'GateCalculator'
|
||||||
|
input_stream: 'transformed_image_in'
|
||||||
|
input_side_packet: 'ALLOW:allow_signal'
|
||||||
|
output_stream: 'image_out_to_transform'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'ImageTransformationCalculator'
|
||||||
|
input_stream: 'IMAGE:image_out_to_transform'
|
||||||
|
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||||
|
output_stream: 'IMAGE:image_out'
|
||||||
|
}""", {
|
||||||
|
'allow_signal': True,
|
||||||
|
'rotation_degrees': 0
|
||||||
|
}))
|
||||||
|
def test_solution_reset(self, text_config, side_inputs):
|
||||||
|
config_proto = text_format.Parse(text_config,
|
||||||
|
calculator_pb2.CalculatorGraphConfig())
|
||||||
|
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
|
||||||
|
with solution_base.SolutionBase(
|
||||||
|
graph_config=config_proto, side_inputs=side_inputs) as solution:
|
||||||
|
for _ in range(20):
|
||||||
|
outputs = solution.process(input_image)
|
||||||
|
self.assertTrue(np.array_equal(input_image, outputs.image_out))
|
||||||
|
solution.reset()
|
||||||
|
|
||||||
def _process_and_verify(self,
|
def _process_and_verify(self,
|
||||||
config_proto,
|
config_proto,
|
||||||
side_inputs=None,
|
side_inputs=None,
|
||||||
|
|
|
@ -26,7 +26,7 @@ import numpy.testing as npt
|
||||||
from mediapipe.python.solutions import objectron as mp_objectron
|
from mediapipe.python.solutions import objectron as mp_objectron
|
||||||
|
|
||||||
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
|
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
|
||||||
DIFF_THRESHOLD = 35 # pixels
|
DIFF_THRESHOLD = 30 # pixels
|
||||||
EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457],
|
EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457],
|
||||||
[383, 505], [80, 478], [408, 345],
|
[383, 505], [80, 478], [408, 345],
|
||||||
[130, 347], [384, 355], [72, 353]],
|
[130, 347], [384, 355], [72, 353]],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user