Project import generated by Copybara.

GitOrigin-RevId: 8e1da4611d93ccb7d9674713157d43be0348d98f
This commit is contained in:
MediaPipe Team 2021-07-27 18:20:05 -07:00 committed by chuoling
parent 50c92c6623
commit b899d17f18
79 changed files with 1808 additions and 946 deletions

View File

@ -220,6 +220,7 @@ import cv2
import mediapipe as mp import mediapipe as mp
mp_drawing = mp.solutions.drawing_utils mp_drawing = mp.solutions.drawing_utils
mp_hands = mp.solutions.hands mp_hands = mp.solutions.hands
drawing_styles = mp.solutions.drawing_styles
# For static images: # For static images:
IMAGE_FILES = [] IMAGE_FILES = []
@ -248,7 +249,9 @@ with mp_hands.Hands(
f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})' f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})'
) )
mp_drawing.draw_landmarks( mp_drawing.draw_landmarks(
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS) annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
drawing_styles.get_default_hand_landmark_style(),
drawing_styles.get_default_hand_connection_style())
cv2.imwrite( cv2.imwrite(
'/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1)) '/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1))
@ -278,7 +281,9 @@ with mp_hands.Hands(
if results.multi_hand_landmarks: if results.multi_hand_landmarks:
for hand_landmarks in results.multi_hand_landmarks: for hand_landmarks in results.multi_hand_landmarks:
mp_drawing.draw_landmarks( mp_drawing.draw_landmarks(
image, hand_landmarks, mp_hands.HAND_CONNECTIONS) image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
drawing_styles.get_default_hand_landmark_style(),
drawing_styles.get_default_hand_connection_style())
cv2.imshow('MediaPipe Hands', image) cv2.imshow('MediaPipe Hands', image)
if cv2.waitKey(5) & 0xFF == 27: if cv2.waitKey(5) & 0xFF == 27:
break break

View File

@ -24,6 +24,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kDataTag[] = "DATA";
constexpr char kHeaderTag[] = "HEADER";
class AddHeaderCalculatorTest : public ::testing::Test {}; class AddHeaderCalculatorTest : public ::testing::Test {};
TEST_F(AddHeaderCalculatorTest, HeaderStream) { TEST_F(AddHeaderCalculatorTest, HeaderStream) {
@ -36,11 +39,11 @@ TEST_F(AddHeaderCalculatorTest, HeaderStream) {
CalculatorRunner runner(node); CalculatorRunner runner(node);
// Set header and add 5 packets. // Set header and add 5 packets.
runner.MutableInputs()->Tag("HEADER").header = runner.MutableInputs()->Tag(kHeaderTag).header =
Adopt(new std::string("my_header")); Adopt(new std::string("my_header"));
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet); runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
} }
// Run calculator. // Run calculator.
@ -85,13 +88,14 @@ TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
CalculatorRunner runner(node); CalculatorRunner runner(node);
// Set header and add 5 packets. // Set header and add 5 packets.
runner.MutableInputs()->Tag("HEADER").header = runner.MutableInputs()->Tag(kHeaderTag).header =
Adopt(new std::string("my_header")); Adopt(new std::string("my_header"));
runner.MutableInputs()->Tag("HEADER").packets.push_back( runner.MutableInputs()
Adopt(new std::string("not allowed"))); ->Tag(kHeaderTag)
.packets.push_back(Adopt(new std::string("not allowed")));
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet); runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
} }
// Run calculator. // Run calculator.
@ -108,11 +112,11 @@ TEST_F(AddHeaderCalculatorTest, InputSidePacket) {
CalculatorRunner runner(node); CalculatorRunner runner(node);
// Set header and add 5 packets. // Set header and add 5 packets.
runner.MutableSidePackets()->Tag("HEADER") = runner.MutableSidePackets()->Tag(kHeaderTag) =
Adopt(new std::string("my_header")); Adopt(new std::string("my_header"));
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet); runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
} }
// Run calculator. // Run calculator.
@ -143,13 +147,13 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) {
CalculatorRunner runner(node); CalculatorRunner runner(node);
// Set both headers and add 5 packets. // Set both headers and add 5 packets.
runner.MutableSidePackets()->Tag("HEADER") = runner.MutableSidePackets()->Tag(kHeaderTag) =
Adopt(new std::string("my_header")); Adopt(new std::string("my_header"));
runner.MutableSidePackets()->Tag("HEADER") = runner.MutableSidePackets()->Tag(kHeaderTag) =
Adopt(new std::string("my_header")); Adopt(new std::string("my_header"));
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000)); Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet); runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
} }
// Run should fail because header can only be provided one way. // Run should fail because header can only be provided one way.

View File

@ -19,6 +19,13 @@
namespace mediapipe { namespace mediapipe {
constexpr char kIncrementTag[] = "INCREMENT";
constexpr char kInitialValueTag[] = "INITIAL_VALUE";
constexpr char kBatchSizeTag[] = "BATCH_SIZE";
constexpr char kErrorCountTag[] = "ERROR_COUNT";
constexpr char kMaxCountTag[] = "MAX_COUNT";
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of // Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
// sequential numbers from INITIAL_VALUE (default 0) with a common // sequential numbers from INITIAL_VALUE (default 0) with a common
// difference of INCREMENT (default 1) between successive numbers (with // difference of INCREMENT (default 1) between successive numbers (with
@ -33,53 +40,53 @@ class CountingSourceCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>(); cc->Outputs().Index(0).Set<int>();
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) { if (cc->InputSidePackets().HasTag(kErrorOnOpenTag)) {
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>(); cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
} }
RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") || RET_CHECK(cc->InputSidePackets().HasTag(kMaxCountTag) ||
cc->InputSidePackets().HasTag("ERROR_COUNT")); cc->InputSidePackets().HasTag(kErrorCountTag));
if (cc->InputSidePackets().HasTag("MAX_COUNT")) { if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
cc->InputSidePackets().Tag("MAX_COUNT").Set<int>(); cc->InputSidePackets().Tag(kMaxCountTag).Set<int>();
} }
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
cc->InputSidePackets().Tag("ERROR_COUNT").Set<int>(); cc->InputSidePackets().Tag(kErrorCountTag).Set<int>();
} }
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) { if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
cc->InputSidePackets().Tag("BATCH_SIZE").Set<int>(); cc->InputSidePackets().Tag(kBatchSizeTag).Set<int>();
} }
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
cc->InputSidePackets().Tag("INITIAL_VALUE").Set<int>(); cc->InputSidePackets().Tag(kInitialValueTag).Set<int>();
} }
if (cc->InputSidePackets().HasTag("INCREMENT")) { if (cc->InputSidePackets().HasTag(kIncrementTag)) {
cc->InputSidePackets().Tag("INCREMENT").Set<int>(); cc->InputSidePackets().Tag(kIncrementTag).Set<int>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") && if (cc->InputSidePackets().HasTag(kErrorOnOpenTag) &&
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) { cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
return absl::NotFoundError("expected error"); return absl::NotFoundError("expected error");
} }
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get<int>(); error_count_ = cc->InputSidePackets().Tag(kErrorCountTag).Get<int>();
RET_CHECK_LE(0, error_count_); RET_CHECK_LE(0, error_count_);
} }
if (cc->InputSidePackets().HasTag("MAX_COUNT")) { if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>(); max_count_ = cc->InputSidePackets().Tag(kMaxCountTag).Get<int>();
RET_CHECK_LE(0, max_count_); RET_CHECK_LE(0, max_count_);
} }
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) { if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get<int>(); batch_size_ = cc->InputSidePackets().Tag(kBatchSizeTag).Get<int>();
RET_CHECK_LT(0, batch_size_); RET_CHECK_LT(0, batch_size_);
} }
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) { if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get<int>(); counter_ = cc->InputSidePackets().Tag(kInitialValueTag).Get<int>();
} }
if (cc->InputSidePackets().HasTag("INCREMENT")) { if (cc->InputSidePackets().HasTag(kIncrementTag)) {
increment_ = cc->InputSidePackets().Tag("INCREMENT").Get<int>(); increment_ = cc->InputSidePackets().Tag(kIncrementTag).Get<int>();
RET_CHECK_LT(0, increment_); RET_CHECK_LT(0, increment_);
} }
RET_CHECK(error_count_ >= 0 || max_count_ >= 0); RET_CHECK(error_count_ >= 0 || max_count_ >= 0);

View File

@ -35,11 +35,14 @@
// } // }
namespace mediapipe { namespace mediapipe {
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
constexpr char kEncodedTag[] = "ENCODED";
class DequantizeByteArrayCalculator : public CalculatorBase { class DequantizeByteArrayCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("ENCODED").Set<std::string>(); cc->Inputs().Tag(kEncodedTag).Set<std::string>();
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>(); cc->Outputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -66,7 +69,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
const std::string& encoded = const std::string& encoded =
cc->Inputs().Tag("ENCODED").Value().Get<std::string>(); cc->Inputs().Tag(kEncodedTag).Value().Get<std::string>();
std::vector<float> float_vector; std::vector<float> float_vector;
float_vector.reserve(encoded.length()); float_vector.reserve(encoded.length());
for (int i = 0; i < encoded.length(); ++i) { for (int i = 0; i < encoded.length(); ++i) {
@ -74,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_); static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
} }
cc->Outputs() cc->Outputs()
.Tag("FLOAT_VECTOR") .Tag(kFloatVectorTag)
.AddPacket(MakePacket<std::vector<float>>(float_vector) .AddPacket(MakePacket<std::vector<float>>(float_vector)
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
return absl::OkStatus(); return absl::OkStatus();

View File

@ -25,6 +25,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
constexpr char kEncodedTag[] = "ENCODED";
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorGraphConfig::Node node_config = CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
@ -39,7 +42,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
)pb"); )pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::string empty_string; std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back( runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0))); MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
@ -64,7 +69,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
)pb"); )pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::string empty_string; std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back( runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0))); MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
@ -89,7 +96,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
)pb"); )pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::string empty_string; std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back( runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0))); MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
EXPECT_FALSE(status.ok()); EXPECT_FALSE(status.ok());
@ -114,14 +123,16 @@ TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
)pb"); )pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01}; unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
runner.MutableInputs()->Tag("ENCODED").packets.push_back( runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>( MakePacket<std::string>(
std::string(reinterpret_cast<char const*>(input), 4)) std::string(reinterpret_cast<char const*>(input), 4))
.At(Timestamp(0))); .At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = const std::vector<Packet>& outputs =
runner.Outputs().Tag("FLOAT_VECTOR").packets; runner.Outputs().Tag(kFloatVectorTag).packets;
EXPECT_EQ(1, outputs.size()); EXPECT_EQ(1, outputs.size());
const std::vector<float>& result = outputs[0].Get<std::vector<float>>(); const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
ASSERT_FALSE(result.empty()); ASSERT_FALSE(result.empty());

View File

@ -24,6 +24,11 @@
namespace mediapipe { namespace mediapipe {
constexpr char kFinishedTag[] = "FINISHED";
constexpr char kAllowTag[] = "ALLOW";
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
constexpr char kOptionsTag[] = "OPTIONS";
// FlowLimiterCalculator is used to limit the number of frames in flight // FlowLimiterCalculator is used to limit the number of frames in flight
// by dropping input frames when necessary. // by dropping input frames when necessary.
// //
@ -69,16 +74,19 @@ class FlowLimiterCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
auto& side_inputs = cc->InputSidePackets(); auto& side_inputs = cc->InputSidePackets();
side_inputs.Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional(); side_inputs.Tag(kOptionsTag).Set<FlowLimiterCalculatorOptions>().Optional();
cc->Inputs().Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional(); cc->Inputs()
.Tag(kOptionsTag)
.Set<FlowLimiterCalculatorOptions>()
.Optional();
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1); RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
cc->Inputs().Get("", i).SetAny(); cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
} }
cc->Inputs().Get("FINISHED", 0).SetAny(); cc->Inputs().Get("FINISHED", 0).SetAny();
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>().Optional(); cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>().Optional();
cc->Outputs().Tag("ALLOW").Set<bool>().Optional(); cc->Outputs().Tag(kAllowTag).Set<bool>().Optional();
cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetInputStreamHandler("ImmediateInputStreamHandler");
cc->SetProcessTimestampBounds(true); cc->SetProcessTimestampBounds(true);
return absl::OkStatus(); return absl::OkStatus();
@ -87,9 +95,9 @@ class FlowLimiterCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
options_ = cc->Options<FlowLimiterCalculatorOptions>(); options_ = cc->Options<FlowLimiterCalculatorOptions>();
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets()); options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
options_.set_max_in_flight( options_.set_max_in_flight(
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>()); cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
} }
input_queues_.resize(cc->Inputs().NumEntries("")); input_queues_.resize(cc->Inputs().NumEntries(""));
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
@ -104,8 +112,8 @@ class FlowLimiterCalculator : public CalculatorBase {
// Outputs a packet indicating whether a frame was sent or dropped. // Outputs a packet indicating whether a frame was sent or dropped.
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
if (cc->Outputs().HasTag("ALLOW")) { if (cc->Outputs().HasTag(kAllowTag)) {
cc->Outputs().Tag("ALLOW").AddPacket(MakePacket<bool>(allow).At(ts)); cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
} }
} }
@ -155,7 +163,7 @@ class FlowLimiterCalculator : public CalculatorBase {
options_ = tool::RetrieveOptions(options_, cc->Inputs()); options_ = tool::RetrieveOptions(options_, cc->Inputs());
// Process the FINISHED input stream. // Process the FINISHED input stream.
Packet finished_packet = cc->Inputs().Tag("FINISHED").Value(); Packet finished_packet = cc->Inputs().Tag(kFinishedTag).Value();
if (finished_packet.Timestamp() == cc->InputTimestamp()) { if (finished_packet.Timestamp() == cc->InputTimestamp()) {
while (!frames_in_flight_.empty() && while (!frames_in_flight_.empty() &&
frames_in_flight_.front() <= finished_packet.Timestamp()) { frames_in_flight_.front() <= finished_packet.Timestamp()) {
@ -210,8 +218,8 @@ class FlowLimiterCalculator : public CalculatorBase {
Timestamp bound = Timestamp bound =
cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream(); cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream();
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0)); SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
if (cc->Outputs().HasTag("ALLOW")) { if (cc->Outputs().HasTag(kAllowTag)) {
SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW")); SetNextTimestampBound(bound, &cc->Outputs().Tag(kAllowTag));
} }
} }

View File

@ -36,6 +36,13 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kDropTimestampsTag[] = "DROP_TIMESTAMPS";
constexpr char kClockTag[] = "CLOCK";
constexpr char kWarmupTimeTag[] = "WARMUP_TIME";
constexpr char kSleepTimeTag[] = "SLEEP_TIME";
constexpr char kPacketTag[] = "PACKET";
// A simple Semaphore for synchronizing test threads. // A simple Semaphore for synchronizing test threads.
class AtomicSemaphore { class AtomicSemaphore {
public: public:
@ -204,17 +211,17 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) {
class SleepCalculator : public CalculatorBase { class SleepCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("PACKET").SetAny(); cc->Inputs().Tag(kPacketTag).SetAny();
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
cc->InputSidePackets().Tag("SLEEP_TIME").Set<int64>(); cc->InputSidePackets().Tag(kSleepTimeTag).Set<int64>();
cc->InputSidePackets().Tag("WARMUP_TIME").Set<int64>(); cc->InputSidePackets().Tag(kWarmupTimeTag).Set<int64>();
cc->InputSidePackets().Tag("CLOCK").Set<mediapipe::Clock*>(); cc->InputSidePackets().Tag(kClockTag).Set<mediapipe::Clock*>();
cc->SetTimestampOffset(0); cc->SetTimestampOffset(0);
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<mediapipe::Clock*>(); clock_ = cc->InputSidePackets().Tag(kClockTag).Get<mediapipe::Clock*>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -222,10 +229,12 @@ class SleepCalculator : public CalculatorBase {
++packet_count; ++packet_count;
absl::Duration sleep_time = absl::Microseconds( absl::Duration sleep_time = absl::Microseconds(
packet_count == 1 packet_count == 1
? cc->InputSidePackets().Tag("WARMUP_TIME").Get<int64>() ? cc->InputSidePackets().Tag(kWarmupTimeTag).Get<int64>()
: cc->InputSidePackets().Tag("SLEEP_TIME").Get<int64>()); : cc->InputSidePackets().Tag(kSleepTimeTag).Get<int64>());
clock_->Sleep(sleep_time); clock_->Sleep(sleep_time);
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); cc->Outputs()
.Tag(kPacketTag)
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
return absl::OkStatus(); return absl::OkStatus();
} }
@ -240,24 +249,27 @@ REGISTER_CALCULATOR(SleepCalculator);
class DropCalculator : public CalculatorBase { class DropCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("PACKET").SetAny(); cc->Inputs().Tag(kPacketTag).SetAny();
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set<bool>(); cc->InputSidePackets().Tag(kDropTimestampsTag).Set<bool>();
cc->SetProcessTimestampBounds(true); cc->SetProcessTimestampBounds(true);
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) { if (!cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
++packet_count; ++packet_count;
} }
bool drop = (packet_count == 3); bool drop = (packet_count == 3);
if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) { if (!drop && !cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); cc->Outputs()
.Tag(kPacketTag)
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
} }
if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get<bool>()) { if (!drop || !cc->InputSidePackets().Tag(kDropTimestampsTag).Get<bool>()) {
cc->Outputs().Tag("PACKET").SetNextTimestampBound( cc->Outputs()
cc->InputTimestamp().NextAllowedInStream()); .Tag(kPacketTag)
.SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -21,6 +21,11 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kStateChangeTag[] = "STATE_CHANGE";
constexpr char kDisallowTag[] = "DISALLOW";
constexpr char kAllowTag[] = "ALLOW";
enum GateState { enum GateState {
GATE_UNINITIALIZED, GATE_UNINITIALIZED,
GATE_ALLOW, GATE_ALLOW,
@ -83,30 +88,31 @@ class GateCalculator : public CalculatorBase {
GateCalculator() {} GateCalculator() {}
static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) { static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") || bool input_via_side_packet = cc->InputSidePackets().HasTag(kAllowTag) ||
cc->InputSidePackets().HasTag("DISALLOW"); cc->InputSidePackets().HasTag(kDisallowTag);
bool input_via_stream = bool input_via_stream =
cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW"); cc->Inputs().HasTag(kAllowTag) || cc->Inputs().HasTag(kDisallowTag);
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW // Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
// input. // input.
RET_CHECK(input_via_side_packet ^ input_via_stream); RET_CHECK(input_via_side_packet ^ input_via_stream);
if (input_via_side_packet) { if (input_via_side_packet) {
RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^ RET_CHECK(cc->InputSidePackets().HasTag(kAllowTag) ^
cc->InputSidePackets().HasTag("DISALLOW")); cc->InputSidePackets().HasTag(kDisallowTag));
if (cc->InputSidePackets().HasTag("ALLOW")) { if (cc->InputSidePackets().HasTag(kAllowTag)) {
cc->InputSidePackets().Tag("ALLOW").Set<bool>(); cc->InputSidePackets().Tag(kAllowTag).Set<bool>();
} else { } else {
cc->InputSidePackets().Tag("DISALLOW").Set<bool>(); cc->InputSidePackets().Tag(kDisallowTag).Set<bool>();
} }
} else { } else {
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW")); RET_CHECK(cc->Inputs().HasTag(kAllowTag) ^
cc->Inputs().HasTag(kDisallowTag));
if (cc->Inputs().HasTag("ALLOW")) { if (cc->Inputs().HasTag(kAllowTag)) {
cc->Inputs().Tag("ALLOW").Set<bool>(); cc->Inputs().Tag(kAllowTag).Set<bool>();
} else { } else {
cc->Inputs().Tag("DISALLOW").Set<bool>(); cc->Inputs().Tag(kDisallowTag).Set<bool>();
} }
} }
return absl::OkStatus(); return absl::OkStatus();
@ -125,8 +131,8 @@ class GateCalculator : public CalculatorBase {
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
} }
if (cc->Outputs().HasTag("STATE_CHANGE")) { if (cc->Outputs().HasTag(kStateChangeTag)) {
cc->Outputs().Tag("STATE_CHANGE").Set<bool>(); cc->Outputs().Tag(kStateChangeTag).Set<bool>();
} }
return absl::OkStatus(); return absl::OkStatus();
@ -134,14 +140,14 @@ class GateCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
use_side_packet_for_allow_disallow_ = false; use_side_packet_for_allow_disallow_ = false;
if (cc->InputSidePackets().HasTag("ALLOW")) { if (cc->InputSidePackets().HasTag(kAllowTag)) {
use_side_packet_for_allow_disallow_ = true; use_side_packet_for_allow_disallow_ = true;
allow_by_side_packet_decision_ = allow_by_side_packet_decision_ =
cc->InputSidePackets().Tag("ALLOW").Get<bool>(); cc->InputSidePackets().Tag(kAllowTag).Get<bool>();
} else if (cc->InputSidePackets().HasTag("DISALLOW")) { } else if (cc->InputSidePackets().HasTag(kDisallowTag)) {
use_side_packet_for_allow_disallow_ = true; use_side_packet_for_allow_disallow_ = true;
allow_by_side_packet_decision_ = allow_by_side_packet_decision_ =
!cc->InputSidePackets().Tag("DISALLOW").Get<bool>(); !cc->InputSidePackets().Tag(kDisallowTag).Get<bool>();
} }
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
@ -160,18 +166,18 @@ class GateCalculator : public CalculatorBase {
if (use_side_packet_for_allow_disallow_) { if (use_side_packet_for_allow_disallow_) {
allow = allow_by_side_packet_decision_; allow = allow_by_side_packet_decision_;
} else { } else {
if (cc->Inputs().HasTag("ALLOW") && if (cc->Inputs().HasTag(kAllowTag) &&
!cc->Inputs().Tag("ALLOW").IsEmpty()) { !cc->Inputs().Tag(kAllowTag).IsEmpty()) {
allow = cc->Inputs().Tag("ALLOW").Get<bool>(); allow = cc->Inputs().Tag(kAllowTag).Get<bool>();
} }
if (cc->Inputs().HasTag("DISALLOW") && if (cc->Inputs().HasTag(kDisallowTag) &&
!cc->Inputs().Tag("DISALLOW").IsEmpty()) { !cc->Inputs().Tag(kDisallowTag).IsEmpty()) {
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>(); allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
} }
} }
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW; const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
if (cc->Outputs().HasTag("STATE_CHANGE")) { if (cc->Outputs().HasTag(kStateChangeTag)) {
if (last_gate_state_ != GATE_UNINITIALIZED && if (last_gate_state_ != GATE_UNINITIALIZED &&
last_gate_state_ != new_gate_state) { last_gate_state_ != new_gate_state) {
VLOG(2) << "State transition in " << cc->NodeName() << " @ " VLOG(2) << "State transition in " << cc->NodeName() << " @ "
@ -179,7 +185,7 @@ class GateCalculator : public CalculatorBase {
<< ToString(last_gate_state_) << " to " << ToString(last_gate_state_) << " to "
<< ToString(new_gate_state); << ToString(new_gate_state);
cc->Outputs() cc->Outputs()
.Tag("STATE_CHANGE") .Tag(kStateChangeTag)
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp())); .AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
} }
} }

View File

@ -22,6 +22,9 @@ namespace mediapipe {
namespace { namespace {
constexpr char kDisallowTag[] = "DISALLOW";
constexpr char kAllowTag[] = "ALLOW";
class GateCalculatorTest : public ::testing::Test { class GateCalculatorTest : public ::testing::Test {
protected: protected:
// Helper to run a graph and return status. // Helper to run a graph and return status.
@ -117,7 +120,7 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
input_stream: "test_input" input_stream: "test_input"
output_stream: "test_output" output_stream: "test_output"
)"); )");
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true)); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42; constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
@ -139,7 +142,7 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
input_stream: "test_input" input_stream: "test_input"
output_stream: "test_output" output_stream: "test_output"
)"); )");
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false)); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42; constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
@ -161,7 +164,7 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
input_stream: "test_input" input_stream: "test_input"
output_stream: "test_output" output_stream: "test_output"
)"); )");
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false)); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42; constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
@ -179,7 +182,7 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
input_stream: "test_input" input_stream: "test_input"
output_stream: "test_output" output_stream: "test_output"
)"); )");
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true)); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42; constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);

View File

@ -39,20 +39,24 @@ using testing::ElementsAre;
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kClockTag[] = "CLOCK";
using mediapipe::Clock; using mediapipe::Clock;
// A Calculator with a fixed Process call latency. // A Calculator with a fixed Process call latency.
class SleepCalculator : public CalculatorBase { class SleepCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>(); cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
cc->Inputs().Index(0).SetAny(); cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->SetTimestampOffset(TimestampDiff(0)); cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>(); clock_ =
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -29,6 +29,9 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kMinuendTag[] = "MINUEND";
constexpr char kSubtrahendTag[] = "SUBTRAHEND";
// A 3x4 Matrix of random integers in [0,1000). // A 3x4 Matrix of random integers in [0,1000).
const char kMatrixText[] = const char kMatrixText[] =
"rows: 3\n" "rows: 3\n"
@ -104,12 +107,13 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
Matrix* side_matrix = new Matrix(); Matrix* side_matrix = new Matrix();
MatrixFromTextProto(kMatrixText, side_matrix); MatrixFromTextProto(kMatrixText, side_matrix);
runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix); runner.MutableSidePackets()->Tag(kSubtrahendTag) = Adopt(side_matrix);
Matrix* input_matrix = new Matrix(); Matrix* input_matrix = new Matrix();
MatrixFromTextProto(kMatrixText2, input_matrix); MatrixFromTextProto(kMatrixText2, input_matrix);
runner.MutableInputs()->Tag("MINUEND").packets.push_back( runner.MutableInputs()
Adopt(input_matrix).At(Timestamp(0))); ->Tag(kMinuendTag)
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size()); EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
@ -133,12 +137,12 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
Matrix* side_matrix = new Matrix(); Matrix* side_matrix = new Matrix();
MatrixFromTextProto(kMatrixText, side_matrix); MatrixFromTextProto(kMatrixText, side_matrix);
runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix); runner.MutableSidePackets()->Tag(kMinuendTag) = Adopt(side_matrix);
Matrix* input_matrix = new Matrix(); Matrix* input_matrix = new Matrix();
MatrixFromTextProto(kMatrixText2, input_matrix); MatrixFromTextProto(kMatrixText2, input_matrix);
runner.MutableInputs() runner.MutableInputs()
->Tag("SUBTRAHEND") ->Tag(kSubtrahendTag)
.packets.push_back(Adopt(input_matrix).At(Timestamp(0))); .packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());

View File

@ -17,6 +17,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPacketTag[] = "PACKET";
// For each non empty input packet, emits a single output packet containing a // For each non empty input packet, emits a single output packet containing a
// boolean value "true", "false" in response to empty packets (a.k.a. timestamp // boolean value "true", "false" in response to empty packets (a.k.a. timestamp
// bound updates) This can be used to "flag" the presence of an arbitrary packet // bound updates) This can be used to "flag" the presence of an arbitrary packet
@ -58,8 +61,8 @@ namespace mediapipe {
class PacketPresenceCalculator : public CalculatorBase { class PacketPresenceCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("PACKET").SetAny(); cc->Inputs().Tag(kPacketTag).SetAny();
cc->Outputs().Tag("PRESENCE").Set<bool>(); cc->Outputs().Tag(kPresenceTag).Set<bool>();
// Process() function is invoked in response to input stream timestamp // Process() function is invoked in response to input stream timestamp
// bound updates. // bound updates.
cc->SetProcessTimestampBounds(true); cc->SetProcessTimestampBounds(true);
@ -73,8 +76,8 @@ class PacketPresenceCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
cc->Outputs() cc->Outputs()
.Tag("PRESENCE") .Tag(kPresenceTag)
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag("PACKET").IsEmpty()) .AddPacket(MakePacket<bool>(!cc->Inputs().Tag(kPacketTag).IsEmpty())
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -39,6 +39,11 @@ namespace mediapipe {
REGISTER_CALCULATOR(PacketResamplerCalculator); REGISTER_CALCULATOR(PacketResamplerCalculator);
namespace { namespace {
constexpr char kSeedTag[] = "SEED";
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
constexpr char kOptionsTag[] = "OPTIONS";
// Returns a TimestampDiff (assuming microseconds) corresponding to the // Returns a TimestampDiff (assuming microseconds) corresponding to the
// given time in seconds. // given time in seconds.
TimestampDiff TimestampDiffFromSeconds(double seconds) { TimestampDiff TimestampDiffFromSeconds(double seconds) {
@ -50,16 +55,16 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
const auto& resampler_options = const auto& resampler_options =
cc->Options<PacketResamplerCalculatorOptions>(); cc->Options<PacketResamplerCalculatorOptions>();
if (cc->InputSidePackets().HasTag("OPTIONS")) { if (cc->InputSidePackets().HasTag(kOptionsTag)) {
cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>(); cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
} }
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0); CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
if (!input_data_id.IsValid()) { if (!input_data_id.IsValid()) {
input_data_id = cc->Inputs().GetId("", 0); input_data_id = cc->Inputs().GetId("", 0);
} }
cc->Inputs().Get(input_data_id).SetAny(); cc->Inputs().Get(input_data_id).SetAny();
if (cc->Inputs().HasTag("VIDEO_HEADER")) { if (cc->Inputs().HasTag(kVideoHeaderTag)) {
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>(); cc->Inputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
} }
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0); CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
@ -67,15 +72,15 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
output_data_id = cc->Outputs().GetId("", 0); output_data_id = cc->Outputs().GetId("", 0);
} }
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id)); cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
if (cc->Outputs().HasTag("VIDEO_HEADER")) { if (cc->Outputs().HasTag(kVideoHeaderTag)) {
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>(); cc->Outputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
} }
if (resampler_options.jitter() != 0.0) { if (resampler_options.jitter() != 0.0) {
RET_CHECK_GT(resampler_options.jitter(), 0.0); RET_CHECK_GT(resampler_options.jitter(), 0.0);
RET_CHECK_LE(resampler_options.jitter(), 1.0); RET_CHECK_LE(resampler_options.jitter(), 1.0);
RET_CHECK(cc->InputSidePackets().HasTag("SEED")); RET_CHECK(cc->InputSidePackets().HasTag(kSeedTag));
cc->InputSidePackets().Tag("SEED").Set<std::string>(); cc->InputSidePackets().Tag(kSeedTag).Set<std::string>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -143,9 +148,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream() && if (cc->InputTimestamp() == Timestamp::PreStream() &&
cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) &&
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { !cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) {
video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>(); video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get<VideoHeader>();
video_header_.frame_rate = frame_rate_; video_header_.frame_rate = frame_rate_;
if (cc->Inputs().Get(input_data_id_).IsEmpty()) { if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
return absl::OkStatus(); return absl::OkStatus();
@ -234,7 +239,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
"ignored, because we are adding jitter."; "ignored, because we are adding jitter.";
} }
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>(); const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
random_ = CreateSecureRandom(seed); random_ = CreateSecureRandom(seed);
if (random_ == nullptr) { if (random_ == nullptr) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
@ -357,7 +362,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open(
"ignored, because we are adding jitter."; "ignored, because we are adding jitter.";
} }
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>(); const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
random_ = CreateSecureRandom(seed); random_ = CreateSecureRandom(seed);
if (random_ == nullptr) { if (random_ == nullptr) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
@ -504,7 +509,7 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
"ignored, because we are adding jitter."; "ignored, because we are adding jitter.";
} }
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>(); const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
random_ = CreateSecureRandom(seed); random_ = CreateSecureRandom(seed);
if (random_ == nullptr) { if (random_ == nullptr) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
@ -635,9 +640,9 @@ absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
base_timestamp_ + base_timestamp_ +
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_); TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
} }
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) { if (cc->Outputs().UsesTags() && cc->Outputs().HasTag(kVideoHeaderTag)) {
cc->Outputs() cc->Outputs()
.Tag("VIDEO_HEADER") .Tag(kVideoHeaderTag)
.Add(new VideoHeader(calculator_->video_header_), .Add(new VideoHeader(calculator_->video_header_),
Timestamp::PreStream()); Timestamp::PreStream());
} }

View File

@ -32,6 +32,12 @@ namespace mediapipe {
using ::testing::ElementsAre; using ::testing::ElementsAre;
namespace { namespace {
constexpr char kOptionsTag[] = "OPTIONS";
constexpr char kSeedTag[] = "SEED";
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
constexpr char kDataTag[] = "DATA";
// A simple version of CalculatorRunner with built-in convenience // A simple version of CalculatorRunner with built-in convenience
// methods for setting inputs from a vector and checking outputs // methods for setting inputs from a vector and checking outputs
// against expected outputs (both timestamps and contents). // against expected outputs (both timestamps and contents).
@ -464,7 +470,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
)pb")); )pb"));
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) { for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
runner.MutableInputs()->Tag("DATA").packets.push_back( runner.MutableInputs()->Tag(kDataTag).packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
} }
VideoHeader video_header_in; VideoHeader video_header_in;
@ -474,16 +480,16 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
video_header_in.duration = 1.0; video_header_in.duration = 1.0;
video_header_in.format = ImageFormat::SRGB; video_header_in.format = ImageFormat::SRGB;
runner.MutableInputs() runner.MutableInputs()
->Tag("VIDEO_HEADER") ->Tag(kVideoHeaderTag)
.packets.push_back( .packets.push_back(
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream())); Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size()); ASSERT_EQ(1, runner.Outputs().Tag(kVideoHeaderTag).packets.size());
EXPECT_EQ(Timestamp::PreStream(), EXPECT_EQ(Timestamp::PreStream(),
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp()); runner.Outputs().Tag(kVideoHeaderTag).packets[0].Timestamp());
const VideoHeader& video_header_out = const VideoHeader& video_header_out =
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>(); runner.Outputs().Tag(kVideoHeaderTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(video_header_in.width, video_header_out.width); EXPECT_EQ(video_header_in.width, video_header_out.width);
EXPECT_EQ(video_header_in.height, video_header_out.height); EXPECT_EQ(video_header_in.height, video_header_out.height);
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate); EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
@ -725,7 +731,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
[mediapipe.PacketResamplerCalculatorOptions.ext] { [mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 30 frame_rate: 30
})pb")); })pb"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000}); runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size()); EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
@ -740,7 +746,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
frame_rate: 30 frame_rate: 30
base_timestamp: 0 base_timestamp: 0
})pb")); })pb"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options); runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000}); runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());

View File

@ -29,6 +29,8 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kPeriodTag[] = "PERIOD";
// A simple version of CalculatorRunner with built-in convenience methods for // A simple version of CalculatorRunner with built-in convenience methods for
// setting inputs from a vector and checking outputs against a vector of // setting inputs from a vector and checking outputs against a vector of
// expected outputs. // expected outputs.
@ -121,7 +123,7 @@ TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) {
SimpleRunner runner(node); SimpleRunner runner(node);
runner.SetInput({2, 4, 6, 8, 10, 12, 14}); runner.SetInput({2, 4, 6, 8, 10, 12, 14});
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5); runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<int64> expected_timestamps = {2, 8, 14}; const std::vector<int64> expected_timestamps = {2, 8, 14};
@ -160,7 +162,7 @@ TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) {
SimpleRunner runner(node); SimpleRunner runner(node);
runner.SetInput({2, 4, 6, 8, 10, 12, 14}); runner.SetInput({2, 4, 6, 8, 10, 12, 14});
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5); runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<int64> expected_timestamps = {2, 6, 10, 14}; const std::vector<int64> expected_timestamps = {2, 6, 10, 14};

View File

@ -39,6 +39,8 @@ using ::testing::Pair;
using ::testing::Value; using ::testing::Value;
namespace { namespace {
constexpr char kDisallowTag[] = "DISALLOW";
// Returns the timestamp values for a vector of Packets. // Returns the timestamp values for a vector of Packets.
// TODO: puth this kind of test util in a common place. // TODO: puth this kind of test util in a common place.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) { std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
@ -702,14 +704,14 @@ class DroppingGateCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny(); cc->Inputs().Index(0).SetAny();
cc->Inputs().Tag("DISALLOW").Set<bool>(); cc->Inputs().Tag(kDisallowTag).Set<bool>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
if (!cc->Inputs().Index(0).IsEmpty() && if (!cc->Inputs().Index(0).IsEmpty() &&
!cc->Inputs().Tag("DISALLOW").Get<bool>()) { !cc->Inputs().Tag(kDisallowTag).Get<bool>()) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -41,11 +41,14 @@
// } // }
namespace mediapipe { namespace mediapipe {
constexpr char kEncodedTag[] = "ENCODED";
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
class QuantizeFloatVectorCalculator : public CalculatorBase { class QuantizeFloatVectorCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>(); cc->Inputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
cc->Outputs().Tag("ENCODED").Set<std::string>(); cc->Outputs().Tag(kEncodedTag).Set<std::string>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -70,7 +73,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
const std::vector<float>& float_vector = const std::vector<float>& float_vector =
cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>(); cc->Inputs().Tag(kFloatVectorTag).Value().Get<std::vector<float>>();
int feature_size = float_vector.size(); int feature_size = float_vector.size();
std::string encoded_features; std::string encoded_features;
encoded_features.reserve(feature_size); encoded_features.reserve(feature_size);
@ -86,7 +89,9 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
(old_value - min_quantized_value_) * (255.0 / range_)); (old_value - min_quantized_value_) * (255.0 / range_));
encoded_features += encoded; encoded_features += encoded;
} }
cc->Outputs().Tag("ENCODED").AddPacket( cc->Outputs()
.Tag(kEncodedTag)
.AddPacket(
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp())); MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -25,6 +25,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kEncodedTag[] = "ENCODED";
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) { TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorGraphConfig::Node node_config = CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
@ -40,7 +43,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::vector<float> empty_vector; std::vector<float> empty_vector;
runner.MutableInputs() runner.MutableInputs()
->Tag("FLOAT_VECTOR") ->Tag(kFloatVectorTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
@ -67,7 +70,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::vector<float> empty_vector; std::vector<float> empty_vector;
runner.MutableInputs() runner.MutableInputs()
->Tag("FLOAT_VECTOR") ->Tag(kFloatVectorTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
@ -94,7 +97,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::vector<float> empty_vector; std::vector<float> empty_vector;
runner.MutableInputs() runner.MutableInputs()
->Tag("FLOAT_VECTOR") ->Tag(kFloatVectorTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
auto status = runner.Run(); auto status = runner.Run();
@ -121,11 +124,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::vector<float> empty_vector; std::vector<float> empty_vector;
runner.MutableInputs() runner.MutableInputs()
->Tag("FLOAT_VECTOR") ->Tag(kFloatVectorTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets; const std::vector<Packet>& outputs =
runner.Outputs().Tag(kEncodedTag).packets;
EXPECT_EQ(1, outputs.size()); EXPECT_EQ(1, outputs.size());
EXPECT_TRUE(outputs[0].Get<std::string>().empty()); EXPECT_TRUE(outputs[0].Get<std::string>().empty());
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp()); EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
@ -147,11 +151,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f}; std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
runner.MutableInputs() runner.MutableInputs()
->Tag("FLOAT_VECTOR") ->Tag(kFloatVectorTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<float>>(vector).At(Timestamp(0))); MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets; const std::vector<Packet>& outputs =
runner.Outputs().Tag(kEncodedTag).packets;
EXPECT_EQ(1, outputs.size()); EXPECT_EQ(1, outputs.size());
const std::string& result = outputs[0].Get<std::string>(); const std::string& result = outputs[0].Get<std::string>();
ASSERT_FALSE(result.empty()); ASSERT_FALSE(result.empty());
@ -185,11 +190,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
std::vector<float> vector = {-65.0f, 65.0f}; std::vector<float> vector = {-65.0f, 65.0f};
runner.MutableInputs() runner.MutableInputs()
->Tag("FLOAT_VECTOR") ->Tag(kFloatVectorTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<float>>(vector).At(Timestamp(0))); MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets; const std::vector<Packet>& outputs =
runner.Outputs().Tag(kEncodedTag).packets;
EXPECT_EQ(1, outputs.size()); EXPECT_EQ(1, outputs.size());
const std::string& result = outputs[0].Get<std::string>(); const std::string& result = outputs[0].Get<std::string>();
ASSERT_FALSE(result.empty()); ASSERT_FALSE(result.empty());

View File

@ -23,6 +23,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kAllowTag[] = "ALLOW";
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined // RealTimeFlowLimiterCalculator is used to limit the number of pipelined
// processing operations in a section of the graph. // processing operations in a section of the graph.
// //
@ -86,11 +89,11 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i))); cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
} }
cc->Inputs().Get("FINISHED", 0).SetAny(); cc->Inputs().Get("FINISHED", 0).SetAny();
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>(); cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>();
} }
if (cc->Outputs().HasTag("ALLOW")) { if (cc->Outputs().HasTag(kAllowTag)) {
cc->Outputs().Tag("ALLOW").Set<bool>(); cc->Outputs().Tag(kAllowTag).Set<bool>();
} }
cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetInputStreamHandler("ImmediateInputStreamHandler");
@ -101,8 +104,8 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
finished_id_ = cc->Inputs().GetId("FINISHED", 0); finished_id_ = cc->Inputs().GetId("FINISHED", 0);
max_in_flight_ = 1; max_in_flight_ = 1;
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>(); max_in_flight_ = cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>();
} }
RET_CHECK_GE(max_in_flight_, 1); RET_CHECK_GE(max_in_flight_, 1);
num_in_flight_ = 0; num_in_flight_ = 0;

View File

@ -33,6 +33,9 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kFinishedTag[] = "FINISHED";
// A simple Semaphore for synchronizing test threads. // A simple Semaphore for synchronizing test threads.
class AtomicSemaphore { class AtomicSemaphore {
public: public:
@ -112,7 +115,7 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) {
Timestamp timestamp = Timestamp timestamp =
Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond); Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
runner.MutableInputs() runner.MutableInputs()
->Tag("FINISHED") ->Tag(kFinishedTag)
.packets.push_back(MakePacket<bool>(true).At(timestamp)); .packets.push_back(MakePacket<bool>(true).At(timestamp));
} }

View File

@ -22,6 +22,8 @@ namespace mediapipe {
namespace { namespace {
constexpr char kPacketOffsetTag[] = "PACKET_OFFSET";
// Adds packets containing integers equal to their original timestamp. // Adds packets containing integers equal to their original timestamp.
void AddPackets(CalculatorRunner* runner) { void AddPackets(CalculatorRunner* runner) {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
@ -111,7 +113,7 @@ TEST(SequenceShiftCalculatorTest, SidePacketOffset) {
CalculatorRunner runner(node); CalculatorRunner runner(node);
AddPackets(&runner); AddPackets(&runner);
runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2)); runner.MutableSidePackets()->Tag(kPacketOffsetTag) = Adopt(new int(-2));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& input_packets = const std::vector<Packet>& input_packets =
runner.MutableInputs()->Index(0).packets; runner.MutableInputs()->Index(0).packets;

View File

@ -102,6 +102,10 @@ mediapipe::ScaleMode_Mode ParseScaleMode(
// IMAGE: ImageFrame representing the input image. // IMAGE: ImageFrame representing the input image.
// IMAGE_GPU: GpuBuffer representing the input image. // IMAGE_GPU: GpuBuffer representing the input image.
// //
// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as
// pair<int, int>. If set, it will override corresponding field in calculator
// options and input side packet.
//
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in // ROTATION_DEGREES (optional): The counterclockwise rotation angle in
// degrees. This allows different rotation angles for different frames. It has // degrees. This allows different rotation angles for different frames. It has
// to be a multiple of 90 degrees. If provided, it overrides the // to be a multiple of 90 degrees. If provided, it overrides the
@ -221,6 +225,10 @@ absl::Status ImageTransformationCalculator::GetContract(
} }
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<std::pair<int, int>>();
}
if (cc->Inputs().HasTag("ROTATION_DEGREES")) { if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>(); cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
} }
@ -329,6 +337,13 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) {
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) { !cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>(); flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
} }
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS") &&
!cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
const auto& image_size =
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<std::pair<int, int>>();
output_width_ = image_size.first;
output_height_ = image_size.second;
}
if (use_gpu_) { if (use_gpu_) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU

View File

@ -88,6 +88,13 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"], deps = ["//mediapipe/framework:calculator_proto"],
) )
proto_library(
name = "tensor_to_vector_string_calculator_options_proto",
srcs = ["tensor_to_vector_string_calculator_options.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library( proto_library(
name = "unpack_media_sequence_calculator_proto", name = "unpack_media_sequence_calculator_proto",
srcs = ["unpack_media_sequence_calculator.proto"], srcs = ["unpack_media_sequence_calculator.proto"],
@ -257,6 +264,14 @@ mediapipe_cc_proto_library(
deps = [":tensor_to_vector_float_calculator_options_proto"], deps = [":tensor_to_vector_float_calculator_options_proto"],
) )
mediapipe_cc_proto_library(
name = "tensor_to_vector_string_calculator_options_cc_proto",
srcs = ["tensor_to_vector_string_calculator_options.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":tensor_to_vector_string_calculator_options_proto"],
)
mediapipe_cc_proto_library( mediapipe_cc_proto_library(
name = "unpack_media_sequence_calculator_cc_proto", name = "unpack_media_sequence_calculator_cc_proto",
srcs = ["unpack_media_sequence_calculator.proto"], srcs = ["unpack_media_sequence_calculator.proto"],
@ -694,6 +709,26 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "tensor_to_vector_string_calculator",
srcs = ["tensor_to_vector_string_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
":tensor_to_vector_string_calculator_options_cc_proto",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
],
}),
alwayslink = 1,
)
cc_library( cc_library(
name = "unpack_media_sequence_calculator", name = "unpack_media_sequence_calculator",
srcs = ["unpack_media_sequence_calculator.cc"], srcs = ["unpack_media_sequence_calculator.cc"],
@ -1059,6 +1094,20 @@ cc_test(
], ],
) )
cc_test(
name = "tensor_to_vector_string_calculator_test",
srcs = ["tensor_to_vector_string_calculator_test.cc"],
deps = [
":tensor_to_vector_string_calculator",
":tensor_to_vector_string_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test( cc_test(
name = "unpack_media_sequence_calculator_test", name = "unpack_media_sequence_calculator_test",
srcs = ["unpack_media_sequence_calculator_test.cc"], srcs = ["unpack_media_sequence_calculator_test.cc"],

View File

@ -40,6 +40,24 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence; namespace mpms = mediapipe::mediasequence;
constexpr char kBboxTag[] = "BBOX";
constexpr char kEncodedMediaStartTimestampTag[] =
"ENCODED_MEDIA_START_TIMESTAMP";
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
constexpr char kClassSegmentationTag[] = "CLASS_SEGMENTATION";
constexpr char kKeypointsTestTag[] = "KEYPOINTS_TEST";
constexpr char kBboxPredictedTag[] = "BBOX_PREDICTED";
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
constexpr char kAudioTestTag[] = "AUDIO_TEST";
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
constexpr char kImageTag[] = "IMAGE";
class PackMediaSequenceCalculatorTest : public ::testing::Test { class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected: protected:
void SetUpCalculator(const std::vector<std::string>& input_streams, void SetUpCalculator(const std::vector<std::string>& input_streams,
@ -83,17 +101,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
auto image_ptr = auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back( runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i))); Adopt(image_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -127,17 +145,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
auto image_ptr = auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("IMAGE_PREFIX") ->Tag(kImagePrefixTag)
.packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -161,21 +179,21 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
for (int i = 0; i < num_timesteps; ++i) { for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i); auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FLOAT_FEATURE_TEST") ->Tag(kFloatFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i); vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FLOAT_FEATURE_OTHER") ->Tag(kFloatFeatureOtherTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -228,20 +246,20 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
auto vf_ptr = absl::make_unique<std::vector<float>>(2, 3); auto vf_ptr = absl::make_unique<std::vector<float>>(2, 3);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FLOAT_CONTEXT_FEATURE_TEST") ->Tag(kFloatContextFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
vf_ptr = absl::make_unique<std::vector<float>>(2, 4); vf_ptr = absl::make_unique<std::vector<float>>(2, 4);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FLOAT_CONTEXT_FEATURE_OTHER") ->Tag(kFloatContextFeatureOtherTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream())); .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -259,7 +277,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
SetUpCalculator({"IMAGE:images"}, context, false, true); SetUpCalculator({"IMAGE:images"}, context, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes; std::vector<uchar> bytes;
@ -268,13 +286,13 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_encoded_image(bytes.data(), bytes.size());
auto image_ptr = auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back( runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(0))); Adopt(image_ptr.release()).At(Timestamp(0)));
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -307,17 +325,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
auto flow_ptr = auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED") ->Tag(kForwardFlowEncodedTag)
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -371,17 +389,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
detections->push_back(detection); detections->push_back(detection);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("BBOX_PREDICTED") ->Tag(kBboxPredictedTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i))); .packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -450,11 +468,11 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) {
detections->push_back(detection); detections->push_back(detection);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("BBOX_PREDICTED") ->Tag(kBboxPredictedTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i))); .packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
auto status = runner_->Run(); auto status = runner_->Run();
@ -498,7 +516,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
detections->push_back(detection); detections->push_back(detection);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("BBOX_PREDICTED") ->Tag(kBboxPredictedTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i))); .packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
} }
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
@ -513,16 +531,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
auto image_ptr = auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back( runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i))); Adopt(image_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -564,18 +582,18 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>> absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}}; points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
runner_->MutableInputs() runner_->MutableInputs()
->Tag("KEYPOINTS_TEST") ->Tag(kKeypointsTestTag)
.packets.push_back(PointToForeign(&points).At(Timestamp(0))); .packets.push_back(PointToForeign(&points).At(Timestamp(0)));
runner_->MutableInputs() runner_->MutableInputs()
->Tag("KEYPOINTS_TEST") ->Tag(kKeypointsTestTag)
.packets.push_back(PointToForeign(&points).At(Timestamp(1))); .packets.push_back(PointToForeign(&points).At(Timestamp(1)));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -615,17 +633,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
detections->push_back(detection); detections->push_back(detection);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("CLASS_SEGMENTATION") ->Tag(kClassSegmentationTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i))); .packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -664,17 +682,17 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
auto flow_ptr = auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED") ->Tag(kForwardFlowEncodedTag)
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -710,11 +728,11 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
auto flow_ptr = auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED") ->Tag(kForwardFlowEncodedTag)
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
absl::Status status = runner_->Run(); absl::Status status = runner_->Run();
@ -731,13 +749,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) {
mpms::AddImageTimestamp(1, input_sequence.get()); mpms::AddImageTimestamp(1, input_sequence.get());
mpms::AddImageTimestamp(2, input_sequence.get()); mpms::AddImageTimestamp(2, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -757,13 +775,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) {
mpms::AddForwardFlowTimestamp(1, input_sequence.get()); mpms::AddForwardFlowTimestamp(1, input_sequence.get());
mpms::AddForwardFlowTimestamp(2, input_sequence.get()); mpms::AddForwardFlowTimestamp(2, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -794,13 +812,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) {
mpms::GetFeatureTimestampSize("OTHER", *input_sequence)); mpms::GetFeatureTimestampSize("OTHER", *input_sequence));
ASSERT_EQ(num_timesteps, ASSERT_EQ(num_timesteps,
mpms::GetFeatureFloatsSize("OTHER", *input_sequence)); mpms::GetFeatureFloatsSize("OTHER", *input_sequence));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -826,7 +844,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
auto image_ptr = auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back( runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10))); Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10)));
} }
@ -838,11 +856,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get()); mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get()); mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets; runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size()); ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence = const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>(); output_packets[0].Get<tf::SequenceExample>();
@ -879,7 +897,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
auto image_ptr = auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image); ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back( runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i))); Adopt(image_ptr.release()).At(Timestamp(i)));
} }
@ -893,7 +911,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5) Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
.ConvertToProto(detection.mutable_location_data()); .ConvertToProto(detection.mutable_location_data());
detections->push_back(detection); detections->push_back(detection);
runner_->MutableInputs()->Tag("BBOX").packets.push_back( runner_->MutableInputs()->Tag(kBboxTag).packets.push_back(
Adopt(detections.release()).At(Timestamp(i))); Adopt(detections.release()).At(Timestamp(i)));
} }
@ -909,7 +927,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
mpms::AddBBoxTrackIndex({-1}, input_sequence.get()); mpms::AddBBoxTrackIndex({-1}, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
// If the all the previous values aren't cleared, this assert will fail. // If the all the previous values aren't cleared, this assert will fail.
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
@ -925,11 +943,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) {
for (int i = 0; i < num_timesteps; ++i) { for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(1000000, i); auto vf_ptr = ::absl::make_unique<std::vector<float>>(1000000, i);
runner_->MutableInputs() runner_->MutableInputs()
->Tag("FLOAT_FEATURE_TEST") ->Tag(kFloatFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i))); .packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
ASSERT_FALSE(runner_->Run().ok()); ASSERT_FALSE(runner_->Run().ok());
} }

View File

@ -26,6 +26,8 @@ namespace mediapipe {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace { namespace {
constexpr char kReferenceTag[] = "REFERENCE";
constexpr char kMatrix[] = "MATRIX"; constexpr char kMatrix[] = "MATRIX";
constexpr char kTensor[] = "TENSOR"; constexpr char kTensor[] = "TENSOR";
@ -68,7 +70,8 @@ class TensorToMatrixCalculatorTest : public ::testing::Test {
if (include_rate) { if (include_rate) {
header->set_packet_rate(1.0); header->set_packet_rate(1.0);
} }
runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release()); runner_->MutableInputs()->Tag(kReferenceTag).header =
Adopt(header.release());
} }
std::unique_ptr<CalculatorRunner> runner_; std::unique_ptr<CalculatorRunner> runner_;

View File

@ -0,0 +1,118 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Calculator converts from one-dimensional Tensor of DT_STRING to
// vector<std::string> OR from (batched) two-dimensional Tensor of DT_STRING to
// vector<vector<std::string>.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
class TensorToVectorStringCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
TensorToVectorStringCalculatorOptions options_;
};
REGISTER_CALCULATOR(TensorToVectorStringCalculator);
absl::Status TensorToVectorStringCalculator::GetContract(
CalculatorContract* cc) {
// Start with only one input packet.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<tf::Tensor>(
// Input Tensor
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
const auto& options = cc->Options<TensorToVectorStringCalculatorOptions>();
if (options.tensor_is_2d()) {
RET_CHECK(!options.flatten_nd());
cc->Outputs().Index(0).Set<std::vector<std::vector<std::string>>>(
/* "Output vector<vector<std::string>>." */);
} else {
cc->Outputs().Index(0).Set<std::vector<std::string>>(
// Output vector<std::string>.
);
}
return absl::OkStatus();
}
absl::Status TensorToVectorStringCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<TensorToVectorStringCalculatorOptions>();
// Inform mediapipe that this calculator produces an output at time t for
// each input received at time t (i.e. this calculator does not buffer
// inputs). This enables mediapipe to propagate time of arrival estimates in
// mediapipe graphs through this calculator.
cc->SetOffset(/*offset=*/0);
return absl::OkStatus();
}
absl::Status TensorToVectorStringCalculator::Process(CalculatorContext* cc) {
const tf::Tensor& input_tensor =
cc->Inputs().Index(0).Value().Get<tf::Tensor>();
RET_CHECK(tf::DT_STRING == input_tensor.dtype())
<< "expected DT_STRING input but got "
<< tensorflow::DataTypeString(input_tensor.dtype());
if (options_.tensor_is_2d()) {
RET_CHECK(2 == input_tensor.dims())
<< "Expected 2-dimensional Tensor, but the tensor shape is: "
<< input_tensor.shape().DebugString();
auto output = absl::make_unique<std::vector<std::vector<std::string>>>(
input_tensor.dim_size(0),
std::vector<std::string>(input_tensor.dim_size(1)));
for (int i = 0; i < input_tensor.dim_size(0); ++i) {
auto& instance_output = output->at(i);
const auto& slice =
input_tensor.Slice(i, i + 1).unaligned_flat<tensorflow::tstring>();
for (int j = 0; j < input_tensor.dim_size(1); ++j) {
instance_output.at(j) = slice(j);
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else {
if (!options_.flatten_nd()) {
RET_CHECK(1 == input_tensor.dims())
<< "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the "
<< "tensor shape is: " << input_tensor.shape().DebugString();
}
auto output =
absl::make_unique<std::vector<std::string>>(input_tensor.NumElements());
const auto& tensor_values = input_tensor.flat<tensorflow::tstring>();
for (int i = 0; i < input_tensor.NumElements(); ++i) {
output->at(i) = tensor_values(i);
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
}
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,33 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorToVectorStringCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorToVectorStringCalculatorOptions ext = 386534187;
}
// If true, unpack a 2d tensor (matrix) into a vector<vector<string>>. If
// false, convert a 1d tensor (vector) into a vector<string>.
optional bool tensor_is_2d = 1 [default = false];
// If true, an N-D tensor will be flattened to a vector<string>. This is
// exclusive with tensor_is_2d.
optional bool flatten_nd = 2 [default = false];
}

View File

@ -0,0 +1,130 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class TensorToVectorStringCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToVectorStringCalculator");
config.add_input_stream("input_tensor");
config.add_output_stream("output_tensor");
auto options = config.mutable_options()->MutableExtension(
TensorToVectorStringCalculatorOptions::ext);
options->set_tensor_is_2d(tensor_is_2d);
options->set_flatten_nd(flatten_nd);
runner_ = absl::make_unique<CalculatorRunner>(config);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(TensorToVectorStringCalculatorTest, ConvertsToVectorFloat) {
SetUpRunner(false, false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
auto tensor_vec = tensor->vec<tensorflow::tstring>();
for (int i = 0; i < 5; ++i) {
tensor_vec(i) = absl::StrCat("foo", i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::string>& output_vector =
output_packets[0].Get<std::vector<std::string>>();
EXPECT_EQ(5, output_vector.size());
for (int i = 0; i < 5; ++i) {
const std::string expected = absl::StrCat("foo", i);
EXPECT_EQ(expected, output_vector[i]);
}
}
TEST_F(TensorToVectorStringCalculatorTest, ConvertsBatchedToVectorVectorFloat) {
SetUpRunner(true, false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{1, 5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
auto slice = tensor->Slice(0, 1).flat<tensorflow::tstring>();
for (int i = 0; i < 5; ++i) {
slice(i) = absl::StrCat("foo", i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::vector<std::string>>& output_vectors =
output_packets[0].Get<std::vector<std::vector<std::string>>>();
ASSERT_EQ(1, output_vectors.size());
const std::vector<std::string>& output_vector = output_vectors[0];
EXPECT_EQ(5, output_vector.size());
for (int i = 0; i < 5; ++i) {
const std::string expected = absl::StrCat("foo", i);
EXPECT_EQ(expected, output_vector[i]);
}
}
TEST_F(TensorToVectorStringCalculatorTest, FlattenShouldTakeAllDimensions) {
SetUpRunner(false, true);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 2, 2});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
auto slice = tensor->flat<tensorflow::tstring>();
for (int i = 0; i < 2 * 2 * 2; ++i) {
slice(i) = absl::StrCat("foo", i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::string>& output_vector =
output_packets[0].Get<std::vector<std::string>>();
EXPECT_EQ(2 * 2 * 2, output_vector.size());
for (int i = 0; i < 2 * 2 * 2; ++i) {
const std::string expected = absl::StrCat("foo", i);
EXPECT_EQ(expected, output_vector[i]);
}
}
} // namespace
} // namespace mediapipe

View File

@ -49,6 +49,11 @@ namespace tf = ::tensorflow;
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS";
constexpr char kSessionTag[] = "SESSION";
constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
// This is a simple implementation of a semaphore using standard C++ libraries. // This is a simple implementation of a semaphore using standard C++ libraries.
// It is supposed to be used only by TensorflowInferenceCalculator to throttle // It is supposed to be used only by TensorflowInferenceCalculator to throttle
// the concurrent calls of Tensorflow Session::Run. This is useful when multiple // the concurrent calls of Tensorflow Session::Run. This is useful when multiple
@ -252,10 +257,10 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
} }
// A mediapipe::TensorFlowSession with a model loaded and ready for use. // A mediapipe::TensorFlowSession with a model loaded and ready for use.
// For this calculator it must include a tag_to_tensor_map. // For this calculator it must include a tag_to_tensor_map.
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>(); cc->InputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) {
cc->InputSidePackets() cc->InputSidePackets()
.Tag("RECURRENT_INIT_TENSORS") .Tag(kRecurrentInitTensorsTag)
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>(); .Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
} }
return absl::OkStatus(); return absl::OkStatus();
@ -265,11 +270,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
std::unique_ptr<InferenceState> inference_state = std::unique_ptr<InferenceState> inference_state =
absl::make_unique<InferenceState>(); absl::make_unique<InferenceState>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { !cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) {
std::map<std::string, tf::Tensor>* init_tensor_map; std::map<std::string, tf::Tensor>* init_tensor_map;
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>( init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); cc->InputSidePackets().Tag(kRecurrentInitTensorsTag));
for (const auto& p : *init_tensor_map) { for (const auto& p : *init_tensor_map) {
inference_state->input_tensor_batches_[p.first].emplace_back(p.second); inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
} }
@ -280,13 +285,13 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>(); options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag));
session_ = cc->InputSidePackets() session_ = cc->InputSidePackets()
.Tag("SESSION") .Tag(kSessionTag)
.Get<TensorFlowSession>() .Get<TensorFlowSession>()
.session.get(); .session.get();
tag_to_tensor_map_ = cc->InputSidePackets() tag_to_tensor_map_ = cc->InputSidePackets()
.Tag("SESSION") .Tag(kSessionTag)
.Get<TensorFlowSession>() .Get<TensorFlowSession>()
.tag_to_tensor_map; .tag_to_tensor_map;

View File

@ -41,6 +41,11 @@ namespace mediapipe {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace { namespace {
constexpr char kMultipliedTag[] = "MULTIPLIED";
constexpr char kBTag[] = "B";
constexpr char kSessionTag[] = "SESSION";
std::string GetGraphDefPath() { std::string GetGraphDefPath() {
#ifdef __APPLE__ #ifdef __APPLE__
char path[1024]; char path[1024];
@ -86,8 +91,8 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes( MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options, "TensorFlowSessionFromFrozenGraphGenerator", extendable_options,
input_side_packets, &output_side_packets)); input_side_packets, &output_side_packets));
runner_->MutableSidePackets()->Tag("SESSION") = runner_->MutableSidePackets()->Tag(kSessionTag) =
output_side_packets.Tag("SESSION"); output_side_packets.Tag(kSessionTag);
} }
Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) { Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) {
@ -140,7 +145,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_b = const std::vector<Packet>& output_packets_b =
runner_->Outputs().Tag("B").packets; runner_->Outputs().Tag(kBTag).packets;
ASSERT_EQ(output_packets_b.size(), 1); ASSERT_EQ(output_packets_b.size(), 1);
const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>(); const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({1, 3}); tf::TensorShape expected_shape({1, 3});
@ -148,7 +153,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b); tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b);
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size()); ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape); expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape);
@ -181,7 +186,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size()); ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({3}); tf::TensorShape expected_shape({3});
@ -220,7 +225,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size()); ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({3}); tf::TensorShape expected_shape({3});
@ -274,7 +279,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}); auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -311,7 +316,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}); auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -351,7 +356,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(3, output_packets_mult.size()); ASSERT_EQ(3, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}); auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -392,7 +397,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}); auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -430,7 +435,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}); auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -481,7 +486,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(5, output_packets_mult.size()); ASSERT_EQ(5, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10}); auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -528,7 +533,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0; LOG(INFO) << "timestamp: " << 0;
@ -569,7 +574,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size()); ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0; LOG(INFO) << "timestamp: " << 0;
@ -662,7 +667,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) {
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(0, output_packets_mult.size()); ASSERT_EQ(0, output_packets_mult.size());
} }
@ -691,7 +696,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult = const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets; runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size()); ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>(); const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15}); auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15});

View File

@ -47,6 +47,11 @@ namespace mediapipe {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace { namespace {
constexpr char kSessionTag[] = "SESSION";
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
// Updates the graph nodes to use the device as specified by device_id. // Updates the graph nodes to use the device as specified by device_id.
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
for (auto& node : *graph_def->mutable_node()) { for (auto& node : *graph_def->mutable_node()) {
@ -64,27 +69,29 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>(); cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
bool has_exactly_one_model = bool has_exactly_one_model =
!options.graph_proto_path().empty() !options.graph_proto_path().empty()
? !(cc->InputSidePackets().HasTag("STRING_MODEL") | ? !(cc->InputSidePackets().HasTag(kStringModelTag) |
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) cc->InputSidePackets().HasTag(kStringModelFilePathTag))
: (cc->InputSidePackets().HasTag("STRING_MODEL") ^ : (cc->InputSidePackets().HasTag(kStringModelTag) ^
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")); cc->InputSidePackets().HasTag(kStringModelFilePathTag));
RET_CHECK(has_exactly_one_model) RET_CHECK(has_exactly_one_model)
<< "Must have exactly one of graph_proto_path in options or " << "Must have exactly one of graph_proto_path in options or "
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
if (cc->InputSidePackets().HasTag("STRING_MODEL")) { if (cc->InputSidePackets().HasTag(kStringModelTag)) {
cc->InputSidePackets() cc->InputSidePackets()
.Tag("STRING_MODEL") .Tag(kStringModelTag)
.Set<std::string>( .Set<std::string>(
// String model from embedded path // String model from embedded path
); );
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) { } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
cc->InputSidePackets() cc->InputSidePackets()
.Tag("STRING_MODEL_FILE_PATH") .Tag(kStringModelFilePathTag)
.Set<std::string>( .Set<std::string>(
// Filename of std::string model. // Filename of std::string model.
); );
} }
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>( cc->OutputSidePackets()
.Tag(kSessionTag)
.Set<TensorFlowSession>(
// A TensorFlow model loaded and ready for use along with // A TensorFlow model loaded and ready for use along with
// a map from tags to tensor names. // a map from tags to tensor names.
); );
@ -111,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
session->session.reset(tf::NewSession(session_options)); session->session.reset(tf::NewSession(session_options));
std::string graph_def_serialized; std::string graph_def_serialized;
if (cc->InputSidePackets().HasTag("STRING_MODEL")) { if (cc->InputSidePackets().HasTag(kStringModelTag)) {
graph_def_serialized = graph_def_serialized =
cc->InputSidePackets().Tag("STRING_MODEL").Get<std::string>(); cc->InputSidePackets().Tag(kStringModelTag).Get<std::string>();
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) { } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
const std::string& frozen_graph = cc->InputSidePackets() const std::string& frozen_graph = cc->InputSidePackets()
.Tag("STRING_MODEL_FILE_PATH") .Tag(kStringModelFilePathTag)
.Get<std::string>(); .Get<std::string>();
RET_CHECK_OK( RET_CHECK_OK(
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
@ -147,7 +154,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
} }
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds."; << " microseconds.";

View File

@ -37,6 +37,10 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
constexpr char kSessionTag[] = "SESSION";
std::string GetGraphDefPath() { std::string GetGraphDefPath() {
return mediapipe::file::JoinPath("./", return mediapipe::file::JoinPath("./",
"mediapipe/calculators/tensorflow/" "mediapipe/calculators/tensorflow/"
@ -112,7 +116,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session); VerifySignatureMap(session);
} }
@ -190,12 +194,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
std::string serialized_graph_contents; std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents)); &serialized_graph_contents));
runner.MutableSidePackets()->Tag("STRING_MODEL") = runner.MutableSidePackets()->Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents)); Adopt(new std::string(serialized_graph_contents));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session); VerifySignatureMap(session);
} }
@ -213,12 +217,12 @@ TEST_F(
} }
})", })",
calculator_options_->DebugString())); calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session); VerifySignatureMap(session);
} }
@ -234,7 +238,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
} }
})", })",
calculator_options_->DebugString())); calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
auto run_status = runner.Run(); auto run_status = runner.Run();
EXPECT_THAT( EXPECT_THAT(
@ -255,12 +259,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
} }
})", })",
calculator_options_->DebugString())); calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
std::string serialized_graph_contents; std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents)); &serialized_graph_contents));
runner.MutableSidePackets()->Tag("STRING_MODEL") = runner.MutableSidePackets()->Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents)); Adopt(new std::string(serialized_graph_contents));
auto run_status = runner.Run(); auto run_status = runner.Run();
EXPECT_THAT( EXPECT_THAT(
@ -282,12 +286,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
} }
})", })",
calculator_options_->DebugString())); calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") = runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
std::string serialized_graph_contents; std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents)); &serialized_graph_contents));
runner.MutableSidePackets()->Tag("STRING_MODEL") = runner.MutableSidePackets()->Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents)); Adopt(new std::string(serialized_graph_contents));
auto run_status = runner.Run(); auto run_status = runner.Run();
EXPECT_THAT( EXPECT_THAT(
@ -310,7 +314,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session); VerifySignatureMap(session);
} }

View File

@ -43,6 +43,11 @@ namespace mediapipe {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace { namespace {
constexpr char kSessionTag[] = "SESSION";
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
// Updates the graph nodes to use the device as specified by device_id. // Updates the graph nodes to use the device as specified by device_id.
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
for (auto& node : *graph_def->mutable_node()) { for (auto& node : *graph_def->mutable_node()) {
@ -64,25 +69,26 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext); TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
bool has_exactly_one_model = bool has_exactly_one_model =
!options.graph_proto_path().empty() !options.graph_proto_path().empty()
? !(input_side_packets->HasTag("STRING_MODEL") | ? !(input_side_packets->HasTag(kStringModelTag) |
input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) input_side_packets->HasTag(kStringModelFilePathTag))
: (input_side_packets->HasTag("STRING_MODEL") ^ : (input_side_packets->HasTag(kStringModelTag) ^
input_side_packets->HasTag("STRING_MODEL_FILE_PATH")); input_side_packets->HasTag(kStringModelFilePathTag));
RET_CHECK(has_exactly_one_model) RET_CHECK(has_exactly_one_model)
<< "Must have exactly one of graph_proto_path in options or " << "Must have exactly one of graph_proto_path in options or "
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH"; "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
if (input_side_packets->HasTag("STRING_MODEL")) { if (input_side_packets->HasTag(kStringModelTag)) {
input_side_packets->Tag("STRING_MODEL") input_side_packets->Tag(kStringModelTag)
.Set<std::string>( .Set<std::string>(
// String model from embedded path // String model from embedded path
); );
} else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) { } else if (input_side_packets->HasTag(kStringModelFilePathTag)) {
input_side_packets->Tag("STRING_MODEL_FILE_PATH") input_side_packets->Tag(kStringModelFilePathTag)
.Set<std::string>( .Set<std::string>(
// Filename of std::string model. // Filename of std::string model.
); );
} }
output_side_packets->Tag("SESSION").Set<TensorFlowSession>( output_side_packets->Tag(kSessionTag)
.Set<TensorFlowSession>(
// A TensorFlow model loaded and ready for use along with // A TensorFlow model loaded and ready for use along with
// a map from tags to tensor names. // a map from tags to tensor names.
); );
@ -112,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
session->session.reset(tf::NewSession(session_options)); session->session.reset(tf::NewSession(session_options));
std::string graph_def_serialized; std::string graph_def_serialized;
if (input_side_packets.HasTag("STRING_MODEL")) { if (input_side_packets.HasTag(kStringModelTag)) {
graph_def_serialized = graph_def_serialized =
input_side_packets.Tag("STRING_MODEL").Get<std::string>(); input_side_packets.Tag(kStringModelTag).Get<std::string>();
} else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) { } else if (input_side_packets.HasTag(kStringModelFilePathTag)) {
const std::string& frozen_graph = const std::string& frozen_graph =
input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get<std::string>(); input_side_packets.Tag(kStringModelFilePathTag).Get<std::string>();
RET_CHECK_OK( RET_CHECK_OK(
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized)); mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
} else { } else {
@ -147,7 +153,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
} }
output_side_packets->Tag("SESSION") = Adopt(session.release()); output_side_packets->Tag(kSessionTag) = Adopt(session.release());
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds."; << " microseconds.";

View File

@ -37,6 +37,10 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
constexpr char kSessionTag[] = "SESSION";
std::string GetGraphDefPath() { std::string GetGraphDefPath() {
return mediapipe::file::JoinPath("./", return mediapipe::file::JoinPath("./",
"mediapipe/calculators/tensorflow/" "mediapipe/calculators/tensorflow/"
@ -72,7 +76,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test {
void VerifySignatureMap(PacketSet* output_side_packets) { void VerifySignatureMap(PacketSet* output_side_packets) {
const TensorFlowSession& session = const TensorFlowSession& session =
output_side_packets->Tag("SESSION").Get<TensorFlowSession>(); output_side_packets->Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
@ -179,7 +183,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents)); &serialized_graph_contents));
generator_options_->clear_graph_proto_path(); generator_options_->clear_graph_proto_path();
input_side_packets.Tag("STRING_MODEL") = input_side_packets.Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents)); Adopt(new std::string(serialized_graph_contents));
absl::Status run_status = tool::RunGenerateAndValidateTypes( absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
@ -196,7 +200,7 @@ TEST_F(
PacketSet output_side_packets( PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value()); tool::CreateTagMap({"SESSION:session"}).value());
generator_options_->clear_graph_proto_path(); generator_options_->clear_graph_proto_path();
input_side_packets.Tag("STRING_MODEL_FILE_PATH") = input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
absl::Status run_status = tool::RunGenerateAndValidateTypes( absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
@ -211,7 +215,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value());
PacketSet output_side_packets( PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value()); tool::CreateTagMap({"SESSION:session"}).value());
input_side_packets.Tag("STRING_MODEL_FILE_PATH") = input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
absl::Status run_status = tool::RunGenerateAndValidateTypes( absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
@ -233,9 +237,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
std::string serialized_graph_contents; std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents)); &serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL") = input_side_packets.Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents)); Adopt(new std::string(serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL_FILE_PATH") = input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
absl::Status run_status = tool::RunGenerateAndValidateTypes( absl::Status run_status = tool::RunGenerateAndValidateTypes(
@ -258,9 +262,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
std::string serialized_graph_contents; std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents)); &serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL") = input_side_packets.Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents)); Adopt(new std::string(serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL_FILE_PATH") = input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath())); Adopt(new std::string(GetGraphDefPath()));
generator_options_->clear_graph_proto_path(); generator_options_->clear_graph_proto_path();

View File

@ -31,6 +31,9 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kSessionTag[] = "SESSION";
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple tensorflow saved models // Given the path to a directory containing multiple tensorflow saved models
@ -108,7 +111,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>(); cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
} }
// A TensorFlow model loaded and ready for use along with tensor // A TensorFlow model loaded and ready for use along with tensor
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>(); cc->OutputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -160,7 +163,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
output_signature.first, options)] = output_signature.second.name(); output_signature.first, options)] = output_signature.second.name();
} }
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -35,6 +35,9 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
constexpr char kSessionTag[] = "SESSION";
std::string GetSavedModelDir() { std::string GetSavedModelDir() {
std::string out_path = std::string out_path =
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
@ -79,7 +82,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
options_->DebugString())); options_->DebugString()));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
@ -119,11 +122,11 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
} }
})", })",
options_->DebugString())); options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") = runner.MutableSidePackets()->Tag(kStringSavedModelPathTag) =
MakePacket<std::string>(GetSavedModelDir()); MakePacket<std::string>(GetSavedModelDir());
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
} }
@ -201,7 +204,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
options_->DebugString())); options_->DebugString()));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
} }
@ -224,7 +227,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
options_->DebugString())); options_->DebugString()));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session = const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>(); runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
std::vector<tensorflow::DeviceAttributes> devices; std::vector<tensorflow::DeviceAttributes> devices;

View File

@ -33,6 +33,9 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kSessionTag[] = "SESSION";
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple tensorflow saved models // Given the path to a directory containing multiple tensorflow saved models
@ -100,7 +103,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>(); input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
} }
// A TensorFlow model loaded and ready for use along with tensor // A TensorFlow model loaded and ready for use along with tensor
output_side_packets->Tag("SESSION").Set<TensorFlowSession>(); output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -153,7 +156,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
output_signature.first, options)] = output_signature.second.name(); output_signature.first, options)] = output_signature.second.name();
} }
output_side_packets->Tag("SESSION") = Adopt(session.release()); output_side_packets->Tag(kSessionTag) = Adopt(session.release());
return absl::OkStatus(); return absl::OkStatus();
} }
}; };

View File

@ -34,6 +34,9 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
constexpr char kSessionTag[] = "SESSION";
std::string GetSavedModelDir() { std::string GetSavedModelDir() {
std::string out_path = std::string out_path =
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/", file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
@ -75,7 +78,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets); input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message(); MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session = const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>(); output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
@ -107,7 +110,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
generator_options_->clear_saved_model_path(); generator_options_->clear_saved_model_path();
PacketSet input_side_packets( PacketSet input_side_packets(
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value()); tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value());
input_side_packets.Tag("STRING_SAVED_MODEL_PATH") = input_side_packets.Tag(kStringSavedModelPathTag) =
Adopt(new std::string(GetSavedModelDir())); Adopt(new std::string(GetSavedModelDir()));
PacketSet output_side_packets( PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value()); tool::CreateTagMap({"SESSION:session"}).value());
@ -116,7 +119,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets); input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message(); MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session = const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>(); output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
} }
@ -192,7 +195,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets); input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message(); MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session = const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>(); output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
} }
@ -213,7 +216,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets); input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message(); MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session = const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>(); output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set. // Session must be set.
ASSERT_NE(session.session, nullptr); ASSERT_NE(session.session, nullptr);
std::vector<tensorflow::DeviceAttributes> devices; std::vector<tensorflow::DeviceAttributes> devices;

View File

@ -33,6 +33,31 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence; namespace mpms = mediapipe::mediasequence;
constexpr char kImageFrameRateTag[] = "IMAGE_FRAME_RATE";
constexpr char kEncodedMediaStartTimestampTag[] =
"ENCODED_MEDIA_START_TIMESTAMP";
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
constexpr char kResamplerOptionsTag[] = "RESAMPLER_OPTIONS";
constexpr char kSandboxedDecoderOptionsTag[] = "SANDBOXED_DECODER_OPTIONS";
constexpr char kDecoderOptionsTag[] = "DECODER_OPTIONS";
constexpr char kAudioDecoderOptionsTag[] = "AUDIO_DECODER_OPTIONS";
constexpr char kDataPathTag[] = "DATA_PATH";
constexpr char kDatasetRootTag[] = "DATASET_ROOT";
constexpr char kMediaIdTag[] = "MEDIA_ID";
constexpr char kFloatFeatureFdenseMaxTag[] = "FLOAT_FEATURE_FDENSE_MAX";
constexpr char kFloatFeatureFdenseAvgTag[] = "FLOAT_FEATURE_FDENSE_AVG";
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
constexpr char kAudioTestTag[] = "AUDIO_TEST";
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
constexpr char kBboxPrefixTag[] = "BBOX_PREFIX";
constexpr char kKeypointsTag[] = "KEYPOINTS";
constexpr char kBboxTag[] = "BBOX";
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kImageTag[] = "IMAGE";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
class UnpackMediaSequenceCalculatorTest : public ::testing::Test { class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
protected: protected:
void SetUpCalculator(const std::vector<std::string>& output_streams, void SetUpCalculator(const std::vector<std::string>& output_streams,
@ -95,13 +120,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) {
mpms::AddImageEncoded(test_image_string, input_sequence.get()); mpms::AddImageEncoded(test_image_string, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets; runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size()); ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
@ -124,13 +149,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) {
mpms::AddImageEncoded(test_image_string, input_sequence.get()); mpms::AddImageEncoded(test_image_string, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets; runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size()); ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
@ -154,13 +179,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) {
mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get()); mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE_PREFIX").packets; runner_->Outputs().Tag(kImagePrefixTag).packets;
ASSERT_EQ(num_images, output_packets.size()); ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
@ -182,12 +207,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) {
mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get()); mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
ASSERT_EQ(num_forward_flow_images, output_packets.size()); ASSERT_EQ(num_forward_flow_images, output_packets.size());
for (int i = 0; i < num_forward_flow_images; ++i) { for (int i = 0; i < num_forward_flow_images; ++i) {
@ -211,12 +236,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) {
mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get()); mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets; runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
ASSERT_EQ(num_forward_flow_images, output_packets.size()); ASSERT_EQ(num_forward_flow_images, output_packets.size());
for (int i = 0; i < num_forward_flow_images; ++i) { for (int i = 0; i < num_forward_flow_images; ++i) {
@ -240,13 +265,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) {
mpms::AddBBoxTimestamp(i, input_sequence.get()); mpms::AddBBoxTimestamp(i, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("BBOX").packets; runner_->Outputs().Tag(kBboxTag).packets;
ASSERT_EQ(bboxes.size(), output_packets.size()); ASSERT_EQ(bboxes.size(), output_packets.size());
for (int i = 0; i < bboxes.size(); ++i) { for (int i = 0; i < bboxes.size(); ++i) {
@ -274,13 +299,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) {
mpms::AddBBoxTimestamp(prefix, i, input_sequence.get()); mpms::AddBBoxTimestamp(prefix, i, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("BBOX_PREFIX").packets; runner_->Outputs().Tag(kBboxPrefixTag).packets;
ASSERT_EQ(bboxes.size(), output_packets.size()); ASSERT_EQ(bboxes.size(), output_packets.size());
for (int i = 0; i < bboxes.size(); ++i) { for (int i = 0; i < bboxes.size(); ++i) {
@ -306,13 +331,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get()); mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets; runner_->Outputs().Tag(kFloatFeatureTestTag).packets;
ASSERT_EQ(num_float_lists, output_packets.size()); ASSERT_EQ(num_float_lists, output_packets.size());
for (int i = 0; i < num_float_lists; ++i) { for (int i = 0; i < num_float_lists; ++i) {
@ -322,7 +347,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
} }
const std::vector<Packet>& output_packets_other = const std::vector<Packet>& output_packets_other =
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
ASSERT_EQ(num_float_lists, output_packets_other.size()); ASSERT_EQ(num_float_lists, output_packets_other.size());
for (int i = 0; i < num_float_lists; ++i) { for (int i = 0; i < num_float_lists; ++i) {
@ -352,12 +377,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get()); mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get());
} }
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets; runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size()); ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
@ -366,7 +391,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
} }
const std::vector<Packet>& output_packets_other = const std::vector<Packet>& output_packets_other =
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets; runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
ASSERT_EQ(num_float_lists, output_packets_other.size()); ASSERT_EQ(num_float_lists, output_packets_other.size());
for (int i = 0; i < num_float_lists; ++i) { for (int i = 0; i < num_float_lists; ++i) {
@ -389,12 +414,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get()); input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& fdense_avg_packets = const std::vector<Packet>& fdense_avg_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets; runner_->Outputs().Tag(kFloatFeatureFdenseAvgTag).packets;
ASSERT_EQ(fdense_avg_packets.size(), 1); ASSERT_EQ(fdense_avg_packets.size(), 1);
const auto& fdense_avg_vector = const auto& fdense_avg_vector =
fdense_avg_packets[0].Get<std::vector<float>>(); fdense_avg_packets[0].Get<std::vector<float>>();
@ -403,7 +428,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
::testing::Eq(Timestamp::PostStream())); ::testing::Eq(Timestamp::PostStream()));
const std::vector<Packet>& fdense_max_packets = const std::vector<Packet>& fdense_max_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
ASSERT_EQ(fdense_max_packets.size(), 1); ASSERT_EQ(fdense_max_packets.size(), 1);
const auto& fdense_max_vector = const auto& fdense_max_vector =
fdense_max_packets[0].Get<std::vector<float>>(); fdense_max_packets[0].Get<std::vector<float>>();
@ -430,13 +455,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get()); input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets; runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size()); ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
@ -463,13 +488,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(), mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get()); input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& fdense_max_packets = const std::vector<Packet>& fdense_max_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets; runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
ASSERT_EQ(fdense_max_packets.size(), 1); ASSERT_EQ(fdense_max_packets.size(), 1);
const auto& fdense_max_vector = const auto& fdense_max_vector =
fdense_max_packets[0].Get<std::vector<float>>(); fdense_max_packets[0].Get<std::vector<float>>();
@ -481,17 +506,17 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) { TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"}); SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
std::string root = "test_root"; std::string root = "test_root";
runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root); runner_->MutableSidePackets()->Tag(kDatasetRootTag) = PointToForeign(&root);
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_ASSERT_OK(runner_->OutputSidePackets() MP_ASSERT_OK(runner_->OutputSidePackets()
.Tag("DATA_PATH") .Tag(kDataPathTag)
.ValidateAsType<std::string>()); .ValidateAsType<std::string>());
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(), ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
root + "/" + data_path_); root + "/" + data_path_);
} }
@ -501,28 +526,28 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) {
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext) options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_dataset_root_directory(root); ->set_dataset_root_directory(root);
SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options); SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_ASSERT_OK(runner_->OutputSidePackets() MP_ASSERT_OK(runner_->OutputSidePackets()
.Tag("DATA_PATH") .Tag(kDataPathTag)
.ValidateAsType<std::string>()); .ValidateAsType<std::string>());
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(), ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
root + "/" + data_path_); root + "/" + data_path_);
} }
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) { TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
SetUpCalculator({}, {"DATA_PATH:data_path"}); SetUpCalculator({}, {"DATA_PATH:data_path"});
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_ASSERT_OK(runner_->OutputSidePackets() MP_ASSERT_OK(runner_->OutputSidePackets()
.Tag("DATA_PATH") .Tag(kDataPathTag)
.ValidateAsType<std::string>()); .ValidateAsType<std::string>());
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(), ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
data_path_); data_path_);
} }
@ -534,20 +559,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) {
->set_padding_after_label(2); ->set_padding_after_label(2);
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
&options); &options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets() MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS") .Tag(kAudioDecoderOptionsTag)
.ValidateAsType<AudioDecoderOptions>()); .ValidateAsType<AudioDecoderOptions>());
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS") .Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>() .Get<AudioDecoderOptions>()
.start_time(), .start_time(),
2.0, 1e-5); 2.0, 1e-5);
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS") .Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>() .Get<AudioDecoderOptions>()
.end_time(), .end_time(),
7.0, 1e-5); 7.0, 1e-5);
@ -563,20 +588,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) {
->set_force_decoding_from_start_of_media(true); ->set_force_decoding_from_start_of_media(true);
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {}, SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
&options); &options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets() MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS") .Tag(kAudioDecoderOptionsTag)
.ValidateAsType<AudioDecoderOptions>()); .ValidateAsType<AudioDecoderOptions>());
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS") .Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>() .Get<AudioDecoderOptions>()
.start_time(), .start_time(),
0.0, 1e-5); 0.0, 1e-5);
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS") .Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>() .Get<AudioDecoderOptions>()
.end_time(), .end_time(),
7.0, 1e-5); 7.0, 1e-5);
@ -594,27 +619,27 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
->mutable_base_packet_resampler_options() ->mutable_base_packet_resampler_options()
->set_frame_rate(1.0); ->set_frame_rate(1.0);
SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options); SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets() MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS") .Tag(kResamplerOptionsTag)
.ValidateAsType<CalculatorOptions>()); .ValidateAsType<CalculatorOptions>());
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS") .Tag(kResamplerOptionsTag)
.Get<CalculatorOptions>() .Get<CalculatorOptions>()
.GetExtension(PacketResamplerCalculatorOptions::ext) .GetExtension(PacketResamplerCalculatorOptions::ext)
.start_time(), .start_time(),
2000000, 1); 2000000, 1);
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS") .Tag(kResamplerOptionsTag)
.Get<CalculatorOptions>() .Get<CalculatorOptions>()
.GetExtension(PacketResamplerCalculatorOptions::ext) .GetExtension(PacketResamplerCalculatorOptions::ext)
.end_time(), .end_time(),
7000000, 1); 7000000, 1);
EXPECT_NEAR(runner_->OutputSidePackets() EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS") .Tag(kResamplerOptionsTag)
.Get<CalculatorOptions>() .Get<CalculatorOptions>()
.GetExtension(PacketResamplerCalculatorOptions::ext) .GetExtension(PacketResamplerCalculatorOptions::ext)
.frame_rate(), .frame_rate(),
@ -623,13 +648,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) { TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) {
SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"}); SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"});
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release()); Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run()); MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets() MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("IMAGE_FRAME_RATE") .Tag(kImageFrameRateTag)
.ValidateAsType<double>()); .ValidateAsType<double>());
EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get<double>(), EXPECT_EQ(runner_->OutputSidePackets().Tag(kImageFrameRateTag).Get<double>(),
image_frame_rate_); image_frame_rate_);
} }

View File

@ -26,6 +26,10 @@ namespace {
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
constexpr char kSingleIntTag[] = "SINGLE_INT";
constexpr char kTensorOutTag[] = "TENSOR_OUT";
constexpr char kVectorIntTag[] = "VECTOR_INT";
class VectorIntToTensorCalculatorTest : public ::testing::Test { class VectorIntToTensorCalculatorTest : public ::testing::Test {
protected: protected:
void SetUpRunner( void SetUpRunner(
@ -61,13 +65,13 @@ class VectorIntToTensorCalculatorTest : public ::testing::Test {
const int64 time = 1234; const int64 time = 1234;
runner_->MutableInputs() runner_->MutableInputs()
->Tag("VECTOR_INT") ->Tag(kVectorIntTag)
.packets.push_back(Adopt(input.release()).At(Timestamp(time))); .packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok()); EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets; runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value()); EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>(); const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -95,13 +99,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
tensorflow::DT_INT32, false, true); tensorflow::DT_INT32, false, true);
const int64 time = 1234; const int64 time = 1234;
runner_->MutableInputs() runner_->MutableInputs()
->Tag("SINGLE_INT") ->Tag(kSingleIntTag)
.packets.push_back(MakePacket<int>(1).At(Timestamp(time))); .packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok()); EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets; runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value()); EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>(); const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -121,13 +125,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) {
} }
const int64 time = 1234; const int64 time = 1234;
runner_->MutableInputs() runner_->MutableInputs()
->Tag("VECTOR_INT") ->Tag(kVectorIntTag)
.packets.push_back(Adopt(input.release()).At(Timestamp(time))); .packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok()); EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets; runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value()); EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>(); const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -152,13 +156,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
tensorflow::DT_INT64, false, true); tensorflow::DT_INT64, false, true);
const int64 time = 1234; const int64 time = 1234;
runner_->MutableInputs() runner_->MutableInputs()
->Tag("SINGLE_INT") ->Tag(kSingleIntTag)
.packets.push_back(MakePacket<int>(1LL << 31).At(Timestamp(time))); .packets.push_back(MakePacket<int>(1LL << 31).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok()); EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets; runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value()); EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>(); const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -179,13 +183,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
} }
const int64 time = 1234; const int64 time = 1234;
runner_->MutableInputs() runner_->MutableInputs()
->Tag("VECTOR_INT") ->Tag(kVectorIntTag)
.packets.push_back(Adopt(input.release()).At(Timestamp(time))); .packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok()); EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets = const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets; runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size()); EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value()); EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>(); const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

View File

@ -18,6 +18,10 @@
namespace mediapipe { namespace mediapipe {
constexpr char kFloatsTag[] = "FLOATS";
constexpr char kFloatTag[] = "FLOAT";
constexpr char kTensorsTag[] = "TENSORS";
// A calculator for converting TFLite tensors to to a float or a float vector. // A calculator for converting TFLite tensors to to a float or a float vector.
// //
// Input: // Input:
@ -48,15 +52,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);
absl::Status TfLiteTensorsToFloatsCalculator::GetContract( absl::Status TfLiteTensorsToFloatsCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("TENSORS")); RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
cc->Outputs().HasTag(kFloatTag));
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
if (cc->Outputs().HasTag("FLOATS")) { if (cc->Outputs().HasTag(kFloatsTag)) {
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>(); cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
} }
if (cc->Outputs().HasTag("FLOAT")) { if (cc->Outputs().HasTag(kFloatTag)) {
cc->Outputs().Tag("FLOAT").Set<float>(); cc->Outputs().Tag(kFloatTag).Set<float>();
} }
return absl::OkStatus(); return absl::OkStatus();
@ -69,10 +74,10 @@ absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) {
} }
absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty());
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>(); cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
// TODO: Add option to specify which tensor to take from. // TODO: Add option to specify which tensor to take from.
const TfLiteTensor* raw_tensor = &input_tensors[0]; const TfLiteTensor* raw_tensor = &input_tensors[0];
const float* raw_floats = raw_tensor->data.f; const float* raw_floats = raw_tensor->data.f;
@ -82,18 +87,19 @@ absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
num_values *= raw_tensor->dims->data[i]; num_values *= raw_tensor->dims->data[i];
} }
if (cc->Outputs().HasTag("FLOAT")) { if (cc->Outputs().HasTag(kFloatTag)) {
// TODO: Could add an index in the option to specifiy returning one // TODO: Could add an index in the option to specifiy returning one
// value of a float array. // value of a float array.
RET_CHECK_EQ(num_values, 1); RET_CHECK_EQ(num_values, 1);
cc->Outputs().Tag("FLOAT").AddPacket( cc->Outputs().Tag(kFloatTag).AddPacket(
MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp())); MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag("FLOATS")) { if (cc->Outputs().HasTag(kFloatsTag)) {
auto output_floats = absl::make_unique<std::vector<float>>( auto output_floats = absl::make_unique<std::vector<float>>(
raw_floats, raw_floats + num_values); raw_floats, raw_floats + num_values);
cc->Outputs().Tag("FLOATS").Add(output_floats.release(), cc->Outputs()
cc->InputTimestamp()); .Tag(kFloatsTag)
.Add(output_floats.release(), cc->InputTimestamp());
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -87,7 +87,7 @@ absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) {
// Initialize the clock. // Initialize the clock.
if (cc->InputSidePackets().HasTag(kClockTag)) { if (cc->InputSidePackets().HasTag(kClockTag)) {
clock_ = cc->InputSidePackets() clock_ = cc->InputSidePackets()
.Tag("CLOCK") .Tag(kClockTag)
.Get<std::shared_ptr<::mediapipe::Clock>>(); .Get<std::shared_ptr<::mediapipe::Clock>>();
} else { } else {
clock_.reset( clock_.reset(

View File

@ -27,6 +27,8 @@
namespace mediapipe { namespace mediapipe {
constexpr char kIterableTag[] = "ITERABLE";
typedef CollectionHasMinSizeCalculator<std::vector<int>> typedef CollectionHasMinSizeCalculator<std::vector<int>>
TestIntCollectionHasMinSizeCalculator; TestIntCollectionHasMinSizeCalculator;
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator); REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
@ -34,7 +36,7 @@ REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
void AddInputVector(const std::vector<int>& input, int64 timestamp, void AddInputVector(const std::vector<int>& input, int64 timestamp,
CalculatorRunner* runner) { CalculatorRunner* runner) {
runner->MutableInputs() runner->MutableInputs()
->Tag("ITERABLE") ->Tag(kIterableTag)
.packets.push_back( .packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp))); MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
} }

View File

@ -144,7 +144,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase {
} }
cc->Outputs() cc->Outputs()
.Tag("DETECTIONS") .Tag(kDetectionsTag)
.Add(output_detections.release(), cc->InputTimestamp()); .Add(output_detections.release(), cc->InputTimestamp());
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -25,6 +25,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
constexpr char kDetectionsTag[] = "DETECTIONS";
LocationData CreateRelativeLocationData(double xmin, double ymin, double width, LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
double height) { double height) {
LocationData location_data; LocationData location_data;
@ -76,19 +79,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) {
detections->push_back( detections->push_back(
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>( auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f}); std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
runner.MutableInputs() runner.MutableInputs()
->Tag("LETTERBOX_PADDING") ->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("DETECTIONS").packets; runner.Outputs().Tag(kDetectionsTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& output_detections = output[0].Get<std::vector<Detection>>(); const auto& output_detections = output[0].Get<std::vector<Detection>>();
@ -124,19 +127,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) {
detections->push_back( detections->push_back(
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag")); CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>( auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.f, 0.2f, 0.f, 0.3f}); std::array<float, 4>{0.f, 0.2f, 0.f, 0.3f});
runner.MutableInputs() runner.MutableInputs()
->Tag("LETTERBOX_PADDING") ->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("DETECTIONS").packets; runner.Outputs().Tag(kDetectionsTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& output_detections = output[0].Get<std::vector<Detection>>(); const auto& output_detections = output[0].Get<std::vector<Detection>>();

View File

@ -31,6 +31,9 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kDetectionsTag[] = "DETECTIONS";
using ::testing::ElementsAre; using ::testing::ElementsAre;
using ::testing::FloatNear; using ::testing::FloatNear;
@ -74,19 +77,19 @@ absl::StatusOr<Detection> RunProjectionCalculator(
)pb")); )pb"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back(MakePacket<std::vector<Detection>>( .packets.push_back(MakePacket<std::vector<Detection>>(
std::vector<Detection>({std::move(detection)})) std::vector<Detection>({std::move(detection)}))
.At(Timestamp::PostStream())); .At(Timestamp::PostStream()));
runner.MutableInputs() runner.MutableInputs()
->Tag("PROJECTION_MATRIX") ->Tag(kProjectionMatrixTag)
.packets.push_back( .packets.push_back(
MakePacket<std::array<float, 16>>(std::move(project_mat)) MakePacket<std::array<float, 16>>(std::move(project_mat))
.At(Timestamp::PostStream())); .At(Timestamp::PostStream()));
MP_RETURN_IF_ERROR(runner.Run()); MP_RETURN_IF_ERROR(runner.Run());
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("DETECTIONS").packets; runner.Outputs().Tag(kDetectionsTag).packets;
RET_CHECK_EQ(output.size(), 1); RET_CHECK_EQ(output.size(), 1);
const auto& output_detections = output[0].Get<std::vector<Detection>>(); const auto& output_detections = output[0].Get<std::vector<Detection>>();

View File

@ -32,6 +32,14 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kRectsTag[] = "RECTS";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kRectTag[] = "RECT";
constexpr char kDetectionTag[] = "DETECTION";
MATCHER_P4(RectEq, x_center, y_center, width, height, "") { MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
return testing::Value(arg.x_center(), testing::Eq(x_center)) && return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
testing::Value(arg.y_center(), testing::Eq(y_center)) && testing::Value(arg.y_center(), testing::Eq(y_center)) &&
@ -94,12 +102,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) {
DetectionWithLocationData(100, 200, 300, 400)); DetectionWithLocationData(100, 200, 300, 400));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION") ->Tag(kDetectionTag)
.packets.push_back( .packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream())); Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets; const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<Rect>(); const auto& rect = output[0].Get<Rect>();
EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
@ -120,16 +128,16 @@ absl::StatusOr<Rect> RunDetectionKeyPointsToRectCalculation(
)pb")); )pb"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION") ->Tag(kDetectionTag)
.packets.push_back(MakePacket<Detection>(std::move(detection)) .packets.push_back(MakePacket<Detection>(std::move(detection))
.At(Timestamp::PostStream())); .At(Timestamp::PostStream()));
runner.MutableInputs() runner.MutableInputs()
->Tag("IMAGE_SIZE") ->Tag(kImageSizeTag)
.packets.push_back(MakePacket<std::pair<int, int>>(image_size) .packets.push_back(MakePacket<std::pair<int, int>>(image_size)
.At(Timestamp::PostStream())); .At(Timestamp::PostStream()));
MP_RETURN_IF_ERROR(runner.Run()); MP_RETURN_IF_ERROR(runner.Run());
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets; const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
RET_CHECK_EQ(output.size(), 1); RET_CHECK_EQ(output.size(), 1);
return output[0].Get<Rect>(); return output[0].Get<Rect>();
} }
@ -176,12 +184,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) {
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION") ->Tag(kDetectionTag)
.packets.push_back( .packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream())); Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets; const std::vector<Packet>& output =
runner.Outputs().Tag(kNormRectTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<NormalizedRect>(); const auto& rect = output[0].Get<NormalizedRect>();
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
@ -201,12 +210,13 @@ absl::StatusOr<NormalizedRect> RunDetectionKeyPointsToNormRectCalculation(
)pb")); )pb"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION") ->Tag(kDetectionTag)
.packets.push_back(MakePacket<Detection>(std::move(detection)) .packets.push_back(MakePacket<Detection>(std::move(detection))
.At(Timestamp::PostStream())); .At(Timestamp::PostStream()));
MP_RETURN_IF_ERROR(runner.Run()); MP_RETURN_IF_ERROR(runner.Run());
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets; const std::vector<Packet>& output =
runner.Outputs().Tag(kNormRectTag).packets;
RET_CHECK_EQ(output.size(), 1); RET_CHECK_EQ(output.size(), 1);
return output[0].Get<NormalizedRect>(); return output[0].Get<NormalizedRect>();
} }
@ -248,12 +258,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) {
detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets; const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<Rect>(); const auto& rect = output[0].Get<Rect>();
EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
@ -271,12 +281,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) {
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets; const std::vector<Packet>& output =
runner.Outputs().Tag(kNormRectTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<NormalizedRect>(); const auto& rect = output[0].Get<NormalizedRect>();
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
@ -294,12 +305,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) {
detections->push_back(DetectionWithLocationData(200, 300, 400, 500)); detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets; const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<Rect>>(); const auto& rects = output[0].Get<std::vector<Rect>>();
ASSERT_EQ(rects.size(), 2); ASSERT_EQ(rects.size(), 2);
@ -319,13 +330,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) {
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5)); detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("NORM_RECTS").packets; runner.Outputs().Tag(kNormRectsTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<NormalizedRect>>(); const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
ASSERT_EQ(rects.size(), 2); ASSERT_EQ(rects.size(), 2);
@ -344,12 +355,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) {
DetectionWithLocationData(100, 200, 300, 400)); DetectionWithLocationData(100, 200, 300, 400));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION") ->Tag(kDetectionTag)
.packets.push_back( .packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream())); Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets; const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<Rect>>(); const auto& rects = output[0].Get<std::vector<Rect>>();
EXPECT_EQ(rects.size(), 1); EXPECT_EQ(rects.size(), 1);
@ -367,13 +378,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) {
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION") ->Tag(kDetectionTag)
.packets.push_back( .packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream())); Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("NORM_RECTS").packets; runner.Outputs().Tag(kNormRectsTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<NormalizedRect>>(); const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
ASSERT_EQ(rects.size(), 1); ASSERT_EQ(rects.size(), 1);
@ -391,7 +402,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) {
detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4)); detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
@ -411,7 +422,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) {
detections->push_back(DetectionWithLocationData(100, 200, 300, 400)); detections->push_back(DetectionWithLocationData(100, 200, 300, 400));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));

View File

@ -30,6 +30,10 @@
namespace mediapipe { namespace mediapipe {
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kDetectionListTag[] = "DETECTION_LIST";
using ::testing::DoubleNear; using ::testing::DoubleNear;
// Error tolerance for pixels, distances, etc. // Error tolerance for pixels, distances, etc.
@ -97,13 +101,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) {
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"); CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag");
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION_LIST") ->Tag(kDetectionListTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("RENDER_DATA").packets; runner.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& actual = output[0].Get<RenderData>(); const auto& actual = output[0].Get<RenderData>();
EXPECT_EQ(actual.render_annotations_size(), 3); EXPECT_EQ(actual.render_annotations_size(), 3);
@ -131,13 +135,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) {
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag")); CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = const std::vector<Packet>& output =
runner.Outputs().Tag("RENDER_DATA").packets; runner.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& actual = output[0].Get<RenderData>(); const auto& actual = output[0].Get<RenderData>();
EXPECT_EQ(actual.render_annotations_size(), 3); EXPECT_EQ(actual.render_annotations_size(), 3);
@ -165,7 +169,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
*(detection_list->add_detection()) = *(detection_list->add_detection()) =
CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1"); CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1");
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTION_LIST") ->Tag(kDetectionListTag)
.packets.push_back( .packets.push_back(
Adopt(detection_list.release()).At(Timestamp::PostStream())); Adopt(detection_list.release()).At(Timestamp::PostStream()));
@ -174,13 +178,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
detections->push_back( detections->push_back(
CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2")); CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2"));
runner.MutableInputs() runner.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream())); Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& actual = const std::vector<Packet>& actual =
runner.Outputs().Tag("RENDER_DATA").packets; runner.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, actual.size()); ASSERT_EQ(1, actual.size());
// Check the feature tag for item from detection list. // Check the feature tag for item from detection list.
EXPECT_EQ( EXPECT_EQ(
@ -209,19 +213,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
auto detection_list1(absl::make_unique<DetectionList>()); auto detection_list1(absl::make_unique<DetectionList>());
runner1.MutableInputs() runner1.MutableInputs()
->Tag("DETECTION_LIST") ->Tag(kDetectionListTag)
.packets.push_back( .packets.push_back(
Adopt(detection_list1.release()).At(Timestamp::PostStream())); Adopt(detection_list1.release()).At(Timestamp::PostStream()));
auto detections1(absl::make_unique<std::vector<Detection>>()); auto detections1(absl::make_unique<std::vector<Detection>>());
runner1.MutableInputs() runner1.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections1.release()).At(Timestamp::PostStream())); Adopt(detections1.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed.";
const std::vector<Packet>& exact1 = const std::vector<Packet>& exact1 =
runner1.Outputs().Tag("RENDER_DATA").packets; runner1.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(0, exact1.size()); ASSERT_EQ(0, exact1.size());
// Check when produce_empty_packet is true. // Check when produce_empty_packet is true.
@ -240,19 +244,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
auto detection_list2(absl::make_unique<DetectionList>()); auto detection_list2(absl::make_unique<DetectionList>());
runner2.MutableInputs() runner2.MutableInputs()
->Tag("DETECTION_LIST") ->Tag(kDetectionListTag)
.packets.push_back( .packets.push_back(
Adopt(detection_list2.release()).At(Timestamp::PostStream())); Adopt(detection_list2.release()).At(Timestamp::PostStream()));
auto detections2(absl::make_unique<std::vector<Detection>>()); auto detections2(absl::make_unique<std::vector<Detection>>());
runner2.MutableInputs() runner2.MutableInputs()
->Tag("DETECTIONS") ->Tag(kDetectionsTag)
.packets.push_back( .packets.push_back(
Adopt(detections2.release()).At(Timestamp::PostStream())); Adopt(detections2.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed.";
const std::vector<Packet>& exact2 = const std::vector<Packet>& exact2 =
runner2.Outputs().Tag("RENDER_DATA").packets; runner2.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, exact2.size()); ASSERT_EQ(1, exact2.size());
EXPECT_EQ(exact2[0].Get<RenderData>().render_annotations_size(), 0); EXPECT_EQ(exact2[0].Get<RenderData>().render_annotations_size(), 0);
} }

View File

@ -32,6 +32,12 @@
namespace mediapipe { namespace mediapipe {
constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kScoresTag[] = "SCORES";
constexpr char kLabelsTag[] = "LABELS";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr float kFontHeightScale = 1.25f; constexpr float kFontHeightScale = 1.25f;
// A calculator takes in pairs of labels and scores or classifications, outputs // A calculator takes in pairs of labels and scores or classifications, outputs
@ -74,20 +80,20 @@ class LabelsToRenderDataCalculator : public CalculatorBase {
REGISTER_CALCULATOR(LabelsToRenderDataCalculator); REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) { absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("CLASSIFICATIONS")) { if (cc->Inputs().HasTag(kClassificationsTag)) {
cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>(); cc->Inputs().Tag(kClassificationsTag).Set<ClassificationList>();
} else { } else {
RET_CHECK(cc->Inputs().HasTag("LABELS")) RET_CHECK(cc->Inputs().HasTag(kLabelsTag))
<< "Must provide input stream \"LABELS\""; << "Must provide input stream \"LABELS\"";
cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>(); cc->Inputs().Tag(kLabelsTag).Set<std::vector<std::string>>();
if (cc->Inputs().HasTag("SCORES")) { if (cc->Inputs().HasTag(kScoresTag)) {
cc->Inputs().Tag("SCORES").Set<std::vector<float>>(); cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
} }
} }
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>(); cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
} }
cc->Outputs().Tag("RENDER_DATA").Set<RenderData>(); cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -100,10 +106,10 @@ absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
} }
absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag("VIDEO_PRESTREAM") && if (cc->Inputs().HasTag(kVideoPrestreamTag) &&
cc->InputTimestamp() == Timestamp::PreStream()) { cc->InputTimestamp() == Timestamp::PreStream()) {
const VideoHeader& video_header = const VideoHeader& video_header =
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>(); cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
video_width_ = video_header.width; video_width_ = video_header.width;
video_height_ = video_header.height; video_height_ = video_header.height;
return absl::OkStatus(); return absl::OkStatus();
@ -114,9 +120,9 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
std::vector<std::string> labels; std::vector<std::string> labels;
std::vector<float> scores; std::vector<float> scores;
if (cc->Inputs().HasTag("CLASSIFICATIONS")) { if (cc->Inputs().HasTag(kClassificationsTag)) {
const ClassificationList& classifications = const ClassificationList& classifications =
cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>(); cc->Inputs().Tag(kClassificationsTag).Get<ClassificationList>();
labels.resize(classifications.classification_size()); labels.resize(classifications.classification_size());
scores.resize(classifications.classification_size()); scores.resize(classifications.classification_size());
for (int i = 0; i < classifications.classification_size(); ++i) { for (int i = 0; i < classifications.classification_size(); ++i) {
@ -129,15 +135,15 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
} }
} else { } else {
const std::vector<std::string>& label_vector = const std::vector<std::string>& label_vector =
cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>(); cc->Inputs().Tag(kLabelsTag).Get<std::vector<std::string>>();
labels.resize(label_vector.size()); labels.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) { for (int i = 0; i < label_vector.size(); ++i) {
labels[i] = label_vector[i]; labels[i] = label_vector[i];
} }
if (cc->Inputs().HasTag("SCORES")) { if (cc->Inputs().HasTag(kScoresTag)) {
std::vector<float> score_vector = std::vector<float> score_vector =
cc->Inputs().Tag("SCORES").Get<std::vector<float>>(); cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
CHECK_EQ(label_vector.size(), score_vector.size()); CHECK_EQ(label_vector.size(), score_vector.size());
scores.resize(label_vector.size()); scores.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) { for (int i = 0; i < label_vector.size(); ++i) {
@ -169,7 +175,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
auto* text = label_annotation->mutable_text(); auto* text = label_annotation->mutable_text();
std::string display_text = labels[i]; std::string display_text = labels[i];
if (cc->Inputs().HasTag("SCORES")) { if (cc->Inputs().HasTag(kScoresTag)) {
absl::StrAppend(&display_text, ":", scores[i]); absl::StrAppend(&display_text, ":", scores[i]);
} }
text->set_display_text(display_text); text->set_display_text(display_text);
@ -179,7 +185,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
text->set_font_face(options_.font_face()); text->set_font_face(options_.font_face());
} }
cc->Outputs() cc->Outputs()
.Tag("RENDER_DATA") .Tag(kRenderDataTag)
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp())); .AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
return absl::OkStatus(); return absl::OkStatus();

View File

@ -24,6 +24,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
constexpr char kLandmarksTag[] = "LANDMARKS";
NormalizedLandmark CreateLandmark(float x, float y) { NormalizedLandmark CreateLandmark(float x, float y) {
NormalizedLandmark landmark; NormalizedLandmark landmark;
landmark.set_x(x); landmark.set_x(x);
@ -48,18 +51,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) {
*landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f); *landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f);
*landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f); *landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f);
runner.MutableInputs() runner.MutableInputs()
->Tag("LANDMARKS") ->Tag(kLandmarksTag)
.packets.push_back( .packets.push_back(
Adopt(landmarks.release()).At(Timestamp::PostStream())); Adopt(landmarks.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>( auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f}); std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
runner.MutableInputs() runner.MutableInputs()
->Tag("LETTERBOX_PADDING") ->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets; const std::vector<Packet>& output =
runner.Outputs().Tag(kLandmarksTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>(); const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
@ -84,18 +88,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) {
landmark = landmarks->add_landmark(); landmark = landmarks->add_landmark();
*landmark = CreateLandmark(0.7f, 0.7f); *landmark = CreateLandmark(0.7f, 0.7f);
runner.MutableInputs() runner.MutableInputs()
->Tag("LANDMARKS") ->Tag(kLandmarksTag)
.packets.push_back( .packets.push_back(
Adopt(landmarks.release()).At(Timestamp::PostStream())); Adopt(landmarks.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>( auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.0f, 0.2f, 0.0f, 0.3f}); std::array<float, 4>{0.0f, 0.2f, 0.0f, 0.3f});
runner.MutableInputs() runner.MutableInputs()
->Tag("LETTERBOX_PADDING") ->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream())); .packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets; const std::vector<Packet>& output =
runner.Outputs().Tag(kLandmarksTag).packets;
ASSERT_EQ(1, output.size()); ASSERT_EQ(1, output.size());
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>(); const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();

View File

@ -16,6 +16,10 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator( absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) { mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
mediapipe::CalculatorRunner runner( mediapipe::CalculatorRunner runner(
@ -26,17 +30,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
output_stream: "NORM_LANDMARKS:projected_landmarks" output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb")); )pb"));
runner.MutableInputs() runner.MutableInputs()
->Tag("NORM_LANDMARKS") ->Tag(kNormLandmarksTag)
.packets.push_back( .packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input)) MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1))); .At(Timestamp(1)));
runner.MutableInputs() runner.MutableInputs()
->Tag("NORM_RECT") ->Tag(kNormRectTag)
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect)) .packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
.At(Timestamp(1))); .At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run()); MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
RET_CHECK_EQ(output_packets.size(), 1); RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>(); return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
} }
@ -104,17 +108,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
output_stream: "NORM_LANDMARKS:projected_landmarks" output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb")); )pb"));
runner.MutableInputs() runner.MutableInputs()
->Tag("NORM_LANDMARKS") ->Tag(kNormLandmarksTag)
.packets.push_back( .packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input)) MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1))); .At(Timestamp(1)));
runner.MutableInputs() runner.MutableInputs()
->Tag("PROJECTION_MATRIX") ->Tag(kProjectionMatrixTag)
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix)) .packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
.At(Timestamp(1))); .At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run()); MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets; const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
RET_CHECK_EQ(output_packets.size(), 1); RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>(); return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
} }

View File

@ -20,6 +20,11 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
namespace mediapipe { namespace mediapipe {
constexpr char kContentsTag[] = "CONTENTS";
constexpr char kFileSuffixTag[] = "FILE_SUFFIX";
constexpr char kFileDirectoryTag[] = "FILE_DIRECTORY";
// The calculator takes the path to local directory and desired file suffix to // The calculator takes the path to local directory and desired file suffix to
// mach as input side packets, and outputs the contents of those files that // mach as input side packets, and outputs the contents of those files that
// match the pattern. Those matched files will be sent sequentially through the // match the pattern. Those matched files will be sent sequentially through the
@ -35,16 +40,16 @@ namespace mediapipe {
class LocalFilePatternContentsCalculator : public CalculatorBase { class LocalFilePatternContentsCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("FILE_DIRECTORY").Set<std::string>(); cc->InputSidePackets().Tag(kFileDirectoryTag).Set<std::string>();
cc->InputSidePackets().Tag("FILE_SUFFIX").Set<std::string>(); cc->InputSidePackets().Tag(kFileSuffixTag).Set<std::string>();
cc->Outputs().Tag("CONTENTS").Set<std::string>(); cc->Outputs().Tag(kContentsTag).Set<std::string>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory( MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory(
cc->InputSidePackets().Tag("FILE_DIRECTORY").Get<std::string>(), cc->InputSidePackets().Tag(kFileDirectoryTag).Get<std::string>(),
cc->InputSidePackets().Tag("FILE_SUFFIX").Get<std::string>(), cc->InputSidePackets().Tag(kFileSuffixTag).Get<std::string>(),
&filenames_)); &filenames_));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -57,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase {
filenames_[current_output_], contents.get())); filenames_[current_output_], contents.get()));
++current_output_; ++current_output_;
cc->Outputs() cc->Outputs()
.Tag("CONTENTS") .Tag(kContentsTag)
.Add(contents.release(), Timestamp(current_output_)); .Add(contents.release(), Timestamp(current_output_));
} else { } else {
return tool::StatusStop(); return tool::StatusStop();

View File

@ -217,7 +217,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) {
// Initialize the clock. // Initialize the clock.
if (cc->InputSidePackets().HasTag(kClockTag)) { if (cc->InputSidePackets().HasTag(kClockTag)) {
clock_ = cc->InputSidePackets() clock_ = cc->InputSidePackets()
.Tag("CLOCK") .Tag(kClockTag)
.Get<std::shared_ptr<::mediapipe::Clock>>(); .Get<std::shared_ptr<::mediapipe::Clock>>();
} else { } else {
clock_ = std::shared_ptr<::mediapipe::Clock>( clock_ = std::shared_ptr<::mediapipe::Clock>(

View File

@ -17,6 +17,12 @@
namespace mediapipe { namespace mediapipe {
constexpr char kThresholdTag[] = "THRESHOLD";
constexpr char kRejectTag[] = "REJECT";
constexpr char kAcceptTag[] = "ACCEPT";
constexpr char kFlagTag[] = "FLAG";
constexpr char kFloatTag[] = "FLOAT";
// Applies a threshold on a stream of numeric values and outputs a flag and/or // Applies a threshold on a stream of numeric values and outputs a flag and/or
// accept/reject stream. The threshold can be specified by one of the following: // accept/reject stream. The threshold can be specified by one of the following:
// 1) Input stream. // 1) Input stream.
@ -61,24 +67,24 @@ class ThresholdingCalculator : public CalculatorBase {
REGISTER_CALCULATOR(ThresholdingCalculator); REGISTER_CALCULATOR(ThresholdingCalculator);
absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("FLOAT")); RET_CHECK(cc->Inputs().HasTag(kFloatTag));
cc->Inputs().Tag("FLOAT").Set<float>(); cc->Inputs().Tag(kFloatTag).Set<float>();
if (cc->Outputs().HasTag("FLAG")) { if (cc->Outputs().HasTag(kFlagTag)) {
cc->Outputs().Tag("FLAG").Set<bool>(); cc->Outputs().Tag(kFlagTag).Set<bool>();
} }
if (cc->Outputs().HasTag("ACCEPT")) { if (cc->Outputs().HasTag(kAcceptTag)) {
cc->Outputs().Tag("ACCEPT").Set<bool>(); cc->Outputs().Tag(kAcceptTag).Set<bool>();
} }
if (cc->Outputs().HasTag("REJECT")) { if (cc->Outputs().HasTag(kRejectTag)) {
cc->Outputs().Tag("REJECT").Set<bool>(); cc->Outputs().Tag(kRejectTag).Set<bool>();
} }
if (cc->Inputs().HasTag("THRESHOLD")) { if (cc->Inputs().HasTag(kThresholdTag)) {
cc->Inputs().Tag("THRESHOLD").Set<double>(); cc->Inputs().Tag(kThresholdTag).Set<double>();
} }
if (cc->InputSidePackets().HasTag("THRESHOLD")) { if (cc->InputSidePackets().HasTag(kThresholdTag)) {
cc->InputSidePackets().Tag("THRESHOLD").Set<double>(); cc->InputSidePackets().Tag(kThresholdTag).Set<double>();
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
<< "Using both the threshold input side packet and input stream is not " << "Using both the threshold input side packet and input stream is not "
"supported."; "supported.";
} }
@ -92,43 +98,45 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
const auto& options = const auto& options =
cc->Options<::mediapipe::ThresholdingCalculatorOptions>(); cc->Options<::mediapipe::ThresholdingCalculatorOptions>();
if (options.has_threshold()) { if (options.has_threshold()) {
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD")) RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
<< "Using both the threshold option and input stream is not supported."; << "Using both the threshold option and input stream is not supported.";
RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD")) RET_CHECK(!cc->InputSidePackets().HasTag(kThresholdTag))
<< "Using both the threshold option and input side packet is not " << "Using both the threshold option and input side packet is not "
"supported."; "supported.";
threshold_ = options.threshold(); threshold_ = options.threshold();
} }
if (cc->InputSidePackets().HasTag("THRESHOLD")) { if (cc->InputSidePackets().HasTag(kThresholdTag)) {
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>(); threshold_ = cc->InputSidePackets().Tag(kThresholdTag).Get<double>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) { absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag("THRESHOLD") && if (cc->Inputs().HasTag(kThresholdTag) &&
!cc->Inputs().Tag("THRESHOLD").IsEmpty()) { !cc->Inputs().Tag(kThresholdTag).IsEmpty()) {
threshold_ = cc->Inputs().Tag("THRESHOLD").Get<double>(); threshold_ = cc->Inputs().Tag(kThresholdTag).Get<double>();
} }
bool accept = false; bool accept = false;
RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty()); RET_CHECK(!cc->Inputs().Tag(kFloatTag).IsEmpty());
accept = accept = static_cast<double>(cc->Inputs().Tag(kFloatTag).Get<float>()) >
static_cast<double>(cc->Inputs().Tag("FLOAT").Get<float>()) > threshold_; threshold_;
if (cc->Outputs().HasTag("FLAG")) { if (cc->Outputs().HasTag(kFlagTag)) {
cc->Outputs().Tag("FLAG").AddPacket( cc->Outputs().Tag(kFlagTag).AddPacket(
MakePacket<bool>(accept).At(cc->InputTimestamp())); MakePacket<bool>(accept).At(cc->InputTimestamp()));
} }
if (accept && cc->Outputs().HasTag("ACCEPT")) { if (accept && cc->Outputs().HasTag(kAcceptTag)) {
cc->Outputs().Tag("ACCEPT").AddPacket( cc->Outputs()
MakePacket<bool>(true).At(cc->InputTimestamp())); .Tag(kAcceptTag)
.AddPacket(MakePacket<bool>(true).At(cc->InputTimestamp()));
} }
if (!accept && cc->Outputs().HasTag("REJECT")) { if (!accept && cc->Outputs().HasTag(kRejectTag)) {
cc->Outputs().Tag("REJECT").AddPacket( cc->Outputs()
MakePacket<bool>(false).At(cc->InputTimestamp())); .Tag(kRejectTag)
.AddPacket(MakePacket<bool>(false).At(cc->InputTimestamp()));
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -39,6 +39,14 @@
namespace mediapipe { namespace mediapipe {
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
constexpr char kSummaryTag[] = "SUMMARY";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
constexpr char kScoresTag[] = "SCORES";
// A calculator that takes a vector of scores and returns the indexes, scores, // A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements, classification protos, and summary std::string // labels of the top k elements, classification protos, and summary std::string
// (in csv format). // (in csv format).
@ -79,22 +87,22 @@ class TopKScoresCalculator : public CalculatorBase {
REGISTER_CALCULATOR(TopKScoresCalculator); REGISTER_CALCULATOR(TopKScoresCalculator);
absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("SCORES")); RET_CHECK(cc->Inputs().HasTag(kScoresTag));
cc->Inputs().Tag("SCORES").Set<std::vector<float>>(); cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
if (cc->Outputs().HasTag("TOP_K_INDEXES")) { if (cc->Outputs().HasTag(kTopKIndexesTag)) {
cc->Outputs().Tag("TOP_K_INDEXES").Set<std::vector<int>>(); cc->Outputs().Tag(kTopKIndexesTag).Set<std::vector<int>>();
} }
if (cc->Outputs().HasTag("TOP_K_SCORES")) { if (cc->Outputs().HasTag(kTopKScoresTag)) {
cc->Outputs().Tag("TOP_K_SCORES").Set<std::vector<float>>(); cc->Outputs().Tag(kTopKScoresTag).Set<std::vector<float>>();
} }
if (cc->Outputs().HasTag("TOP_K_LABELS")) { if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>(); cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
} }
if (cc->Outputs().HasTag("CLASSIFICATIONS")) { if (cc->Outputs().HasTag(kClassificationsTag)) {
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>(); cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
} }
if (cc->Outputs().HasTag("SUMMARY")) { if (cc->Outputs().HasTag(kSummaryTag)) {
cc->Outputs().Tag("SUMMARY").Set<std::string>(); cc->Outputs().Tag(kSummaryTag).Set<std::string>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -114,7 +122,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
if (options.has_label_map_path()) { if (options.has_label_map_path()) {
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path())); MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
} }
if (cc->Outputs().HasTag("TOP_K_LABELS")) { if (cc->Outputs().HasTag(kTopKLabelsTag)) {
RET_CHECK(!label_map_.empty()); RET_CHECK(!label_map_.empty());
} }
return absl::OkStatus(); return absl::OkStatus();
@ -122,7 +130,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
const std::vector<float>& input_vector = const std::vector<float>& input_vector =
cc->Inputs().Tag("SCORES").Get<std::vector<float>>(); cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
std::vector<int> top_k_indexes; std::vector<int> top_k_indexes;
std::vector<float> top_k_scores; std::vector<float> top_k_scores;
@ -166,26 +174,26 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
top_k_labels.push_back(label_map_[index]); top_k_labels.push_back(label_map_[index]);
} }
} }
if (cc->Outputs().HasTag("TOP_K_INDEXES")) { if (cc->Outputs().HasTag(kTopKIndexesTag)) {
cc->Outputs() cc->Outputs()
.Tag("TOP_K_INDEXES") .Tag(kTopKIndexesTag)
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes) .AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag("TOP_K_SCORES")) { if (cc->Outputs().HasTag(kTopKScoresTag)) {
cc->Outputs() cc->Outputs()
.Tag("TOP_K_SCORES") .Tag(kTopKScoresTag)
.AddPacket(MakePacket<std::vector<float>>(top_k_scores) .AddPacket(MakePacket<std::vector<float>>(top_k_scores)
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag("TOP_K_LABELS")) { if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs() cc->Outputs()
.Tag("TOP_K_LABELS") .Tag(kTopKLabelsTag)
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels) .AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag("SUMMARY")) { if (cc->Outputs().HasTag(kSummaryTag)) {
std::vector<std::string> results; std::vector<std::string> results;
for (int index = 0; index < top_k_indexes.size(); ++index) { for (int index = 0; index < top_k_indexes.size(); ++index) {
if (label_map_loaded_) { if (label_map_loaded_) {
@ -196,12 +204,13 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index])); absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
} }
} }
cc->Outputs().Tag("SUMMARY").AddPacket( cc->Outputs()
MakePacket<std::string>(absl::StrJoin(results, ",")) .Tag(kSummaryTag)
.AddPacket(MakePacket<std::string>(absl::StrJoin(results, ","))
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} }
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) { if (cc->Outputs().HasTag(kTopKClassificationTag)) {
auto classification_list = absl::make_unique<ClassificationList>(); auto classification_list = absl::make_unique<ClassificationList>();
for (int index = 0; index < top_k_indexes.size(); ++index) { for (int index = 0; index < top_k_indexes.size(); ++index) {
Classification* classification = Classification* classification =

View File

@ -23,6 +23,10 @@
namespace mediapipe { namespace mediapipe {
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
constexpr char kScoresTag[] = "SCORES";
TEST(TopKScoresCalculatorTest, TestNodeConfig) { TEST(TopKScoresCalculatorTest, TestNodeConfig) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TopKScoresCalculator" calculator: "TopKScoresCalculator"
@ -55,19 +59,21 @@ TEST(TopKScoresCalculatorTest, TestTopKOnly) {
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back( runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs = const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets; runner.Outputs().Tag(kTopKIndexesTag).packets;
ASSERT_EQ(1, indexes_outputs.size()); ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>(); const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(2, indexes.size()); EXPECT_EQ(2, indexes.size());
EXPECT_EQ(3, indexes[0]); EXPECT_EQ(3, indexes[0]);
EXPECT_EQ(0, indexes[1]); EXPECT_EQ(0, indexes[1]);
const std::vector<Packet>& scores_outputs = const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets; runner.Outputs().Tag(kTopKScoresTag).packets;
ASSERT_EQ(1, scores_outputs.size()); ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>(); const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(2, scores.size()); EXPECT_EQ(2, scores.size());
@ -88,12 +94,14 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back( runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs = const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets; runner.Outputs().Tag(kTopKIndexesTag).packets;
ASSERT_EQ(1, indexes_outputs.size()); ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>(); const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(4, indexes.size()); EXPECT_EQ(4, indexes.size());
@ -102,7 +110,7 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
EXPECT_EQ(2, indexes[2]); EXPECT_EQ(2, indexes[2]);
EXPECT_EQ(1, indexes[3]); EXPECT_EQ(1, indexes[3]);
const std::vector<Packet>& scores_outputs = const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets; runner.Outputs().Tag(kTopKScoresTag).packets;
ASSERT_EQ(1, scores_outputs.size()); ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>(); const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(4, scores.size()); EXPECT_EQ(4, scores.size());
@ -125,12 +133,14 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1}; std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back( runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0))); MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs = const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets; runner.Outputs().Tag(kTopKIndexesTag).packets;
ASSERT_EQ(1, indexes_outputs.size()); ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>(); const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(3, indexes.size()); EXPECT_EQ(3, indexes.size());
@ -138,7 +148,7 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
EXPECT_EQ(0, indexes[1]); EXPECT_EQ(0, indexes[1]);
EXPECT_EQ(2, indexes[2]); EXPECT_EQ(2, indexes[2]);
const std::vector<Packet>& scores_outputs = const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets; runner.Outputs().Tag(kTopKScoresTag).packets;
ASSERT_EQ(1, scores_outputs.size()); ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>(); const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(3, scores.size()); EXPECT_EQ(3, scores.size());

View File

@ -47,6 +47,21 @@
namespace mediapipe { namespace mediapipe {
constexpr char kFrameAlignmentTag[] = "FRAME_ALIGNMENT";
constexpr char kOutputIndexFilenameTag[] = "OUTPUT_INDEX_FILENAME";
constexpr char kIndexProtoStringTag[] = "INDEX_PROTO_STRING";
constexpr char kVizTag[] = "VIZ";
constexpr char kBoxesTag[] = "BOXES";
constexpr char kReacqSwitchTag[] = "REACQ_SWITCH";
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
constexpr char kAddIndexTag[] = "ADD_INDEX";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kDescriptorsTag[] = "DESCRIPTORS";
constexpr char kFeaturesTag[] = "FEATURES";
constexpr char kVideoTag[] = "VIDEO";
constexpr char kTrackedBoxesTag[] = "TRACKED_BOXES";
constexpr char kTrackingTag[] = "TRACKING";
// A calculator to detect reappeared box positions from single frame. // A calculator to detect reappeared box positions from single frame.
// //
// Input stream: // Input stream:
@ -110,66 +125,66 @@ class BoxDetectorCalculator : public CalculatorBase {
REGISTER_CALCULATOR(BoxDetectorCalculator); REGISTER_CALCULATOR(BoxDetectorCalculator);
absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("TRACKING")) { if (cc->Inputs().HasTag(kTrackingTag)) {
cc->Inputs().Tag("TRACKING").Set<TrackingData>(); cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
} }
if (cc->Inputs().HasTag("TRACKED_BOXES")) { if (cc->Inputs().HasTag(kTrackedBoxesTag)) {
cc->Inputs().Tag("TRACKED_BOXES").Set<TimedBoxProtoList>(); cc->Inputs().Tag(kTrackedBoxesTag).Set<TimedBoxProtoList>();
} }
if (cc->Inputs().HasTag("VIDEO")) { if (cc->Inputs().HasTag(kVideoTag)) {
cc->Inputs().Tag("VIDEO").Set<ImageFrame>(); cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
} }
if (cc->Inputs().HasTag("FEATURES")) { if (cc->Inputs().HasTag(kFeaturesTag)) {
RET_CHECK(cc->Inputs().HasTag("DESCRIPTORS")) RET_CHECK(cc->Inputs().HasTag(kDescriptorsTag))
<< "FEATURES and DESCRIPTORS need to be specified together."; << "FEATURES and DESCRIPTORS need to be specified together.";
cc->Inputs().Tag("FEATURES").Set<std::vector<cv::KeyPoint>>(); cc->Inputs().Tag(kFeaturesTag).Set<std::vector<cv::KeyPoint>>();
} }
if (cc->Inputs().HasTag("DESCRIPTORS")) { if (cc->Inputs().HasTag(kDescriptorsTag)) {
RET_CHECK(cc->Inputs().HasTag("FEATURES")) RET_CHECK(cc->Inputs().HasTag(kFeaturesTag))
<< "FEATURES and DESCRIPTORS need to be specified together."; << "FEATURES and DESCRIPTORS need to be specified together.";
cc->Inputs().Tag("DESCRIPTORS").Set<std::vector<float>>(); cc->Inputs().Tag(kDescriptorsTag).Set<std::vector<float>>();
} }
if (cc->Inputs().HasTag("IMAGE_SIZE")) { if (cc->Inputs().HasTag(kImageSizeTag)) {
cc->Inputs().Tag("IMAGE_SIZE").Set<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
} }
if (cc->Inputs().HasTag("ADD_INDEX")) { if (cc->Inputs().HasTag(kAddIndexTag)) {
cc->Inputs().Tag("ADD_INDEX").Set<std::string>(); cc->Inputs().Tag(kAddIndexTag).Set<std::string>();
} }
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) { if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>(); cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
} }
if (cc->Inputs().HasTag("REACQ_SWITCH")) { if (cc->Inputs().HasTag(kReacqSwitchTag)) {
cc->Inputs().Tag("REACQ_SWITCH").Set<bool>(); cc->Inputs().Tag(kReacqSwitchTag).Set<bool>();
} }
if (cc->Outputs().HasTag("BOXES")) { if (cc->Outputs().HasTag(kBoxesTag)) {
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>(); cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
} }
if (cc->Outputs().HasTag("VIZ")) { if (cc->Outputs().HasTag(kVizTag)) {
RET_CHECK(cc->Inputs().HasTag("VIDEO")) RET_CHECK(cc->Inputs().HasTag(kVideoTag))
<< "Output stream VIZ requires VIDEO to be present."; << "Output stream VIZ requires VIDEO to be present.";
cc->Outputs().Tag("VIZ").Set<ImageFrame>(); cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
} }
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) { if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
cc->InputSidePackets().Tag("INDEX_PROTO_STRING").Set<std::string>(); cc->InputSidePackets().Tag(kIndexProtoStringTag).Set<std::string>();
} }
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) { if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Set<std::string>(); cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Set<std::string>();
} }
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) { if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set<int>(); cc->InputSidePackets().Tag(kFrameAlignmentTag).Set<int>();
} }
return absl::OkStatus(); return absl::OkStatus();
@ -179,10 +194,10 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<BoxDetectorCalculatorOptions>(); options_ = cc->Options<BoxDetectorCalculatorOptions>();
box_detector_ = BoxDetectorInterface::Create(options_.detector_options()); box_detector_ = BoxDetectorInterface::Create(options_.detector_options());
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) { if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
BoxDetectorIndex predefined_index; BoxDetectorIndex predefined_index;
if (!predefined_index.ParseFromString(cc->InputSidePackets() if (!predefined_index.ParseFromString(cc->InputSidePackets()
.Tag("INDEX_PROTO_STRING") .Tag(kIndexProtoStringTag)
.Get<std::string>())) { .Get<std::string>())) {
LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING"; LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING";
} }
@ -202,12 +217,13 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
box_detector_->AddBoxDetectorIndex(predefined_index); box_detector_->AddBoxDetectorIndex(predefined_index);
} }
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) { if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
write_index_ = true; write_index_ = true;
} }
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) { if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get<int>(); frame_alignment_ =
cc->InputSidePackets().Tag(kFrameAlignmentTag).Get<int>();
} }
return absl::OkStatus(); return absl::OkStatus();
@ -218,16 +234,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
const int64 timestamp_msec = timestamp.Value() / 1000; const int64 timestamp_msec = timestamp.Value() / 1000;
InputStream* cancel_object_id_stream = InputStream* cancel_object_id_stream =
cc->Inputs().HasTag("CANCEL_OBJECT_ID") cc->Inputs().HasTag(kCancelObjectIdTag)
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID")) ? &(cc->Inputs().Tag(kCancelObjectIdTag))
: nullptr; : nullptr;
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
const int cancel_object_id = cancel_object_id_stream->Get<int>(); const int cancel_object_id = cancel_object_id_stream->Get<int>();
box_detector_->CancelBoxDetection(cancel_object_id); box_detector_->CancelBoxDetection(cancel_object_id);
} }
InputStream* add_index_stream = cc->Inputs().HasTag("ADD_INDEX") InputStream* add_index_stream = cc->Inputs().HasTag(kAddIndexTag)
? &(cc->Inputs().Tag("ADD_INDEX")) ? &(cc->Inputs().Tag(kAddIndexTag))
: nullptr; : nullptr;
if (add_index_stream && !add_index_stream->IsEmpty()) { if (add_index_stream && !add_index_stream->IsEmpty()) {
BoxDetectorIndex predefined_index; BoxDetectorIndex predefined_index;
@ -238,8 +254,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
box_detector_->AddBoxDetectorIndex(predefined_index); box_detector_->AddBoxDetectorIndex(predefined_index);
} }
InputStream* reacq_switch_stream = cc->Inputs().HasTag("REACQ_SWITCH") InputStream* reacq_switch_stream = cc->Inputs().HasTag(kReacqSwitchTag)
? &(cc->Inputs().Tag("REACQ_SWITCH")) ? &(cc->Inputs().Tag(kReacqSwitchTag))
: nullptr; : nullptr;
if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) { if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) {
detector_switch_ = reacq_switch_stream->Get<bool>(); detector_switch_ = reacq_switch_stream->Get<bool>();
@ -249,16 +265,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
InputStream* track_stream = cc->Inputs().HasTag("TRACKING") InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
? &(cc->Inputs().Tag("TRACKING")) ? &(cc->Inputs().Tag(kTrackingTag))
: nullptr; : nullptr;
InputStream* video_stream = InputStream* video_stream =
cc->Inputs().HasTag("VIDEO") ? &(cc->Inputs().Tag("VIDEO")) : nullptr; cc->Inputs().HasTag(kVideoTag) ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
InputStream* feature_stream = cc->Inputs().HasTag("FEATURES") InputStream* feature_stream = cc->Inputs().HasTag(kFeaturesTag)
? &(cc->Inputs().Tag("FEATURES")) ? &(cc->Inputs().Tag(kFeaturesTag))
: nullptr; : nullptr;
InputStream* descriptor_stream = cc->Inputs().HasTag("DESCRIPTORS") InputStream* descriptor_stream = cc->Inputs().HasTag(kDescriptorsTag)
? &(cc->Inputs().Tag("DESCRIPTORS")) ? &(cc->Inputs().Tag(kDescriptorsTag))
: nullptr; : nullptr;
CHECK(track_stream != nullptr || video_stream != nullptr || CHECK(track_stream != nullptr || video_stream != nullptr ||
@ -266,8 +282,9 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
<< "One and only one of {tracking_data, input image frame, " << "One and only one of {tracking_data, input image frame, "
"feature/descriptor} need to be valid."; "feature/descriptor} need to be valid.";
InputStream* tracked_boxes_stream = cc->Inputs().HasTag("TRACKED_BOXES") InputStream* tracked_boxes_stream =
? &(cc->Inputs().Tag("TRACKED_BOXES")) cc->Inputs().HasTag(kTrackedBoxesTag)
? &(cc->Inputs().Tag(kTrackedBoxesTag))
: nullptr; : nullptr;
std::unique_ptr<TimedBoxProtoList> detected_boxes(new TimedBoxProtoList()); std::unique_ptr<TimedBoxProtoList> detected_boxes(new TimedBoxProtoList());
@ -309,7 +326,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
} }
const auto& image_size = const auto& image_size =
cc->Inputs().Tag("IMAGE_SIZE").Get<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
float inv_scale = 1.0f / std::max(image_size.first, image_size.second); float inv_scale = 1.0f / std::max(image_size.first, image_size.second);
TimedBoxProtoList tracked_boxes; TimedBoxProtoList tracked_boxes;
@ -359,7 +376,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
detected_boxes.get()); detected_boxes.get());
} }
if (cc->Outputs().HasTag("VIZ")) { if (cc->Outputs().HasTag(kVizTag)) {
cv::Mat viz_view; cv::Mat viz_view;
std::unique_ptr<ImageFrame> viz_frame; std::unique_ptr<ImageFrame> viz_frame;
if (video_stream != nullptr && !video_stream->IsEmpty()) { if (video_stream != nullptr && !video_stream->IsEmpty()) {
@ -370,11 +387,11 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
for (const auto& box : detected_boxes->box()) { for (const auto& box : detected_boxes->box()) {
RenderBox(box, &viz_view); RenderBox(box, &viz_view);
} }
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
} }
if (cc->Outputs().HasTag("BOXES")) { if (cc->Outputs().HasTag(kBoxesTag)) {
cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp); cc->Outputs().Tag(kBoxesTag).Add(detected_boxes.release(), timestamp);
} }
return absl::OkStatus(); return absl::OkStatus();
@ -384,7 +401,7 @@ absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) {
if (write_index_) { if (write_index_) {
BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex(); BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex();
MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents( MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents(
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get<std::string>(), cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Get<std::string>(),
index.SerializeAsString())); index.SerializeAsString()));
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -293,6 +293,22 @@ const int BoxTrackerCalculator::kMotionBoxPathMinQueueSize = 2;
namespace { namespace {
constexpr char kCacheDirTag[] = "CACHE_DIR";
constexpr char kInitialPosTag[] = "INITIAL_POS";
constexpr char kRaBoxesTag[] = "RA_BOXES";
constexpr char kBoxesTag[] = "BOXES";
constexpr char kVizTag[] = "VIZ";
constexpr char kRaTrackProtoStringTag[] = "RA_TRACK_PROTO_STRING";
constexpr char kRaTrackTag[] = "RA_TRACK";
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
constexpr char kRestartPosTag[] = "RESTART_POS";
constexpr char kStartPosProtoStringTag[] = "START_POS_PROTO_STRING";
constexpr char kStartPosTag[] = "START_POS";
constexpr char kStartTag[] = "START";
constexpr char kVideoTag[] = "VIDEO";
constexpr char kTrackTimeTag[] = "TRACK_TIME";
constexpr char kTrackingTag[] = "TRACKING";
// Convert box position according to rotation angle in degrees. // Convert box position according to rotation angle in degrees.
void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom, void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom,
float in_right, int rotation, float* out_top, float in_right, int rotation, float* out_top,
@ -374,78 +390,78 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec,
} // namespace. } // namespace.
absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("TRACKING")) { if (cc->Inputs().HasTag(kTrackingTag)) {
cc->Inputs().Tag("TRACKING").Set<TrackingData>(); cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
} }
if (cc->Inputs().HasTag("TRACK_TIME")) { if (cc->Inputs().HasTag(kTrackTimeTag)) {
RET_CHECK(cc->Inputs().HasTag("TRACKING")) RET_CHECK(cc->Inputs().HasTag(kTrackingTag))
<< "TRACK_TIME needs TRACKING input"; << "TRACK_TIME needs TRACKING input";
cc->Inputs().Tag("TRACK_TIME").SetAny(); cc->Inputs().Tag(kTrackTimeTag).SetAny();
} }
if (cc->Inputs().HasTag("VIDEO")) { if (cc->Inputs().HasTag(kVideoTag)) {
cc->Inputs().Tag("VIDEO").Set<ImageFrame>(); cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
} }
if (cc->Inputs().HasTag("START")) { if (cc->Inputs().HasTag(kStartTag)) {
// Actual packet content does not matter. // Actual packet content does not matter.
cc->Inputs().Tag("START").SetAny(); cc->Inputs().Tag(kStartTag).SetAny();
} }
if (cc->Inputs().HasTag("START_POS")) { if (cc->Inputs().HasTag(kStartPosTag)) {
cc->Inputs().Tag("START_POS").Set<TimedBoxProtoList>(); cc->Inputs().Tag(kStartPosTag).Set<TimedBoxProtoList>();
} }
if (cc->Inputs().HasTag("START_POS_PROTO_STRING")) { if (cc->Inputs().HasTag(kStartPosProtoStringTag)) {
cc->Inputs().Tag("START_POS_PROTO_STRING").Set<std::string>(); cc->Inputs().Tag(kStartPosProtoStringTag).Set<std::string>();
} }
if (cc->Inputs().HasTag("RESTART_POS")) { if (cc->Inputs().HasTag(kRestartPosTag)) {
cc->Inputs().Tag("RESTART_POS").Set<TimedBoxProtoList>(); cc->Inputs().Tag(kRestartPosTag).Set<TimedBoxProtoList>();
} }
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) { if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>(); cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
} }
if (cc->Inputs().HasTag("RA_TRACK")) { if (cc->Inputs().HasTag(kRaTrackTag)) {
cc->Inputs().Tag("RA_TRACK").Set<TimedBoxProtoList>(); cc->Inputs().Tag(kRaTrackTag).Set<TimedBoxProtoList>();
} }
if (cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")) { if (cc->Inputs().HasTag(kRaTrackProtoStringTag)) {
cc->Inputs().Tag("RA_TRACK_PROTO_STRING").Set<std::string>(); cc->Inputs().Tag(kRaTrackProtoStringTag).Set<std::string>();
} }
if (cc->Outputs().HasTag("VIZ")) { if (cc->Outputs().HasTag(kVizTag)) {
RET_CHECK(cc->Inputs().HasTag("VIDEO")) RET_CHECK(cc->Inputs().HasTag(kVideoTag))
<< "Output stream VIZ requires VIDEO to be present."; << "Output stream VIZ requires VIDEO to be present.";
cc->Outputs().Tag("VIZ").Set<ImageFrame>(); cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
} }
if (cc->Outputs().HasTag("BOXES")) { if (cc->Outputs().HasTag(kBoxesTag)) {
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>(); cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
} }
if (cc->Outputs().HasTag("RA_BOXES")) { if (cc->Outputs().HasTag(kRaBoxesTag)) {
cc->Outputs().Tag("RA_BOXES").Set<TimedBoxProtoList>(); cc->Outputs().Tag(kRaBoxesTag).Set<TimedBoxProtoList>();
} }
#if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__) #if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS")) RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag))
<< "Unsupported on mobile"; << "Unsupported on mobile";
#else #else
if (cc->InputSidePackets().HasTag("INITIAL_POS")) { if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
cc->InputSidePackets().Tag("INITIAL_POS").Set<std::string>(); cc->InputSidePackets().Tag(kInitialPosTag).Set<std::string>();
} }
#endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__) #endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
if (cc->InputSidePackets().HasTag("CACHE_DIR")) { if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>(); cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
} }
RET_CHECK(cc->Inputs().HasTag("TRACKING") != RET_CHECK(cc->Inputs().HasTag(kTrackingTag) !=
cc->InputSidePackets().HasTag("CACHE_DIR")) cc->InputSidePackets().HasTag(kCacheDirTag))
<< "Either TRACKING or CACHE_DIR needs to be specified."; << "Either TRACKING or CACHE_DIR needs to be specified.";
if (cc->InputSidePackets().HasTag(kOptionsTag)) { if (cc->InputSidePackets().HasTag(kOptionsTag)) {
@ -459,7 +475,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
options_ = tool::RetrieveOptions(cc->Options<BoxTrackerCalculatorOptions>(), options_ = tool::RetrieveOptions(cc->Options<BoxTrackerCalculatorOptions>(),
cc->InputSidePackets(), kOptionsTag); cc->InputSidePackets(), kOptionsTag);
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS") || RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag) ||
!options_.has_initial_position()) !options_.has_initial_position())
<< "Can not specify initial position as side packet and via options"; << "Can not specify initial position as side packet and via options";
@ -468,11 +484,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
} }
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__) #if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__)
if (cc->InputSidePackets().HasTag("INITIAL_POS")) { if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
LOG(INFO) << "Parsing: " LOG(INFO) << "Parsing: "
<< cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>(); << cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>();
initial_pos_ = ParseTextProtoOrDie<TimedBoxProtoList>( initial_pos_ = ParseTextProtoOrDie<TimedBoxProtoList>(
cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>()); cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>());
} }
#endif // !defined(__ANDROID__) && !defined(__APPLE__) && #endif // !defined(__ANDROID__) && !defined(__APPLE__) &&
// !defined(__EMSCRIPTEN__) // !defined(__EMSCRIPTEN__)
@ -484,10 +500,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
} }
visualize_tracking_data_ = visualize_tracking_data_ =
options_.visualize_tracking_data() && cc->Outputs().HasTag("VIZ"); options_.visualize_tracking_data() && cc->Outputs().HasTag(kVizTag);
visualize_state_ = options_.visualize_state() && cc->Outputs().HasTag("VIZ"); visualize_state_ =
options_.visualize_state() && cc->Outputs().HasTag(kVizTag);
visualize_internal_state_ = visualize_internal_state_ =
options_.visualize_internal_state() && cc->Outputs().HasTag("VIZ"); options_.visualize_internal_state() && cc->Outputs().HasTag(kVizTag);
// Force recording of internal state for rendering. // Force recording of internal state for rendering.
if (visualize_internal_state_) { if (visualize_internal_state_) {
@ -500,8 +517,8 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
options_.mutable_tracker_options()->set_record_path_states(true); options_.mutable_tracker_options()->set_record_path_states(true);
} }
if (cc->InputSidePackets().HasTag("CACHE_DIR")) { if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>(); cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
RET_CHECK(!cache_dir_.empty()); RET_CHECK(!cache_dir_.empty());
box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options())); box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options()));
} else { } else {
@ -511,7 +528,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
} }
if (options_.streaming_track_data_cache_size() > 0) { if (options_.streaming_track_data_cache_size() > 0) {
RET_CHECK(!cc->InputSidePackets().HasTag("CACHE_DIR")) RET_CHECK(!cc->InputSidePackets().HasTag(kCacheDirTag))
<< "Streaming mode not compatible with cache dir."; << "Streaming mode not compatible with cache dir.";
} }
@ -533,11 +550,11 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
InputStream* track_stream = cc->Inputs().HasTag("TRACKING") InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
? &(cc->Inputs().Tag("TRACKING")) ? &(cc->Inputs().Tag(kTrackingTag))
: nullptr; : nullptr;
InputStream* track_time_stream = cc->Inputs().HasTag("TRACK_TIME") InputStream* track_time_stream = cc->Inputs().HasTag(kTrackTimeTag)
? &(cc->Inputs().Tag("TRACK_TIME")) ? &(cc->Inputs().Tag(kTrackTimeTag))
: nullptr; : nullptr;
// Cache tracking data if possible. // Cache tracking data if possible.
@ -562,8 +579,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
} }
InputStream* start_pos_stream = cc->Inputs().HasTag("START_POS") InputStream* start_pos_stream = cc->Inputs().HasTag(kStartPosTag)
? &(cc->Inputs().Tag("START_POS")) ? &(cc->Inputs().Tag(kStartPosTag))
: nullptr; : nullptr;
MotionBoxMap fast_forward_boxes; MotionBoxMap fast_forward_boxes;
@ -575,8 +592,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
InputStream* start_pos_proto_string_stream = InputStream* start_pos_proto_string_stream =
cc->Inputs().HasTag("START_POS_PROTO_STRING") cc->Inputs().HasTag(kStartPosProtoStringTag)
? &(cc->Inputs().Tag("START_POS_PROTO_STRING")) ? &(cc->Inputs().Tag(kStartPosProtoStringTag))
: nullptr; : nullptr;
if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) { if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) {
if (start_pos_proto_string_stream && if (start_pos_proto_string_stream &&
@ -589,8 +606,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
} }
InputStream* restart_pos_stream = cc->Inputs().HasTag("RESTART_POS") InputStream* restart_pos_stream = cc->Inputs().HasTag(kRestartPosTag)
? &(cc->Inputs().Tag("RESTART_POS")) ? &(cc->Inputs().Tag(kRestartPosTag))
: nullptr; : nullptr;
if (restart_pos_stream && !restart_pos_stream->IsEmpty()) { if (restart_pos_stream && !restart_pos_stream->IsEmpty()) {
@ -600,8 +617,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
InputStream* cancel_object_id_stream = InputStream* cancel_object_id_stream =
cc->Inputs().HasTag("CANCEL_OBJECT_ID") cc->Inputs().HasTag(kCancelObjectIdTag)
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID")) ? &(cc->Inputs().Tag(kCancelObjectIdTag))
: nullptr; : nullptr;
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) { if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
const int cancel_object_id = cancel_object_id_stream->Get<int>(); const int cancel_object_id = cancel_object_id_stream->Get<int>();
@ -616,8 +633,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
TrackingData track_data_to_render; TrackingData track_data_to_render;
if (cc->Outputs().HasTag("VIZ")) { if (cc->Outputs().HasTag(kVizTag)) {
InputStream* video_stream = &(cc->Inputs().Tag("VIDEO")); InputStream* video_stream = &(cc->Inputs().Tag(kVideoTag));
if (!video_stream->IsEmpty()) { if (!video_stream->IsEmpty()) {
input_view = formats::MatView(&video_stream->Get<ImageFrame>()); input_view = formats::MatView(&video_stream->Get<ImageFrame>());
@ -745,7 +762,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
++frame_num_since_reset_; ++frame_num_since_reset_;
// Generate results for queued up request. // Generate results for queued up request.
if (cc->Outputs().HasTag("BOXES") && !queued_track_requests_.empty()) { if (cc->Outputs().HasTag(kBoxesTag) && !queued_track_requests_.empty()) {
for (int j = 0; j < queued_track_requests_.size(); ++j) { for (int j = 0; j < queued_track_requests_.size(); ++j) {
const Timestamp& past_time = queued_track_requests_[j]; const Timestamp& past_time = queued_track_requests_[j];
RET_CHECK(past_time.Value() < timestamp.Value()) RET_CHECK(past_time.Value() < timestamp.Value())
@ -770,7 +787,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
// Output for every time. // Output for every time.
cc->Outputs().Tag("BOXES").Add(past_box_list.release(), past_time); cc->Outputs().Tag(kBoxesTag).Add(past_box_list.release(), past_time);
} }
queued_track_requests_.clear(); queued_track_requests_.clear();
@ -845,8 +862,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
// Handle random access track requests. // Handle random access track requests.
InputStream* ra_track_stream = cc->Inputs().HasTag("RA_TRACK") InputStream* ra_track_stream = cc->Inputs().HasTag(kRaTrackTag)
? &(cc->Inputs().Tag("RA_TRACK")) ? &(cc->Inputs().Tag(kRaTrackTag))
: nullptr; : nullptr;
if (ra_track_stream && !ra_track_stream->IsEmpty()) { if (ra_track_stream && !ra_track_stream->IsEmpty()) {
@ -861,8 +878,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
} }
InputStream* ra_track_proto_string_stream = InputStream* ra_track_proto_string_stream =
cc->Inputs().HasTag("RA_TRACK_PROTO_STRING") cc->Inputs().HasTag(kRaTrackProtoStringTag)
? &(cc->Inputs().Tag("RA_TRACK_PROTO_STRING")) ? &(cc->Inputs().Tag(kRaTrackProtoStringTag))
: nullptr; : nullptr;
if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) { if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) {
if (ra_track_proto_string_stream && if (ra_track_proto_string_stream &&
@ -881,15 +898,15 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
// Always output in batch, only output in streaming if tracking data // Always output in batch, only output in streaming if tracking data
// is present (might be in fast forward mode instead). // is present (might be in fast forward mode instead).
if (cc->Outputs().HasTag("BOXES") && if (cc->Outputs().HasTag(kBoxesTag) &&
(box_tracker_ || !track_stream->IsEmpty())) { (box_tracker_ || !track_stream->IsEmpty())) {
std::unique_ptr<TimedBoxProtoList> boxes(new TimedBoxProtoList()); std::unique_ptr<TimedBoxProtoList> boxes(new TimedBoxProtoList());
*boxes = std::move(box_track_list); *boxes = std::move(box_track_list);
cc->Outputs().Tag("BOXES").Add(boxes.release(), timestamp); cc->Outputs().Tag(kBoxesTag).Add(boxes.release(), timestamp);
} }
if (viz_frame) { if (viz_frame) {
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
} }
return absl::OkStatus(); return absl::OkStatus();
@ -1001,7 +1018,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack(
} }
cc->Outputs() cc->Outputs()
.Tag("RA_BOXES") .Tag(kRaBoxesTag)
.Add(result_list.release(), cc->InputTimestamp()); .Add(result_list.release(), cc->InputTimestamp());
} }

View File

@ -29,6 +29,13 @@
namespace mediapipe { namespace mediapipe {
constexpr char kCacheDirTag[] = "CACHE_DIR";
constexpr char kCompleteTag[] = "COMPLETE";
constexpr char kTrackingChunkTag[] = "TRACKING_CHUNK";
constexpr char kTrackingTag[] = "TRACKING";
constexpr char kCameraTag[] = "CAMERA";
constexpr char kFlowTag[] = "FLOW";
using mediapipe::CameraMotion; using mediapipe::CameraMotion;
using mediapipe::FlowPackager; using mediapipe::FlowPackager;
using mediapipe::RegionFlowFeatureList; using mediapipe::RegionFlowFeatureList;
@ -91,27 +98,27 @@ class FlowPackagerCalculator : public CalculatorBase {
REGISTER_CALCULATOR(FlowPackagerCalculator); REGISTER_CALCULATOR(FlowPackagerCalculator);
absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) {
if (!cc->Inputs().HasTag("FLOW")) { if (!cc->Inputs().HasTag(kFlowTag)) {
return tool::StatusFail("No input flow was specified."); return tool::StatusFail("No input flow was specified.");
} }
cc->Inputs().Tag("FLOW").Set<RegionFlowFeatureList>(); cc->Inputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
if (cc->Inputs().HasTag("CAMERA")) { if (cc->Inputs().HasTag(kCameraTag)) {
cc->Inputs().Tag("CAMERA").Set<CameraMotion>(); cc->Inputs().Tag(kCameraTag).Set<CameraMotion>();
} }
if (cc->Outputs().HasTag("TRACKING")) { if (cc->Outputs().HasTag(kTrackingTag)) {
cc->Outputs().Tag("TRACKING").Set<TrackingData>(); cc->Outputs().Tag(kTrackingTag).Set<TrackingData>();
} }
if (cc->Outputs().HasTag("TRACKING_CHUNK")) { if (cc->Outputs().HasTag(kTrackingChunkTag)) {
cc->Outputs().Tag("TRACKING_CHUNK").Set<TrackingDataChunk>(); cc->Outputs().Tag(kTrackingChunkTag).Set<TrackingDataChunk>();
} }
if (cc->Outputs().HasTag("COMPLETE")) { if (cc->Outputs().HasTag(kCompleteTag)) {
cc->Outputs().Tag("COMPLETE").Set<bool>(); cc->Outputs().Tag(kCompleteTag).Set<bool>();
} }
if (cc->InputSidePackets().HasTag("CACHE_DIR")) { if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>(); cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
} }
return absl::OkStatus(); return absl::OkStatus();
@ -122,24 +129,24 @@ absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) {
flow_packager_.reset(new FlowPackager(options_.flow_packager_options())); flow_packager_.reset(new FlowPackager(options_.flow_packager_options()));
use_caching_ = cc->InputSidePackets().HasTag("CACHE_DIR"); use_caching_ = cc->InputSidePackets().HasTag(kCacheDirTag);
build_chunk_ = use_caching_ || cc->Outputs().HasTag("TRACKING_CHUNK"); build_chunk_ = use_caching_ || cc->Outputs().HasTag(kTrackingChunkTag);
if (use_caching_) { if (use_caching_) {
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>(); cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
InputStream* flow_stream = &(cc->Inputs().Tag("FLOW")); InputStream* flow_stream = &(cc->Inputs().Tag(kFlowTag));
const RegionFlowFeatureList& flow = flow_stream->Get<RegionFlowFeatureList>(); const RegionFlowFeatureList& flow = flow_stream->Get<RegionFlowFeatureList>();
const Timestamp timestamp = flow_stream->Value().Timestamp(); const Timestamp timestamp = flow_stream->Value().Timestamp();
const CameraMotion* camera_motion = nullptr; const CameraMotion* camera_motion = nullptr;
if (cc->Inputs().HasTag("CAMERA")) { if (cc->Inputs().HasTag(kCameraTag)) {
InputStream* camera_stream = &(cc->Inputs().Tag("CAMERA")); InputStream* camera_stream = &(cc->Inputs().Tag(kCameraTag));
camera_motion = &camera_stream->Get<CameraMotion>(); camera_motion = &camera_stream->Get<CameraMotion>();
} }
@ -161,7 +168,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
if (frame_idx_ > 0) { if (frame_idx_ > 0) {
item->set_prev_timestamp_usec(prev_timestamp_.Value()); item->set_prev_timestamp_usec(prev_timestamp_.Value());
} }
if (cc->Outputs().HasTag("TRACKING")) { if (cc->Outputs().HasTag(kTrackingTag)) {
// Need to copy as output is requested. // Need to copy as output is requested.
*item->mutable_tracking_data() = *tracking_data; *item->mutable_tracking_data() = *tracking_data;
} else { } else {
@ -172,9 +179,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
options_.caching_chunk_size_msec() * (chunk_idx_ + 1); options_.caching_chunk_size_msec() * (chunk_idx_ + 1);
if (timestamp.Value() / 1000 >= next_chunk_msec) { if (timestamp.Value() / 1000 >= next_chunk_msec) {
if (cc->Outputs().HasTag("TRACKING_CHUNK")) { if (cc->Outputs().HasTag(kTrackingChunkTag)) {
cc->Outputs() cc->Outputs()
.Tag("TRACKING_CHUNK") .Tag(kTrackingChunkTag)
.Add(new TrackingDataChunk(tracking_chunk_), .Add(new TrackingDataChunk(tracking_chunk_),
Timestamp(tracking_chunk_.item(0).timestamp_usec())); Timestamp(tracking_chunk_.item(0).timestamp_usec()));
} }
@ -185,9 +192,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
} }
} }
if (cc->Outputs().HasTag("TRACKING")) { if (cc->Outputs().HasTag(kTrackingTag)) {
cc->Outputs() cc->Outputs()
.Tag("TRACKING") .Tag(kTrackingTag)
.Add(tracking_data.release(), flow_stream->Value().Timestamp()); .Add(tracking_data.release(), flow_stream->Value().Timestamp());
} }
@ -199,9 +206,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
if (frame_idx_ > 0) { if (frame_idx_ > 0) {
tracking_chunk_.set_last_chunk(true); tracking_chunk_.set_last_chunk(true);
if (cc->Outputs().HasTag("TRACKING_CHUNK")) { if (cc->Outputs().HasTag(kTrackingChunkTag)) {
cc->Outputs() cc->Outputs()
.Tag("TRACKING_CHUNK") .Tag(kTrackingChunkTag)
.Add(new TrackingDataChunk(tracking_chunk_), .Add(new TrackingDataChunk(tracking_chunk_),
Timestamp(tracking_chunk_.item(0).timestamp_usec())); Timestamp(tracking_chunk_.item(0).timestamp_usec()));
} }
@ -211,8 +218,8 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
} }
} }
if (cc->Outputs().HasTag("COMPLETE")) { if (cc->Outputs().HasTag(kCompleteTag)) {
cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream()); cc->Outputs().Tag(kCompleteTag).Add(new bool(true), Timestamp::PreStream());
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -38,6 +38,18 @@
namespace mediapipe { namespace mediapipe {
constexpr char kDownsampleTag[] = "DOWNSAMPLE";
constexpr char kCsvFileTag[] = "CSV_FILE";
constexpr char kGrayVideoOutTag[] = "GRAY_VIDEO_OUT";
constexpr char kVideoOutTag[] = "VIDEO_OUT";
constexpr char kDenseFgTag[] = "DENSE_FG";
constexpr char kVizTag[] = "VIZ";
constexpr char kSaliencyTag[] = "SALIENCY";
constexpr char kCameraTag[] = "CAMERA";
constexpr char kFlowTag[] = "FLOW";
constexpr char kSelectionTag[] = "SELECTION";
constexpr char kVideoTag[] = "VIDEO";
using mediapipe::AffineAdapter; using mediapipe::AffineAdapter;
using mediapipe::CameraMotion; using mediapipe::CameraMotion;
using mediapipe::FrameSelectionResult; using mediapipe::FrameSelectionResult;
@ -190,55 +202,56 @@ class MotionAnalysisCalculator : public CalculatorBase {
REGISTER_CALCULATOR(MotionAnalysisCalculator); REGISTER_CALCULATOR(MotionAnalysisCalculator);
absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) { absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("VIDEO")) { if (cc->Inputs().HasTag(kVideoTag)) {
cc->Inputs().Tag("VIDEO").Set<ImageFrame>(); cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
} }
// Optional input stream from frame selection calculator. // Optional input stream from frame selection calculator.
if (cc->Inputs().HasTag("SELECTION")) { if (cc->Inputs().HasTag(kSelectionTag)) {
cc->Inputs().Tag("SELECTION").Set<FrameSelectionResult>(); cc->Inputs().Tag(kSelectionTag).Set<FrameSelectionResult>();
} }
RET_CHECK(cc->Inputs().HasTag("VIDEO") || cc->Inputs().HasTag("SELECTION")) RET_CHECK(cc->Inputs().HasTag(kVideoTag) ||
cc->Inputs().HasTag(kSelectionTag))
<< "Either VIDEO, SELECTION must be specified."; << "Either VIDEO, SELECTION must be specified.";
if (cc->Outputs().HasTag("FLOW")) { if (cc->Outputs().HasTag(kFlowTag)) {
cc->Outputs().Tag("FLOW").Set<RegionFlowFeatureList>(); cc->Outputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
} }
if (cc->Outputs().HasTag("CAMERA")) { if (cc->Outputs().HasTag(kCameraTag)) {
cc->Outputs().Tag("CAMERA").Set<CameraMotion>(); cc->Outputs().Tag(kCameraTag).Set<CameraMotion>();
} }
if (cc->Outputs().HasTag("SALIENCY")) { if (cc->Outputs().HasTag(kSaliencyTag)) {
cc->Outputs().Tag("SALIENCY").Set<SalientPointFrame>(); cc->Outputs().Tag(kSaliencyTag).Set<SalientPointFrame>();
} }
if (cc->Outputs().HasTag("VIZ")) { if (cc->Outputs().HasTag(kVizTag)) {
cc->Outputs().Tag("VIZ").Set<ImageFrame>(); cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
} }
if (cc->Outputs().HasTag("DENSE_FG")) { if (cc->Outputs().HasTag(kDenseFgTag)) {
cc->Outputs().Tag("DENSE_FG").Set<ImageFrame>(); cc->Outputs().Tag(kDenseFgTag).Set<ImageFrame>();
} }
if (cc->Outputs().HasTag("VIDEO_OUT")) { if (cc->Outputs().HasTag(kVideoOutTag)) {
cc->Outputs().Tag("VIDEO_OUT").Set<ImageFrame>(); cc->Outputs().Tag(kVideoOutTag).Set<ImageFrame>();
} }
if (cc->Outputs().HasTag("GRAY_VIDEO_OUT")) { if (cc->Outputs().HasTag(kGrayVideoOutTag)) {
// We only output grayscale video if we're actually performing full region- // We only output grayscale video if we're actually performing full region-
// flow analysis on the video. // flow analysis on the video.
RET_CHECK(cc->Inputs().HasTag("VIDEO") && RET_CHECK(cc->Inputs().HasTag(kVideoTag) &&
!cc->Inputs().HasTag("SELECTION")); !cc->Inputs().HasTag(kSelectionTag));
cc->Outputs().Tag("GRAY_VIDEO_OUT").Set<ImageFrame>(); cc->Outputs().Tag(kGrayVideoOutTag).Set<ImageFrame>();
} }
if (cc->InputSidePackets().HasTag("CSV_FILE")) { if (cc->InputSidePackets().HasTag(kCsvFileTag)) {
cc->InputSidePackets().Tag("CSV_FILE").Set<std::string>(); cc->InputSidePackets().Tag(kCsvFileTag).Set<std::string>();
} }
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) { if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
cc->InputSidePackets().Tag("DOWNSAMPLE").Set<float>(); cc->InputSidePackets().Tag(kDownsampleTag).Set<float>();
} }
if (cc->InputSidePackets().HasTag(kOptionsTag)) { if (cc->InputSidePackets().HasTag(kOptionsTag)) {
@ -253,16 +266,16 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
tool::RetrieveOptions(cc->Options<MotionAnalysisCalculatorOptions>(), tool::RetrieveOptions(cc->Options<MotionAnalysisCalculatorOptions>(),
cc->InputSidePackets(), kOptionsTag); cc->InputSidePackets(), kOptionsTag);
video_input_ = cc->Inputs().HasTag("VIDEO"); video_input_ = cc->Inputs().HasTag(kVideoTag);
selection_input_ = cc->Inputs().HasTag("SELECTION"); selection_input_ = cc->Inputs().HasTag(kSelectionTag);
region_flow_feature_output_ = cc->Outputs().HasTag("FLOW"); region_flow_feature_output_ = cc->Outputs().HasTag(kFlowTag);
camera_motion_output_ = cc->Outputs().HasTag("CAMERA"); camera_motion_output_ = cc->Outputs().HasTag(kCameraTag);
saliency_output_ = cc->Outputs().HasTag("SALIENCY"); saliency_output_ = cc->Outputs().HasTag(kSaliencyTag);
visualize_output_ = cc->Outputs().HasTag("VIZ"); visualize_output_ = cc->Outputs().HasTag(kVizTag);
dense_foreground_output_ = cc->Outputs().HasTag("DENSE_FG"); dense_foreground_output_ = cc->Outputs().HasTag(kDenseFgTag);
video_output_ = cc->Outputs().HasTag("VIDEO_OUT"); video_output_ = cc->Outputs().HasTag(kVideoOutTag);
grayscale_output_ = cc->Outputs().HasTag("GRAY_VIDEO_OUT"); grayscale_output_ = cc->Outputs().HasTag(kGrayVideoOutTag);
csv_file_input_ = cc->InputSidePackets().HasTag("CSV_FILE"); csv_file_input_ = cc->InputSidePackets().HasTag(kCsvFileTag);
hybrid_meta_analysis_ = options_.meta_analysis() == hybrid_meta_analysis_ = options_.meta_analysis() ==
MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID; MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID;
@ -310,7 +323,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
if (csv_file_input_) { if (csv_file_input_) {
// Read from file and parse. // Read from file and parse.
const std::string filename = const std::string filename =
cc->InputSidePackets().Tag("CSV_FILE").Get<std::string>(); cc->InputSidePackets().Tag(kCsvFileTag).Get<std::string>();
std::string file_contents; std::string file_contents;
std::ifstream input_file(filename, std::ios::in); std::ifstream input_file(filename, std::ios::in);
@ -327,11 +340,12 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
// Get video header from video or selection input if present. // Get video header from video or selection input if present.
const VideoHeader* video_header = nullptr; const VideoHeader* video_header = nullptr;
if (video_input_ && !cc->Inputs().Tag("VIDEO").Header().IsEmpty()) { if (video_input_ && !cc->Inputs().Tag(kVideoTag).Header().IsEmpty()) {
video_header = &(cc->Inputs().Tag("VIDEO").Header().Get<VideoHeader>()); video_header = &(cc->Inputs().Tag(kVideoTag).Header().Get<VideoHeader>());
} else if (selection_input_ && } else if (selection_input_ &&
!cc->Inputs().Tag("SELECTION").Header().IsEmpty()) { !cc->Inputs().Tag(kSelectionTag).Header().IsEmpty()) {
video_header = &(cc->Inputs().Tag("SELECTION").Header().Get<VideoHeader>()); video_header =
&(cc->Inputs().Tag(kSelectionTag).Header().Get<VideoHeader>());
} else { } else {
LOG(WARNING) << "No input video header found. Downstream calculators " LOG(WARNING) << "No input video header found. Downstream calculators "
"expecting video headers are likely to fail."; "expecting video headers are likely to fail.";
@ -339,7 +353,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
with_saliency_ = options_.analysis_options().compute_motion_saliency(); with_saliency_ = options_.analysis_options().compute_motion_saliency();
// Force computation of saliency if requested as output. // Force computation of saliency if requested as output.
if (cc->Outputs().HasTag("SALIENCY")) { if (cc->Outputs().HasTag(kSaliencyTag)) {
with_saliency_ = true; with_saliency_ = true;
if (!options_.analysis_options().compute_motion_saliency()) { if (!options_.analysis_options().compute_motion_saliency()) {
LOG(WARNING) << "Enable saliency computation. Set " LOG(WARNING) << "Enable saliency computation. Set "
@ -353,11 +367,11 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
} }
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) { if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
options_.mutable_analysis_options() options_.mutable_analysis_options()
->mutable_flow_options() ->mutable_flow_options()
->set_downsample_factor( ->set_downsample_factor(
cc->InputSidePackets().Tag("DOWNSAMPLE").Get<float>()); cc->InputSidePackets().Tag(kDownsampleTag).Get<float>());
} }
// If no video header is provided, just return and initialize on the first // If no video header is provided, just return and initialize on the first
@ -369,30 +383,33 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE /////////////// ////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE ///////////////
if (visualize_output_) { if (visualize_output_) {
cc->Outputs().Tag("VIZ").SetHeader(Adopt(new VideoHeader(*video_header))); cc->Outputs().Tag(kVizTag).SetHeader(Adopt(new VideoHeader(*video_header)));
} }
if (video_output_) { if (video_output_) {
cc->Outputs() cc->Outputs()
.Tag("VIDEO_OUT") .Tag(kVideoOutTag)
.SetHeader(Adopt(new VideoHeader(*video_header))); .SetHeader(Adopt(new VideoHeader(*video_header)));
} }
if (cc->Outputs().HasTag("DENSE_FG")) { if (cc->Outputs().HasTag(kDenseFgTag)) {
std::unique_ptr<VideoHeader> foreground_header( std::unique_ptr<VideoHeader> foreground_header(
new VideoHeader(*video_header)); new VideoHeader(*video_header));
foreground_header->format = ImageFormat::GRAY8; foreground_header->format = ImageFormat::GRAY8;
cc->Outputs().Tag("DENSE_FG").SetHeader(Adopt(foreground_header.release()));
}
if (cc->Outputs().HasTag("CAMERA")) {
cc->Outputs().Tag("CAMERA").SetHeader(
Adopt(new VideoHeader(*video_header)));
}
if (cc->Outputs().HasTag("SALIENCY")) {
cc->Outputs() cc->Outputs()
.Tag("SALIENCY") .Tag(kDenseFgTag)
.SetHeader(Adopt(foreground_header.release()));
}
if (cc->Outputs().HasTag(kCameraTag)) {
cc->Outputs()
.Tag(kCameraTag)
.SetHeader(Adopt(new VideoHeader(*video_header)));
}
if (cc->Outputs().HasTag(kSaliencyTag)) {
cc->Outputs()
.Tag(kSaliencyTag)
.SetHeader(Adopt(new VideoHeader(*video_header))); .SetHeader(Adopt(new VideoHeader(*video_header)));
} }
@ -405,9 +422,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
} }
InputStream* video_stream = InputStream* video_stream =
video_input_ ? &(cc->Inputs().Tag("VIDEO")) : nullptr; video_input_ ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
InputStream* selection_stream = InputStream* selection_stream =
selection_input_ ? &(cc->Inputs().Tag("SELECTION")) : nullptr; selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr;
// Checked on Open. // Checked on Open.
CHECK(video_stream || selection_stream); CHECK(video_stream || selection_stream);
@ -425,8 +442,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
CameraMotion output_motion = meta_motions_.front(); CameraMotion output_motion = meta_motions_.front();
meta_motions_.pop_front(); meta_motions_.pop_front();
output_motion.set_timestamp_usec(timestamp.Value()); output_motion.set_timestamp_usec(timestamp.Value());
cc->Outputs().Tag("CAMERA").Add(new CameraMotion(output_motion), cc->Outputs()
timestamp); .Tag(kCameraTag)
.Add(new CameraMotion(output_motion), timestamp);
} }
if (region_flow_feature_output_) { if (region_flow_feature_output_) {
@ -435,8 +453,8 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
meta_features_.pop_front(); meta_features_.pop_front();
output_features.set_timestamp_usec(timestamp.Value()); output_features.set_timestamp_usec(timestamp.Value());
cc->Outputs().Tag("FLOW").Add(new RegionFlowFeatureList(output_features), cc->Outputs().Tag(kFlowTag).Add(
timestamp); new RegionFlowFeatureList(output_features), timestamp);
} }
++frame_idx_; ++frame_idx_;
@ -478,16 +496,17 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) { MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) {
// Output concatenated results, nothing to compute here. // Output concatenated results, nothing to compute here.
if (camera_motion_output_) { if (camera_motion_output_) {
cc->Outputs().Tag("CAMERA").Add( cc->Outputs()
frame_selection_result->release_camera_motion(), timestamp); .Tag(kCameraTag)
.Add(frame_selection_result->release_camera_motion(), timestamp);
} }
if (region_flow_feature_output_) { if (region_flow_feature_output_) {
cc->Outputs().Tag("FLOW").Add(frame_selection_result->release_features(), cc->Outputs().Tag(kFlowTag).Add(
timestamp); frame_selection_result->release_features(), timestamp);
} }
if (video_output_) { if (video_output_) {
cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value()); cc->Outputs().Tag(kVideoOutTag).AddPacket(video_stream->Value());
} }
return absl::OkStatus(); return absl::OkStatus();
@ -565,7 +584,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
grayscale_mat.copyTo(image_frame_mat); grayscale_mat.copyTo(image_frame_mat);
cc->Outputs() cc->Outputs()
.Tag("GRAY_VIDEO_OUT") .Tag(kGrayVideoOutTag)
.Add(grayscale_image.release(), timestamp); .Add(grayscale_image.release(), timestamp);
} }
@ -640,7 +659,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
*feature_list, *camera_motion, *feature_list, *camera_motion,
with_saliency_ ? saliency[k].get() : nullptr, &visualization); with_saliency_ ? saliency[k].get() : nullptr, &visualization);
cc->Outputs().Tag("VIZ").Add(visualization_frame.release(), timestamp); cc->Outputs().Tag(kVizTag).Add(visualization_frame.release(), timestamp);
} }
// Output dense foreground mask. // Output dense foreground mask.
@ -650,26 +669,26 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
cv::Mat foreground = formats::MatView(foreground_frame.get()); cv::Mat foreground = formats::MatView(foreground_frame.get());
motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion, motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion,
&foreground); &foreground);
cc->Outputs().Tag("DENSE_FG").Add(foreground_frame.release(), timestamp); cc->Outputs().Tag(kDenseFgTag).Add(foreground_frame.release(), timestamp);
} }
// Output flow features if requested. // Output flow features if requested.
if (region_flow_feature_output_) { if (region_flow_feature_output_) {
cc->Outputs().Tag("FLOW").Add(feature_list.release(), timestamp); cc->Outputs().Tag(kFlowTag).Add(feature_list.release(), timestamp);
} }
// Output camera motion. // Output camera motion.
if (camera_motion_output_) { if (camera_motion_output_) {
cc->Outputs().Tag("CAMERA").Add(camera_motion.release(), timestamp); cc->Outputs().Tag(kCameraTag).Add(camera_motion.release(), timestamp);
} }
if (video_output_) { if (video_output_) {
cc->Outputs().Tag("VIDEO_OUT").AddPacket(packet_buffer_[k]); cc->Outputs().Tag(kVideoOutTag).AddPacket(packet_buffer_[k]);
} }
// Output saliency. // Output saliency.
if (saliency_output_) { if (saliency_output_) {
cc->Outputs().Tag("SALIENCY").Add(saliency[k].release(), timestamp); cc->Outputs().Tag(kSaliencyTag).Add(saliency[k].release(), timestamp);
} }
} }

View File

@ -27,6 +27,12 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kSavedAudioPathTag[] = "SAVED_AUDIO_PATH";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kVideoTag[] = "VIDEO";
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
// cv::VideoCapture set data type to unsigned char by default. Therefore, the // cv::VideoCapture set data type to unsigned char by default. Therefore, the
// image format is only related to the number of channles the cv::Mat has. // image format is only related to the number of channles the cv::Mat has.
ImageFormat::Format GetImageFormat(int num_channels) { ImageFormat::Format GetImageFormat(int num_channels) {
@ -87,20 +93,20 @@ ImageFormat::Format GetImageFormat(int num_channels) {
class OpenCvVideoDecoderCalculator : public CalculatorBase { class OpenCvVideoDecoderCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>(); cc->InputSidePackets().Tag(kInputFilePathTag).Set<std::string>();
cc->Outputs().Tag("VIDEO").Set<ImageFrame>(); cc->Outputs().Tag(kVideoTag).Set<ImageFrame>();
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
cc->Outputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>(); cc->Outputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
} }
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set<std::string>(); cc->OutputSidePackets().Tag(kSavedAudioPathTag).Set<std::string>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
const std::string& input_file_path = const std::string& input_file_path =
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>(); cc->InputSidePackets().Tag(kInputFilePathTag).Get<std::string>();
cap_ = absl::make_unique<cv::VideoCapture>(input_file_path); cap_ = absl::make_unique<cv::VideoCapture>(input_file_path);
if (!cap_->isOpened()) { if (!cap_->isOpened()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
@ -140,16 +146,16 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
header->frame_rate = fps; header->frame_rate = fps;
header->duration = frame_count_ / fps; header->duration = frame_count_ / fps;
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
cc->Outputs() cc->Outputs()
.Tag("VIDEO_PRESTREAM") .Tag(kVideoPrestreamTag)
.Add(header.release(), Timestamp::PreStream()); .Add(header.release(), Timestamp::PreStream());
cc->Outputs().Tag("VIDEO_PRESTREAM").Close(); cc->Outputs().Tag(kVideoPrestreamTag).Close();
} }
// Rewind to the very first frame. // Rewind to the very first frame.
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0); cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
#ifdef HAVE_FFMPEG #ifdef HAVE_FFMPEG
std::string saved_audio_path = std::tmpnam(nullptr); std::string saved_audio_path = std::tmpnam(nullptr);
std::string ffmpeg_command = std::string ffmpeg_command =
@ -159,14 +165,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str()); int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str());
if (status_code == 0) { if (status_code == 0) {
cc->OutputSidePackets() cc->OutputSidePackets()
.Tag("SAVED_AUDIO_PATH") .Tag(kSavedAudioPathTag)
.Set(MakePacket<std::string>(saved_audio_path)); .Set(MakePacket<std::string>(saved_audio_path));
} else { } else {
LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path
<< " by executing the following command: " << " by executing the following command: "
<< ffmpeg_command; << ffmpeg_command;
cc->OutputSidePackets() cc->OutputSidePackets()
.Tag("SAVED_AUDIO_PATH") .Tag(kSavedAudioPathTag)
.Set(MakePacket<std::string>(std::string())); .Set(MakePacket<std::string>(std::string()));
} }
#else #else
@ -208,7 +214,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
// If the timestamp of the current frame is not greater than the one of the // If the timestamp of the current frame is not greater than the one of the
// previous frame, the new frame will be discarded. // previous frame, the new frame will be discarded.
if (prev_timestamp_ < timestamp) { if (prev_timestamp_ < timestamp) {
cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp); cc->Outputs().Tag(kVideoTag).Add(image_frame.release(), timestamp);
prev_timestamp_ = timestamp; prev_timestamp_ = timestamp;
decoded_frames_++; decoded_frames_++;
} }

View File

@ -29,6 +29,10 @@ namespace mediapipe {
namespace { namespace {
constexpr char kVideoTag[] = "VIDEO";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
CalculatorGraphConfig::Node node_config = CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
@ -37,19 +41,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
output_stream: "VIDEO:video" output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath("./",
"/mediapipe/calculators/video/" "/mediapipe/calculators/video/"
"testdata/format_MP4_AVC720P_AAC.video")); "testdata/format_MP4_AVC720P_AAC.video"));
MP_EXPECT_OK(runner.Run()); MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("VIDEO_PRESTREAM") .Tag(kVideoPrestreamTag)
.packets[0] .packets[0]
.ValidateAsType<VideoHeader>()); .ValidateAsType<VideoHeader>());
const mediapipe::VideoHeader& header = const mediapipe::VideoHeader& header =
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>(); runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(ImageFormat::SRGB, header.format);
EXPECT_EQ(1280, header.width); EXPECT_EQ(1280, header.width);
EXPECT_EQ(640, header.height); EXPECT_EQ(640, header.height);
@ -58,10 +62,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
// The number of the output packets should be 180. // The number of the output packets should be 180.
// Some OpenCV version returns the first two frames with the same timestamp on // Some OpenCV version returns the first two frames with the same timestamp on
// macos and we might miss one frame here. // macos and we might miss one frame here.
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size(); int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
EXPECT_GE(num_of_packets, 179); EXPECT_GE(num_of_packets, 179);
for (int i = 0; i < num_of_packets; ++i) { for (int i = 0; i < num_of_packets; ++i) {
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
cv::Mat output_mat = cv::Mat output_mat =
formats::MatView(&(image_frame_packet.Get<ImageFrame>())); formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
EXPECT_EQ(1280, output_mat.size().width); EXPECT_EQ(1280, output_mat.size().width);
@ -83,19 +87,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
output_stream: "VIDEO:video" output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath("./",
"/mediapipe/calculators/video/" "/mediapipe/calculators/video/"
"testdata/format_FLV_H264_AAC.video")); "testdata/format_FLV_H264_AAC.video"));
MP_EXPECT_OK(runner.Run()); MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("VIDEO_PRESTREAM") .Tag(kVideoPrestreamTag)
.packets[0] .packets[0]
.ValidateAsType<VideoHeader>()); .ValidateAsType<VideoHeader>());
const mediapipe::VideoHeader& header = const mediapipe::VideoHeader& header =
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>(); runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(ImageFormat::SRGB, header.format);
EXPECT_EQ(640, header.width); EXPECT_EQ(640, header.width);
EXPECT_EQ(320, header.height); EXPECT_EQ(320, header.height);
@ -103,9 +107,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
// can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4). // can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4).
// EXPECT_FLOAT_EQ(6.0f, header.duration); // EXPECT_FLOAT_EQ(6.0f, header.duration);
// EXPECT_FLOAT_EQ(30.0f, header.frame_rate); // EXPECT_FLOAT_EQ(30.0f, header.frame_rate);
EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size()); EXPECT_EQ(180, runner.Outputs().Tag(kVideoTag).packets.size());
for (int i = 0; i < 180; ++i) { for (int i = 0; i < 180; ++i) {
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
cv::Mat output_mat = cv::Mat output_mat =
formats::MatView(&(image_frame_packet.Get<ImageFrame>())); formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
EXPECT_EQ(640, output_mat.size().width); EXPECT_EQ(640, output_mat.size().width);
@ -127,19 +131,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
output_stream: "VIDEO:video" output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath("./",
"/mediapipe/calculators/video/" "/mediapipe/calculators/video/"
"testdata/format_MKV_VP8_VORBIS.video")); "testdata/format_MKV_VP8_VORBIS.video"));
MP_EXPECT_OK(runner.Run()); MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1); EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("VIDEO_PRESTREAM") .Tag(kVideoPrestreamTag)
.packets[0] .packets[0]
.ValidateAsType<VideoHeader>()); .ValidateAsType<VideoHeader>());
const mediapipe::VideoHeader& header = const mediapipe::VideoHeader& header =
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>(); runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(ImageFormat::SRGB, header.format); EXPECT_EQ(ImageFormat::SRGB, header.format);
EXPECT_EQ(640, header.width); EXPECT_EQ(640, header.width);
EXPECT_EQ(320, header.height); EXPECT_EQ(320, header.height);
@ -148,10 +152,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
// The number of the output packets should be 180. // The number of the output packets should be 180.
// Some OpenCV version returns the first two frames with the same timestamp on // Some OpenCV version returns the first two frames with the same timestamp on
// macos and we might miss one frame here. // macos and we might miss one frame here.
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size(); int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
EXPECT_GE(num_of_packets, 179); EXPECT_GE(num_of_packets, 179);
for (int i = 0; i < num_of_packets; ++i) { for (int i = 0; i < num_of_packets; ++i) {
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i]; Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
cv::Mat output_mat = cv::Mat output_mat =
formats::MatView(&(image_frame_packet.Get<ImageFrame>())); formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
EXPECT_EQ(640, output_mat.size().width); EXPECT_EQ(640, output_mat.size().width);

View File

@ -36,6 +36,11 @@
namespace mediapipe { namespace mediapipe {
constexpr char kAudioFilePathTag[] = "AUDIO_FILE_PATH";
constexpr char kOutputFilePathTag[] = "OUTPUT_FILE_PATH";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kVideoTag[] = "VIDEO";
// Encodes the input video stream and produces a media file. // Encodes the input video stream and produces a media file.
// The media file can be output to the output_file_path specified as a side // The media file can be output to the output_file_path specified as a side
// packet. Currently, the calculator only supports one video stream (in // packet. Currently, the calculator only supports one video stream (in
@ -90,15 +95,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase {
}; };
absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) { absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("VIDEO")); RET_CHECK(cc->Inputs().HasTag(kVideoTag));
cc->Inputs().Tag("VIDEO").Set<ImageFrame>(); cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>(); cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
} }
RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH")); RET_CHECK(cc->InputSidePackets().HasTag(kOutputFilePathTag));
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set<std::string>(); cc->InputSidePackets().Tag(kOutputFilePathTag).Set<std::string>();
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set<std::string>(); cc->InputSidePackets().Tag(kAudioFilePathTag).Set<std::string>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -116,7 +121,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
<< "Video format must be specified in " << "Video format must be specified in "
"OpenCvVideoEncoderCalculatorOptions"; "OpenCvVideoEncoderCalculatorOptions";
output_file_path_ = output_file_path_ =
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get<std::string>(); cc->InputSidePackets().Tag(kOutputFilePathTag).Get<std::string>();
std::vector<std::string> splited_file_path = std::vector<std::string> splited_file_path =
absl::StrSplit(output_file_path_, '.'); absl::StrSplit(output_file_path_, '.');
RET_CHECK(splited_file_path.size() >= 2 && RET_CHECK(splited_file_path.size() >= 2 &&
@ -126,7 +131,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
// If the video header will be available, the video metadata will be fetched // If the video header will be available, the video metadata will be fetched
// from the video header directly. The calculator will receive the video // from the video header directly. The calculator will receive the video
// header packet at timestamp prestream. // header packet at timestamp prestream.
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
return absl::OkStatus(); return absl::OkStatus();
} }
return SetUpVideoWriter(options.fps(), options.width(), options.height()); return SetUpVideoWriter(options.fps(), options.width(), options.height());
@ -135,13 +140,13 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream()) { if (cc->InputTimestamp() == Timestamp::PreStream()) {
const VideoHeader& video_header = const VideoHeader& video_header =
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>(); cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
return SetUpVideoWriter(video_header.frame_rate, video_header.width, return SetUpVideoWriter(video_header.frame_rate, video_header.width,
video_header.height); video_header.height);
} }
const ImageFrame& image_frame = const ImageFrame& image_frame =
cc->Inputs().Tag("VIDEO").Value().Get<ImageFrame>(); cc->Inputs().Tag(kVideoTag).Value().Get<ImageFrame>();
ImageFormat::Format format = image_frame.Format(); ImageFormat::Format format = image_frame.Format();
cv::Mat frame; cv::Mat frame;
if (format == ImageFormat::GRAY8) { if (format == ImageFormat::GRAY8) {
@ -149,7 +154,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
if (frame.empty()) { if (frame.empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Receive empty frame at timestamp " << "Receive empty frame at timestamp "
<< cc->Inputs().Tag("VIDEO").Value().Timestamp() << cc->Inputs().Tag(kVideoTag).Value().Timestamp()
<< " in OpenCvVideoEncoderCalculator::Process()"; << " in OpenCvVideoEncoderCalculator::Process()";
} }
} else { } else {
@ -157,7 +162,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
if (tmp_frame.empty()) { if (tmp_frame.empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Receive empty frame at timestamp " << "Receive empty frame at timestamp "
<< cc->Inputs().Tag("VIDEO").Value().Timestamp() << cc->Inputs().Tag(kVideoTag).Value().Timestamp()
<< " in OpenCvVideoEncoderCalculator::Process()"; << " in OpenCvVideoEncoderCalculator::Process()";
} }
if (format == ImageFormat::SRGB) { if (format == ImageFormat::SRGB) {
@ -177,10 +182,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) {
if (writer_ && writer_->isOpened()) { if (writer_ && writer_->isOpened()) {
writer_->release(); writer_->release();
} }
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
#ifdef HAVE_FFMPEG #ifdef HAVE_FFMPEG
const std::string& audio_file_path = const std::string& audio_file_path =
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Get<std::string>(); cc->InputSidePackets().Tag(kAudioFilePathTag).Get<std::string>();
if (audio_file_path.empty()) { if (audio_file_path.empty()) {
LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the " LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the "
"audio tracks to the generated video because the audio " "audio tracks to the generated video because the audio "

View File

@ -23,6 +23,11 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kBackwardFlowTag[] = "BACKWARD_FLOW";
constexpr char kForwardFlowTag[] = "FORWARD_FLOW";
constexpr char kSecondFrameTag[] = "SECOND_FRAME";
constexpr char kFirstFrameTag[] = "FIRST_FRAME";
// Checks that img1 and img2 have the same dimensions. // Checks that img1 and img2 have the same dimensions.
bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) { bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) {
return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height()); return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height());
@ -94,19 +99,19 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase {
}; };
absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) { absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) {
if (!cc->Inputs().HasTag("FIRST_FRAME") || if (!cc->Inputs().HasTag(kFirstFrameTag) ||
!cc->Inputs().HasTag("SECOND_FRAME")) { !cc->Inputs().HasTag(kSecondFrameTag)) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"Missing required input streams. Both FIRST_FRAME and SECOND_FRAME " "Missing required input streams. Both FIRST_FRAME and SECOND_FRAME "
"must be specified."); "must be specified.");
} }
cc->Inputs().Tag("FIRST_FRAME").Set<ImageFrame>(); cc->Inputs().Tag(kFirstFrameTag).Set<ImageFrame>();
cc->Inputs().Tag("SECOND_FRAME").Set<ImageFrame>(); cc->Inputs().Tag(kSecondFrameTag).Set<ImageFrame>();
if (cc->Outputs().HasTag("FORWARD_FLOW")) { if (cc->Outputs().HasTag(kForwardFlowTag)) {
cc->Outputs().Tag("FORWARD_FLOW").Set<OpticalFlowField>(); cc->Outputs().Tag(kForwardFlowTag).Set<OpticalFlowField>();
} }
if (cc->Outputs().HasTag("BACKWARD_FLOW")) { if (cc->Outputs().HasTag(kBackwardFlowTag)) {
cc->Outputs().Tag("BACKWARD_FLOW").Set<OpticalFlowField>(); cc->Outputs().Tag(kBackwardFlowTag).Set<OpticalFlowField>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -116,10 +121,10 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
absl::MutexLock lock(&mutex_); absl::MutexLock lock(&mutex_);
tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1()); tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1());
} }
if (cc->Outputs().HasTag("FORWARD_FLOW")) { if (cc->Outputs().HasTag(kForwardFlowTag)) {
forward_requested_ = true; forward_requested_ = true;
} }
if (cc->Outputs().HasTag("BACKWARD_FLOW")) { if (cc->Outputs().HasTag(kBackwardFlowTag)) {
backward_requested_ = true; backward_requested_ = true;
} }
@ -128,15 +133,15 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
const ImageFrame& first_frame = const ImageFrame& first_frame =
cc->Inputs().Tag("FIRST_FRAME").Value().Get<ImageFrame>(); cc->Inputs().Tag(kFirstFrameTag).Value().Get<ImageFrame>();
const ImageFrame& second_frame = const ImageFrame& second_frame =
cc->Inputs().Tag("SECOND_FRAME").Value().Get<ImageFrame>(); cc->Inputs().Tag(kSecondFrameTag).Value().Get<ImageFrame>();
if (forward_requested_) { if (forward_requested_) {
auto forward_optical_flow_field = absl::make_unique<OpticalFlowField>(); auto forward_optical_flow_field = absl::make_unique<OpticalFlowField>();
MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame, MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame,
forward_optical_flow_field.get())); forward_optical_flow_field.get()));
cc->Outputs() cc->Outputs()
.Tag("FORWARD_FLOW") .Tag(kForwardFlowTag)
.Add(forward_optical_flow_field.release(), cc->InputTimestamp()); .Add(forward_optical_flow_field.release(), cc->InputTimestamp());
} }
if (backward_requested_) { if (backward_requested_) {
@ -144,7 +149,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame, MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame,
backward_optical_flow_field.get())); backward_optical_flow_field.get()));
cc->Outputs() cc->Outputs()
.Tag("BACKWARD_FLOW") .Tag(kBackwardFlowTag)
.Add(backward_optical_flow_field.release(), cc->InputTimestamp()); .Add(backward_optical_flow_field.release(), cc->InputTimestamp());
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -19,6 +19,9 @@
namespace mediapipe { namespace mediapipe {
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kFrameTag[] = "FRAME";
// Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp // Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp
// PreStream. Note that this calculator only fills in format, width, and height, // PreStream. Note that this calculator only fills in format, width, and height,
// i.e. frame_rate and duration will not be filled, unless: // i.e. frame_rate and duration will not be filled, unless:
@ -64,8 +67,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
if (!cc->Inputs().UsesTags()) { if (!cc->Inputs().UsesTags()) {
cc->Inputs().Index(0).Set<ImageFrame>(); cc->Inputs().Index(0).Set<ImageFrame>();
} else { } else {
cc->Inputs().Tag("FRAME").Set<ImageFrame>(); cc->Inputs().Tag(kFrameTag).Set<ImageFrame>();
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>(); cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
} }
cc->Outputs().Index(0).Set<VideoHeader>(); cc->Outputs().Index(0).Set<VideoHeader>();
return absl::OkStatus(); return absl::OkStatus();
@ -73,8 +76,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) {
frame_rate_in_prestream_ = cc->Inputs().UsesTags() && frame_rate_in_prestream_ = cc->Inputs().UsesTags() &&
cc->Inputs().HasTag("FRAME") && cc->Inputs().HasTag(kFrameTag) &&
cc->Inputs().HasTag("VIDEO_PRESTREAM"); cc->Inputs().HasTag(kVideoPrestreamTag);
header_ = absl::make_unique<VideoHeader>(); header_ = absl::make_unique<VideoHeader>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -82,15 +85,15 @@ absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream(
CalculatorContext* cc) { CalculatorContext* cc) {
cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment(); cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment();
if (cc->InputTimestamp() == Timestamp::PreStream()) { if (cc->InputTimestamp() == Timestamp::PreStream()) {
RET_CHECK(cc->Inputs().Tag("FRAME").IsEmpty()); RET_CHECK(cc->Inputs().Tag(kFrameTag).IsEmpty());
RET_CHECK(!cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty()); RET_CHECK(!cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty());
*header_ = cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>(); *header_ = cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero"; RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero";
} else { } else {
RET_CHECK(cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty()) RET_CHECK(cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty())
<< "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream()."; << "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream().";
RET_CHECK(!cc->Inputs().Tag("FRAME").IsEmpty()); RET_CHECK(!cc->Inputs().Tag(kFrameTag).IsEmpty());
const auto& frame = cc->Inputs().Tag("FRAME").Get<ImageFrame>(); const auto& frame = cc->Inputs().Tag(kFrameTag).Get<ImageFrame>();
header_->format = frame.Format(); header_->format = frame.Format();
header_->width = frame.Width(); header_->width = frame.Width();
header_->height = frame.Height(); header_->height = frame.Height();

View File

@ -44,28 +44,32 @@ using mediapipe::MakePacket;
using mediapipe::OutputStreamShardSet; using mediapipe::OutputStreamShardSet;
using mediapipe::Timestamp; using mediapipe::Timestamp;
namespace proto_ns = mediapipe::proto_ns; namespace proto_ns = mediapipe::proto_ns;
constexpr char kEventTag[] = "EVENT";
constexpr char kOutTag[] = "OUT";
using mediapipe::CalculatorGraph; using mediapipe::CalculatorGraph;
using mediapipe::Packet; using mediapipe::Packet;
class InfiniteSequenceCalculator : public mediapipe::CalculatorBase { class InfiniteSequenceCalculator : public mediapipe::CalculatorBase {
public: public:
static absl::Status GetContract(mediapipe::CalculatorContract* cc) { static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
cc->Outputs().Tag("OUT").Set<int>(); cc->Outputs().Tag(kOutTag).Set<int>();
cc->Outputs().Tag("EVENT").Set<int>(); cc->Outputs().Tag(kEventTag).Set<int>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1))); cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
cc->Outputs().Tag("OUT").AddPacket( cc->Outputs().Tag(kOutTag).AddPacket(
MakePacket<int>(count_).At(Timestamp(count_))); MakePacket<int>(count_).At(Timestamp(count_)));
count_++; count_++;
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Close(CalculatorContext* cc) override { absl::Status Close(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2))); cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -81,11 +85,11 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
cc->Inputs().Get("", i).SetAny(); cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
} }
cc->Outputs().Tag("EVENT").Set<int>(); cc->Outputs().Tag(kEventTag).Set<int>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1))); cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
@ -98,7 +102,7 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
: mediapipe::tool::StatusStop(); : mediapipe::tool::StatusStop();
} }
absl::Status Close(CalculatorContext* cc) override { absl::Status Close(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2))); cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -65,6 +65,16 @@ namespace mediapipe {
namespace { namespace {
constexpr char kCounter2Tag[] = "COUNTER2";
constexpr char kCounter1Tag[] = "COUNTER1";
constexpr char kExtraTag[] = "EXTRA";
constexpr char kWaitSemTag[] = "WAIT_SEM";
constexpr char kPostSemTag[] = "POST_SEM";
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
constexpr char kOutputTag[] = "OUTPUT";
constexpr char kInputTag[] = "INPUT";
constexpr char kSelectTag[] = "SELECT";
using testing::ElementsAre; using testing::ElementsAre;
using testing::HasSubstr; using testing::HasSubstr;
@ -125,8 +135,8 @@ class DemuxTimedCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
cc->Inputs().Tag("SELECT").Set<int>(); cc->Inputs().Tag(kSelectTag).Set<int>();
PacketType* data_input = &cc->Inputs().Tag("INPUT"); PacketType* data_input = &cc->Inputs().Tag(kInputTag);
data_input->SetAny(); data_input->SetAny();
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
id < cc->Outputs().EndId("OUTPUT"); ++id) { id < cc->Outputs().EndId("OUTPUT"); ++id) {
@ -182,7 +192,7 @@ REGISTER_CALCULATOR(DemuxTimedCalculator);
class MuxTimedCalculator : public CalculatorBase { class MuxTimedCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("SELECT").Set<int>(); cc->Inputs().Tag(kSelectTag).Set<int>();
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT");
PacketType* data_input0 = &cc->Inputs().Get(data_input_id); PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
data_input0->SetAny(); data_input0->SetAny();
@ -191,7 +201,7 @@ class MuxTimedCalculator : public CalculatorBase {
cc->Inputs().Get(data_input_id).SetSameAs(data_input0); cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
} }
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); cc->Outputs().Tag(kOutputTag).SetSameAs(data_input0);
cc->SetTimestampOffset(TimestampDiff(0)); cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -598,12 +608,12 @@ class ErrorOnOpenCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny(); cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>(); cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
if (cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) { if (cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
return absl::NotFoundError("expected error"); return absl::NotFoundError("expected error");
} }
return absl::OkStatus(); return absl::OkStatus();
@ -920,8 +930,8 @@ class SemaphoreCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny(); cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("POST_SEM").Set<Semaphore*>(); cc->InputSidePackets().Tag(kPostSemTag).Set<Semaphore*>();
cc->InputSidePackets().Tag("WAIT_SEM").Set<Semaphore*>(); cc->InputSidePackets().Tag(kWaitSemTag).Set<Semaphore*>();
cc->SetTimestampOffset(TimestampDiff(0)); cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -929,8 +939,8 @@ class SemaphoreCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
cc->InputSidePackets().Tag("POST_SEM").Get<Semaphore*>()->Release(1); cc->InputSidePackets().Tag(kPostSemTag).Get<Semaphore*>()->Release(1);
cc->InputSidePackets().Tag("WAIT_SEM").Get<Semaphore*>()->Acquire(1); cc->InputSidePackets().Tag(kWaitSemTag).Get<Semaphore*>()->Acquire(1);
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return absl::OkStatus(); return absl::OkStatus();
} }
@ -1177,9 +1187,9 @@ class IncrementingStatusHandler : public StatusHandler {
static absl::Status FillExpectations( static absl::Status FillExpectations(
const MediaPipeOptions& extendable_options, const MediaPipeOptions& extendable_options,
PacketTypeSet* input_side_packets) { PacketTypeSet* input_side_packets) {
input_side_packets->Tag("EXTRA").SetAny().Optional(); input_side_packets->Tag(kExtraTag).SetAny().Optional();
input_side_packets->Tag("COUNTER1").Set<std::unique_ptr<int>>(); input_side_packets->Tag(kCounter1Tag).Set<std::unique_ptr<int>>();
input_side_packets->Tag("COUNTER2").Set<std::unique_ptr<int>>(); input_side_packets->Tag(kCounter2Tag).Set<std::unique_ptr<int>>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -1187,7 +1197,7 @@ class IncrementingStatusHandler : public StatusHandler {
const MediaPipeOptions& extendable_options, const MediaPipeOptions& extendable_options,
const PacketSet& input_side_packets, // const PacketSet& input_side_packets, //
const absl::Status& pre_run_status) { const absl::Status& pre_run_status) {
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER1")); int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter1Tag));
(*counter)++; (*counter)++;
return pre_run_status_result_; return pre_run_status_result_;
} }
@ -1195,7 +1205,7 @@ class IncrementingStatusHandler : public StatusHandler {
static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
const PacketSet& input_side_packets, // const PacketSet& input_side_packets, //
const absl::Status& run_status) { const absl::Status& run_status) {
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER2")); int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter2Tag));
(*counter)++; (*counter)++;
return post_run_status_result_; return post_run_status_result_;
} }
@ -2228,20 +2238,20 @@ class DemuxUntimedCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
cc->Inputs().Tag("INPUT").SetAny(); cc->Inputs().Tag(kInputTag).SetAny();
cc->Inputs().Tag("SELECT").Set<int>(); cc->Inputs().Tag(kSelectTag).Set<int>();
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
id < cc->Outputs().EndId("OUTPUT"); ++id) { id < cc->Outputs().EndId("OUTPUT"); ++id) {
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag("INPUT")); cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag(kInputTag));
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
int index = cc->Inputs().Tag("SELECT").Get<int>(); int index = cc->Inputs().Tag(kSelectTag).Get<int>();
if (!cc->Inputs().Tag("INPUT").IsEmpty()) { if (!cc->Inputs().Tag(kInputTag).IsEmpty()) {
cc->Outputs() cc->Outputs()
.Get("OUTPUT", index) .Get("OUTPUT", index)
.AddPacket(cc->Inputs().Tag("INPUT").Value()); .AddPacket(cc->Inputs().Tag(kInputTag).Value());
} else { } else {
cc->Outputs() cc->Outputs()
.Get("OUTPUT", index) .Get("OUTPUT", index)

View File

@ -32,6 +32,11 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kTag[] = "";
constexpr char kBTag[] = "B";
constexpr char kATag[] = "A";
constexpr char kSideOutputTag[] = "SIDE_OUTPUT";
// Inputs: 2 streams with ints. Headers are strings. // Inputs: 2 streams with ints. Headers are strings.
// Input side packets: 1. // Input side packets: 1.
// Outputs: 3 streams with ints. #0 and #1 will contain the negated values from // Outputs: 3 streams with ints. #0 and #1 will contain the negated values from
@ -48,7 +53,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0)); cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0));
cc->InputSidePackets().Index(0).SetAny(); cc->InputSidePackets().Index(0).SetAny();
cc->OutputSidePackets() cc->OutputSidePackets()
.Tag("SIDE_OUTPUT") .Tag(kSideOutputTag)
.SetSameAs(&cc->InputSidePackets().Index(0)); .SetSameAs(&cc->InputSidePackets().Index(0));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -64,7 +69,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
Adopt(new std::string(absl::StrCat(input_header_string, i)))); Adopt(new std::string(absl::StrCat(input_header_string, i))));
} }
cc->OutputSidePackets() cc->OutputSidePackets()
.Tag("SIDE_OUTPUT") .Tag(kSideOutputTag)
.Set(cc->InputSidePackets().Index(0)); .Set(cc->InputSidePackets().Index(0));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -152,7 +157,7 @@ TEST(CalculatorRunner, RunsCalculator) {
Adopt(new int(input_side_packet_content)); Adopt(new int(input_side_packet_content));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
EXPECT_EQ(input_side_packet_content, EXPECT_EQ(input_side_packet_content,
runner.OutputSidePackets().Tag("SIDE_OUTPUT").Get<int>()); runner.OutputSidePackets().Tag(kSideOutputTag).Get<int>());
const auto& outputs = runner.Outputs(); const auto& outputs = runner.Outputs();
ASSERT_EQ(3, outputs.NumEntries()); ASSERT_EQ(3, outputs.NumEntries());
@ -209,9 +214,9 @@ TEST(CalculatorRunner, MultiTagTestCalculatorOk) {
const auto& outputs = runner.Outputs(); const auto& outputs = runner.Outputs();
ASSERT_EQ(3, outputs.NumEntries()); ASSERT_EQ(3, outputs.NumEntries());
for (int ts = 0; ts < 5; ++ts) { for (int ts = 0; ts < 5; ++ts) {
const std::vector<Packet>& a_packets = outputs.Tag("A").packets; const std::vector<Packet>& a_packets = outputs.Tag(kATag).packets;
const std::vector<Packet>& b_packets = outputs.Tag("B").packets; const std::vector<Packet>& b_packets = outputs.Tag(kBTag).packets;
const std::vector<Packet>& c_packets = outputs.Tag("").packets; const std::vector<Packet>& c_packets = outputs.Tag(kTag).packets;
EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp()); EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp());
EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp()); EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp());
EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp()); EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp());

View File

@ -24,6 +24,10 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kTag2Tag[] = "TAG_2";
constexpr char kTag0Tag[] = "TAG_0";
constexpr char kTag1Tag[] = "TAG_1";
TEST(CollectionTest, BasicByIndex) { TEST(CollectionTest, BasicByIndex) {
tool::TagAndNameInfo info; tool::TagAndNameInfo info;
info.names.push_back("name_1"); info.names.push_back("name_1");
@ -55,14 +59,14 @@ TEST(CollectionTest, BasicByTag) {
info.names.push_back("name_2"); info.names.push_back("name_2");
info.tags.push_back("TAG_2"); info.tags.push_back("TAG_2");
internal::Collection<int> collection(info); internal::Collection<int> collection(info);
collection.Tag("TAG_1") = 101; collection.Tag(kTag1Tag) = 101;
collection.Tag("TAG_0") = 100; collection.Tag(kTag0Tag) = 100;
collection.Tag("TAG_2") = 102; collection.Tag(kTag2Tag) = 102;
// Test the stored values. // Test the stored values.
EXPECT_EQ(100, collection.Tag("TAG_0")); EXPECT_EQ(100, collection.Tag(kTag0Tag));
EXPECT_EQ(101, collection.Tag("TAG_1")); EXPECT_EQ(101, collection.Tag(kTag1Tag));
EXPECT_EQ(102, collection.Tag("TAG_2")); EXPECT_EQ(102, collection.Tag(kTag2Tag));
// Test access using a range based for. // Test access using a range based for.
int i = 0; int i = 0;
for (int num : collection) { for (int num : collection) {

View File

@ -134,6 +134,21 @@ void Tensor::AllocateMtlBuffer(id<MTLDevice> device) const {
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
bool Tensor::NeedsHalfFloatRenderTarget() const {
static bool has_color_buffer_float =
gl_context_->HasGlExtension("WEBGL_color_buffer_float") ||
gl_context_->HasGlExtension("EXT_color_buffer_float");
if (!has_color_buffer_float) {
static bool has_color_buffer_half_float =
gl_context_->HasGlExtension("EXT_color_buffer_half_float");
LOG_IF(FATAL, !has_color_buffer_half_float)
<< "EXT_color_buffer_half_float or WEBGL_color_buffer_float "
<< "required on web to use MP tensor";
return true;
}
return false;
}
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const { Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
LOG_IF(FATAL, valid_ == kValidNone) LOG_IF(FATAL, valid_ == kValidNone)
<< "Tensor must be written prior to read from."; << "Tensor must be written prior to read from.";
@ -164,8 +179,24 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
// Set alignment for the proper value (default) to avoid address sanitizer // Set alignment for the proper value (default) to avoid address sanitizer
// error "out of boundary reading". // error "out of boundary reading".
glPixelStorei(GL_UNPACK_ALIGNMENT, 4); glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
#ifdef __EMSCRIPTEN__
// Under WebGL1, format must match in order to use glTexSubImage2D, so if we
// have a half-float texture, then uploading from GL_FLOAT here would fail.
// We change the texture's data type to float here to accommodate.
// Furthermore, for a full-image replacement operation, glTexImage2D is
// expected to be more performant than glTexSubImage2D. Note that for WebGL2
// we cannot use glTexImage2D, because we allocate using glTexStorage2D in
// that case, which is incompatible.
if (gl_context_->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
0, GL_RGBA, GL_FLOAT, temp_buffer.get());
texture_is_half_float_ = false;
} else
#endif // __EMSCRIPTEN__
{
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_, glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_,
GL_RGBA, GL_FLOAT, temp_buffer.get()); GL_RGBA, GL_FLOAT, temp_buffer.get());
}
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
valid_ |= kValidOpenGlTexture2d; valid_ |= kValidOpenGlTexture2d;
} }
@ -175,6 +206,16 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const { Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const {
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_); auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
AllocateOpenGlTexture2d(); AllocateOpenGlTexture2d();
#ifdef __EMSCRIPTEN__
// On web, we may have to change type from float to half-float
if (!texture_is_half_float_ && NeedsHalfFloatRenderTarget()) {
glBindTexture(GL_TEXTURE_2D, opengl_texture2d_);
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_, 0,
GL_RGBA, GL_HALF_FLOAT_OES, 0 /* data */);
glBindTexture(GL_TEXTURE_2D, 0);
texture_is_half_float_ = true;
}
#endif
valid_ = kValidOpenGlTexture2d; valid_ = kValidOpenGlTexture2d;
return {opengl_texture2d_, std::move(lock)}; return {opengl_texture2d_, std::move(lock)};
} }
@ -255,8 +296,18 @@ void Tensor::AllocateOpenGlTexture2d() const {
<< "with GLES 2.0"; << "with GLES 2.0";
// Allocate the image data; note that it's no longer RGBA32F, so will be // Allocate the image data; note that it's no longer RGBA32F, so will be
// lower precision. // lower precision.
auto type = GL_FLOAT;
// On web, we might need to change type to half-float (e.g. for iOS-
// Safari) in order to have a valid framebuffer. See b/194442743 for more
// details.
#ifdef __EMSCRIPTEN__
if (NeedsHalfFloatRenderTarget()) {
type = GL_HALF_FLOAT_OES;
texture_is_half_float_ = true;
}
#endif // __EMSCRIPTEN
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_, glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
0, GL_RGBA, GL_FLOAT, 0 /* data */); 0, GL_RGBA, type, 0 /* data */);
} }
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
glGenFramebuffers(1, &frame_buffer_); glGenFramebuffers(1, &frame_buffer_);
@ -443,7 +494,6 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
glPixelStorei(GL_PACK_ALIGNMENT, 4); glPixelStorei(GL_PACK_ALIGNMENT, 4);
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT, glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
buffer); buffer);
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_); uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
const int actual_depth_size = const int actual_depth_size =
BhwcDepthFromShape(shape_) * element_size(); BhwcDepthFromShape(shape_) * element_size();

View File

@ -266,11 +266,15 @@ class Tensor {
mutable GLuint frame_buffer_ = GL_INVALID_INDEX; mutable GLuint frame_buffer_ = GL_INVALID_INDEX;
mutable int texture_width_; mutable int texture_width_;
mutable int texture_height_; mutable int texture_height_;
#ifdef __EMSCRIPTEN__
mutable bool texture_is_half_float_ = false;
#endif // __EMSCRIPTEN__
void AllocateOpenGlTexture2d() const; void AllocateOpenGlTexture2d() const;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
mutable GLuint opengl_buffer_ = GL_INVALID_INDEX; mutable GLuint opengl_buffer_ = GL_INVALID_INDEX;
void AllocateOpenGlBuffer() const; void AllocateOpenGlBuffer() const;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
bool NeedsHalfFloatRenderTarget() const;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
}; };

View File

@ -31,6 +31,11 @@ namespace mediapipe {
namespace { namespace {
constexpr char kOutputTag[] = "OUTPUT";
constexpr char kEnableTag[] = "ENABLE";
constexpr char kSelectTag[] = "SELECT";
constexpr char kSideinputTag[] = "SIDEINPUT";
// Shows validation success for a graph and a subgraph. // Shows validation success for a graph and a subgraph.
TEST(GraphValidationTest, InitializeGraphFromProtos) { TEST(GraphValidationTest, InitializeGraphFromProtos) {
auto config_1 = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( auto config_1 = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -323,20 +328,21 @@ TEST(GraphValidationTest, OptionalSubgraphStreamsMismatched) {
class OptionalSideInputTestCalculator : public CalculatorBase { class OptionalSideInputTestCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("SIDEINPUT").Set<std::string>().Optional(); cc->InputSidePackets().Tag(kSideinputTag).Set<std::string>().Optional();
cc->Inputs().Tag("SELECT").Set<int>().Optional(); cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
cc->Inputs().Tag("ENABLE").Set<bool>().Optional(); cc->Inputs().Tag(kEnableTag).Set<bool>().Optional();
cc->Outputs().Tag("OUTPUT").Set<std::string>(); cc->Outputs().Tag(kOutputTag).Set<std::string>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
std::string value("default"); std::string value("default");
if (cc->InputSidePackets().HasTag("SIDEINPUT")) { if (cc->InputSidePackets().HasTag(kSideinputTag)) {
value = cc->InputSidePackets().Tag("SIDEINPUT").Get<std::string>(); value = cc->InputSidePackets().Tag(kSideinputTag).Get<std::string>();
} }
cc->Outputs().Tag("OUTPUT").Add(new std::string(value), cc->Outputs()
cc->InputTimestamp()); .Tag(kOutputTag)
.Add(new std::string(value), cc->InputTimestamp());
return absl::OkStatus(); return absl::OkStatus();
} }
}; };

View File

@ -26,17 +26,20 @@ namespace {
namespace test_ns { namespace test_ns {
constexpr char kOutTag[] = "OUT";
constexpr char kInTag[] = "IN";
class TestSinkCalculator : public CalculatorBase { class TestSinkCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>(); cc->Inputs().Tag(kInTag).Set<mediapipe::InputOnlyProto>();
cc->Outputs().Tag("OUT").Set<int>(); cc->Outputs().Tag(kOutTag).Set<int>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x(); int x = cc->Inputs().Tag(kInTag).Get<mediapipe::InputOnlyProto>().x();
cc->Outputs().Tag("OUT").AddPacket( cc->Outputs().Tag(kOutTag).AddPacket(
MakePacket<int>(x).At(cc->InputTimestamp())); MakePacket<int>(x).At(cc->InputTimestamp()));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -34,6 +34,19 @@
namespace mediapipe { namespace mediapipe {
constexpr char kOutTag[] = "OUT";
constexpr char kClockTag[] = "CLOCK";
constexpr char kSleepMicrosTag[] = "SLEEP_MICROS";
constexpr char kCloseTag[] = "CLOSE";
constexpr char kProcessTag[] = "PROCESS";
constexpr char kOpenTag[] = "OPEN";
constexpr char kTag[] = "";
constexpr char kMeanTag[] = "MEAN";
constexpr char kDataTag[] = "DATA";
constexpr char kPairTag[] = "PAIR";
constexpr char kLowTag[] = "LOW";
constexpr char kHighTag[] = "HIGH";
using RandomEngine = std::mt19937_64; using RandomEngine = std::mt19937_64;
// A Calculator that outputs twice the value of its input packet (an int). // A Calculator that outputs twice the value of its input packet (an int).
@ -95,9 +108,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
PacketTypeSet* input_side_packets, // PacketTypeSet* input_side_packets, //
PacketTypeSet* output_side_packets) { PacketTypeSet* output_side_packets) {
input_side_packets->Index(0).Set<uint64>(); input_side_packets->Index(0).Set<uint64>();
output_side_packets->Tag("HIGH").Set<uint32>(); output_side_packets->Tag(kHighTag).Set<uint32>();
output_side_packets->Tag("LOW").Set<uint32>(); output_side_packets->Tag(kLowTag).Set<uint32>();
output_side_packets->Tag("PAIR").Set<std::pair<uint32, uint32>>(); output_side_packets->Tag(kPairTag).Set<std::pair<uint32, uint32>>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -108,9 +121,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
uint64 value = input_side_packets.Index(0).Get<uint64>(); uint64 value = input_side_packets.Index(0).Get<uint64>();
uint32 high = value >> 32; uint32 high = value >> 32;
uint32 low = value & 0xFFFFFFFF; uint32 low = value & 0xFFFFFFFF;
output_side_packets->Tag("HIGH") = Adopt(new uint32(high)); output_side_packets->Tag(kHighTag) = Adopt(new uint32(high));
output_side_packets->Tag("LOW") = Adopt(new uint32(low)); output_side_packets->Tag(kLowTag) = Adopt(new uint32(low));
output_side_packets->Tag("PAIR") = output_side_packets->Tag(kPairTag) =
Adopt(new std::pair<uint32, uint32>(high, low)); Adopt(new std::pair<uint32, uint32>(high, low));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -221,8 +234,8 @@ class StdDevCalculator : public CalculatorBase {
StdDevCalculator() {} StdDevCalculator() {}
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("DATA").Set<int>(); cc->Inputs().Tag(kDataTag).Set<int>();
cc->Inputs().Tag("MEAN").Set<double>(); cc->Inputs().Tag(kMeanTag).Set<double>();
cc->Outputs().Index(0).Set<int>(); cc->Outputs().Index(0).Set<int>();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -234,15 +247,15 @@ class StdDevCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
if (cc->InputTimestamp() == Timestamp::PreStream()) { if (cc->InputTimestamp() == Timestamp::PreStream()) {
RET_CHECK(cc->Inputs().Tag("DATA").Value().IsEmpty()); RET_CHECK(cc->Inputs().Tag(kDataTag).Value().IsEmpty());
RET_CHECK(!cc->Inputs().Tag("MEAN").Value().IsEmpty()); RET_CHECK(!cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
mean_ = cc->Inputs().Tag("MEAN").Get<double>(); mean_ = cc->Inputs().Tag(kMeanTag).Get<double>();
initialized_ = true; initialized_ = true;
} else { } else {
RET_CHECK(initialized_); RET_CHECK(initialized_);
RET_CHECK(!cc->Inputs().Tag("DATA").Value().IsEmpty()); RET_CHECK(!cc->Inputs().Tag(kDataTag).Value().IsEmpty());
RET_CHECK(cc->Inputs().Tag("MEAN").Value().IsEmpty()); RET_CHECK(cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
double diff = cc->Inputs().Tag("DATA").Get<int>() - mean_; double diff = cc->Inputs().Tag(kDataTag).Get<int>() - mean_;
cummulative_variance_ += diff * diff; cummulative_variance_ += diff * diff;
++count_; ++count_;
} }
@ -564,8 +577,8 @@ class LambdaCalculator : public CalculatorBase {
id < cc->Outputs().EndId(); ++id) { id < cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).SetAny(); cc->Outputs().Get(id).SetAny();
} }
if (cc->InputSidePackets().HasTag("") > 0) { if (cc->InputSidePackets().HasTag(kTag) > 0) {
cc->InputSidePackets().Tag("").Set<ProcessFunction>(); cc->InputSidePackets().Tag(kTag).Set<ProcessFunction>();
} }
for (const std::string& tag : {"OPEN", "PROCESS", "CLOSE"}) { for (const std::string& tag : {"OPEN", "PROCESS", "CLOSE"}) {
if (cc->InputSidePackets().HasTag(tag)) { if (cc->InputSidePackets().HasTag(tag)) {
@ -576,24 +589,24 @@ class LambdaCalculator : public CalculatorBase {
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("OPEN")) { if (cc->InputSidePackets().HasTag(kOpenTag)) {
return GetContextFn(cc, "OPEN")(cc); return GetContextFn(cc, "OPEN")(cc);
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("PROCESS")) { if (cc->InputSidePackets().HasTag(kProcessTag)) {
return GetContextFn(cc, "PROCESS")(cc); return GetContextFn(cc, "PROCESS")(cc);
} }
if (cc->InputSidePackets().HasTag("") > 0) { if (cc->InputSidePackets().HasTag(kTag) > 0) {
return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs()); return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs());
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Close(CalculatorContext* cc) final { absl::Status Close(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("CLOSE")) { if (cc->InputSidePackets().HasTag(kCloseTag)) {
return GetContextFn(cc, "CLOSE")(cc); return GetContextFn(cc, "CLOSE")(cc);
} }
return absl::OkStatus(); return absl::OkStatus();
@ -645,17 +658,18 @@ class PassThroughWithSleepCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>(); cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>(); cc->InputSidePackets().Tag(kSleepMicrosTag).Set<int>();
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>(); cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>(); sleep_micros_ = cc->InputSidePackets().Tag(kSleepMicrosTag).Get<int>();
if (sleep_micros_ < 0) { if (sleep_micros_ < 0) {
return absl::InternalError("SLEEP_MICROS should be >= 0"); return absl::InternalError("SLEEP_MICROS should be >= 0");
} }
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>(); clock_ =
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
@ -678,8 +692,8 @@ class MultiplyIntCalculator : public CalculatorBase {
cc->Inputs().Index(0).Set<int>(); cc->Inputs().Index(0).Set<int>();
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0)); cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); // cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
RET_CHECK(cc->Outputs().HasTag("OUT")); RET_CHECK(cc->Outputs().HasTag(kOutTag));
cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0)); cc->Outputs().Tag(kOutTag).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
@ -689,7 +703,7 @@ class MultiplyIntCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
int x = cc->Inputs().Index(0).Value().Get<int>(); int x = cc->Inputs().Index(0).Value().Get<int>();
int y = cc->Inputs().Index(1).Value().Get<int>(); int y = cc->Inputs().Index(1).Value().Get<int>();
cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp()); cc->Outputs().Tag(kOutTag).Add(new int(x * y), cc->InputTimestamp());
return absl::OkStatus(); return absl::OkStatus();
} }
}; };

View File

@ -60,6 +60,18 @@ def mediapipe_aar(
assets: additional assets to be included into the archive. assets: additional assets to be included into the archive.
assets_dir: path where the assets will the packaged. assets_dir: path where the assets will the packaged.
""" """
# When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
# the OpenCV so libraries will be excluded from the AAR package to
# save the package size.
native.config_setting(
name = "exclude_opencv_so_lib",
define_values = {
"EXCLUDE_OPENCV_SO_LIB": "1",
},
visibility = ["//visibility:public"],
)
_mediapipe_jni( _mediapipe_jni(
name = name + "_jni", name = name + "_jni",
gen_libmediapipe = gen_libmediapipe, gen_libmediapipe = gen_libmediapipe,
@ -133,6 +145,7 @@ EOF
] + select({ ] + select({
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"], "//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
"//mediapipe/framework/port:disable_opencv": [], "//mediapipe/framework/port:disable_opencv": [],
"exclude_opencv_so_lib": [],
}), }),
assets = assets, assets = assets,
assets_dir = assets_dir, assets_dir = assets_dir,

View File

@ -245,14 +245,6 @@ void Box::Fit(const std::vector<T>& vertices) {
auto system_g = system_h.colPivHouseholderQr(); auto system_g = system_h.colPivHouseholderQr();
auto solution = system_g.solve(v).eval(); auto solution = system_g.solve(v).eval();
transformation_.topLeftCorner<3, 4>() = solution.transpose(); transformation_.topLeftCorner<3, 4>() = solution.transpose();
// Adjust rotation matrix to its nearest orthogonal matrix.
const auto rotation = transformation_.topLeftCorner<3, 3>();
Eigen::JacobiSVD<Eigen::Matrix3f> svd(
rotation, Eigen::ComputeFullV | Eigen::ComputeFullU);
const Eigen::Matrix3f matrix_u = svd.matrixU();
const Eigen::Matrix3f matrix_v = svd.matrixV();
transformation_.topLeftCorner<3, 3>() = matrix_u * matrix_v.transpose();
Update(); Update();
} }

View File

@ -254,11 +254,11 @@ class SolutionBase:
for stream_name in self._output_stream_type_info.keys(): for stream_name in self._output_stream_type_info.keys():
self._graph.observe_output_stream(stream_name, callback, True) self._graph.observe_output_stream(stream_name, callback, True)
input_side_packets = { self._input_side_packets = {
name: self._make_packet(self._side_input_type_info[name], data) name: self._make_packet(self._side_input_type_info[name], data)
for name, data in (side_inputs or {}).items() for name, data in (side_inputs or {}).items()
} }
self._graph.start_run(input_side_packets) self._graph.start_run(self._input_side_packets)
# TODO: Use "inspect.Parameter" to fetch the input argument names and # TODO: Use "inspect.Parameter" to fetch the input argument names and
# types from "_input_stream_type_info" and then auto generate the process # types from "_input_stream_type_info" and then auto generate the process
@ -353,6 +353,12 @@ class SolutionBase:
self._input_stream_type_info = None self._input_stream_type_info = None
self._output_stream_type_info = None self._output_stream_type_info = None
def reset(self) -> None:
"""Resets the graph for another run."""
if self._graph:
self._graph.close()
self._graph.start_run(self._input_side_packets)
def _initialize_graph_interface( def _initialize_graph_interface(
self, self,
validated_graph: validated_graph_config.ValidatedGraphConfig, validated_graph: validated_graph_config.ValidatedGraphConfig,

View File

@ -298,6 +298,56 @@ class SolutionBaseTest(parameterized.TestCase):
'ImageTransformation.output_height': 0 'ImageTransformation.output_height': 0
}) })
@parameterized.named_parameters(('graph_without_side_packets', """
input_stream: 'image_in'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:transformed_image_in'
output_stream: 'IMAGE:image_out'
}
""", None), ('graph_with_side_packets', """
input_stream: 'image_in'
input_side_packet: 'allow_signal'
input_side_packet: 'rotation_degrees'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'GateCalculator'
input_stream: 'transformed_image_in'
input_side_packet: 'ALLOW:allow_signal'
output_stream: 'image_out_to_transform'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_out_to_transform'
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
output_stream: 'IMAGE:image_out'
}""", {
'allow_signal': True,
'rotation_degrees': 0
}))
def test_solution_reset(self, text_config, side_inputs):
config_proto = text_format.Parse(text_config,
calculator_pb2.CalculatorGraphConfig())
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
with solution_base.SolutionBase(
graph_config=config_proto, side_inputs=side_inputs) as solution:
for _ in range(20):
outputs = solution.process(input_image)
self.assertTrue(np.array_equal(input_image, outputs.image_out))
solution.reset()
def _process_and_verify(self, def _process_and_verify(self,
config_proto, config_proto,
side_inputs=None, side_inputs=None,

View File

@ -26,7 +26,7 @@ import numpy.testing as npt
from mediapipe.python.solutions import objectron as mp_objectron from mediapipe.python.solutions import objectron as mp_objectron
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
DIFF_THRESHOLD = 35 # pixels DIFF_THRESHOLD = 30 # pixels
EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457], EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457],
[383, 505], [80, 478], [408, 345], [383, 505], [80, 478], [408, 345],
[130, 347], [384, 355], [72, 353]], [130, 347], [384, 355], [72, 353]],