Project import generated by Copybara.

GitOrigin-RevId: 8e1da4611d93ccb7d9674713157d43be0348d98f
MediaPipe Team 2021-07-27 18:20:05 -07:00 committed by chuoling
parent 50c92c6623
commit b899d17f18
79 changed files with 1808 additions and 946 deletions

View File

@ -220,6 +220,7 @@ import cv2
import mediapipe as mp
mp_drawing = mp.solutions.drawing_utils
mp_hands = mp.solutions.hands
drawing_styles = mp.solutions.drawing_styles
# For static images:
IMAGE_FILES = []
@ -248,7 +249,9 @@ with mp_hands.Hands(
f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})'
)
mp_drawing.draw_landmarks(
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
drawing_styles.get_default_hand_landmark_style(),
drawing_styles.get_default_hand_connection_style())
cv2.imwrite(
'/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1))
@ -278,7 +281,9 @@ with mp_hands.Hands(
if results.multi_hand_landmarks:
for hand_landmarks in results.multi_hand_landmarks:
mp_drawing.draw_landmarks(
image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
drawing_styles.get_default_hand_landmark_style(),
drawing_styles.get_default_hand_connection_style())
cv2.imshow('MediaPipe Hands', image)
if cv2.waitKey(5) & 0xFF == 27:
break

View File

@ -24,6 +24,9 @@
namespace mediapipe {
constexpr char kDataTag[] = "DATA";
constexpr char kHeaderTag[] = "HEADER";
class AddHeaderCalculatorTest : public ::testing::Test {};
TEST_F(AddHeaderCalculatorTest, HeaderStream) {
@ -36,11 +39,11 @@ TEST_F(AddHeaderCalculatorTest, HeaderStream) {
CalculatorRunner runner(node);
// Set header and add 5 packets.
runner.MutableInputs()->Tag("HEADER").header =
runner.MutableInputs()->Tag(kHeaderTag).header =
Adopt(new std::string("my_header"));
for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
}
// Run calculator.
@ -85,13 +88,14 @@ TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
CalculatorRunner runner(node);
// Set header and add 5 packets.
runner.MutableInputs()->Tag("HEADER").header =
runner.MutableInputs()->Tag(kHeaderTag).header =
Adopt(new std::string("my_header"));
runner.MutableInputs()->Tag("HEADER").packets.push_back(
Adopt(new std::string("not allowed")));
runner.MutableInputs()
->Tag(kHeaderTag)
.packets.push_back(Adopt(new std::string("not allowed")));
for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
}
// Run calculator.
@ -108,11 +112,11 @@ TEST_F(AddHeaderCalculatorTest, InputSidePacket) {
CalculatorRunner runner(node);
// Set header and add 5 packets.
runner.MutableSidePackets()->Tag("HEADER") =
runner.MutableSidePackets()->Tag(kHeaderTag) =
Adopt(new std::string("my_header"));
for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
}
// Run calculator.
@ -143,13 +147,13 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) {
CalculatorRunner runner(node);
// Set both headers and add 5 packets.
runner.MutableSidePackets()->Tag("HEADER") =
runner.MutableSidePackets()->Tag(kHeaderTag) =
Adopt(new std::string("my_header"));
runner.MutableSidePackets()->Tag("HEADER") =
runner.MutableSidePackets()->Tag(kHeaderTag) =
Adopt(new std::string("my_header"));
for (int i = 0; i < 5; ++i) {
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
}
// Run should fail because header can only be provided one way.
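Most of the C++ changes in this commit are the same mechanical refactor: string-literal tags passed to Tag()/HasTag() are hoisted into file-scope constexpr constants. The following is a minimal sketch of that pattern, not code from the commit; the kSelectTag name and the calculator are hypothetical, assuming the usual CalculatorBase boilerplate.

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// Hypothetical tag constant. A typo in the constant's name fails to compile,
// whereas a typo in one of several repeated string literals would only
// surface as a tag mismatch at graph-build or run time.
constexpr char kSelectTag[] = "SELECT";

class ExampleSelectCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    if (cc->Inputs().HasTag(kSelectTag)) {
      cc->Inputs().Tag(kSelectTag).Set<int>();
    }
    return absl::OkStatus();
  }
  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }
};

}  // namespace mediapipe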

View File

@ -19,6 +19,13 @@
namespace mediapipe {
constexpr char kIncrementTag[] = "INCREMENT";
constexpr char kInitialValueTag[] = "INITIAL_VALUE";
constexpr char kBatchSizeTag[] = "BATCH_SIZE";
constexpr char kErrorCountTag[] = "ERROR_COUNT";
constexpr char kMaxCountTag[] = "MAX_COUNT";
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
// sequential numbers from INITIAL_VALUE (default 0) with a common
// difference of INCREMENT (default 1) between successive numbers (with
@ -33,53 +40,53 @@ class CountingSourceCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>();
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) {
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
if (cc->InputSidePackets().HasTag(kErrorOnOpenTag)) {
cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
}
RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") ||
cc->InputSidePackets().HasTag("ERROR_COUNT"));
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
cc->InputSidePackets().Tag("MAX_COUNT").Set<int>();
RET_CHECK(cc->InputSidePackets().HasTag(kMaxCountTag) ||
cc->InputSidePackets().HasTag(kErrorCountTag));
if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
cc->InputSidePackets().Tag(kMaxCountTag).Set<int>();
}
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
cc->InputSidePackets().Tag("ERROR_COUNT").Set<int>();
if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
cc->InputSidePackets().Tag(kErrorCountTag).Set<int>();
}
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
cc->InputSidePackets().Tag("BATCH_SIZE").Set<int>();
if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
cc->InputSidePackets().Tag(kBatchSizeTag).Set<int>();
}
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
cc->InputSidePackets().Tag("INITIAL_VALUE").Set<int>();
if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
cc->InputSidePackets().Tag(kInitialValueTag).Set<int>();
}
if (cc->InputSidePackets().HasTag("INCREMENT")) {
cc->InputSidePackets().Tag("INCREMENT").Set<int>();
if (cc->InputSidePackets().HasTag(kIncrementTag)) {
cc->InputSidePackets().Tag(kIncrementTag).Set<int>();
}
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") &&
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
if (cc->InputSidePackets().HasTag(kErrorOnOpenTag) &&
cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
return absl::NotFoundError("expected error");
}
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get<int>();
if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
error_count_ = cc->InputSidePackets().Tag(kErrorCountTag).Get<int>();
RET_CHECK_LE(0, error_count_);
}
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>();
if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
max_count_ = cc->InputSidePackets().Tag(kMaxCountTag).Get<int>();
RET_CHECK_LE(0, max_count_);
}
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get<int>();
if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
batch_size_ = cc->InputSidePackets().Tag(kBatchSizeTag).Get<int>();
RET_CHECK_LT(0, batch_size_);
}
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get<int>();
if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
counter_ = cc->InputSidePackets().Tag(kInitialValueTag).Get<int>();
}
if (cc->InputSidePackets().HasTag("INCREMENT")) {
increment_ = cc->InputSidePackets().Tag("INCREMENT").Get<int>();
if (cc->InputSidePackets().HasTag(kIncrementTag)) {
increment_ = cc->InputSidePackets().Tag(kIncrementTag).Get<int>();
RET_CHECK_LT(0, increment_);
}
RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
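For context on the side packets renamed above, here is a rough usage sketch in the CalculatorRunner style used by the tests in this commit; the node config, the stream name "count", and the expected values are illustrative assumptions rather than repository code.

CalculatorGraphConfig::Node node =
    ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
      calculator: "CountingSourceCalculator"
      input_side_packet: "MAX_COUNT:max_count"
      input_side_packet: "INCREMENT:increment"
      output_stream: "count"
    )pb");
CalculatorRunner runner(node);
runner.MutableSidePackets()->Tag(kMaxCountTag) = MakePacket<int>(3);
runner.MutableSidePackets()->Tag(kIncrementTag) = MakePacket<int>(2);
MP_ASSERT_OK(runner.Run());
// Per the comment above, this should emit the packets 0, 2, 4 on "count".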

View File

@ -35,11 +35,14 @@
// }
namespace mediapipe {
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
constexpr char kEncodedTag[] = "ENCODED";
class DequantizeByteArrayCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("ENCODED").Set<std::string>();
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
cc->Inputs().Tag(kEncodedTag).Set<std::string>();
cc->Outputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
return absl::OkStatus();
}
@ -66,7 +69,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final {
const std::string& encoded =
cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
cc->Inputs().Tag(kEncodedTag).Value().Get<std::string>();
std::vector<float> float_vector;
float_vector.reserve(encoded.length());
for (int i = 0; i < encoded.length(); ++i) {
@ -74,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
}
cc->Outputs()
.Tag("FLOAT_VECTOR")
.Tag(kFloatVectorTag)
.AddPacket(MakePacket<std::vector<float>>(float_vector)
.At(cc->InputTimestamp()));
return absl::OkStatus();
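The loop above is a straight affine map per input byte. Below is a self-contained sketch of that arithmetic; the option values (min_quantized_value = -64, max_quantized_value = 64, hence scalar = 128/255 and bias = -64) are assumptions for illustration, not taken from this file.

#include <cstdio>

int main() {
  const float bias = -64.0f;             // assumed min_quantized_value
  const float scalar = 128.0f / 255.0f;  // assumed range / 255
  const unsigned char encoded[] = {0x7F, 0xFF, 0x00, 0x01};
  for (unsigned char byte : encoded) {
    // Mirrors float_vector.push_back(byte * scalar_ + bias_) above.
    std::printf("0x%02X -> %.3f\n", byte, byte * scalar + bias);
  }
  return 0;
}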

View File

@ -25,6 +25,9 @@
namespace mediapipe {
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
constexpr char kEncodedTag[] = "ENCODED";
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
@ -39,7 +42,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
)pb");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
@ -64,7 +69,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
)pb");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
@ -89,7 +96,9 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
)pb");
CalculatorRunner runner(node_config);
std::string empty_string;
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(empty_string).At(Timestamp(0)));
auto status = runner.Run();
EXPECT_FALSE(status.ok());
@ -114,14 +123,16 @@ TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
)pb");
CalculatorRunner runner(node_config);
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
runner.MutableInputs()
->Tag(kEncodedTag)
.packets.push_back(
MakePacket<std::string>(
std::string(reinterpret_cast<char const*>(input), 4))
.At(Timestamp(0)));
auto status = runner.Run();
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs =
runner.Outputs().Tag("FLOAT_VECTOR").packets;
runner.Outputs().Tag(kFloatVectorTag).packets;
EXPECT_EQ(1, outputs.size());
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
ASSERT_FALSE(result.empty());

View File

@ -24,6 +24,11 @@
namespace mediapipe {
constexpr char kFinishedTag[] = "FINISHED";
constexpr char kAllowTag[] = "ALLOW";
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
constexpr char kOptionsTag[] = "OPTIONS";
// FlowLimiterCalculator is used to limit the number of frames in flight
// by dropping input frames when necessary.
//
@ -69,16 +74,19 @@ class FlowLimiterCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
auto& side_inputs = cc->InputSidePackets();
side_inputs.Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
cc->Inputs().Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
side_inputs.Tag(kOptionsTag).Set<FlowLimiterCalculatorOptions>().Optional();
cc->Inputs()
.Tag(kOptionsTag)
.Set<FlowLimiterCalculatorOptions>()
.Optional();
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
}
cc->Inputs().Get("FINISHED", 0).SetAny();
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>().Optional();
cc->Outputs().Tag("ALLOW").Set<bool>().Optional();
cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>().Optional();
cc->Outputs().Tag(kAllowTag).Set<bool>().Optional();
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
@ -87,9 +95,9 @@ class FlowLimiterCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) final {
options_ = cc->Options<FlowLimiterCalculatorOptions>();
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
options_.set_max_in_flight(
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>());
cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
}
input_queues_.resize(cc->Inputs().NumEntries(""));
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
@ -104,8 +112,8 @@ class FlowLimiterCalculator : public CalculatorBase {
// Outputs a packet indicating whether a frame was sent or dropped.
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
if (cc->Outputs().HasTag("ALLOW")) {
cc->Outputs().Tag("ALLOW").AddPacket(MakePacket<bool>(allow).At(ts));
if (cc->Outputs().HasTag(kAllowTag)) {
cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
}
}
@ -155,7 +163,7 @@ class FlowLimiterCalculator : public CalculatorBase {
options_ = tool::RetrieveOptions(options_, cc->Inputs());
// Process the FINISHED input stream.
Packet finished_packet = cc->Inputs().Tag("FINISHED").Value();
Packet finished_packet = cc->Inputs().Tag(kFinishedTag).Value();
if (finished_packet.Timestamp() == cc->InputTimestamp()) {
while (!frames_in_flight_.empty() &&
frames_in_flight_.front() <= finished_packet.Timestamp()) {
@ -210,8 +218,8 @@ class FlowLimiterCalculator : public CalculatorBase {
Timestamp bound =
cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream();
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
if (cc->Outputs().HasTag("ALLOW")) {
SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW"));
if (cc->Outputs().HasTag(kAllowTag)) {
SetNextTimestampBound(bound, &cc->Outputs().Tag(kAllowTag));
}
}
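For readers unfamiliar with this calculator, a typical wiring looks roughly like the sketch below; it is an assumption based on the FINISHED back-edge convention used above, not a graph from this commit, and the stream names are made up.

CalculatorGraphConfig config =
    ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "raw_frames"
      node {
        calculator: "FlowLimiterCalculator"
        input_stream: "raw_frames"
        input_stream: "FINISHED:processed_frames"
        input_stream_info: { tag_index: "FINISHED" back_edge: true }
        input_side_packet: "MAX_IN_FLIGHT:max_in_flight"
        output_stream: "gated_frames"
        output_stream: "ALLOW:allow"
      }
      # A downstream "heavy" node would consume gated_frames and produce
      # processed_frames, closing the back edge.
    )pb");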

View File

@ -36,6 +36,13 @@
namespace mediapipe {
namespace {
constexpr char kDropTimestampsTag[] = "DROP_TIMESTAMPS";
constexpr char kClockTag[] = "CLOCK";
constexpr char kWarmupTimeTag[] = "WARMUP_TIME";
constexpr char kSleepTimeTag[] = "SLEEP_TIME";
constexpr char kPacketTag[] = "PACKET";
// A simple Semaphore for synchronizing test threads.
class AtomicSemaphore {
public:
@ -204,17 +211,17 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) {
class SleepCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("PACKET").SetAny();
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
cc->InputSidePackets().Tag("SLEEP_TIME").Set<int64>();
cc->InputSidePackets().Tag("WARMUP_TIME").Set<int64>();
cc->InputSidePackets().Tag("CLOCK").Set<mediapipe::Clock*>();
cc->Inputs().Tag(kPacketTag).SetAny();
cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
cc->InputSidePackets().Tag(kSleepTimeTag).Set<int64>();
cc->InputSidePackets().Tag(kWarmupTimeTag).Set<int64>();
cc->InputSidePackets().Tag(kClockTag).Set<mediapipe::Clock*>();
cc->SetTimestampOffset(0);
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<mediapipe::Clock*>();
clock_ = cc->InputSidePackets().Tag(kClockTag).Get<mediapipe::Clock*>();
return absl::OkStatus();
}
@ -222,10 +229,12 @@ class SleepCalculator : public CalculatorBase {
++packet_count;
absl::Duration sleep_time = absl::Microseconds(
packet_count == 1
? cc->InputSidePackets().Tag("WARMUP_TIME").Get<int64>()
: cc->InputSidePackets().Tag("SLEEP_TIME").Get<int64>());
? cc->InputSidePackets().Tag(kWarmupTimeTag).Get<int64>()
: cc->InputSidePackets().Tag(kSleepTimeTag).Get<int64>());
clock_->Sleep(sleep_time);
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
cc->Outputs()
.Tag(kPacketTag)
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
return absl::OkStatus();
}
@ -240,24 +249,27 @@ REGISTER_CALCULATOR(SleepCalculator);
class DropCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("PACKET").SetAny();
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set<bool>();
cc->Inputs().Tag(kPacketTag).SetAny();
cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
cc->InputSidePackets().Tag(kDropTimestampsTag).Set<bool>();
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
if (!cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
++packet_count;
}
bool drop = (packet_count == 3);
if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
if (!drop && !cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
cc->Outputs()
.Tag(kPacketTag)
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
}
if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get<bool>()) {
cc->Outputs().Tag("PACKET").SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream());
if (!drop || !cc->InputSidePackets().Tag(kDropTimestampsTag).Get<bool>()) {
cc->Outputs()
.Tag(kPacketTag)
.SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
}
return absl::OkStatus();
}

View File

@ -21,6 +21,11 @@
namespace mediapipe {
namespace {
constexpr char kStateChangeTag[] = "STATE_CHANGE";
constexpr char kDisallowTag[] = "DISALLOW";
constexpr char kAllowTag[] = "ALLOW";
enum GateState {
GATE_UNINITIALIZED,
GATE_ALLOW,
@ -83,30 +88,31 @@ class GateCalculator : public CalculatorBase {
GateCalculator() {}
static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") ||
cc->InputSidePackets().HasTag("DISALLOW");
bool input_via_side_packet = cc->InputSidePackets().HasTag(kAllowTag) ||
cc->InputSidePackets().HasTag(kDisallowTag);
bool input_via_stream =
cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW");
cc->Inputs().HasTag(kAllowTag) || cc->Inputs().HasTag(kDisallowTag);
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
// input.
RET_CHECK(input_via_side_packet ^ input_via_stream);
if (input_via_side_packet) {
RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^
cc->InputSidePackets().HasTag("DISALLOW"));
RET_CHECK(cc->InputSidePackets().HasTag(kAllowTag) ^
cc->InputSidePackets().HasTag(kDisallowTag));
if (cc->InputSidePackets().HasTag("ALLOW")) {
cc->InputSidePackets().Tag("ALLOW").Set<bool>();
if (cc->InputSidePackets().HasTag(kAllowTag)) {
cc->InputSidePackets().Tag(kAllowTag).Set<bool>();
} else {
cc->InputSidePackets().Tag("DISALLOW").Set<bool>();
cc->InputSidePackets().Tag(kDisallowTag).Set<bool>();
}
} else {
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
RET_CHECK(cc->Inputs().HasTag(kAllowTag) ^
cc->Inputs().HasTag(kDisallowTag));
if (cc->Inputs().HasTag("ALLOW")) {
cc->Inputs().Tag("ALLOW").Set<bool>();
if (cc->Inputs().HasTag(kAllowTag)) {
cc->Inputs().Tag(kAllowTag).Set<bool>();
} else {
cc->Inputs().Tag("DISALLOW").Set<bool>();
cc->Inputs().Tag(kDisallowTag).Set<bool>();
}
}
return absl::OkStatus();
@ -125,8 +131,8 @@ class GateCalculator : public CalculatorBase {
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
}
if (cc->Outputs().HasTag("STATE_CHANGE")) {
cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
if (cc->Outputs().HasTag(kStateChangeTag)) {
cc->Outputs().Tag(kStateChangeTag).Set<bool>();
}
return absl::OkStatus();
@ -134,14 +140,14 @@ class GateCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) final {
use_side_packet_for_allow_disallow_ = false;
if (cc->InputSidePackets().HasTag("ALLOW")) {
if (cc->InputSidePackets().HasTag(kAllowTag)) {
use_side_packet_for_allow_disallow_ = true;
allow_by_side_packet_decision_ =
cc->InputSidePackets().Tag("ALLOW").Get<bool>();
} else if (cc->InputSidePackets().HasTag("DISALLOW")) {
cc->InputSidePackets().Tag(kAllowTag).Get<bool>();
} else if (cc->InputSidePackets().HasTag(kDisallowTag)) {
use_side_packet_for_allow_disallow_ = true;
allow_by_side_packet_decision_ =
!cc->InputSidePackets().Tag("DISALLOW").Get<bool>();
!cc->InputSidePackets().Tag(kDisallowTag).Get<bool>();
}
cc->SetOffset(TimestampDiff(0));
@ -160,18 +166,18 @@ class GateCalculator : public CalculatorBase {
if (use_side_packet_for_allow_disallow_) {
allow = allow_by_side_packet_decision_;
} else {
if (cc->Inputs().HasTag("ALLOW") &&
!cc->Inputs().Tag("ALLOW").IsEmpty()) {
allow = cc->Inputs().Tag("ALLOW").Get<bool>();
if (cc->Inputs().HasTag(kAllowTag) &&
!cc->Inputs().Tag(kAllowTag).IsEmpty()) {
allow = cc->Inputs().Tag(kAllowTag).Get<bool>();
}
if (cc->Inputs().HasTag("DISALLOW") &&
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
if (cc->Inputs().HasTag(kDisallowTag) &&
!cc->Inputs().Tag(kDisallowTag).IsEmpty()) {
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
}
}
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
if (cc->Outputs().HasTag("STATE_CHANGE")) {
if (cc->Outputs().HasTag(kStateChangeTag)) {
if (last_gate_state_ != GATE_UNINITIALIZED &&
last_gate_state_ != new_gate_state) {
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
@ -179,7 +185,7 @@ class GateCalculator : public CalculatorBase {
<< ToString(last_gate_state_) << " to "
<< ToString(new_gate_state);
cc->Outputs()
.Tag("STATE_CHANGE")
.Tag(kStateChangeTag)
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
}
}
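A hedged sketch of the stream-driven flavor of this gate (the side-packet flavor is what the test file below exercises); the stream names are illustrative only.

CalculatorGraphConfig::Node node =
    ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
      calculator: "GateCalculator"
      input_stream: "detections"
      input_stream: "ALLOW:allow_detections"
      output_stream: "gated_detections"
      output_stream: "STATE_CHANGE:gate_state_change"
    )pb");
// Exactly one of ALLOW or DISALLOW may be wired, and it must come either from
// a stream or from a side packet, never both, per the RET_CHECKs above.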

View File

@ -22,6 +22,9 @@ namespace mediapipe {
namespace {
constexpr char kDisallowTag[] = "DISALLOW";
constexpr char kAllowTag[] = "ALLOW";
class GateCalculatorTest : public ::testing::Test {
protected:
// Helper to run a graph and return status.
@ -117,7 +120,7 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true));
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
@ -139,7 +142,7 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false));
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
@ -161,7 +164,7 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false));
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
@ -179,7 +182,7 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
input_stream: "test_input"
output_stream: "test_output"
)");
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true));
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);

View File

@ -39,20 +39,24 @@ using testing::ElementsAre;
namespace mediapipe {
namespace {
constexpr char kClockTag[] = "CLOCK";
using mediapipe::Clock;
// A Calculator with a fixed Process call latency.
class SleepCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
clock_ =
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
return absl::OkStatus();
}

View File

@ -29,6 +29,9 @@
namespace mediapipe {
namespace {
constexpr char kMinuendTag[] = "MINUEND";
constexpr char kSubtrahendTag[] = "SUBTRAHEND";
// A 3x4 Matrix of random integers in [0,1000).
const char kMatrixText[] =
"rows: 3\n"
@ -104,12 +107,13 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
CalculatorRunner runner(node_config);
Matrix* side_matrix = new Matrix();
MatrixFromTextProto(kMatrixText, side_matrix);
runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix);
runner.MutableSidePackets()->Tag(kSubtrahendTag) = Adopt(side_matrix);
Matrix* input_matrix = new Matrix();
MatrixFromTextProto(kMatrixText2, input_matrix);
runner.MutableInputs()->Tag("MINUEND").packets.push_back(
Adopt(input_matrix).At(Timestamp(0)));
runner.MutableInputs()
->Tag(kMinuendTag)
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
@ -133,12 +137,12 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
CalculatorRunner runner(node_config);
Matrix* side_matrix = new Matrix();
MatrixFromTextProto(kMatrixText, side_matrix);
runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix);
runner.MutableSidePackets()->Tag(kMinuendTag) = Adopt(side_matrix);
Matrix* input_matrix = new Matrix();
MatrixFromTextProto(kMatrixText2, input_matrix);
runner.MutableInputs()
->Tag("SUBTRAHEND")
->Tag(kSubtrahendTag)
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());

View File

@ -17,6 +17,9 @@
namespace mediapipe {
constexpr char kPresenceTag[] = "PRESENCE";
constexpr char kPacketTag[] = "PACKET";
// For each non-empty input packet, emits a single output packet containing the
// boolean value "true"; emits "false" in response to empty packets (a.k.a. timestamp
// bound updates). This can be used to "flag" the presence of an arbitrary packet
@ -58,8 +61,8 @@ namespace mediapipe {
class PacketPresenceCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("PACKET").SetAny();
cc->Outputs().Tag("PRESENCE").Set<bool>();
cc->Inputs().Tag(kPacketTag).SetAny();
cc->Outputs().Tag(kPresenceTag).Set<bool>();
// Process() function is invoked in response to input stream timestamp
// bound updates.
cc->SetProcessTimestampBounds(true);
@ -73,8 +76,8 @@ class PacketPresenceCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final {
cc->Outputs()
.Tag("PRESENCE")
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag("PACKET").IsEmpty())
.Tag(kPresenceTag)
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag(kPacketTag).IsEmpty())
.At(cc->InputTimestamp()));
return absl::OkStatus();
}
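A minimal node sketch for this calculator, assuming the PACKET/PRESENCE tags from the contract above; the stream names are placeholders.

CalculatorGraphConfig::Node node =
    ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
      calculator: "PacketPresenceCalculator"
      input_stream: "PACKET:value"
      output_stream: "PRESENCE:value_is_present"
    )pb");
// value_is_present carries true at timestamps where "value" has a packet and
// false where only the timestamp bound advanced.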

View File

@ -39,6 +39,11 @@ namespace mediapipe {
REGISTER_CALCULATOR(PacketResamplerCalculator);
namespace {
constexpr char kSeedTag[] = "SEED";
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
constexpr char kOptionsTag[] = "OPTIONS";
// Returns a TimestampDiff (assuming microseconds) corresponding to the
// given time in seconds.
TimestampDiff TimestampDiffFromSeconds(double seconds) {
@ -50,16 +55,16 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
const auto& resampler_options =
cc->Options<PacketResamplerCalculatorOptions>();
if (cc->InputSidePackets().HasTag("OPTIONS")) {
cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>();
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
}
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
if (!input_data_id.IsValid()) {
input_data_id = cc->Inputs().GetId("", 0);
}
cc->Inputs().Get(input_data_id).SetAny();
if (cc->Inputs().HasTag("VIDEO_HEADER")) {
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
if (cc->Inputs().HasTag(kVideoHeaderTag)) {
cc->Inputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
}
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
@ -67,15 +72,15 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
output_data_id = cc->Outputs().GetId("", 0);
}
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
if (cc->Outputs().HasTag("VIDEO_HEADER")) {
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
if (cc->Outputs().HasTag(kVideoHeaderTag)) {
cc->Outputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
}
if (resampler_options.jitter() != 0.0) {
RET_CHECK_GT(resampler_options.jitter(), 0.0);
RET_CHECK_LE(resampler_options.jitter(), 1.0);
RET_CHECK(cc->InputSidePackets().HasTag("SEED"));
cc->InputSidePackets().Tag("SEED").Set<std::string>();
RET_CHECK(cc->InputSidePackets().HasTag(kSeedTag));
cc->InputSidePackets().Tag(kSeedTag).Set<std::string>();
}
return absl::OkStatus();
}
@ -143,9 +148,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream() &&
cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") &&
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) &&
!cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) {
video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get<VideoHeader>();
video_header_.frame_rate = frame_rate_;
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
return absl::OkStatus();
@ -234,7 +239,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::InvalidArgumentError(
@ -357,7 +362,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open(
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::InvalidArgumentError(
@ -504,7 +509,7 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
"ignored, because we are adding jitter.";
}
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
random_ = CreateSecureRandom(seed);
if (random_ == nullptr) {
return absl::InvalidArgumentError(
@ -635,9 +640,9 @@ absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
base_timestamp_ +
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
}
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag(kVideoHeaderTag)) {
cc->Outputs()
.Tag("VIDEO_HEADER")
.Tag(kVideoHeaderTag)
.Add(new VideoHeader(calculator_->video_header_),
Timestamp::PreStream());
}

View File

@ -32,6 +32,12 @@ namespace mediapipe {
using ::testing::ElementsAre;
namespace {
constexpr char kOptionsTag[] = "OPTIONS";
constexpr char kSeedTag[] = "SEED";
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
constexpr char kDataTag[] = "DATA";
// A simple version of CalculatorRunner with built-in convenience
// methods for setting inputs from a vector and checking outputs
// against expected outputs (both timestamps and contents).
@ -464,7 +470,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
)pb"));
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
runner.MutableInputs()->Tag("DATA").packets.push_back(
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
}
VideoHeader video_header_in;
@ -474,16 +480,16 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
video_header_in.duration = 1.0;
video_header_in.format = ImageFormat::SRGB;
runner.MutableInputs()
->Tag("VIDEO_HEADER")
->Tag(kVideoHeaderTag)
.packets.push_back(
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
MP_ASSERT_OK(runner.Run());
ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size());
ASSERT_EQ(1, runner.Outputs().Tag(kVideoHeaderTag).packets.size());
EXPECT_EQ(Timestamp::PreStream(),
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp());
runner.Outputs().Tag(kVideoHeaderTag).packets[0].Timestamp());
const VideoHeader& video_header_out =
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>();
runner.Outputs().Tag(kVideoHeaderTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(video_header_in.width, video_header_out.width);
EXPECT_EQ(video_header_in.height, video_header_out.height);
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
@ -725,7 +731,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
[mediapipe.PacketResamplerCalculatorOptions.ext] {
frame_rate: 30
})pb"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
@ -740,7 +746,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
frame_rate: 30
base_timestamp: 0
})pb"));
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
runner.SetInput({-222, 15000, 32000, 49999, 150000});
MP_ASSERT_OK(runner.Run());

View File

@ -29,6 +29,8 @@
namespace mediapipe {
namespace {
constexpr char kPeriodTag[] = "PERIOD";
// A simple version of CalculatorRunner with built-in convenience methods for
// setting inputs from a vector and checking outputs against a vector of
// expected outputs.
@ -121,7 +123,7 @@ TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) {
SimpleRunner runner(node);
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
MP_ASSERT_OK(runner.Run());
const std::vector<int64> expected_timestamps = {2, 8, 14};
@ -160,7 +162,7 @@ TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) {
SimpleRunner runner(node);
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
MP_ASSERT_OK(runner.Run());
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
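The expected timestamps in these thinning tests follow directly from the period rule. Below is a tiny self-contained sketch of the asynchronous rule, an assumption about what the {2, 8, 14} expectation in the first test above encodes rather than the calculator's actual implementation.

#include <cstdint>
#include <vector>

// Keep a packet only if it arrives at least `period` after the last kept one.
std::vector<int64_t> AsyncThin(const std::vector<int64_t>& timestamps,
                               int64_t period) {
  std::vector<int64_t> kept;
  for (int64_t t : timestamps) {
    if (kept.empty() || t >= kept.back() + period) kept.push_back(t);
  }
  return kept;
}

// AsyncThin({2, 4, 6, 8, 10, 12, 14}, 5) returns {2, 8, 14}, matching the
// ASyncUniformStreamThinningTestBySidePacket expectation above.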

View File

@ -39,6 +39,8 @@ using ::testing::Pair;
using ::testing::Value;
namespace {
constexpr char kDisallowTag[] = "DISALLOW";
// Returns the timestamp values for a vector of Packets.
// TODO: put this kind of test util in a common place.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
@ -702,14 +704,14 @@ class DroppingGateCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Inputs().Tag("DISALLOW").Set<bool>();
cc->Inputs().Tag(kDisallowTag).Set<bool>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
if (!cc->Inputs().Index(0).IsEmpty() &&
!cc->Inputs().Tag("DISALLOW").Get<bool>()) {
!cc->Inputs().Tag(kDisallowTag).Get<bool>()) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
}
return absl::OkStatus();

View File

@ -41,11 +41,14 @@
// }
namespace mediapipe {
constexpr char kEncodedTag[] = "ENCODED";
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
class QuantizeFloatVectorCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
cc->Outputs().Tag("ENCODED").Set<std::string>();
cc->Inputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
cc->Outputs().Tag(kEncodedTag).Set<std::string>();
return absl::OkStatus();
}
@ -70,7 +73,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final {
const std::vector<float>& float_vector =
cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
cc->Inputs().Tag(kFloatVectorTag).Value().Get<std::vector<float>>();
int feature_size = float_vector.size();
std::string encoded_features;
encoded_features.reserve(feature_size);
@ -86,7 +89,9 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
(old_value - min_quantized_value_) * (255.0 / range_));
encoded_features += encoded;
}
cc->Outputs().Tag("ENCODED").AddPacket(
cc->Outputs()
.Tag(kEncodedTag)
.AddPacket(
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
return absl::OkStatus();
}
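As a sanity check on the encoding expression above, a short self-contained sketch of the per-element arithmetic; the option values (min_quantized_value = -64, max_quantized_value = 64, so range_ = 128) are assumptions, and the inputs mirror the TestNonEmptyVector case in the test file below.

#include <cstdio>

int main() {
  const float kMin = -64.0f;
  const float kRange = 128.0f;  // assumed option values
  const float inputs[] = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
  for (float v : inputs) {
    // Mirrors: encoded = (v - min_quantized_value_) * (255.0 / range_)
    const unsigned char encoded =
        static_cast<unsigned char>((v - kMin) * (255.0f / kRange));
    std::printf("%6.1f -> %3d\n", v, encoded);
  }
  return 0;
}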

View File

@ -25,6 +25,9 @@
namespace mediapipe {
constexpr char kEncodedTag[] = "ENCODED";
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
@ -40,7 +43,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
CalculatorRunner runner(node_config);
std::vector<float> empty_vector;
runner.MutableInputs()
->Tag("FLOAT_VECTOR")
->Tag(kFloatVectorTag)
.packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
auto status = runner.Run();
@ -67,7 +70,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
CalculatorRunner runner(node_config);
std::vector<float> empty_vector;
runner.MutableInputs()
->Tag("FLOAT_VECTOR")
->Tag(kFloatVectorTag)
.packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
auto status = runner.Run();
@ -94,7 +97,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
CalculatorRunner runner(node_config);
std::vector<float> empty_vector;
runner.MutableInputs()
->Tag("FLOAT_VECTOR")
->Tag(kFloatVectorTag)
.packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
auto status = runner.Run();
@ -121,11 +124,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
CalculatorRunner runner(node_config);
std::vector<float> empty_vector;
runner.MutableInputs()
->Tag("FLOAT_VECTOR")
->Tag(kFloatVectorTag)
.packets.push_back(
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
const std::vector<Packet>& outputs =
runner.Outputs().Tag(kEncodedTag).packets;
EXPECT_EQ(1, outputs.size());
EXPECT_TRUE(outputs[0].Get<std::string>().empty());
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
@ -147,11 +151,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
CalculatorRunner runner(node_config);
std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
runner.MutableInputs()
->Tag("FLOAT_VECTOR")
->Tag(kFloatVectorTag)
.packets.push_back(
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
const std::vector<Packet>& outputs =
runner.Outputs().Tag(kEncodedTag).packets;
EXPECT_EQ(1, outputs.size());
const std::string& result = outputs[0].Get<std::string>();
ASSERT_FALSE(result.empty());
@ -185,11 +190,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
CalculatorRunner runner(node_config);
std::vector<float> vector = {-65.0f, 65.0f};
runner.MutableInputs()
->Tag("FLOAT_VECTOR")
->Tag(kFloatVectorTag)
.packets.push_back(
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
const std::vector<Packet>& outputs =
runner.Outputs().Tag(kEncodedTag).packets;
EXPECT_EQ(1, outputs.size());
const std::string& result = outputs[0].Get<std::string>();
ASSERT_FALSE(result.empty());

View File

@ -23,6 +23,9 @@
namespace mediapipe {
constexpr char kAllowTag[] = "ALLOW";
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
// processing operations in a section of the graph.
//
@ -86,11 +89,11 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
}
cc->Inputs().Get("FINISHED", 0).SetAny();
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>();
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>();
}
if (cc->Outputs().HasTag("ALLOW")) {
cc->Outputs().Tag("ALLOW").Set<bool>();
if (cc->Outputs().HasTag(kAllowTag)) {
cc->Outputs().Tag(kAllowTag).Set<bool>();
}
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
@ -101,8 +104,8 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) final {
finished_id_ = cc->Inputs().GetId("FINISHED", 0);
max_in_flight_ = 1;
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>();
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
max_in_flight_ = cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>();
}
RET_CHECK_GE(max_in_flight_, 1);
num_in_flight_ = 0;

View File

@ -33,6 +33,9 @@
namespace mediapipe {
namespace {
constexpr char kFinishedTag[] = "FINISHED";
// A simple Semaphore for synchronizing test threads.
class AtomicSemaphore {
public:
@ -112,7 +115,7 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) {
Timestamp timestamp =
Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
runner.MutableInputs()
->Tag("FINISHED")
->Tag(kFinishedTag)
.packets.push_back(MakePacket<bool>(true).At(timestamp));
}

View File

@ -22,6 +22,8 @@ namespace mediapipe {
namespace {
constexpr char kPacketOffsetTag[] = "PACKET_OFFSET";
// Adds packets containing integers equal to their original timestamp.
void AddPackets(CalculatorRunner* runner) {
for (int i = 0; i < 10; ++i) {
@ -111,7 +113,7 @@ TEST(SequenceShiftCalculatorTest, SidePacketOffset) {
CalculatorRunner runner(node);
AddPackets(&runner);
runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2));
runner.MutableSidePackets()->Tag(kPacketOffsetTag) = Adopt(new int(-2));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& input_packets =
runner.MutableInputs()->Index(0).packets;

View File

@ -102,6 +102,10 @@ mediapipe::ScaleMode_Mode ParseScaleMode(
// IMAGE: ImageFrame representing the input image.
// IMAGE_GPU: GpuBuffer representing the input image.
//
// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as
// pair<int, int>. If set, it overrides the corresponding fields in the
// calculator options and the input side packet.
//
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in
// degrees. This allows different rotation angles for different frames. It has
// to be a multiple of 90 degrees. If provided, it overrides the
@ -221,6 +225,10 @@ absl::Status ImageTransformationCalculator::GetContract(
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<std::pair<int, int>>();
}
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
}
@ -329,6 +337,13 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) {
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
}
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS") &&
!cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
const auto& image_size =
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<std::pair<int, int>>();
output_width_ = image_size.first;
output_height_ = image_size.second;
}
if (use_gpu_) {
#if !MEDIAPIPE_DISABLE_GPU
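A hedged sketch of driving the new OUTPUT_DIMENSIONS input from a CalculatorRunner-style test; the runner setup, the timestamp, and the 320x240 value are assumptions for illustration.

runner.MutableInputs()
    ->Tag("OUTPUT_DIMENSIONS")
    .packets.push_back(
        MakePacket<std::pair<int, int>>(320, 240).At(Timestamp(0)));
// When Process() runs at Timestamp(0), output_width_ and output_height_
// pick up 320x240 for that frame, overriding the calculator options.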

View File

@ -88,6 +88,13 @@ proto_library(
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "tensor_to_vector_string_calculator_options_proto",
srcs = ["tensor_to_vector_string_calculator_options.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "unpack_media_sequence_calculator_proto",
srcs = ["unpack_media_sequence_calculator.proto"],
@ -257,6 +264,14 @@ mediapipe_cc_proto_library(
deps = [":tensor_to_vector_float_calculator_options_proto"],
)
mediapipe_cc_proto_library(
name = "tensor_to_vector_string_calculator_options_cc_proto",
srcs = ["tensor_to_vector_string_calculator_options.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":tensor_to_vector_string_calculator_options_proto"],
)
mediapipe_cc_proto_library(
name = "unpack_media_sequence_calculator_cc_proto",
srcs = ["unpack_media_sequence_calculator.proto"],
@ -694,6 +709,26 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "tensor_to_vector_string_calculator",
srcs = ["tensor_to_vector_string_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check",
":tensor_to_vector_string_calculator_options_cc_proto",
] + select({
"//conditions:default": [
"@org_tensorflow//tensorflow/core:framework",
],
"//mediapipe:android": [
"@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
],
}),
alwayslink = 1,
)
cc_library(
name = "unpack_media_sequence_calculator",
srcs = ["unpack_media_sequence_calculator.cc"],
@ -1059,6 +1094,20 @@ cc_test(
],
)
cc_test(
name = "tensor_to_vector_string_calculator_test",
srcs = ["tensor_to_vector_string_calculator_test.cc"],
deps = [
":tensor_to_vector_string_calculator",
":tensor_to_vector_string_calculator_options_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
cc_test(
name = "unpack_media_sequence_calculator_test",
srcs = ["unpack_media_sequence_calculator_test.cc"],

View File

@ -40,6 +40,24 @@ namespace {
namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence;
constexpr char kBboxTag[] = "BBOX";
constexpr char kEncodedMediaStartTimestampTag[] =
"ENCODED_MEDIA_START_TIMESTAMP";
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
constexpr char kClassSegmentationTag[] = "CLASS_SEGMENTATION";
constexpr char kKeypointsTestTag[] = "KEYPOINTS_TEST";
constexpr char kBboxPredictedTag[] = "BBOX_PREDICTED";
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
constexpr char kAudioTestTag[] = "AUDIO_TEST";
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
constexpr char kImageTag[] = "IMAGE";
class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected:
void SetUpCalculator(const std::vector<std::string>& input_streams,
@ -83,17 +101,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -127,17 +145,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()
->Tag("IMAGE_PREFIX")
->Tag(kImagePrefixTag)
.packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -161,21 +179,21 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag("FLOAT_FEATURE_TEST")
->Tag(kFloatFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
runner_->MutableInputs()
->Tag("FLOAT_FEATURE_OTHER")
->Tag(kFloatFeatureOtherTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -228,20 +246,20 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
auto vf_ptr = absl::make_unique<std::vector<float>>(2, 3);
runner_->MutableInputs()
->Tag("FLOAT_CONTEXT_FEATURE_TEST")
->Tag(kFloatContextFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
vf_ptr = absl::make_unique<std::vector<float>>(2, 4);
runner_->MutableInputs()
->Tag("FLOAT_CONTEXT_FEATURE_OTHER")
->Tag(kFloatContextFeatureOtherTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -259,7 +277,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
SetUpCalculator({"IMAGE:images"}, context, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
@ -268,13 +286,13 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
encoded_image.set_encoded_image(bytes.data(), bytes.size());
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(0)));
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -307,17 +325,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED")
->Tag(kForwardFlowEncodedTag)
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -371,17 +389,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
detections->push_back(detection);
runner_->MutableInputs()
->Tag("BBOX_PREDICTED")
->Tag(kBboxPredictedTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -450,11 +468,11 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) {
detections->push_back(detection);
runner_->MutableInputs()
->Tag("BBOX_PREDICTED")
->Tag(kBboxPredictedTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
auto status = runner_->Run();
@ -498,7 +516,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
detections->push_back(detection);
runner_->MutableInputs()
->Tag("BBOX_PREDICTED")
->Tag(kBboxPredictedTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
@ -513,16 +531,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -564,18 +582,18 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
runner_->MutableInputs()
->Tag("KEYPOINTS_TEST")
->Tag(kKeypointsTestTag)
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
runner_->MutableInputs()
->Tag("KEYPOINTS_TEST")
->Tag(kKeypointsTestTag)
.packets.push_back(PointToForeign(&points).At(Timestamp(1)));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -615,17 +633,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
detections->push_back(detection);
runner_->MutableInputs()
->Tag("CLASS_SEGMENTATION")
->Tag(kClassSegmentationTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -664,17 +682,17 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED")
->Tag(kForwardFlowEncodedTag)
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -710,11 +728,11 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
auto flow_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
runner_->MutableInputs()
->Tag("FORWARD_FLOW_ENCODED")
->Tag(kForwardFlowEncodedTag)
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
absl::Status status = runner_->Run();
@ -731,13 +749,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) {
mpms::AddImageTimestamp(1, input_sequence.get());
mpms::AddImageTimestamp(2, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -757,13 +775,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) {
mpms::AddForwardFlowTimestamp(1, input_sequence.get());
mpms::AddForwardFlowTimestamp(2, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -794,13 +812,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) {
mpms::GetFeatureTimestampSize("OTHER", *input_sequence));
ASSERT_EQ(num_timesteps,
mpms::GetFeatureFloatsSize("OTHER", *input_sequence));
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -826,7 +844,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10)));
}
@ -838,11 +856,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
@ -879,7 +897,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i)));
}
@ -893,7 +911,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
runner_->MutableInputs()->Tag("BBOX").packets.push_back(
runner_->MutableInputs()->Tag(kBboxTag).packets.push_back(
Adopt(detections.release()).At(Timestamp(i)));
}
@ -909,7 +927,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
mpms::AddBBoxTrackIndex({-1}, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
  // If all the previous values aren't cleared, this assert will fail.
MP_ASSERT_OK(runner_->Run());
@ -925,11 +943,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) {
for (int i = 0; i < num_timesteps; ++i) {
auto vf_ptr = ::absl::make_unique<std::vector<float>>(1000000, i);
runner_->MutableInputs()
->Tag("FLOAT_FEATURE_TEST")
->Tag(kFloatFeatureTestTag)
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
ASSERT_FALSE(runner_->Run().ok());
}

View File

@ -26,6 +26,8 @@ namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kReferenceTag[] = "REFERENCE";
constexpr char kMatrix[] = "MATRIX";
constexpr char kTensor[] = "TENSOR";
@ -68,7 +70,8 @@ class TensorToMatrixCalculatorTest : public ::testing::Test {
if (include_rate) {
header->set_packet_rate(1.0);
}
runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release());
runner_->MutableInputs()->Tag(kReferenceTag).header =
Adopt(header.release());
}
std::unique_ptr<CalculatorRunner> runner_;

View File

@ -0,0 +1,118 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Calculator converts from one-dimensional Tensor of DT_STRING to
// vector<std::string> OR from (batched) two-dimensional Tensor of DT_STRING to
// vector<vector<std::string>>.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
namespace mediapipe {
namespace tf = ::tensorflow;
class TensorToVectorStringCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
TensorToVectorStringCalculatorOptions options_;
};
REGISTER_CALCULATOR(TensorToVectorStringCalculator);
absl::Status TensorToVectorStringCalculator::GetContract(
CalculatorContract* cc) {
  // Start with only one input stream.
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
<< "Only one input stream is supported.";
cc->Inputs().Index(0).Set<tf::Tensor>(
// Input Tensor
);
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
<< "Only one output stream is supported.";
const auto& options = cc->Options<TensorToVectorStringCalculatorOptions>();
if (options.tensor_is_2d()) {
RET_CHECK(!options.flatten_nd());
cc->Outputs().Index(0).Set<std::vector<std::vector<std::string>>>(
/* "Output vector<vector<std::string>>." */);
} else {
cc->Outputs().Index(0).Set<std::vector<std::string>>(
// Output vector<std::string>.
);
}
return absl::OkStatus();
}
absl::Status TensorToVectorStringCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<TensorToVectorStringCalculatorOptions>();
// Inform mediapipe that this calculator produces an output at time t for
// each input received at time t (i.e. this calculator does not buffer
// inputs). This enables mediapipe to propagate time of arrival estimates in
// mediapipe graphs through this calculator.
cc->SetOffset(/*offset=*/0);
return absl::OkStatus();
}
absl::Status TensorToVectorStringCalculator::Process(CalculatorContext* cc) {
const tf::Tensor& input_tensor =
cc->Inputs().Index(0).Value().Get<tf::Tensor>();
RET_CHECK(tf::DT_STRING == input_tensor.dtype())
<< "expected DT_STRING input but got "
<< tensorflow::DataTypeString(input_tensor.dtype());
if (options_.tensor_is_2d()) {
RET_CHECK(2 == input_tensor.dims())
<< "Expected 2-dimensional Tensor, but the tensor shape is: "
<< input_tensor.shape().DebugString();
auto output = absl::make_unique<std::vector<std::vector<std::string>>>(
input_tensor.dim_size(0),
std::vector<std::string>(input_tensor.dim_size(1)));
for (int i = 0; i < input_tensor.dim_size(0); ++i) {
auto& instance_output = output->at(i);
const auto& slice =
input_tensor.Slice(i, i + 1).unaligned_flat<tensorflow::tstring>();
for (int j = 0; j < input_tensor.dim_size(1); ++j) {
instance_output.at(j) = slice(j);
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
} else {
if (!options_.flatten_nd()) {
RET_CHECK(1 == input_tensor.dims())
<< "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the "
<< "tensor shape is: " << input_tensor.shape().DebugString();
}
auto output =
absl::make_unique<std::vector<std::string>>(input_tensor.NumElements());
const auto& tensor_values = input_tensor.flat<tensorflow::tstring>();
for (int i = 0; i < input_tensor.NumElements(); ++i) {
output->at(i) = tensor_values(i);
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
}
return absl::OkStatus();
}
} // namespace mediapipe
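As a usage sketch (illustrative only, not part of this change), the calculator is wired into a graph like any other single-stream node. The stream names, the helper function name, and the decision to leave both options at their defaults are assumptions; the include paths follow the usual MediaPipe layout.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"
namespace mediapipe {
// Builds a minimal graph that converts a 1-D DT_STRING tensor arriving on
// "input_tensor" into a std::vector<std::string> emitted on "output_strings".
CalculatorGraphConfig MakeTensorToVectorStringGraph() {
  return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    input_stream: "input_tensor"
    output_stream: "output_strings"
    node {
      calculator: "TensorToVectorStringCalculator"
      input_stream: "input_tensor"
      output_stream: "output_strings"
      options {
        [mediapipe.TensorToVectorStringCalculatorOptions.ext] {
          tensor_is_2d: false
          flatten_nd: false
        }
      }
    }
  )pb");
}
}  // namespace mediapipe
Setting tensor_is_2d: true instead would switch the declared output type to std::vector<std::vector<std::string>>, matching the corresponding branch in GetContract above.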

View File

@ -0,0 +1,33 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorToVectorStringCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorToVectorStringCalculatorOptions ext = 386534187;
}
// If true, unpack a 2d tensor (matrix) into a vector<vector<string>>. If
// false, convert a 1d tensor (vector) into a vector<string>.
optional bool tensor_is_2d = 1 [default = false];
// If true, an N-D tensor will be flattened to a vector<string>. This is
// mutually exclusive with tensor_is_2d.
optional bool flatten_nd = 2 [default = false];
}

View File

@ -0,0 +1,130 @@
// Copyright 2018 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mediapipe {
namespace {
namespace tf = ::tensorflow;
class TensorToVectorStringCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToVectorStringCalculator");
config.add_input_stream("input_tensor");
config.add_output_stream("output_tensor");
auto options = config.mutable_options()->MutableExtension(
TensorToVectorStringCalculatorOptions::ext);
options->set_tensor_is_2d(tensor_is_2d);
options->set_flatten_nd(flatten_nd);
runner_ = absl::make_unique<CalculatorRunner>(config);
}
std::unique_ptr<CalculatorRunner> runner_;
};
TEST_F(TensorToVectorStringCalculatorTest, ConvertsToVectorString) {
SetUpRunner(false, false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
auto tensor_vec = tensor->vec<tensorflow::tstring>();
for (int i = 0; i < 5; ++i) {
tensor_vec(i) = absl::StrCat("foo", i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::string>& output_vector =
output_packets[0].Get<std::vector<std::string>>();
EXPECT_EQ(5, output_vector.size());
for (int i = 0; i < 5; ++i) {
const std::string expected = absl::StrCat("foo", i);
EXPECT_EQ(expected, output_vector[i]);
}
}
TEST_F(TensorToVectorStringCalculatorTest, ConvertsBatchedToVectorVectorString) {
SetUpRunner(true, false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{1, 5});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
auto slice = tensor->Slice(0, 1).flat<tensorflow::tstring>();
for (int i = 0; i < 5; ++i) {
slice(i) = absl::StrCat("foo", i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::vector<std::string>>& output_vectors =
output_packets[0].Get<std::vector<std::vector<std::string>>>();
ASSERT_EQ(1, output_vectors.size());
const std::vector<std::string>& output_vector = output_vectors[0];
EXPECT_EQ(5, output_vector.size());
for (int i = 0; i < 5; ++i) {
const std::string expected = absl::StrCat("foo", i);
EXPECT_EQ(expected, output_vector[i]);
}
}
TEST_F(TensorToVectorStringCalculatorTest, FlattenShouldTakeAllDimensions) {
SetUpRunner(false, true);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 2, 2});
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
auto slice = tensor->flat<tensorflow::tstring>();
for (int i = 0; i < 2 * 2 * 2; ++i) {
slice(i) = absl::StrCat("foo", i);
}
const int64 time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Index(0).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const std::vector<std::string>& output_vector =
output_packets[0].Get<std::vector<std::string>>();
EXPECT_EQ(2 * 2 * 2, output_vector.size());
for (int i = 0; i < 2 * 2 * 2; ++i) {
const std::string expected = absl::StrCat("foo", i);
EXPECT_EQ(expected, output_vector[i]);
}
}
} // namespace
} // namespace mediapipe

View File

@ -49,6 +49,11 @@ namespace tf = ::tensorflow;
namespace mediapipe {
namespace {
constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS";
constexpr char kSessionTag[] = "SESSION";
constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
// This is a simple implementation of a semaphore using standard C++ libraries.
// It is intended to be used only by TensorFlowInferenceCalculator to throttle
// concurrent calls to TensorFlow Session::Run. This is useful when multiple
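A minimal sketch of the kind of counting semaphore the comment above describes, written against the C++ standard library; the class and member names here are assumptions, not the calculator's actual implementation.
#include <condition_variable>
#include <cstdint>
#include <mutex>
// Counting semaphore used to cap the number of concurrent Session::Run calls.
class CountingSemaphore {
 public:
  explicit CountingSemaphore(uint32_t initial_count) : count_(initial_count) {}
  // Blocks until a slot is available, then takes it.
  void Acquire() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return count_ > 0; });
    --count_;
  }
  // Returns a slot and wakes one waiter.
  void Release() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      ++count_;
    }
    cv_.notify_one();
  }
 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  uint32_t count_;
};
Each batch handed to Session::Run would call Acquire() before running and Release() afterwards, so at most initial_count batches execute in TensorFlow at once.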
@ -252,10 +257,10 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
}
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
// For this calculator it must include a tag_to_tensor_map.
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
cc->InputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) {
cc->InputSidePackets()
.Tag("RECURRENT_INIT_TENSORS")
.Tag(kRecurrentInitTensorsTag)
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
}
return absl::OkStatus();
@ -265,11 +270,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
std::unique_ptr<InferenceState> inference_state =
absl::make_unique<InferenceState>();
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) &&
!cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) {
std::map<std::string, tf::Tensor>* init_tensor_map;
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
cc->InputSidePackets().Tag(kRecurrentInitTensorsTag));
for (const auto& p : *init_tensor_map) {
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
}
@ -280,13 +285,13 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag));
session_ = cc->InputSidePackets()
.Tag("SESSION")
.Tag(kSessionTag)
.Get<TensorFlowSession>()
.session.get();
tag_to_tensor_map_ = cc->InputSidePackets()
.Tag("SESSION")
.Tag(kSessionTag)
.Get<TensorFlowSession>()
.tag_to_tensor_map;

View File

@ -41,6 +41,11 @@ namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kMultipliedTag[] = "MULTIPLIED";
constexpr char kBTag[] = "B";
constexpr char kSessionTag[] = "SESSION";
std::string GetGraphDefPath() {
#ifdef __APPLE__
char path[1024];
@ -86,8 +91,8 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options,
input_side_packets, &output_side_packets));
runner_->MutableSidePackets()->Tag("SESSION") =
output_side_packets.Tag("SESSION");
runner_->MutableSidePackets()->Tag(kSessionTag) =
output_side_packets.Tag(kSessionTag);
}
Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) {
@ -140,7 +145,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_b =
runner_->Outputs().Tag("B").packets;
runner_->Outputs().Tag(kBTag).packets;
ASSERT_EQ(output_packets_b.size(), 1);
const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({1, 3});
@ -148,7 +153,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b);
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape);
@ -181,7 +186,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({3});
@ -220,7 +225,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
tf::TensorShape expected_shape({3});
@ -274,7 +279,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -311,7 +316,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -351,7 +356,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(3, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -392,7 +397,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -430,7 +435,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -481,7 +486,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(5, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
@ -528,7 +533,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0;
@ -569,7 +574,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(2, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
LOG(INFO) << "timestamp: " << 0;
@ -662,7 +667,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) {
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(0, output_packets_mult.size());
}
@ -691,7 +696,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets_mult =
runner_->Outputs().Tag("MULTIPLIED").packets;
runner_->Outputs().Tag(kMultipliedTag).packets;
ASSERT_EQ(1, output_packets_mult.size());
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15});

View File

@ -47,6 +47,11 @@ namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kSessionTag[] = "SESSION";
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
// Updates the graph nodes to use the device as specified by device_id.
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
for (auto& node : *graph_def->mutable_node()) {
@ -64,27 +69,29 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
bool has_exactly_one_model =
!options.graph_proto_path().empty()
? !(cc->InputSidePackets().HasTag("STRING_MODEL") |
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH"))
: (cc->InputSidePackets().HasTag("STRING_MODEL") ^
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH"));
? !(cc->InputSidePackets().HasTag(kStringModelTag) |
cc->InputSidePackets().HasTag(kStringModelFilePathTag))
: (cc->InputSidePackets().HasTag(kStringModelTag) ^
cc->InputSidePackets().HasTag(kStringModelFilePathTag));
RET_CHECK(has_exactly_one_model)
<< "Must have exactly one of graph_proto_path in options or "
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
if (cc->InputSidePackets().HasTag("STRING_MODEL")) {
if (cc->InputSidePackets().HasTag(kStringModelTag)) {
cc->InputSidePackets()
.Tag("STRING_MODEL")
.Tag(kStringModelTag)
.Set<std::string>(
// String model from embedded path
);
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) {
} else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
cc->InputSidePackets()
.Tag("STRING_MODEL_FILE_PATH")
.Tag(kStringModelFilePathTag)
.Set<std::string>(
// Filename of std::string model.
);
}
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>(
cc->OutputSidePackets()
.Tag(kSessionTag)
.Set<TensorFlowSession>(
// A TensorFlow model loaded and ready for use along with
// a map from tags to tensor names.
);
@ -111,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
session->session.reset(tf::NewSession(session_options));
std::string graph_def_serialized;
if (cc->InputSidePackets().HasTag("STRING_MODEL")) {
if (cc->InputSidePackets().HasTag(kStringModelTag)) {
graph_def_serialized =
cc->InputSidePackets().Tag("STRING_MODEL").Get<std::string>();
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) {
cc->InputSidePackets().Tag(kStringModelTag).Get<std::string>();
} else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
const std::string& frozen_graph = cc->InputSidePackets()
.Tag("STRING_MODEL_FILE_PATH")
.Tag(kStringModelFilePathTag)
.Get<std::string>();
RET_CHECK_OK(
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
@ -147,7 +154,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
}
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds.";

View File

@ -37,6 +37,10 @@ namespace {
namespace tf = ::tensorflow;
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
constexpr char kSessionTag[] = "SESSION";
std::string GetGraphDefPath() {
return mediapipe::file::JoinPath("./",
"mediapipe/calculators/tensorflow/"
@ -112,7 +116,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session);
}
@ -190,12 +194,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
runner.MutableSidePackets()->Tag("STRING_MODEL") =
runner.MutableSidePackets()->Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents));
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session);
}
@ -213,12 +217,12 @@ TEST_F(
}
})",
calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session);
}
@ -234,7 +238,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
}
})",
calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
auto run_status = runner.Run();
EXPECT_THAT(
@ -255,12 +259,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
}
})",
calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
runner.MutableSidePackets()->Tag("STRING_MODEL") =
runner.MutableSidePackets()->Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents));
auto run_status = runner.Run();
EXPECT_THAT(
@ -282,12 +286,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
}
})",
calculator_options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
runner.MutableSidePackets()->Tag("STRING_MODEL") =
runner.MutableSidePackets()->Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents));
auto run_status = runner.Run();
EXPECT_THAT(
@ -310,7 +314,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
VerifySignatureMap(session);
}

View File

@ -43,6 +43,11 @@ namespace mediapipe {
namespace tf = ::tensorflow;
namespace {
constexpr char kSessionTag[] = "SESSION";
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
// Updates the graph nodes to use the device as specified by device_id.
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
for (auto& node : *graph_def->mutable_node()) {
@ -64,25 +69,26 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
bool has_exactly_one_model =
!options.graph_proto_path().empty()
? !(input_side_packets->HasTag("STRING_MODEL") |
input_side_packets->HasTag("STRING_MODEL_FILE_PATH"))
: (input_side_packets->HasTag("STRING_MODEL") ^
input_side_packets->HasTag("STRING_MODEL_FILE_PATH"));
? !(input_side_packets->HasTag(kStringModelTag) |
input_side_packets->HasTag(kStringModelFilePathTag))
: (input_side_packets->HasTag(kStringModelTag) ^
input_side_packets->HasTag(kStringModelFilePathTag));
RET_CHECK(has_exactly_one_model)
<< "Must have exactly one of graph_proto_path in options or "
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
if (input_side_packets->HasTag("STRING_MODEL")) {
input_side_packets->Tag("STRING_MODEL")
if (input_side_packets->HasTag(kStringModelTag)) {
input_side_packets->Tag(kStringModelTag)
.Set<std::string>(
// String model from embedded path
);
} else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) {
input_side_packets->Tag("STRING_MODEL_FILE_PATH")
} else if (input_side_packets->HasTag(kStringModelFilePathTag)) {
input_side_packets->Tag(kStringModelFilePathTag)
.Set<std::string>(
// Filename of std::string model.
);
}
output_side_packets->Tag("SESSION").Set<TensorFlowSession>(
output_side_packets->Tag(kSessionTag)
.Set<TensorFlowSession>(
// A TensorFlow model loaded and ready for use along with
// a map from tags to tensor names.
);
@ -112,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
session->session.reset(tf::NewSession(session_options));
std::string graph_def_serialized;
if (input_side_packets.HasTag("STRING_MODEL")) {
if (input_side_packets.HasTag(kStringModelTag)) {
graph_def_serialized =
input_side_packets.Tag("STRING_MODEL").Get<std::string>();
} else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) {
input_side_packets.Tag(kStringModelTag).Get<std::string>();
} else if (input_side_packets.HasTag(kStringModelFilePathTag)) {
const std::string& frozen_graph =
input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get<std::string>();
input_side_packets.Tag(kStringModelFilePathTag).Get<std::string>();
RET_CHECK_OK(
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
} else {
@ -147,7 +153,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
}
output_side_packets->Tag("SESSION") = Adopt(session.release());
output_side_packets->Tag(kSessionTag) = Adopt(session.release());
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
<< " microseconds.";

View File

@ -37,6 +37,10 @@ namespace {
namespace tf = ::tensorflow;
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";
constexpr char kSessionTag[] = "SESSION";
std::string GetGraphDefPath() {
return mediapipe::file::JoinPath("./",
"mediapipe/calculators/tensorflow/"
@ -72,7 +76,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test {
void VerifySignatureMap(PacketSet* output_side_packets) {
const TensorFlowSession& session =
output_side_packets->Tag("SESSION").Get<TensorFlowSession>();
output_side_packets->Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
@ -179,7 +183,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
generator_options_->clear_graph_proto_path();
input_side_packets.Tag("STRING_MODEL") =
input_side_packets.Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents));
absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
@ -196,7 +200,7 @@ TEST_F(
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value());
generator_options_->clear_graph_proto_path();
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
@ -211,7 +215,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value());
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value());
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
absl::Status run_status = tool::RunGenerateAndValidateTypes(
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
@ -233,9 +237,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL") =
input_side_packets.Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
absl::Status run_status = tool::RunGenerateAndValidateTypes(
@ -258,9 +262,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
std::string serialized_graph_contents;
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
&serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL") =
input_side_packets.Tag(kStringModelTag) =
Adopt(new std::string(serialized_graph_contents));
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
input_side_packets.Tag(kStringModelFilePathTag) =
Adopt(new std::string(GetGraphDefPath()));
generator_options_->clear_graph_proto_path();

View File

@ -31,6 +31,9 @@
namespace mediapipe {
namespace {
constexpr char kSessionTag[] = "SESSION";
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple tensorflow saved models
@ -108,7 +111,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
}
// A TensorFlow model loaded and ready for use along with tensor
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
cc->OutputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
return absl::OkStatus();
}
@ -160,7 +163,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
output_signature.first, options)] = output_signature.second.name();
}
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
return absl::OkStatus();
}

View File

@ -35,6 +35,9 @@ namespace {
namespace tf = ::tensorflow;
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
constexpr char kSessionTag[] = "SESSION";
std::string GetSavedModelDir() {
std::string out_path =
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
@ -79,7 +82,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
options_->DebugString()));
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
@ -119,11 +122,11 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
}
})",
options_->DebugString()));
runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") =
runner.MutableSidePackets()->Tag(kStringSavedModelPathTag) =
MakePacket<std::string>(GetSavedModelDir());
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
@ -201,7 +204,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
options_->DebugString()));
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
@ -224,7 +227,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
options_->DebugString()));
MP_ASSERT_OK(runner.Run());
const TensorFlowSession& session =
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
std::vector<tensorflow::DeviceAttributes> devices;

View File

@ -33,6 +33,9 @@
namespace mediapipe {
namespace {
constexpr char kSessionTag[] = "SESSION";
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple tensorflow saved models
@ -100,7 +103,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
}
// A TensorFlow model loaded and ready for use along with tensor
output_side_packets->Tag("SESSION").Set<TensorFlowSession>();
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
return absl::OkStatus();
}
@ -153,7 +156,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
output_signature.first, options)] = output_signature.second.name();
}
output_side_packets->Tag("SESSION") = Adopt(session.release());
output_side_packets->Tag(kSessionTag) = Adopt(session.release());
return absl::OkStatus();
}
};

View File

@ -34,6 +34,9 @@ namespace {
namespace tf = ::tensorflow;
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
constexpr char kSessionTag[] = "SESSION";
std::string GetSavedModelDir() {
std::string out_path =
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
@ -75,7 +78,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
@ -107,7 +110,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
generator_options_->clear_saved_model_path();
PacketSet input_side_packets(
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value());
input_side_packets.Tag("STRING_SAVED_MODEL_PATH") =
input_side_packets.Tag(kStringSavedModelPathTag) =
Adopt(new std::string(GetSavedModelDir()));
PacketSet output_side_packets(
tool::CreateTagMap({"SESSION:session"}).value());
@ -116,7 +119,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
@ -192,7 +195,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
}
@ -213,7 +216,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
input_side_packets, &output_side_packets);
MP_EXPECT_OK(run_status) << run_status.message();
const TensorFlowSession& session =
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
// Session must be set.
ASSERT_NE(session.session, nullptr);
std::vector<tensorflow::DeviceAttributes> devices;

View File

@ -33,6 +33,31 @@ namespace {
namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence;
constexpr char kImageFrameRateTag[] = "IMAGE_FRAME_RATE";
constexpr char kEncodedMediaStartTimestampTag[] =
"ENCODED_MEDIA_START_TIMESTAMP";
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
constexpr char kResamplerOptionsTag[] = "RESAMPLER_OPTIONS";
constexpr char kSandboxedDecoderOptionsTag[] = "SANDBOXED_DECODER_OPTIONS";
constexpr char kDecoderOptionsTag[] = "DECODER_OPTIONS";
constexpr char kAudioDecoderOptionsTag[] = "AUDIO_DECODER_OPTIONS";
constexpr char kDataPathTag[] = "DATA_PATH";
constexpr char kDatasetRootTag[] = "DATASET_ROOT";
constexpr char kMediaIdTag[] = "MEDIA_ID";
constexpr char kFloatFeatureFdenseMaxTag[] = "FLOAT_FEATURE_FDENSE_MAX";
constexpr char kFloatFeatureFdenseAvgTag[] = "FLOAT_FEATURE_FDENSE_AVG";
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
constexpr char kAudioTestTag[] = "AUDIO_TEST";
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
constexpr char kBboxPrefixTag[] = "BBOX_PREFIX";
constexpr char kKeypointsTag[] = "KEYPOINTS";
constexpr char kBboxTag[] = "BBOX";
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kImageTag[] = "IMAGE";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
protected:
void SetUpCalculator(const std::vector<std::string>& output_streams,
@ -95,13 +120,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) {
mpms::AddImageEncoded(test_image_string, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets;
runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) {
@ -124,13 +149,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) {
mpms::AddImageEncoded(test_image_string, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets;
runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) {
@ -154,13 +179,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) {
mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE_PREFIX").packets;
runner_->Outputs().Tag(kImagePrefixTag).packets;
ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) {
@ -182,12 +207,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) {
mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets;
runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
ASSERT_EQ(num_forward_flow_images, output_packets.size());
for (int i = 0; i < num_forward_flow_images; ++i) {
@ -211,12 +236,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) {
mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets;
runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
ASSERT_EQ(num_forward_flow_images, output_packets.size());
for (int i = 0; i < num_forward_flow_images; ++i) {
@ -240,13 +265,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) {
mpms::AddBBoxTimestamp(i, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("BBOX").packets;
runner_->Outputs().Tag(kBboxTag).packets;
ASSERT_EQ(bboxes.size(), output_packets.size());
for (int i = 0; i < bboxes.size(); ++i) {
@ -274,13 +299,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) {
mpms::AddBBoxTimestamp(prefix, i, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("BBOX_PREFIX").packets;
runner_->Outputs().Tag(kBboxPrefixTag).packets;
ASSERT_EQ(bboxes.size(), output_packets.size());
for (int i = 0; i < bboxes.size(); ++i) {
@ -306,13 +331,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets;
runner_->Outputs().Tag(kFloatFeatureTestTag).packets;
ASSERT_EQ(num_float_lists, output_packets.size());
for (int i = 0; i < num_float_lists; ++i) {
@ -322,7 +347,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
}
const std::vector<Packet>& output_packets_other =
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets;
runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
ASSERT_EQ(num_float_lists, output_packets_other.size());
for (int i = 0; i < num_float_lists; ++i) {
@ -352,12 +377,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get());
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets;
runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) {
@ -366,7 +391,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
}
const std::vector<Packet>& output_packets_other =
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets;
runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
ASSERT_EQ(num_float_lists, output_packets_other.size());
for (int i = 0; i < num_float_lists; ++i) {
@ -389,12 +414,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& fdense_avg_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets;
runner_->Outputs().Tag(kFloatFeatureFdenseAvgTag).packets;
ASSERT_EQ(fdense_avg_packets.size(), 1);
const auto& fdense_avg_vector =
fdense_avg_packets[0].Get<std::vector<float>>();
@ -403,7 +428,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
::testing::Eq(Timestamp::PostStream()));
const std::vector<Packet>& fdense_max_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
ASSERT_EQ(fdense_max_packets.size(), 1);
const auto& fdense_max_vector =
fdense_max_packets[0].Get<std::vector<float>>();
@ -430,13 +455,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("IMAGE").packets;
runner_->Outputs().Tag(kImageTag).packets;
ASSERT_EQ(num_images, output_packets.size());
for (int i = 0; i < num_images; ++i) {
@ -463,13 +488,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
input_sequence.get());
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& fdense_max_packets =
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
ASSERT_EQ(fdense_max_packets.size(), 1);
const auto& fdense_max_vector =
fdense_max_packets[0].Get<std::vector<float>>();
@ -481,17 +506,17 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
std::string root = "test_root";
runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root);
runner_->MutableSidePackets()->Tag(kDatasetRootTag) = PointToForeign(&root);
MP_ASSERT_OK(runner_->Run());
MP_ASSERT_OK(runner_->OutputSidePackets()
.Tag("DATA_PATH")
.Tag(kDataPathTag)
.ValidateAsType<std::string>());
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
root + "/" + data_path_);
}
@ -501,28 +526,28 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) {
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_dataset_root_directory(root);
SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_ASSERT_OK(runner_->OutputSidePackets()
.Tag("DATA_PATH")
.Tag(kDataPathTag)
.ValidateAsType<std::string>());
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
root + "/" + data_path_);
}
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
SetUpCalculator({}, {"DATA_PATH:data_path"});
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_ASSERT_OK(runner_->OutputSidePackets()
.Tag("DATA_PATH")
.Tag(kDataPathTag)
.ValidateAsType<std::string>());
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
data_path_);
}
@ -534,20 +559,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) {
->set_padding_after_label(2);
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
&options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Tag(kAudioDecoderOptionsTag)
.ValidateAsType<AudioDecoderOptions>());
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>()
.start_time(),
2.0, 1e-5);
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>()
.end_time(),
7.0, 1e-5);
@ -563,20 +588,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) {
->set_force_decoding_from_start_of_media(true);
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
&options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Tag(kAudioDecoderOptionsTag)
.ValidateAsType<AudioDecoderOptions>());
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>()
.start_time(),
0.0, 1e-5);
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Tag(kAudioDecoderOptionsTag)
.Get<AudioDecoderOptions>()
.end_time(),
7.0, 1e-5);
@ -594,27 +619,27 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
->mutable_base_packet_resampler_options()
->set_frame_rate(1.0);
SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS")
.Tag(kResamplerOptionsTag)
.ValidateAsType<CalculatorOptions>());
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS")
.Tag(kResamplerOptionsTag)
.Get<CalculatorOptions>()
.GetExtension(PacketResamplerCalculatorOptions::ext)
.start_time(),
2000000, 1);
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS")
.Tag(kResamplerOptionsTag)
.Get<CalculatorOptions>()
.GetExtension(PacketResamplerCalculatorOptions::ext)
.end_time(),
7000000, 1);
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("RESAMPLER_OPTIONS")
.Tag(kResamplerOptionsTag)
.Get<CalculatorOptions>()
.GetExtension(PacketResamplerCalculatorOptions::ext)
.frame_rate(),
@ -623,13 +648,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) {
SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"});
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("IMAGE_FRAME_RATE")
.Tag(kImageFrameRateTag)
.ValidateAsType<double>());
EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get<double>(),
EXPECT_EQ(runner_->OutputSidePackets().Tag(kImageFrameRateTag).Get<double>(),
image_frame_rate_);
}

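The recurring change in this commit hoists repeated string-literal stream and side-packet tags into file-local constexpr char constants. A minimal sketch of the resulting pattern, for illustration only and not part of the patch (FooCalculator and its INPUT/OUTPUT streams are hypothetical; header paths follow the usual MediaPipe layout):

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"

namespace mediapipe {

// One constant per tag, so each literal is spelled out exactly once.
constexpr char kInputTag[] = "INPUT";
constexpr char kOutputTag[] = "OUTPUT";

class FooCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->Inputs().HasTag(kInputTag));
    cc->Inputs().Tag(kInputTag).Set<int>();
    cc->Outputs().Tag(kOutputTag).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    const int value = cc->Inputs().Tag(kInputTag).Get<int>();
    cc->Outputs().Tag(kOutputTag).AddPacket(
        MakePacket<int>(value + 1).At(cc->InputTimestamp()));
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(FooCalculator);

}  // namespace mediapipe

Keeping the constants file-local mirrors what the patch does in each translation unit and avoids introducing shared headers just for tag strings.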
View File

@ -26,6 +26,10 @@ namespace {
namespace tf = ::tensorflow;
constexpr char kSingleIntTag[] = "SINGLE_INT";
constexpr char kTensorOutTag[] = "TENSOR_OUT";
constexpr char kVectorIntTag[] = "VECTOR_INT";
class VectorIntToTensorCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner(
@ -61,13 +65,13 @@ class VectorIntToTensorCalculatorTest : public ::testing::Test {
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
->Tag(kVectorIntTag)
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -95,13 +99,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
tensorflow::DT_INT32, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
->Tag(kSingleIntTag)
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -121,13 +125,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) {
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
->Tag(kVectorIntTag)
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -152,13 +156,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
tensorflow::DT_INT64, false, true);
const int64 time = 1234;
runner_->MutableInputs()
->Tag("SINGLE_INT")
->Tag(kSingleIntTag)
.packets.push_back(MakePacket<int>(1LL << 31).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
@ -179,13 +183,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
}
const int64 time = 1234;
runner_->MutableInputs()
->Tag("VECTOR_INT")
->Tag(kVectorIntTag)
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
EXPECT_TRUE(runner_->Run().ok());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("TENSOR_OUT").packets;
runner_->Outputs().Tag(kTensorOutTag).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();

View File

@ -18,6 +18,10 @@
namespace mediapipe {
constexpr char kFloatsTag[] = "FLOATS";
constexpr char kFloatTag[] = "FLOAT";
constexpr char kTensorsTag[] = "TENSORS";
// A calculator for converting TFLite tensors to a float or a float vector.
//
// Input:
@ -48,15 +52,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);
absl::Status TfLiteTensorsToFloatsCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("TENSORS"));
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT"));
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
cc->Outputs().HasTag(kFloatTag));
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
if (cc->Outputs().HasTag("FLOATS")) {
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>();
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
if (cc->Outputs().HasTag(kFloatsTag)) {
cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
}
if (cc->Outputs().HasTag("FLOAT")) {
cc->Outputs().Tag("FLOAT").Set<float>();
if (cc->Outputs().HasTag(kFloatTag)) {
cc->Outputs().Tag(kFloatTag).Set<float>();
}
return absl::OkStatus();
@ -69,10 +74,10 @@ absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) {
}
absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());
RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty());
const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
// TODO: Add option to specify which tensor to take from.
const TfLiteTensor* raw_tensor = &input_tensors[0];
const float* raw_floats = raw_tensor->data.f;
@ -82,18 +87,19 @@ absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
num_values *= raw_tensor->dims->data[i];
}
if (cc->Outputs().HasTag("FLOAT")) {
if (cc->Outputs().HasTag(kFloatTag)) {
// TODO: Could add an index in the option to specify returning one
// value of a float array.
RET_CHECK_EQ(num_values, 1);
cc->Outputs().Tag("FLOAT").AddPacket(
cc->Outputs().Tag(kFloatTag).AddPacket(
MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("FLOATS")) {
if (cc->Outputs().HasTag(kFloatsTag)) {
auto output_floats = absl::make_unique<std::vector<float>>(
raw_floats, raw_floats + num_values);
cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
cc->InputTimestamp());
cc->Outputs()
.Tag(kFloatsTag)
.Add(output_floats.release(), cc->InputTimestamp());
}
return absl::OkStatus();

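For context on the calculator above: it flattens the first tensor of the TENSORS input (the element count is the product of the tensor dimensions) and emits FLOATS, or FLOAT when that count is exactly 1. Below is a minimal node configuration in the same ParseTextProtoOrDie style used by the tests in this commit; the stream names are placeholders and the include paths are assumed:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Wires the calculator downstream of an inference step. "output_tensors"
// and "scores" are placeholder stream names; only the tags are fixed.
CalculatorGraphConfig::Node TensorsToFloatsNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "TfLiteTensorsToFloatsCalculator"
    input_stream: "TENSORS:output_tensors"
    output_stream: "FLOATS:scores"  # use "FLOAT:score" for 1-element tensors
  )pb");
}

}  // namespace mediapipe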
View File

@ -87,7 +87,7 @@ absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) {
// Initialize the clock.
if (cc->InputSidePackets().HasTag(kClockTag)) {
clock_ = cc->InputSidePackets()
.Tag("CLOCK")
.Tag(kClockTag)
.Get<std::shared_ptr<::mediapipe::Clock>>();
} else {
clock_.reset(

View File

@ -27,6 +27,8 @@
namespace mediapipe {
constexpr char kIterableTag[] = "ITERABLE";
typedef CollectionHasMinSizeCalculator<std::vector<int>>
TestIntCollectionHasMinSizeCalculator;
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
@ -34,7 +36,7 @@ REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
void AddInputVector(const std::vector<int>& input, int64 timestamp,
CalculatorRunner* runner) {
runner->MutableInputs()
->Tag("ITERABLE")
->Tag(kIterableTag)
.packets.push_back(
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
}

View File

@ -144,7 +144,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase {
}
cc->Outputs()
.Tag("DETECTIONS")
.Tag(kDetectionsTag)
.Add(output_detections.release(), cc->InputTimestamp());
return absl::OkStatus();
}

View File

@ -25,6 +25,9 @@
namespace mediapipe {
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
constexpr char kDetectionsTag[] = "DETECTIONS";
LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
double height) {
LocationData location_data;
@ -76,19 +79,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) {
detections->push_back(
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
runner.MutableInputs()
->Tag("LETTERBOX_PADDING")
->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output =
runner.Outputs().Tag("DETECTIONS").packets;
runner.Outputs().Tag(kDetectionsTag).packets;
ASSERT_EQ(1, output.size());
const auto& output_detections = output[0].Get<std::vector<Detection>>();
@ -124,19 +127,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) {
detections->push_back(
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.f, 0.2f, 0.f, 0.3f});
runner.MutableInputs()
->Tag("LETTERBOX_PADDING")
->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output =
runner.Outputs().Tag("DETECTIONS").packets;
runner.Outputs().Tag(kDetectionsTag).packets;
ASSERT_EQ(1, output.size());
const auto& output_detections = output[0].Get<std::vector<Detection>>();

View File

@ -31,6 +31,9 @@
namespace mediapipe {
namespace {
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kDetectionsTag[] = "DETECTIONS";
using ::testing::ElementsAre;
using ::testing::FloatNear;
@ -74,19 +77,19 @@ absl::StatusOr<Detection> RunProjectionCalculator(
)pb"));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(MakePacket<std::vector<Detection>>(
std::vector<Detection>({std::move(detection)}))
.At(Timestamp::PostStream()));
runner.MutableInputs()
->Tag("PROJECTION_MATRIX")
->Tag(kProjectionMatrixTag)
.packets.push_back(
MakePacket<std::array<float, 16>>(std::move(project_mat))
.At(Timestamp::PostStream()));
MP_RETURN_IF_ERROR(runner.Run());
const std::vector<Packet>& output =
runner.Outputs().Tag("DETECTIONS").packets;
runner.Outputs().Tag(kDetectionsTag).packets;
RET_CHECK_EQ(output.size(), 1);
const auto& output_detections = output[0].Get<std::vector<Detection>>();

View File

@ -32,6 +32,14 @@
namespace mediapipe {
namespace {
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kRectsTag[] = "RECTS";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kRectTag[] = "RECT";
constexpr char kDetectionTag[] = "DETECTION";
MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
testing::Value(arg.y_center(), testing::Eq(y_center)) &&
@ -94,12 +102,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) {
DetectionWithLocationData(100, 200, 300, 400));
runner.MutableInputs()
->Tag("DETECTION")
->Tag(kDetectionTag)
.packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<Rect>();
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
@ -120,16 +128,16 @@ absl::StatusOr<Rect> RunDetectionKeyPointsToRectCalculation(
)pb"));
runner.MutableInputs()
->Tag("DETECTION")
->Tag(kDetectionTag)
.packets.push_back(MakePacket<Detection>(std::move(detection))
.At(Timestamp::PostStream()));
runner.MutableInputs()
->Tag("IMAGE_SIZE")
->Tag(kImageSizeTag)
.packets.push_back(MakePacket<std::pair<int, int>>(image_size)
.At(Timestamp::PostStream()));
MP_RETURN_IF_ERROR(runner.Run());
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
RET_CHECK_EQ(output.size(), 1);
return output[0].Get<Rect>();
}
@ -176,12 +184,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) {
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
runner.MutableInputs()
->Tag("DETECTION")
->Tag(kDetectionTag)
.packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
const std::vector<Packet>& output =
runner.Outputs().Tag(kNormRectTag).packets;
ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<NormalizedRect>();
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
@ -201,12 +210,13 @@ absl::StatusOr<NormalizedRect> RunDetectionKeyPointsToNormRectCalculation(
)pb"));
runner.MutableInputs()
->Tag("DETECTION")
->Tag(kDetectionTag)
.packets.push_back(MakePacket<Detection>(std::move(detection))
.At(Timestamp::PostStream()));
MP_RETURN_IF_ERROR(runner.Run());
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
const std::vector<Packet>& output =
runner.Outputs().Tag(kNormRectTag).packets;
RET_CHECK_EQ(output.size(), 1);
return output[0].Get<NormalizedRect>();
}
@ -248,12 +258,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) {
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<Rect>();
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
@ -271,12 +281,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) {
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
const std::vector<Packet>& output =
runner.Outputs().Tag(kNormRectTag).packets;
ASSERT_EQ(1, output.size());
const auto& rect = output[0].Get<NormalizedRect>();
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
@ -294,12 +305,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) {
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets;
const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<Rect>>();
ASSERT_EQ(rects.size(), 2);
@ -319,13 +330,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) {
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output =
runner.Outputs().Tag("NORM_RECTS").packets;
runner.Outputs().Tag(kNormRectsTag).packets;
ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
ASSERT_EQ(rects.size(), 2);
@ -344,12 +355,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) {
DetectionWithLocationData(100, 200, 300, 400));
runner.MutableInputs()
->Tag("DETECTION")
->Tag(kDetectionTag)
.packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets;
const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<Rect>>();
EXPECT_EQ(rects.size(), 1);
@ -367,13 +378,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) {
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
runner.MutableInputs()
->Tag("DETECTION")
->Tag(kDetectionTag)
.packets.push_back(
Adopt(detection.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output =
runner.Outputs().Tag("NORM_RECTS").packets;
runner.Outputs().Tag(kNormRectsTag).packets;
ASSERT_EQ(1, output.size());
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
ASSERT_EQ(rects.size(), 1);
@ -391,7 +402,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) {
detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
@ -411,7 +422,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) {
detections->push_back(DetectionWithLocationData(100, 200, 300, 400));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));

View File

@ -30,6 +30,10 @@
namespace mediapipe {
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kDetectionListTag[] = "DETECTION_LIST";
using ::testing::DoubleNear;
// Error tolerance for pixels, distances, etc.
@ -97,13 +101,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) {
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag");
runner.MutableInputs()
->Tag("DETECTION_LIST")
->Tag(kDetectionListTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output =
runner.Outputs().Tag("RENDER_DATA").packets;
runner.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, output.size());
const auto& actual = output[0].Get<RenderData>();
EXPECT_EQ(actual.render_annotations_size(), 3);
@ -131,13 +135,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) {
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output =
runner.Outputs().Tag("RENDER_DATA").packets;
runner.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, output.size());
const auto& actual = output[0].Get<RenderData>();
EXPECT_EQ(actual.render_annotations_size(), 3);
@ -165,7 +169,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
*(detection_list->add_detection()) =
CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1");
runner.MutableInputs()
->Tag("DETECTION_LIST")
->Tag(kDetectionListTag)
.packets.push_back(
Adopt(detection_list.release()).At(Timestamp::PostStream()));
@ -174,13 +178,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
detections->push_back(
CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2"));
runner.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& actual =
runner.Outputs().Tag("RENDER_DATA").packets;
runner.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, actual.size());
// Check the feature tag for item from detection list.
EXPECT_EQ(
@ -209,19 +213,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
auto detection_list1(absl::make_unique<DetectionList>());
runner1.MutableInputs()
->Tag("DETECTION_LIST")
->Tag(kDetectionListTag)
.packets.push_back(
Adopt(detection_list1.release()).At(Timestamp::PostStream()));
auto detections1(absl::make_unique<std::vector<Detection>>());
runner1.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections1.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed.";
const std::vector<Packet>& exact1 =
runner1.Outputs().Tag("RENDER_DATA").packets;
runner1.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(0, exact1.size());
// Check when produce_empty_packet is true.
@ -240,19 +244,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
auto detection_list2(absl::make_unique<DetectionList>());
runner2.MutableInputs()
->Tag("DETECTION_LIST")
->Tag(kDetectionListTag)
.packets.push_back(
Adopt(detection_list2.release()).At(Timestamp::PostStream()));
auto detections2(absl::make_unique<std::vector<Detection>>());
runner2.MutableInputs()
->Tag("DETECTIONS")
->Tag(kDetectionsTag)
.packets.push_back(
Adopt(detections2.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed.";
const std::vector<Packet>& exact2 =
runner2.Outputs().Tag("RENDER_DATA").packets;
runner2.Outputs().Tag(kRenderDataTag).packets;
ASSERT_EQ(1, exact2.size());
EXPECT_EQ(exact2[0].Get<RenderData>().render_annotations_size(), 0);
}

View File

@ -32,6 +32,12 @@
namespace mediapipe {
constexpr char kRenderDataTag[] = "RENDER_DATA";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kScoresTag[] = "SCORES";
constexpr char kLabelsTag[] = "LABELS";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr float kFontHeightScale = 1.25f;
// A calculator that takes in pairs of labels and scores or classifications, outputs
@ -74,20 +80,20 @@ class LabelsToRenderDataCalculator : public CalculatorBase {
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
if (cc->Inputs().HasTag(kClassificationsTag)) {
cc->Inputs().Tag(kClassificationsTag).Set<ClassificationList>();
} else {
RET_CHECK(cc->Inputs().HasTag("LABELS"))
RET_CHECK(cc->Inputs().HasTag(kLabelsTag))
<< "Must provide input stream \"LABELS\"";
cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
if (cc->Inputs().HasTag("SCORES")) {
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
cc->Inputs().Tag(kLabelsTag).Set<std::vector<std::string>>();
if (cc->Inputs().HasTag(kScoresTag)) {
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
}
}
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
}
cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
return absl::OkStatus();
}
@ -100,10 +106,10 @@ absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
}
absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
if (cc->Inputs().HasTag(kVideoPrestreamTag) &&
cc->InputTimestamp() == Timestamp::PreStream()) {
const VideoHeader& video_header =
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
video_width_ = video_header.width;
video_height_ = video_header.height;
return absl::OkStatus();
@ -114,9 +120,9 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
std::vector<std::string> labels;
std::vector<float> scores;
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
if (cc->Inputs().HasTag(kClassificationsTag)) {
const ClassificationList& classifications =
cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
cc->Inputs().Tag(kClassificationsTag).Get<ClassificationList>();
labels.resize(classifications.classification_size());
scores.resize(classifications.classification_size());
for (int i = 0; i < classifications.classification_size(); ++i) {
@ -129,15 +135,15 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
}
} else {
const std::vector<std::string>& label_vector =
cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
cc->Inputs().Tag(kLabelsTag).Get<std::vector<std::string>>();
labels.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) {
labels[i] = label_vector[i];
}
if (cc->Inputs().HasTag("SCORES")) {
if (cc->Inputs().HasTag(kScoresTag)) {
std::vector<float> score_vector =
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
CHECK_EQ(label_vector.size(), score_vector.size());
scores.resize(label_vector.size());
for (int i = 0; i < label_vector.size(); ++i) {
@ -169,7 +175,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
auto* text = label_annotation->mutable_text();
std::string display_text = labels[i];
if (cc->Inputs().HasTag("SCORES")) {
if (cc->Inputs().HasTag(kScoresTag)) {
absl::StrAppend(&display_text, ":", scores[i]);
}
text->set_display_text(display_text);
@ -179,7 +185,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
text->set_font_face(options_.font_face());
}
cc->Outputs()
.Tag("RENDER_DATA")
.Tag(kRenderDataTag)
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
return absl::OkStatus();

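The calculator above accepts either CLASSIFICATIONS or LABELS (optionally paired with SCORES) and emits RENDER_DATA text annotations, using "label:score" display text when SCORES is present. A hedged CalculatorRunner sketch in the style of the tests in this commit; the include paths, the label values, and the assumption that the calculator is linked into the test target are mine:

#include <string>
#include <vector>

#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {

TEST(LabelsToRenderDataCalculatorSketch, LabelsAndScores) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "LabelsToRenderDataCalculator"
    input_stream: "LABELS:labels"
    input_stream: "SCORES:scores"
    output_stream: "RENDER_DATA:render_data"
  )pb"));

  std::vector<std::string> labels = {"cat", "dog"};
  std::vector<float> scores = {0.9f, 0.1f};
  runner.MutableInputs()->Tag("LABELS").packets.push_back(
      MakePacket<std::vector<std::string>>(labels).At(Timestamp(0)));
  runner.MutableInputs()->Tag("SCORES").packets.push_back(
      MakePacket<std::vector<float>>(scores).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run());
  const std::vector<Packet>& output =
      runner.Outputs().Tag("RENDER_DATA").packets;
  ASSERT_EQ(1, output.size());
  // Each label becomes a text annotation; with SCORES wired, the display
  // text is "label:score".
  EXPECT_GE(output[0].Get<RenderData>().render_annotations_size(), 2);
}

}  // namespace mediapipe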
View File

@ -24,6 +24,9 @@
namespace mediapipe {
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
constexpr char kLandmarksTag[] = "LANDMARKS";
NormalizedLandmark CreateLandmark(float x, float y) {
NormalizedLandmark landmark;
landmark.set_x(x);
@ -48,18 +51,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) {
*landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f);
*landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f);
runner.MutableInputs()
->Tag("LANDMARKS")
->Tag(kLandmarksTag)
.packets.push_back(
Adopt(landmarks.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
runner.MutableInputs()
->Tag("LETTERBOX_PADDING")
->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets;
const std::vector<Packet>& output =
runner.Outputs().Tag(kLandmarksTag).packets;
ASSERT_EQ(1, output.size());
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
@ -84,18 +88,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) {
landmark = landmarks->add_landmark();
*landmark = CreateLandmark(0.7f, 0.7f);
runner.MutableInputs()
->Tag("LANDMARKS")
->Tag(kLandmarksTag)
.packets.push_back(
Adopt(landmarks.release()).At(Timestamp::PostStream()));
auto padding = absl::make_unique<std::array<float, 4>>(
std::array<float, 4>{0.0f, 0.2f, 0.0f, 0.3f});
runner.MutableInputs()
->Tag("LETTERBOX_PADDING")
->Tag(kLetterboxPaddingTag)
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets;
const std::vector<Packet>& output =
runner.Outputs().Tag(kLandmarksTag).packets;
ASSERT_EQ(1, output.size());
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();

View File

@ -16,6 +16,10 @@
namespace mediapipe {
namespace {
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
mediapipe::CalculatorRunner runner(
@ -26,17 +30,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb"));
runner.MutableInputs()
->Tag("NORM_LANDMARKS")
->Tag(kNormLandmarksTag)
.packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1)));
runner.MutableInputs()
->Tag("NORM_RECT")
->Tag(kNormRectTag)
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
.At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
}
@ -104,17 +108,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
output_stream: "NORM_LANDMARKS:projected_landmarks"
)pb"));
runner.MutableInputs()
->Tag("NORM_LANDMARKS")
->Tag(kNormLandmarksTag)
.packets.push_back(
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
.At(Timestamp(1)));
runner.MutableInputs()
->Tag("PROJECTION_MATRIX")
->Tag(kProjectionMatrixTag)
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
.At(Timestamp(1)));
MP_RETURN_IF_ERROR(runner.Run());
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
RET_CHECK_EQ(output_packets.size(), 1);
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
}

View File

@ -20,6 +20,11 @@
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
constexpr char kContentsTag[] = "CONTENTS";
constexpr char kFileSuffixTag[] = "FILE_SUFFIX";
constexpr char kFileDirectoryTag[] = "FILE_DIRECTORY";
// The calculator takes the path to a local directory and the desired file
// suffix to match as input side packets, and outputs the contents of those files that
// match the pattern. Those matched files will be sent sequentially through the
@ -35,16 +40,16 @@ namespace mediapipe {
class LocalFilePatternContentsCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("FILE_DIRECTORY").Set<std::string>();
cc->InputSidePackets().Tag("FILE_SUFFIX").Set<std::string>();
cc->Outputs().Tag("CONTENTS").Set<std::string>();
cc->InputSidePackets().Tag(kFileDirectoryTag).Set<std::string>();
cc->InputSidePackets().Tag(kFileSuffixTag).Set<std::string>();
cc->Outputs().Tag(kContentsTag).Set<std::string>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory(
cc->InputSidePackets().Tag("FILE_DIRECTORY").Get<std::string>(),
cc->InputSidePackets().Tag("FILE_SUFFIX").Get<std::string>(),
cc->InputSidePackets().Tag(kFileDirectoryTag).Get<std::string>(),
cc->InputSidePackets().Tag(kFileSuffixTag).Get<std::string>(),
&filenames_));
return absl::OkStatus();
}
@ -57,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase {
filenames_[current_output_], contents.get()));
++current_output_;
cc->Outputs()
.Tag("CONTENTS")
.Tag(kContentsTag)
.Add(contents.release(), Timestamp(current_output_));
} else {
return tool::StatusStop();

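To make the contract above concrete: the two side packets select the directory and suffix, and each matching file's bytes arrive as one CONTENTS packet at Timestamp(1), Timestamp(2), and so on, after which the calculator stops. A hedged sketch; the directory and suffix values are placeholders and must exist for the run to succeed, and the include paths are assumed:

#include <string>

#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {

TEST(LocalFilePatternContentsCalculatorSketch, ReadsMatchingFiles) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "LocalFilePatternContentsCalculator"
    input_side_packet: "FILE_DIRECTORY:file_directory"
    input_side_packet: "FILE_SUFFIX:file_suffix"
    output_stream: "CONTENTS:contents"
  )pb"));

  runner.MutableSidePackets()->Tag("FILE_DIRECTORY") =
      MakePacket<std::string>("/tmp/frames");  // placeholder directory
  runner.MutableSidePackets()->Tag("FILE_SUFFIX") =
      MakePacket<std::string>(".jpg");  // placeholder suffix

  MP_ASSERT_OK(runner.Run());
  // One std::string packet per matched file, in the order returned by
  // MatchFileTypeInDirectory, timestamped 1, 2, 3, ...
  for (const Packet& packet : runner.Outputs().Tag("CONTENTS").packets) {
    EXPECT_FALSE(packet.Get<std::string>().empty());
  }
}

}  // namespace mediapipe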
View File

@ -217,7 +217,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) {
// Initialize the clock.
if (cc->InputSidePackets().HasTag(kClockTag)) {
clock_ = cc->InputSidePackets()
.Tag("CLOCK")
.Tag(kClockTag)
.Get<std::shared_ptr<::mediapipe::Clock>>();
} else {
clock_ = std::shared_ptr<::mediapipe::Clock>(

View File

@ -17,6 +17,12 @@
namespace mediapipe {
constexpr char kThresholdTag[] = "THRESHOLD";
constexpr char kRejectTag[] = "REJECT";
constexpr char kAcceptTag[] = "ACCEPT";
constexpr char kFlagTag[] = "FLAG";
constexpr char kFloatTag[] = "FLOAT";
// Applies a threshold on a stream of numeric values and outputs a flag and/or
// accept/reject stream. The threshold can be specified by one of the following:
// 1) Input stream.
@ -61,24 +67,24 @@ class ThresholdingCalculator : public CalculatorBase {
REGISTER_CALCULATOR(ThresholdingCalculator);
absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("FLOAT"));
cc->Inputs().Tag("FLOAT").Set<float>();
RET_CHECK(cc->Inputs().HasTag(kFloatTag));
cc->Inputs().Tag(kFloatTag).Set<float>();
if (cc->Outputs().HasTag("FLAG")) {
cc->Outputs().Tag("FLAG").Set<bool>();
if (cc->Outputs().HasTag(kFlagTag)) {
cc->Outputs().Tag(kFlagTag).Set<bool>();
}
if (cc->Outputs().HasTag("ACCEPT")) {
cc->Outputs().Tag("ACCEPT").Set<bool>();
if (cc->Outputs().HasTag(kAcceptTag)) {
cc->Outputs().Tag(kAcceptTag).Set<bool>();
}
if (cc->Outputs().HasTag("REJECT")) {
cc->Outputs().Tag("REJECT").Set<bool>();
if (cc->Outputs().HasTag(kRejectTag)) {
cc->Outputs().Tag(kRejectTag).Set<bool>();
}
if (cc->Inputs().HasTag("THRESHOLD")) {
cc->Inputs().Tag("THRESHOLD").Set<double>();
if (cc->Inputs().HasTag(kThresholdTag)) {
cc->Inputs().Tag(kThresholdTag).Set<double>();
}
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
cc->InputSidePackets().Tag("THRESHOLD").Set<double>();
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD"))
if (cc->InputSidePackets().HasTag(kThresholdTag)) {
cc->InputSidePackets().Tag(kThresholdTag).Set<double>();
RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
<< "Using both the threshold input side packet and input stream is not "
"supported.";
}
@ -92,43 +98,45 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
const auto& options =
cc->Options<::mediapipe::ThresholdingCalculatorOptions>();
if (options.has_threshold()) {
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD"))
RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
<< "Using both the threshold option and input stream is not supported.";
RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD"))
RET_CHECK(!cc->InputSidePackets().HasTag(kThresholdTag))
<< "Using both the threshold option and input side packet is not "
"supported.";
threshold_ = options.threshold();
}
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>();
if (cc->InputSidePackets().HasTag(kThresholdTag)) {
threshold_ = cc->InputSidePackets().Tag(kThresholdTag).Get<double>();
}
return absl::OkStatus();
}
absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag("THRESHOLD") &&
!cc->Inputs().Tag("THRESHOLD").IsEmpty()) {
threshold_ = cc->Inputs().Tag("THRESHOLD").Get<double>();
if (cc->Inputs().HasTag(kThresholdTag) &&
!cc->Inputs().Tag(kThresholdTag).IsEmpty()) {
threshold_ = cc->Inputs().Tag(kThresholdTag).Get<double>();
}
bool accept = false;
RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty());
accept =
static_cast<double>(cc->Inputs().Tag("FLOAT").Get<float>()) > threshold_;
RET_CHECK(!cc->Inputs().Tag(kFloatTag).IsEmpty());
accept = static_cast<double>(cc->Inputs().Tag(kFloatTag).Get<float>()) >
threshold_;
if (cc->Outputs().HasTag("FLAG")) {
cc->Outputs().Tag("FLAG").AddPacket(
if (cc->Outputs().HasTag(kFlagTag)) {
cc->Outputs().Tag(kFlagTag).AddPacket(
MakePacket<bool>(accept).At(cc->InputTimestamp()));
}
if (accept && cc->Outputs().HasTag("ACCEPT")) {
cc->Outputs().Tag("ACCEPT").AddPacket(
MakePacket<bool>(true).At(cc->InputTimestamp()));
if (accept && cc->Outputs().HasTag(kAcceptTag)) {
cc->Outputs()
.Tag(kAcceptTag)
.AddPacket(MakePacket<bool>(true).At(cc->InputTimestamp()));
}
if (!accept && cc->Outputs().HasTag("REJECT")) {
cc->Outputs().Tag("REJECT").AddPacket(
MakePacket<bool>(false).At(cc->InputTimestamp()));
if (!accept && cc->Outputs().HasTag(kRejectTag)) {
cc->Outputs()
.Tag(kRejectTag)
.AddPacket(MakePacket<bool>(false).At(cc->InputTimestamp()));
}
return absl::OkStatus();

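To illustrate the contract above: with the threshold supplied as an input side packet, a FLOAT value above the threshold produces FLAG=true (and an ACCEPT packet if that output is wired). A hedged runner sketch with placeholder values; the include paths and the assumption that the calculator is linked into the test target are mine:

#include <vector>

#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {

TEST(ThresholdingCalculatorSketch, FlagIsTrueAboveThreshold) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "ThresholdingCalculator"
    input_stream: "FLOAT:score"
    input_side_packet: "THRESHOLD:threshold"
    output_stream: "FLAG:accepted"
  )pb"));

  runner.MutableSidePackets()->Tag("THRESHOLD") = MakePacket<double>(0.5);
  runner.MutableInputs()->Tag("FLOAT").packets.push_back(
      MakePacket<float>(0.8f).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run());
  // 0.8 > 0.5, so the single FLAG packet carries true.
  const std::vector<Packet>& flags = runner.Outputs().Tag("FLAG").packets;
  ASSERT_EQ(1, flags.size());
  EXPECT_TRUE(flags[0].Get<bool>());
}

}  // namespace mediapipe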
View File

@ -39,6 +39,14 @@
namespace mediapipe {
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
constexpr char kSummaryTag[] = "SUMMARY";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
constexpr char kScoresTag[] = "SCORES";
// A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements, classification protos, and a summary string
// (in csv format).
@ -79,22 +87,22 @@ class TopKScoresCalculator : public CalculatorBase {
REGISTER_CALCULATOR(TopKScoresCalculator);
absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("SCORES"));
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
cc->Outputs().Tag("TOP_K_INDEXES").Set<std::vector<int>>();
RET_CHECK(cc->Inputs().HasTag(kScoresTag));
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
cc->Outputs().Tag(kTopKIndexesTag).Set<std::vector<int>>();
}
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
cc->Outputs().Tag("TOP_K_SCORES").Set<std::vector<float>>();
if (cc->Outputs().HasTag(kTopKScoresTag)) {
cc->Outputs().Tag(kTopKScoresTag).Set<std::vector<float>>();
}
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
}
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
if (cc->Outputs().HasTag(kClassificationsTag)) {
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
}
if (cc->Outputs().HasTag("SUMMARY")) {
cc->Outputs().Tag("SUMMARY").Set<std::string>();
if (cc->Outputs().HasTag(kSummaryTag)) {
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
}
return absl::OkStatus();
}
@ -114,7 +122,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
if (options.has_label_map_path()) {
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
}
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
RET_CHECK(!label_map_.empty());
}
return absl::OkStatus();
@ -122,7 +130,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
const std::vector<float>& input_vector =
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
std::vector<int> top_k_indexes;
std::vector<float> top_k_scores;
@ -166,26 +174,26 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
top_k_labels.push_back(label_map_[index]);
}
}
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
cc->Outputs()
.Tag("TOP_K_INDEXES")
.Tag(kTopKIndexesTag)
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
if (cc->Outputs().HasTag(kTopKScoresTag)) {
cc->Outputs()
.Tag("TOP_K_SCORES")
.Tag(kTopKScoresTag)
.AddPacket(MakePacket<std::vector<float>>(top_k_scores)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs()
.Tag("TOP_K_LABELS")
.Tag(kTopKLabelsTag)
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("SUMMARY")) {
if (cc->Outputs().HasTag(kSummaryTag)) {
std::vector<std::string> results;
for (int index = 0; index < top_k_indexes.size(); ++index) {
if (label_map_loaded_) {
@ -196,12 +204,13 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
}
}
cc->Outputs().Tag("SUMMARY").AddPacket(
MakePacket<std::string>(absl::StrJoin(results, ","))
cc->Outputs()
.Tag(kSummaryTag)
.AddPacket(MakePacket<std::string>(absl::StrJoin(results, ","))
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) {
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
auto classification_list = absl::make_unique<ClassificationList>();
for (int index = 0; index < top_k_indexes.size(); ++index) {
Classification* classification =

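The SUMMARY branch above joins "index:score" pairs (or "label:score" pairs once a label map is loaded) with commas via absl::StrCat and absl::StrJoin. A standalone sketch of that formatting, written here for illustration and not taken from the patch:

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

// Mirrors the SUMMARY formatting in TopKScoresCalculator::Process when no
// label map is loaded: "index:score" pairs joined with ",".
std::string FormatTopKSummary(const std::vector<int>& top_k_indexes,
                              const std::vector<float>& top_k_scores) {
  std::vector<std::string> results;
  for (int index = 0; index < top_k_indexes.size(); ++index) {
    results.push_back(
        absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
  }
  return absl::StrJoin(results, ",");  // e.g. roughly "3:1,0:0.9"
}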
View File

@ -23,6 +23,10 @@
namespace mediapipe {
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
constexpr char kScoresTag[] = "SCORES";
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "TopKScoresCalculator"
@ -55,19 +59,21 @@ TEST(TopKScoresCalculatorTest, TestTopKOnly) {
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back(
runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets;
runner.Outputs().Tag(kTopKIndexesTag).packets;
ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(2, indexes.size());
EXPECT_EQ(3, indexes[0]);
EXPECT_EQ(0, indexes[1]);
const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets;
runner.Outputs().Tag(kTopKScoresTag).packets;
ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(2, scores.size());
@ -88,12 +94,14 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back(
runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets;
runner.Outputs().Tag(kTopKIndexesTag).packets;
ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(4, indexes.size());
@ -102,7 +110,7 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
EXPECT_EQ(2, indexes[2]);
EXPECT_EQ(1, indexes[3]);
const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets;
runner.Outputs().Tag(kTopKScoresTag).packets;
ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(4, scores.size());
@ -125,12 +133,14 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back(
runner.MutableInputs()
->Tag(kScoresTag)
.packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets;
runner.Outputs().Tag(kTopKIndexesTag).packets;
ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(3, indexes.size());
@ -138,7 +148,7 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
EXPECT_EQ(0, indexes[1]);
EXPECT_EQ(2, indexes[2]);
const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets;
runner.Outputs().Tag(kTopKScoresTag).packets;
ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(3, scores.size());

View File

@ -47,6 +47,21 @@
namespace mediapipe {
constexpr char kFrameAlignmentTag[] = "FRAME_ALIGNMENT";
constexpr char kOutputIndexFilenameTag[] = "OUTPUT_INDEX_FILENAME";
constexpr char kIndexProtoStringTag[] = "INDEX_PROTO_STRING";
constexpr char kVizTag[] = "VIZ";
constexpr char kBoxesTag[] = "BOXES";
constexpr char kReacqSwitchTag[] = "REACQ_SWITCH";
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
constexpr char kAddIndexTag[] = "ADD_INDEX";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kDescriptorsTag[] = "DESCRIPTORS";
constexpr char kFeaturesTag[] = "FEATURES";
constexpr char kVideoTag[] = "VIDEO";
constexpr char kTrackedBoxesTag[] = "TRACKED_BOXES";
constexpr char kTrackingTag[] = "TRACKING";
// A calculator to detect reappeared box positions from a single frame.
//
// Input stream:
@ -110,66 +125,66 @@ class BoxDetectorCalculator : public CalculatorBase {
REGISTER_CALCULATOR(BoxDetectorCalculator);
absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("TRACKING")) {
cc->Inputs().Tag("TRACKING").Set<TrackingData>();
if (cc->Inputs().HasTag(kTrackingTag)) {
cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
}
if (cc->Inputs().HasTag("TRACKED_BOXES")) {
cc->Inputs().Tag("TRACKED_BOXES").Set<TimedBoxProtoList>();
if (cc->Inputs().HasTag(kTrackedBoxesTag)) {
cc->Inputs().Tag(kTrackedBoxesTag).Set<TimedBoxProtoList>();
}
if (cc->Inputs().HasTag("VIDEO")) {
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
if (cc->Inputs().HasTag(kVideoTag)) {
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
}
if (cc->Inputs().HasTag("FEATURES")) {
RET_CHECK(cc->Inputs().HasTag("DESCRIPTORS"))
if (cc->Inputs().HasTag(kFeaturesTag)) {
RET_CHECK(cc->Inputs().HasTag(kDescriptorsTag))
<< "FEATURES and DESCRIPTORS need to be specified together.";
cc->Inputs().Tag("FEATURES").Set<std::vector<cv::KeyPoint>>();
cc->Inputs().Tag(kFeaturesTag).Set<std::vector<cv::KeyPoint>>();
}
if (cc->Inputs().HasTag("DESCRIPTORS")) {
RET_CHECK(cc->Inputs().HasTag("FEATURES"))
if (cc->Inputs().HasTag(kDescriptorsTag)) {
RET_CHECK(cc->Inputs().HasTag(kFeaturesTag))
<< "FEATURES and DESCRIPTORS need to be specified together.";
cc->Inputs().Tag("DESCRIPTORS").Set<std::vector<float>>();
cc->Inputs().Tag(kDescriptorsTag).Set<std::vector<float>>();
}
if (cc->Inputs().HasTag("IMAGE_SIZE")) {
cc->Inputs().Tag("IMAGE_SIZE").Set<std::pair<int, int>>();
if (cc->Inputs().HasTag(kImageSizeTag)) {
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
}
if (cc->Inputs().HasTag("ADD_INDEX")) {
cc->Inputs().Tag("ADD_INDEX").Set<std::string>();
if (cc->Inputs().HasTag(kAddIndexTag)) {
cc->Inputs().Tag(kAddIndexTag).Set<std::string>();
}
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) {
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>();
if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
}
if (cc->Inputs().HasTag("REACQ_SWITCH")) {
cc->Inputs().Tag("REACQ_SWITCH").Set<bool>();
if (cc->Inputs().HasTag(kReacqSwitchTag)) {
cc->Inputs().Tag(kReacqSwitchTag).Set<bool>();
}
if (cc->Outputs().HasTag("BOXES")) {
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>();
if (cc->Outputs().HasTag(kBoxesTag)) {
cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
}
if (cc->Outputs().HasTag("VIZ")) {
RET_CHECK(cc->Inputs().HasTag("VIDEO"))
if (cc->Outputs().HasTag(kVizTag)) {
RET_CHECK(cc->Inputs().HasTag(kVideoTag))
<< "Output stream VIZ requires VIDEO to be present.";
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
}
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) {
cc->InputSidePackets().Tag("INDEX_PROTO_STRING").Set<std::string>();
if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
cc->InputSidePackets().Tag(kIndexProtoStringTag).Set<std::string>();
}
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) {
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Set<std::string>();
if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Set<std::string>();
}
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) {
cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set<int>();
if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
cc->InputSidePackets().Tag(kFrameAlignmentTag).Set<int>();
}
return absl::OkStatus();
@ -179,10 +194,10 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<BoxDetectorCalculatorOptions>();
box_detector_ = BoxDetectorInterface::Create(options_.detector_options());
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) {
if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
BoxDetectorIndex predefined_index;
if (!predefined_index.ParseFromString(cc->InputSidePackets()
.Tag("INDEX_PROTO_STRING")
.Tag(kIndexProtoStringTag)
.Get<std::string>())) {
LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING";
}
@ -202,12 +217,13 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
box_detector_->AddBoxDetectorIndex(predefined_index);
}
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) {
if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
write_index_ = true;
}
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) {
frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get<int>();
if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
frame_alignment_ =
cc->InputSidePackets().Tag(kFrameAlignmentTag).Get<int>();
}
return absl::OkStatus();
@ -218,16 +234,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
const int64 timestamp_msec = timestamp.Value() / 1000;
InputStream* cancel_object_id_stream =
cc->Inputs().HasTag("CANCEL_OBJECT_ID")
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID"))
cc->Inputs().HasTag(kCancelObjectIdTag)
? &(cc->Inputs().Tag(kCancelObjectIdTag))
: nullptr;
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
const int cancel_object_id = cancel_object_id_stream->Get<int>();
box_detector_->CancelBoxDetection(cancel_object_id);
}
InputStream* add_index_stream = cc->Inputs().HasTag("ADD_INDEX")
? &(cc->Inputs().Tag("ADD_INDEX"))
InputStream* add_index_stream = cc->Inputs().HasTag(kAddIndexTag)
? &(cc->Inputs().Tag(kAddIndexTag))
: nullptr;
if (add_index_stream && !add_index_stream->IsEmpty()) {
BoxDetectorIndex predefined_index;
@ -238,8 +254,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
box_detector_->AddBoxDetectorIndex(predefined_index);
}
InputStream* reacq_switch_stream = cc->Inputs().HasTag("REACQ_SWITCH")
? &(cc->Inputs().Tag("REACQ_SWITCH"))
InputStream* reacq_switch_stream = cc->Inputs().HasTag(kReacqSwitchTag)
? &(cc->Inputs().Tag(kReacqSwitchTag))
: nullptr;
if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) {
detector_switch_ = reacq_switch_stream->Get<bool>();
@ -249,16 +265,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
InputStream* track_stream = cc->Inputs().HasTag("TRACKING")
? &(cc->Inputs().Tag("TRACKING"))
InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
? &(cc->Inputs().Tag(kTrackingTag))
: nullptr;
InputStream* video_stream =
cc->Inputs().HasTag("VIDEO") ? &(cc->Inputs().Tag("VIDEO")) : nullptr;
InputStream* feature_stream = cc->Inputs().HasTag("FEATURES")
? &(cc->Inputs().Tag("FEATURES"))
cc->Inputs().HasTag(kVideoTag) ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
InputStream* feature_stream = cc->Inputs().HasTag(kFeaturesTag)
? &(cc->Inputs().Tag(kFeaturesTag))
: nullptr;
InputStream* descriptor_stream = cc->Inputs().HasTag("DESCRIPTORS")
? &(cc->Inputs().Tag("DESCRIPTORS"))
InputStream* descriptor_stream = cc->Inputs().HasTag(kDescriptorsTag)
? &(cc->Inputs().Tag(kDescriptorsTag))
: nullptr;
CHECK(track_stream != nullptr || video_stream != nullptr ||
@ -266,8 +282,9 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
<< "One and only one of {tracking_data, input image frame, "
"feature/descriptor} needs to be valid.";
InputStream* tracked_boxes_stream = cc->Inputs().HasTag("TRACKED_BOXES")
? &(cc->Inputs().Tag("TRACKED_BOXES"))
InputStream* tracked_boxes_stream =
cc->Inputs().HasTag(kTrackedBoxesTag)
? &(cc->Inputs().Tag(kTrackedBoxesTag))
: nullptr;
std::unique_ptr<TimedBoxProtoList> detected_boxes(new TimedBoxProtoList());
@ -309,7 +326,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
}
const auto& image_size =
cc->Inputs().Tag("IMAGE_SIZE").Get<std::pair<int, int>>();
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
float inv_scale = 1.0f / std::max(image_size.first, image_size.second);
TimedBoxProtoList tracked_boxes;
@ -359,7 +376,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
detected_boxes.get());
}
if (cc->Outputs().HasTag("VIZ")) {
if (cc->Outputs().HasTag(kVizTag)) {
cv::Mat viz_view;
std::unique_ptr<ImageFrame> viz_frame;
if (video_stream != nullptr && !video_stream->IsEmpty()) {
@ -370,11 +387,11 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
for (const auto& box : detected_boxes->box()) {
RenderBox(box, &viz_view);
}
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp);
cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
}
if (cc->Outputs().HasTag("BOXES")) {
cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp);
if (cc->Outputs().HasTag(kBoxesTag)) {
cc->Outputs().Tag(kBoxesTag).Add(detected_boxes.release(), timestamp);
}
return absl::OkStatus();
@ -384,7 +401,7 @@ absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) {
if (write_index_) {
BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex();
MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents(
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get<std::string>(),
cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Get<std::string>(),
index.SerializeAsString()));
}
return absl::OkStatus();

View File

@ -293,6 +293,22 @@ const int BoxTrackerCalculator::kMotionBoxPathMinQueueSize = 2;
namespace {
constexpr char kCacheDirTag[] = "CACHE_DIR";
constexpr char kInitialPosTag[] = "INITIAL_POS";
constexpr char kRaBoxesTag[] = "RA_BOXES";
constexpr char kBoxesTag[] = "BOXES";
constexpr char kVizTag[] = "VIZ";
constexpr char kRaTrackProtoStringTag[] = "RA_TRACK_PROTO_STRING";
constexpr char kRaTrackTag[] = "RA_TRACK";
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
constexpr char kRestartPosTag[] = "RESTART_POS";
constexpr char kStartPosProtoStringTag[] = "START_POS_PROTO_STRING";
constexpr char kStartPosTag[] = "START_POS";
constexpr char kStartTag[] = "START";
constexpr char kVideoTag[] = "VIDEO";
constexpr char kTrackTimeTag[] = "TRACK_TIME";
constexpr char kTrackingTag[] = "TRACKING";
// Convert box position according to rotation angle in degrees.
void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom,
float in_right, int rotation, float* out_top,
@ -374,78 +390,78 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec,
} // namespace.
absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("TRACKING")) {
cc->Inputs().Tag("TRACKING").Set<TrackingData>();
if (cc->Inputs().HasTag(kTrackingTag)) {
cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
}
if (cc->Inputs().HasTag("TRACK_TIME")) {
RET_CHECK(cc->Inputs().HasTag("TRACKING"))
if (cc->Inputs().HasTag(kTrackTimeTag)) {
RET_CHECK(cc->Inputs().HasTag(kTrackingTag))
<< "TRACK_TIME needs TRACKING input";
cc->Inputs().Tag("TRACK_TIME").SetAny();
cc->Inputs().Tag(kTrackTimeTag).SetAny();
}
if (cc->Inputs().HasTag("VIDEO")) {
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
if (cc->Inputs().HasTag(kVideoTag)) {
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
}
if (cc->Inputs().HasTag("START")) {
if (cc->Inputs().HasTag(kStartTag)) {
// Actual packet content does not matter.
cc->Inputs().Tag("START").SetAny();
cc->Inputs().Tag(kStartTag).SetAny();
}
if (cc->Inputs().HasTag("START_POS")) {
cc->Inputs().Tag("START_POS").Set<TimedBoxProtoList>();
if (cc->Inputs().HasTag(kStartPosTag)) {
cc->Inputs().Tag(kStartPosTag).Set<TimedBoxProtoList>();
}
if (cc->Inputs().HasTag("START_POS_PROTO_STRING")) {
cc->Inputs().Tag("START_POS_PROTO_STRING").Set<std::string>();
if (cc->Inputs().HasTag(kStartPosProtoStringTag)) {
cc->Inputs().Tag(kStartPosProtoStringTag).Set<std::string>();
}
if (cc->Inputs().HasTag("RESTART_POS")) {
cc->Inputs().Tag("RESTART_POS").Set<TimedBoxProtoList>();
if (cc->Inputs().HasTag(kRestartPosTag)) {
cc->Inputs().Tag(kRestartPosTag).Set<TimedBoxProtoList>();
}
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) {
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>();
if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
}
if (cc->Inputs().HasTag("RA_TRACK")) {
cc->Inputs().Tag("RA_TRACK").Set<TimedBoxProtoList>();
if (cc->Inputs().HasTag(kRaTrackTag)) {
cc->Inputs().Tag(kRaTrackTag).Set<TimedBoxProtoList>();
}
if (cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")) {
cc->Inputs().Tag("RA_TRACK_PROTO_STRING").Set<std::string>();
if (cc->Inputs().HasTag(kRaTrackProtoStringTag)) {
cc->Inputs().Tag(kRaTrackProtoStringTag).Set<std::string>();
}
if (cc->Outputs().HasTag("VIZ")) {
RET_CHECK(cc->Inputs().HasTag("VIDEO"))
if (cc->Outputs().HasTag(kVizTag)) {
RET_CHECK(cc->Inputs().HasTag(kVideoTag))
<< "Output stream VIZ requires VIDEO to be present.";
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag("BOXES")) {
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>();
if (cc->Outputs().HasTag(kBoxesTag)) {
cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
}
if (cc->Outputs().HasTag("RA_BOXES")) {
cc->Outputs().Tag("RA_BOXES").Set<TimedBoxProtoList>();
if (cc->Outputs().HasTag(kRaBoxesTag)) {
cc->Outputs().Tag(kRaBoxesTag).Set<TimedBoxProtoList>();
}
#if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS"))
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag))
<< "Unsupported on mobile";
#else
if (cc->InputSidePackets().HasTag("INITIAL_POS")) {
cc->InputSidePackets().Tag("INITIAL_POS").Set<std::string>();
if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
cc->InputSidePackets().Tag(kInitialPosTag).Set<std::string>();
}
#endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>();
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
}
RET_CHECK(cc->Inputs().HasTag("TRACKING") !=
cc->InputSidePackets().HasTag("CACHE_DIR"))
RET_CHECK(cc->Inputs().HasTag(kTrackingTag) !=
cc->InputSidePackets().HasTag(kCacheDirTag))
<< "Either TRACKING or CACHE_DIR needs to be specified.";
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
@ -459,7 +475,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
options_ = tool::RetrieveOptions(cc->Options<BoxTrackerCalculatorOptions>(),
cc->InputSidePackets(), kOptionsTag);
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS") ||
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag) ||
!options_.has_initial_position())
<< "Can not specify initial position as side packet and via options";
@ -468,11 +484,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
}
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__)
if (cc->InputSidePackets().HasTag("INITIAL_POS")) {
if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
LOG(INFO) << "Parsing: "
<< cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>();
<< cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>();
initial_pos_ = ParseTextProtoOrDie<TimedBoxProtoList>(
cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>());
cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>());
}
#endif // !defined(__ANDROID__) && !defined(__APPLE__) &&
// !defined(__EMSCRIPTEN__)
@ -484,10 +500,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
}
visualize_tracking_data_ =
options_.visualize_tracking_data() && cc->Outputs().HasTag("VIZ");
visualize_state_ = options_.visualize_state() && cc->Outputs().HasTag("VIZ");
options_.visualize_tracking_data() && cc->Outputs().HasTag(kVizTag);
visualize_state_ =
options_.visualize_state() && cc->Outputs().HasTag(kVizTag);
visualize_internal_state_ =
options_.visualize_internal_state() && cc->Outputs().HasTag("VIZ");
options_.visualize_internal_state() && cc->Outputs().HasTag(kVizTag);
// Force recording of internal state for rendering.
if (visualize_internal_state_) {
@ -500,8 +517,8 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
options_.mutable_tracker_options()->set_record_path_states(true);
}
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>();
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
RET_CHECK(!cache_dir_.empty());
box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options()));
} else {
@ -511,7 +528,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
}
if (options_.streaming_track_data_cache_size() > 0) {
RET_CHECK(!cc->InputSidePackets().HasTag("CACHE_DIR"))
RET_CHECK(!cc->InputSidePackets().HasTag(kCacheDirTag))
<< "Streaming mode not compatible with cache dir.";
}
@ -533,11 +550,11 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
InputStream* track_stream = cc->Inputs().HasTag("TRACKING")
? &(cc->Inputs().Tag("TRACKING"))
InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
? &(cc->Inputs().Tag(kTrackingTag))
: nullptr;
InputStream* track_time_stream = cc->Inputs().HasTag("TRACK_TIME")
? &(cc->Inputs().Tag("TRACK_TIME"))
InputStream* track_time_stream = cc->Inputs().HasTag(kTrackTimeTag)
? &(cc->Inputs().Tag(kTrackTimeTag))
: nullptr;
// Cache tracking data if possible.
@ -562,8 +579,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
}
InputStream* start_pos_stream = cc->Inputs().HasTag("START_POS")
? &(cc->Inputs().Tag("START_POS"))
InputStream* start_pos_stream = cc->Inputs().HasTag(kStartPosTag)
? &(cc->Inputs().Tag(kStartPosTag))
: nullptr;
MotionBoxMap fast_forward_boxes;
@ -575,8 +592,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
InputStream* start_pos_proto_string_stream =
cc->Inputs().HasTag("START_POS_PROTO_STRING")
? &(cc->Inputs().Tag("START_POS_PROTO_STRING"))
cc->Inputs().HasTag(kStartPosProtoStringTag)
? &(cc->Inputs().Tag(kStartPosProtoStringTag))
: nullptr;
if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) {
if (start_pos_proto_string_stream &&
@ -589,8 +606,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
}
InputStream* restart_pos_stream = cc->Inputs().HasTag("RESTART_POS")
? &(cc->Inputs().Tag("RESTART_POS"))
InputStream* restart_pos_stream = cc->Inputs().HasTag(kRestartPosTag)
? &(cc->Inputs().Tag(kRestartPosTag))
: nullptr;
if (restart_pos_stream && !restart_pos_stream->IsEmpty()) {
@ -600,8 +617,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
InputStream* cancel_object_id_stream =
cc->Inputs().HasTag("CANCEL_OBJECT_ID")
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID"))
cc->Inputs().HasTag(kCancelObjectIdTag)
? &(cc->Inputs().Tag(kCancelObjectIdTag))
: nullptr;
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
const int cancel_object_id = cancel_object_id_stream->Get<int>();
@ -616,8 +633,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
TrackingData track_data_to_render;
if (cc->Outputs().HasTag("VIZ")) {
InputStream* video_stream = &(cc->Inputs().Tag("VIDEO"));
if (cc->Outputs().HasTag(kVizTag)) {
InputStream* video_stream = &(cc->Inputs().Tag(kVideoTag));
if (!video_stream->IsEmpty()) {
input_view = formats::MatView(&video_stream->Get<ImageFrame>());
@ -745,7 +762,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
++frame_num_since_reset_;
// Generate results for queued up request.
if (cc->Outputs().HasTag("BOXES") && !queued_track_requests_.empty()) {
if (cc->Outputs().HasTag(kBoxesTag) && !queued_track_requests_.empty()) {
for (int j = 0; j < queued_track_requests_.size(); ++j) {
const Timestamp& past_time = queued_track_requests_[j];
RET_CHECK(past_time.Value() < timestamp.Value())
@ -770,7 +787,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
// Output for every time.
cc->Outputs().Tag("BOXES").Add(past_box_list.release(), past_time);
cc->Outputs().Tag(kBoxesTag).Add(past_box_list.release(), past_time);
}
queued_track_requests_.clear();
@ -845,8 +862,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
// Handle random access track requests.
InputStream* ra_track_stream = cc->Inputs().HasTag("RA_TRACK")
? &(cc->Inputs().Tag("RA_TRACK"))
InputStream* ra_track_stream = cc->Inputs().HasTag(kRaTrackTag)
? &(cc->Inputs().Tag(kRaTrackTag))
: nullptr;
if (ra_track_stream && !ra_track_stream->IsEmpty()) {
@ -861,8 +878,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
}
InputStream* ra_track_proto_string_stream =
cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")
? &(cc->Inputs().Tag("RA_TRACK_PROTO_STRING"))
cc->Inputs().HasTag(kRaTrackProtoStringTag)
? &(cc->Inputs().Tag(kRaTrackProtoStringTag))
: nullptr;
if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) {
if (ra_track_proto_string_stream &&
@ -881,15 +898,15 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
// Always output in batch, only output in streaming if tracking data
// is present (might be in fast forward mode instead).
if (cc->Outputs().HasTag("BOXES") &&
if (cc->Outputs().HasTag(kBoxesTag) &&
(box_tracker_ || !track_stream->IsEmpty())) {
std::unique_ptr<TimedBoxProtoList> boxes(new TimedBoxProtoList());
*boxes = std::move(box_track_list);
cc->Outputs().Tag("BOXES").Add(boxes.release(), timestamp);
cc->Outputs().Tag(kBoxesTag).Add(boxes.release(), timestamp);
}
if (viz_frame) {
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp);
cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
}
return absl::OkStatus();
@ -1001,7 +1018,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack(
}
cc->Outputs()
.Tag("RA_BOXES")
.Tag(kRaBoxesTag)
.Add(result_list.release(), cc->InputTimestamp());
}
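For reference, the refactoring applied across these calculator files follows one pattern: each tag string literal is hoisted into a file-scope constexpr constant, and that constant is reused wherever the tag is declared or queried, avoiding typo-prone repetition of the raw string. Below is a minimal sketch of the pattern only; ExampleCalculator and its tag are hypothetical and not part of this commit.

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// Hypothetical tag constant, mirroring the constants introduced in this commit.
constexpr char kInputTag[] = "INPUT";

class ExampleCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Declare the stream type through the constant instead of a string literal.
    cc->Inputs().Tag(kInputTag).Set<int>();
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    // Later lookups reuse the same constant, so a tag typo becomes a compile
    // error at the constant definition rather than a silent runtime mismatch.
    const int value = cc->Inputs().Tag(kInputTag).Get<int>();
    (void)value;
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(ExampleCalculator);

}  // namespace mediapipe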

View File

@ -29,6 +29,13 @@
namespace mediapipe {
constexpr char kCacheDirTag[] = "CACHE_DIR";
constexpr char kCompleteTag[] = "COMPLETE";
constexpr char kTrackingChunkTag[] = "TRACKING_CHUNK";
constexpr char kTrackingTag[] = "TRACKING";
constexpr char kCameraTag[] = "CAMERA";
constexpr char kFlowTag[] = "FLOW";
using mediapipe::CameraMotion;
using mediapipe::FlowPackager;
using mediapipe::RegionFlowFeatureList;
@ -91,27 +98,27 @@ class FlowPackagerCalculator : public CalculatorBase {
REGISTER_CALCULATOR(FlowPackagerCalculator);
absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) {
if (!cc->Inputs().HasTag("FLOW")) {
if (!cc->Inputs().HasTag(kFlowTag)) {
return tool::StatusFail("No input flow was specified.");
}
cc->Inputs().Tag("FLOW").Set<RegionFlowFeatureList>();
cc->Inputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
if (cc->Inputs().HasTag("CAMERA")) {
cc->Inputs().Tag("CAMERA").Set<CameraMotion>();
if (cc->Inputs().HasTag(kCameraTag)) {
cc->Inputs().Tag(kCameraTag).Set<CameraMotion>();
}
if (cc->Outputs().HasTag("TRACKING")) {
cc->Outputs().Tag("TRACKING").Set<TrackingData>();
if (cc->Outputs().HasTag(kTrackingTag)) {
cc->Outputs().Tag(kTrackingTag).Set<TrackingData>();
}
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
cc->Outputs().Tag("TRACKING_CHUNK").Set<TrackingDataChunk>();
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
cc->Outputs().Tag(kTrackingChunkTag).Set<TrackingDataChunk>();
}
if (cc->Outputs().HasTag("COMPLETE")) {
cc->Outputs().Tag("COMPLETE").Set<bool>();
if (cc->Outputs().HasTag(kCompleteTag)) {
cc->Outputs().Tag(kCompleteTag).Set<bool>();
}
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>();
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
}
return absl::OkStatus();
@ -122,24 +129,24 @@ absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) {
flow_packager_.reset(new FlowPackager(options_.flow_packager_options()));
use_caching_ = cc->InputSidePackets().HasTag("CACHE_DIR");
build_chunk_ = use_caching_ || cc->Outputs().HasTag("TRACKING_CHUNK");
use_caching_ = cc->InputSidePackets().HasTag(kCacheDirTag);
build_chunk_ = use_caching_ || cc->Outputs().HasTag(kTrackingChunkTag);
if (use_caching_) {
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>();
cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
}
return absl::OkStatus();
}
absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
InputStream* flow_stream = &(cc->Inputs().Tag("FLOW"));
InputStream* flow_stream = &(cc->Inputs().Tag(kFlowTag));
const RegionFlowFeatureList& flow = flow_stream->Get<RegionFlowFeatureList>();
const Timestamp timestamp = flow_stream->Value().Timestamp();
const CameraMotion* camera_motion = nullptr;
if (cc->Inputs().HasTag("CAMERA")) {
InputStream* camera_stream = &(cc->Inputs().Tag("CAMERA"));
if (cc->Inputs().HasTag(kCameraTag)) {
InputStream* camera_stream = &(cc->Inputs().Tag(kCameraTag));
camera_motion = &camera_stream->Get<CameraMotion>();
}
@ -161,7 +168,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
if (frame_idx_ > 0) {
item->set_prev_timestamp_usec(prev_timestamp_.Value());
}
if (cc->Outputs().HasTag("TRACKING")) {
if (cc->Outputs().HasTag(kTrackingTag)) {
// Need to copy as output is requested.
*item->mutable_tracking_data() = *tracking_data;
} else {
@ -172,9 +179,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
options_.caching_chunk_size_msec() * (chunk_idx_ + 1);
if (timestamp.Value() / 1000 >= next_chunk_msec) {
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
cc->Outputs()
.Tag("TRACKING_CHUNK")
.Tag(kTrackingChunkTag)
.Add(new TrackingDataChunk(tracking_chunk_),
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
}
@ -185,9 +192,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
}
}
if (cc->Outputs().HasTag("TRACKING")) {
if (cc->Outputs().HasTag(kTrackingTag)) {
cc->Outputs()
.Tag("TRACKING")
.Tag(kTrackingTag)
.Add(tracking_data.release(), flow_stream->Value().Timestamp());
}
@ -199,9 +206,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
if (frame_idx_ > 0) {
tracking_chunk_.set_last_chunk(true);
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
cc->Outputs()
.Tag("TRACKING_CHUNK")
.Tag(kTrackingChunkTag)
.Add(new TrackingDataChunk(tracking_chunk_),
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
}
@ -211,8 +218,8 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
}
}
if (cc->Outputs().HasTag("COMPLETE")) {
cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream());
if (cc->Outputs().HasTag(kCompleteTag)) {
cc->Outputs().Tag(kCompleteTag).Add(new bool(true), Timestamp::PreStream());
}
return absl::OkStatus();

View File

@ -38,6 +38,18 @@
namespace mediapipe {
constexpr char kDownsampleTag[] = "DOWNSAMPLE";
constexpr char kCsvFileTag[] = "CSV_FILE";
constexpr char kGrayVideoOutTag[] = "GRAY_VIDEO_OUT";
constexpr char kVideoOutTag[] = "VIDEO_OUT";
constexpr char kDenseFgTag[] = "DENSE_FG";
constexpr char kVizTag[] = "VIZ";
constexpr char kSaliencyTag[] = "SALIENCY";
constexpr char kCameraTag[] = "CAMERA";
constexpr char kFlowTag[] = "FLOW";
constexpr char kSelectionTag[] = "SELECTION";
constexpr char kVideoTag[] = "VIDEO";
using mediapipe::AffineAdapter;
using mediapipe::CameraMotion;
using mediapipe::FrameSelectionResult;
@ -190,55 +202,56 @@ class MotionAnalysisCalculator : public CalculatorBase {
REGISTER_CALCULATOR(MotionAnalysisCalculator);
absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) {
if (cc->Inputs().HasTag("VIDEO")) {
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
if (cc->Inputs().HasTag(kVideoTag)) {
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
}
// Optional input stream from frame selection calculator.
if (cc->Inputs().HasTag("SELECTION")) {
cc->Inputs().Tag("SELECTION").Set<FrameSelectionResult>();
if (cc->Inputs().HasTag(kSelectionTag)) {
cc->Inputs().Tag(kSelectionTag).Set<FrameSelectionResult>();
}
RET_CHECK(cc->Inputs().HasTag("VIDEO") || cc->Inputs().HasTag("SELECTION"))
RET_CHECK(cc->Inputs().HasTag(kVideoTag) ||
cc->Inputs().HasTag(kSelectionTag))
<< "Either VIDEO, SELECTION must be specified.";
if (cc->Outputs().HasTag("FLOW")) {
cc->Outputs().Tag("FLOW").Set<RegionFlowFeatureList>();
if (cc->Outputs().HasTag(kFlowTag)) {
cc->Outputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
}
if (cc->Outputs().HasTag("CAMERA")) {
cc->Outputs().Tag("CAMERA").Set<CameraMotion>();
if (cc->Outputs().HasTag(kCameraTag)) {
cc->Outputs().Tag(kCameraTag).Set<CameraMotion>();
}
if (cc->Outputs().HasTag("SALIENCY")) {
cc->Outputs().Tag("SALIENCY").Set<SalientPointFrame>();
if (cc->Outputs().HasTag(kSaliencyTag)) {
cc->Outputs().Tag(kSaliencyTag).Set<SalientPointFrame>();
}
if (cc->Outputs().HasTag("VIZ")) {
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
if (cc->Outputs().HasTag(kVizTag)) {
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag("DENSE_FG")) {
cc->Outputs().Tag("DENSE_FG").Set<ImageFrame>();
if (cc->Outputs().HasTag(kDenseFgTag)) {
cc->Outputs().Tag(kDenseFgTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag("VIDEO_OUT")) {
cc->Outputs().Tag("VIDEO_OUT").Set<ImageFrame>();
if (cc->Outputs().HasTag(kVideoOutTag)) {
cc->Outputs().Tag(kVideoOutTag).Set<ImageFrame>();
}
if (cc->Outputs().HasTag("GRAY_VIDEO_OUT")) {
if (cc->Outputs().HasTag(kGrayVideoOutTag)) {
// We only output grayscale video if we're actually performing full region-
// flow analysis on the video.
RET_CHECK(cc->Inputs().HasTag("VIDEO") &&
!cc->Inputs().HasTag("SELECTION"));
cc->Outputs().Tag("GRAY_VIDEO_OUT").Set<ImageFrame>();
RET_CHECK(cc->Inputs().HasTag(kVideoTag) &&
!cc->Inputs().HasTag(kSelectionTag));
cc->Outputs().Tag(kGrayVideoOutTag).Set<ImageFrame>();
}
if (cc->InputSidePackets().HasTag("CSV_FILE")) {
cc->InputSidePackets().Tag("CSV_FILE").Set<std::string>();
if (cc->InputSidePackets().HasTag(kCsvFileTag)) {
cc->InputSidePackets().Tag(kCsvFileTag).Set<std::string>();
}
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) {
cc->InputSidePackets().Tag("DOWNSAMPLE").Set<float>();
if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
cc->InputSidePackets().Tag(kDownsampleTag).Set<float>();
}
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
@ -253,16 +266,16 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
tool::RetrieveOptions(cc->Options<MotionAnalysisCalculatorOptions>(),
cc->InputSidePackets(), kOptionsTag);
video_input_ = cc->Inputs().HasTag("VIDEO");
selection_input_ = cc->Inputs().HasTag("SELECTION");
region_flow_feature_output_ = cc->Outputs().HasTag("FLOW");
camera_motion_output_ = cc->Outputs().HasTag("CAMERA");
saliency_output_ = cc->Outputs().HasTag("SALIENCY");
visualize_output_ = cc->Outputs().HasTag("VIZ");
dense_foreground_output_ = cc->Outputs().HasTag("DENSE_FG");
video_output_ = cc->Outputs().HasTag("VIDEO_OUT");
grayscale_output_ = cc->Outputs().HasTag("GRAY_VIDEO_OUT");
csv_file_input_ = cc->InputSidePackets().HasTag("CSV_FILE");
video_input_ = cc->Inputs().HasTag(kVideoTag);
selection_input_ = cc->Inputs().HasTag(kSelectionTag);
region_flow_feature_output_ = cc->Outputs().HasTag(kFlowTag);
camera_motion_output_ = cc->Outputs().HasTag(kCameraTag);
saliency_output_ = cc->Outputs().HasTag(kSaliencyTag);
visualize_output_ = cc->Outputs().HasTag(kVizTag);
dense_foreground_output_ = cc->Outputs().HasTag(kDenseFgTag);
video_output_ = cc->Outputs().HasTag(kVideoOutTag);
grayscale_output_ = cc->Outputs().HasTag(kGrayVideoOutTag);
csv_file_input_ = cc->InputSidePackets().HasTag(kCsvFileTag);
hybrid_meta_analysis_ = options_.meta_analysis() ==
MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID;
@ -310,7 +323,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
if (csv_file_input_) {
// Read from file and parse.
const std::string filename =
cc->InputSidePackets().Tag("CSV_FILE").Get<std::string>();
cc->InputSidePackets().Tag(kCsvFileTag).Get<std::string>();
std::string file_contents;
std::ifstream input_file(filename, std::ios::in);
@ -327,11 +340,12 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
// Get video header from video or selection input if present.
const VideoHeader* video_header = nullptr;
if (video_input_ && !cc->Inputs().Tag("VIDEO").Header().IsEmpty()) {
video_header = &(cc->Inputs().Tag("VIDEO").Header().Get<VideoHeader>());
if (video_input_ && !cc->Inputs().Tag(kVideoTag).Header().IsEmpty()) {
video_header = &(cc->Inputs().Tag(kVideoTag).Header().Get<VideoHeader>());
} else if (selection_input_ &&
!cc->Inputs().Tag("SELECTION").Header().IsEmpty()) {
video_header = &(cc->Inputs().Tag("SELECTION").Header().Get<VideoHeader>());
!cc->Inputs().Tag(kSelectionTag).Header().IsEmpty()) {
video_header =
&(cc->Inputs().Tag(kSelectionTag).Header().Get<VideoHeader>());
} else {
LOG(WARNING) << "No input video header found. Downstream calculators "
"expecting video headers are likely to fail.";
@ -339,7 +353,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
with_saliency_ = options_.analysis_options().compute_motion_saliency();
// Force computation of saliency if requested as output.
if (cc->Outputs().HasTag("SALIENCY")) {
if (cc->Outputs().HasTag(kSaliencyTag)) {
with_saliency_ = true;
if (!options_.analysis_options().compute_motion_saliency()) {
LOG(WARNING) << "Enable saliency computation. Set "
@ -353,11 +367,11 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
}
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) {
if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
options_.mutable_analysis_options()
->mutable_flow_options()
->set_downsample_factor(
cc->InputSidePackets().Tag("DOWNSAMPLE").Get<float>());
cc->InputSidePackets().Tag(kDownsampleTag).Get<float>());
}
// If no video header is provided, just return and initialize on the first
@ -369,30 +383,33 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE ///////////////
if (visualize_output_) {
cc->Outputs().Tag("VIZ").SetHeader(Adopt(new VideoHeader(*video_header)));
cc->Outputs().Tag(kVizTag).SetHeader(Adopt(new VideoHeader(*video_header)));
}
if (video_output_) {
cc->Outputs()
.Tag("VIDEO_OUT")
.Tag(kVideoOutTag)
.SetHeader(Adopt(new VideoHeader(*video_header)));
}
if (cc->Outputs().HasTag("DENSE_FG")) {
if (cc->Outputs().HasTag(kDenseFgTag)) {
std::unique_ptr<VideoHeader> foreground_header(
new VideoHeader(*video_header));
foreground_header->format = ImageFormat::GRAY8;
cc->Outputs().Tag("DENSE_FG").SetHeader(Adopt(foreground_header.release()));
}
if (cc->Outputs().HasTag("CAMERA")) {
cc->Outputs().Tag("CAMERA").SetHeader(
Adopt(new VideoHeader(*video_header)));
}
if (cc->Outputs().HasTag("SALIENCY")) {
cc->Outputs()
.Tag("SALIENCY")
.Tag(kDenseFgTag)
.SetHeader(Adopt(foreground_header.release()));
}
if (cc->Outputs().HasTag(kCameraTag)) {
cc->Outputs()
.Tag(kCameraTag)
.SetHeader(Adopt(new VideoHeader(*video_header)));
}
if (cc->Outputs().HasTag(kSaliencyTag)) {
cc->Outputs()
.Tag(kSaliencyTag)
.SetHeader(Adopt(new VideoHeader(*video_header)));
}
@ -405,9 +422,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
}
InputStream* video_stream =
video_input_ ? &(cc->Inputs().Tag("VIDEO")) : nullptr;
video_input_ ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
InputStream* selection_stream =
selection_input_ ? &(cc->Inputs().Tag("SELECTION")) : nullptr;
selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr;
// Checked on Open.
CHECK(video_stream || selection_stream);
@ -425,8 +442,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
CameraMotion output_motion = meta_motions_.front();
meta_motions_.pop_front();
output_motion.set_timestamp_usec(timestamp.Value());
cc->Outputs().Tag("CAMERA").Add(new CameraMotion(output_motion),
timestamp);
cc->Outputs()
.Tag(kCameraTag)
.Add(new CameraMotion(output_motion), timestamp);
}
if (region_flow_feature_output_) {
@ -435,8 +453,8 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
meta_features_.pop_front();
output_features.set_timestamp_usec(timestamp.Value());
cc->Outputs().Tag("FLOW").Add(new RegionFlowFeatureList(output_features),
timestamp);
cc->Outputs().Tag(kFlowTag).Add(
new RegionFlowFeatureList(output_features), timestamp);
}
++frame_idx_;
@ -478,16 +496,17 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) {
// Output concatenated results, nothing to compute here.
if (camera_motion_output_) {
cc->Outputs().Tag("CAMERA").Add(
frame_selection_result->release_camera_motion(), timestamp);
cc->Outputs()
.Tag(kCameraTag)
.Add(frame_selection_result->release_camera_motion(), timestamp);
}
if (region_flow_feature_output_) {
cc->Outputs().Tag("FLOW").Add(frame_selection_result->release_features(),
timestamp);
cc->Outputs().Tag(kFlowTag).Add(
frame_selection_result->release_features(), timestamp);
}
if (video_output_) {
cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value());
cc->Outputs().Tag(kVideoOutTag).AddPacket(video_stream->Value());
}
return absl::OkStatus();
@ -565,7 +584,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
grayscale_mat.copyTo(image_frame_mat);
cc->Outputs()
.Tag("GRAY_VIDEO_OUT")
.Tag(kGrayVideoOutTag)
.Add(grayscale_image.release(), timestamp);
}
@ -640,7 +659,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
*feature_list, *camera_motion,
with_saliency_ ? saliency[k].get() : nullptr, &visualization);
cc->Outputs().Tag("VIZ").Add(visualization_frame.release(), timestamp);
cc->Outputs().Tag(kVizTag).Add(visualization_frame.release(), timestamp);
}
// Output dense foreground mask.
@ -650,26 +669,26 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
cv::Mat foreground = formats::MatView(foreground_frame.get());
motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion,
&foreground);
cc->Outputs().Tag("DENSE_FG").Add(foreground_frame.release(), timestamp);
cc->Outputs().Tag(kDenseFgTag).Add(foreground_frame.release(), timestamp);
}
// Output flow features if requested.
if (region_flow_feature_output_) {
cc->Outputs().Tag("FLOW").Add(feature_list.release(), timestamp);
cc->Outputs().Tag(kFlowTag).Add(feature_list.release(), timestamp);
}
// Output camera motion.
if (camera_motion_output_) {
cc->Outputs().Tag("CAMERA").Add(camera_motion.release(), timestamp);
cc->Outputs().Tag(kCameraTag).Add(camera_motion.release(), timestamp);
}
if (video_output_) {
cc->Outputs().Tag("VIDEO_OUT").AddPacket(packet_buffer_[k]);
cc->Outputs().Tag(kVideoOutTag).AddPacket(packet_buffer_[k]);
}
// Output saliency.
if (saliency_output_) {
cc->Outputs().Tag("SALIENCY").Add(saliency[k].release(), timestamp);
cc->Outputs().Tag(kSaliencyTag).Add(saliency[k].release(), timestamp);
}
}

View File

@ -27,6 +27,12 @@
namespace mediapipe {
namespace {
constexpr char kSavedAudioPathTag[] = "SAVED_AUDIO_PATH";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kVideoTag[] = "VIDEO";
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
// cv::VideoCapture sets the data type to unsigned char by default. Therefore, the
// image format is only related to the number of channels the cv::Mat has.
ImageFormat::Format GetImageFormat(int num_channels) {
@ -87,20 +93,20 @@ ImageFormat::Format GetImageFormat(int num_channels) {
class OpenCvVideoDecoderCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
cc->Outputs().Tag("VIDEO").Set<ImageFrame>();
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) {
cc->Outputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
cc->InputSidePackets().Tag(kInputFilePathTag).Set<std::string>();
cc->Outputs().Tag(kVideoTag).Set<ImageFrame>();
if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
cc->Outputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
}
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) {
cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set<std::string>();
if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
cc->OutputSidePackets().Tag(kSavedAudioPathTag).Set<std::string>();
}
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
const std::string& input_file_path =
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
cc->InputSidePackets().Tag(kInputFilePathTag).Get<std::string>();
cap_ = absl::make_unique<cv::VideoCapture>(input_file_path);
if (!cap_->isOpened()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
@ -140,16 +146,16 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
header->frame_rate = fps;
header->duration = frame_count_ / fps;
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) {
if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
cc->Outputs()
.Tag("VIDEO_PRESTREAM")
.Tag(kVideoPrestreamTag)
.Add(header.release(), Timestamp::PreStream());
cc->Outputs().Tag("VIDEO_PRESTREAM").Close();
cc->Outputs().Tag(kVideoPrestreamTag).Close();
}
// Rewind to the very first frame.
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) {
if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
#ifdef HAVE_FFMPEG
std::string saved_audio_path = std::tmpnam(nullptr);
std::string ffmpeg_command =
@ -159,14 +165,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str());
if (status_code == 0) {
cc->OutputSidePackets()
.Tag("SAVED_AUDIO_PATH")
.Tag(kSavedAudioPathTag)
.Set(MakePacket<std::string>(saved_audio_path));
} else {
LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path
<< " by executing the following command: "
<< ffmpeg_command;
cc->OutputSidePackets()
.Tag("SAVED_AUDIO_PATH")
.Tag(kSavedAudioPathTag)
.Set(MakePacket<std::string>(std::string()));
}
#else
@ -208,7 +214,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
// If the timestamp of the current frame is not greater than the one of the
// previous frame, the new frame will be discarded.
if (prev_timestamp_ < timestamp) {
cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp);
cc->Outputs().Tag(kVideoTag).Add(image_frame.release(), timestamp);
prev_timestamp_ = timestamp;
decoded_frames_++;
}

View File

@ -29,6 +29,10 @@ namespace mediapipe {
namespace {
constexpr char kVideoTag[] = "VIDEO";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
@ -37,19 +41,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
file::JoinPath("./",
"/mediapipe/calculators/video/"
"testdata/format_MP4_AVC720P_AAC.video"));
MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
MP_EXPECT_OK(runner.Outputs()
.Tag("VIDEO_PRESTREAM")
.Tag(kVideoPrestreamTag)
.packets[0]
.ValidateAsType<VideoHeader>());
const mediapipe::VideoHeader& header =
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(ImageFormat::SRGB, header.format);
EXPECT_EQ(1280, header.width);
EXPECT_EQ(640, header.height);
@ -58,10 +62,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
// The number of output packets should be 180.
// Some OpenCV versions return the first two frames with the same timestamp on
// macOS, so we might miss one frame here.
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size();
int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
EXPECT_GE(num_of_packets, 179);
for (int i = 0; i < num_of_packets; ++i) {
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
cv::Mat output_mat =
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
EXPECT_EQ(1280, output_mat.size().width);
@ -83,19 +87,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
file::JoinPath("./",
"/mediapipe/calculators/video/"
"testdata/format_FLV_H264_AAC.video"));
MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
MP_EXPECT_OK(runner.Outputs()
.Tag("VIDEO_PRESTREAM")
.Tag(kVideoPrestreamTag)
.packets[0]
.ValidateAsType<VideoHeader>());
const mediapipe::VideoHeader& header =
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(ImageFormat::SRGB, header.format);
EXPECT_EQ(640, header.width);
EXPECT_EQ(320, header.height);
@ -103,9 +107,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
// can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4).
// EXPECT_FLOAT_EQ(6.0f, header.duration);
// EXPECT_FLOAT_EQ(30.0f, header.frame_rate);
EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size());
EXPECT_EQ(180, runner.Outputs().Tag(kVideoTag).packets.size());
for (int i = 0; i < 180; ++i) {
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
cv::Mat output_mat =
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
EXPECT_EQ(640, output_mat.size().width);
@ -127,19 +131,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
file::JoinPath("./",
"/mediapipe/calculators/video/"
"testdata/format_MKV_VP8_VORBIS.video"));
MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
MP_EXPECT_OK(runner.Outputs()
.Tag("VIDEO_PRESTREAM")
.Tag(kVideoPrestreamTag)
.packets[0]
.ValidateAsType<VideoHeader>());
const mediapipe::VideoHeader& header =
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
EXPECT_EQ(ImageFormat::SRGB, header.format);
EXPECT_EQ(640, header.width);
EXPECT_EQ(320, header.height);
@ -148,10 +152,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
// The number of output packets should be 180.
// Some OpenCV versions return the first two frames with the same timestamp on
// macOS, so we might miss one frame here.
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size();
int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
EXPECT_GE(num_of_packets, 179);
for (int i = 0; i < num_of_packets; ++i) {
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
cv::Mat output_mat =
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
EXPECT_EQ(640, output_mat.size().width);

View File

@ -36,6 +36,11 @@
namespace mediapipe {
constexpr char kAudioFilePathTag[] = "AUDIO_FILE_PATH";
constexpr char kOutputFilePathTag[] = "OUTPUT_FILE_PATH";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kVideoTag[] = "VIDEO";
// Encodes the input video stream and produces a media file.
// The media file can be output to the output_file_path specified as a side
// packet. Currently, the calculator only supports one video stream (in
@ -90,15 +95,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase {
};
absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("VIDEO"));
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
RET_CHECK(cc->Inputs().HasTag(kVideoTag));
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
}
RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH"));
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set<std::string>();
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) {
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set<std::string>();
RET_CHECK(cc->InputSidePackets().HasTag(kOutputFilePathTag));
cc->InputSidePackets().Tag(kOutputFilePathTag).Set<std::string>();
if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
cc->InputSidePackets().Tag(kAudioFilePathTag).Set<std::string>();
}
return absl::OkStatus();
}
@ -116,7 +121,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
<< "Video format must be specified in "
"OpenCvVideoEncoderCalculatorOptions";
output_file_path_ =
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get<std::string>();
cc->InputSidePackets().Tag(kOutputFilePathTag).Get<std::string>();
std::vector<std::string> splited_file_path =
absl::StrSplit(output_file_path_, '.');
RET_CHECK(splited_file_path.size() >= 2 &&
@ -126,7 +131,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
// If the video header will be available, the video metadata will be fetched
// from the video header directly. The calculator will receive the video
// header packet at timestamp prestream.
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
return absl::OkStatus();
}
return SetUpVideoWriter(options.fps(), options.width(), options.height());
@ -135,13 +140,13 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream()) {
const VideoHeader& video_header =
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
return SetUpVideoWriter(video_header.frame_rate, video_header.width,
video_header.height);
}
const ImageFrame& image_frame =
cc->Inputs().Tag("VIDEO").Value().Get<ImageFrame>();
cc->Inputs().Tag(kVideoTag).Value().Get<ImageFrame>();
ImageFormat::Format format = image_frame.Format();
cv::Mat frame;
if (format == ImageFormat::GRAY8) {
@ -149,7 +154,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
if (frame.empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Receive empty frame at timestamp "
<< cc->Inputs().Tag("VIDEO").Value().Timestamp()
<< cc->Inputs().Tag(kVideoTag).Value().Timestamp()
<< " in OpenCvVideoEncoderCalculator::Process()";
}
} else {
@ -157,7 +162,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
if (tmp_frame.empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Receive empty frame at timestamp "
<< cc->Inputs().Tag("VIDEO").Value().Timestamp()
<< cc->Inputs().Tag(kVideoTag).Value().Timestamp()
<< " in OpenCvVideoEncoderCalculator::Process()";
}
if (format == ImageFormat::SRGB) {
@ -177,10 +182,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) {
if (writer_ && writer_->isOpened()) {
writer_->release();
}
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) {
if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
#ifdef HAVE_FFMPEG
const std::string& audio_file_path =
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Get<std::string>();
cc->InputSidePackets().Tag(kAudioFilePathTag).Get<std::string>();
if (audio_file_path.empty()) {
LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the "
"audio tracks to the generated video because the audio "

View File

@ -23,6 +23,11 @@
namespace mediapipe {
namespace {
constexpr char kBackwardFlowTag[] = "BACKWARD_FLOW";
constexpr char kForwardFlowTag[] = "FORWARD_FLOW";
constexpr char kSecondFrameTag[] = "SECOND_FRAME";
constexpr char kFirstFrameTag[] = "FIRST_FRAME";
// Checks that img1 and img2 have the same dimensions.
bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) {
return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height());
@ -94,19 +99,19 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase {
};
absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) {
if (!cc->Inputs().HasTag("FIRST_FRAME") ||
!cc->Inputs().HasTag("SECOND_FRAME")) {
if (!cc->Inputs().HasTag(kFirstFrameTag) ||
!cc->Inputs().HasTag(kSecondFrameTag)) {
return absl::InvalidArgumentError(
"Missing required input streams. Both FIRST_FRAME and SECOND_FRAME "
"must be specified.");
}
cc->Inputs().Tag("FIRST_FRAME").Set<ImageFrame>();
cc->Inputs().Tag("SECOND_FRAME").Set<ImageFrame>();
if (cc->Outputs().HasTag("FORWARD_FLOW")) {
cc->Outputs().Tag("FORWARD_FLOW").Set<OpticalFlowField>();
cc->Inputs().Tag(kFirstFrameTag).Set<ImageFrame>();
cc->Inputs().Tag(kSecondFrameTag).Set<ImageFrame>();
if (cc->Outputs().HasTag(kForwardFlowTag)) {
cc->Outputs().Tag(kForwardFlowTag).Set<OpticalFlowField>();
}
if (cc->Outputs().HasTag("BACKWARD_FLOW")) {
cc->Outputs().Tag("BACKWARD_FLOW").Set<OpticalFlowField>();
if (cc->Outputs().HasTag(kBackwardFlowTag)) {
cc->Outputs().Tag(kBackwardFlowTag).Set<OpticalFlowField>();
}
return absl::OkStatus();
}
@ -116,10 +121,10 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
absl::MutexLock lock(&mutex_);
tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1());
}
if (cc->Outputs().HasTag("FORWARD_FLOW")) {
if (cc->Outputs().HasTag(kForwardFlowTag)) {
forward_requested_ = true;
}
if (cc->Outputs().HasTag("BACKWARD_FLOW")) {
if (cc->Outputs().HasTag(kBackwardFlowTag)) {
backward_requested_ = true;
}
@ -128,15 +133,15 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
const ImageFrame& first_frame =
cc->Inputs().Tag("FIRST_FRAME").Value().Get<ImageFrame>();
cc->Inputs().Tag(kFirstFrameTag).Value().Get<ImageFrame>();
const ImageFrame& second_frame =
cc->Inputs().Tag("SECOND_FRAME").Value().Get<ImageFrame>();
cc->Inputs().Tag(kSecondFrameTag).Value().Get<ImageFrame>();
if (forward_requested_) {
auto forward_optical_flow_field = absl::make_unique<OpticalFlowField>();
MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame,
forward_optical_flow_field.get()));
cc->Outputs()
.Tag("FORWARD_FLOW")
.Tag(kForwardFlowTag)
.Add(forward_optical_flow_field.release(), cc->InputTimestamp());
}
if (backward_requested_) {
@ -144,7 +149,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame,
backward_optical_flow_field.get()));
cc->Outputs()
.Tag("BACKWARD_FLOW")
.Tag(kBackwardFlowTag)
.Add(backward_optical_flow_field.release(), cc->InputTimestamp());
}
return absl::OkStatus();

View File

@ -19,6 +19,9 @@
namespace mediapipe {
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kFrameTag[] = "FRAME";
// Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp
// PreStream. Note that this calculator only fills in format, width, and height,
// i.e. frame_rate and duration will not be filled, unless:
@ -64,8 +67,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
if (!cc->Inputs().UsesTags()) {
cc->Inputs().Index(0).Set<ImageFrame>();
} else {
cc->Inputs().Tag("FRAME").Set<ImageFrame>();
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
cc->Inputs().Tag(kFrameTag).Set<ImageFrame>();
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
}
cc->Outputs().Index(0).Set<VideoHeader>();
return absl::OkStatus();
@ -73,8 +76,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) {
frame_rate_in_prestream_ = cc->Inputs().UsesTags() &&
cc->Inputs().HasTag("FRAME") &&
cc->Inputs().HasTag("VIDEO_PRESTREAM");
cc->Inputs().HasTag(kFrameTag) &&
cc->Inputs().HasTag(kVideoPrestreamTag);
header_ = absl::make_unique<VideoHeader>();
return absl::OkStatus();
}
@ -82,15 +85,15 @@ absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream(
CalculatorContext* cc) {
cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment();
if (cc->InputTimestamp() == Timestamp::PreStream()) {
RET_CHECK(cc->Inputs().Tag("FRAME").IsEmpty());
RET_CHECK(!cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty());
*header_ = cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
RET_CHECK(cc->Inputs().Tag(kFrameTag).IsEmpty());
RET_CHECK(!cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty());
*header_ = cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero";
} else {
RET_CHECK(cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty())
RET_CHECK(cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty())
<< "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream().";
RET_CHECK(!cc->Inputs().Tag("FRAME").IsEmpty());
const auto& frame = cc->Inputs().Tag("FRAME").Get<ImageFrame>();
RET_CHECK(!cc->Inputs().Tag(kFrameTag).IsEmpty());
const auto& frame = cc->Inputs().Tag(kFrameTag).Get<ImageFrame>();
header_->format = frame.Format();
header_->width = frame.Width();
header_->height = frame.Height();

View File

@ -44,28 +44,32 @@ using mediapipe::MakePacket;
using mediapipe::OutputStreamShardSet;
using mediapipe::Timestamp;
namespace proto_ns = mediapipe::proto_ns;
constexpr char kEventTag[] = "EVENT";
constexpr char kOutTag[] = "OUT";
using mediapipe::CalculatorGraph;
using mediapipe::Packet;
class InfiniteSequenceCalculator : public mediapipe::CalculatorBase {
public:
static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
cc->Outputs().Tag("OUT").Set<int>();
cc->Outputs().Tag("EVENT").Set<int>();
cc->Outputs().Tag(kOutTag).Set<int>();
cc->Outputs().Tag(kEventTag).Set<int>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1)));
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
cc->Outputs().Tag("OUT").AddPacket(
cc->Outputs().Tag(kOutTag).AddPacket(
MakePacket<int>(count_).At(Timestamp(count_)));
count_++;
return absl::OkStatus();
}
absl::Status Close(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2)));
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
return absl::OkStatus();
}
@ -81,11 +85,11 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
cc->Inputs().Get("", i).SetAny();
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
}
cc->Outputs().Tag("EVENT").Set<int>();
cc->Outputs().Tag(kEventTag).Set<int>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1)));
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
@ -98,7 +102,7 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
: mediapipe::tool::StatusStop();
}
absl::Status Close(CalculatorContext* cc) override {
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2)));
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
return absl::OkStatus();
}

View File

@ -65,6 +65,16 @@ namespace mediapipe {
namespace {
constexpr char kCounter2Tag[] = "COUNTER2";
constexpr char kCounter1Tag[] = "COUNTER1";
constexpr char kExtraTag[] = "EXTRA";
constexpr char kWaitSemTag[] = "WAIT_SEM";
constexpr char kPostSemTag[] = "POST_SEM";
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
constexpr char kOutputTag[] = "OUTPUT";
constexpr char kInputTag[] = "INPUT";
constexpr char kSelectTag[] = "SELECT";
using testing::ElementsAre;
using testing::HasSubstr;
@ -125,8 +135,8 @@ class DemuxTimedCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
cc->Inputs().Tag("SELECT").Set<int>();
PacketType* data_input = &cc->Inputs().Tag("INPUT");
cc->Inputs().Tag(kSelectTag).Set<int>();
PacketType* data_input = &cc->Inputs().Tag(kInputTag);
data_input->SetAny();
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
id < cc->Outputs().EndId("OUTPUT"); ++id) {
@ -182,7 +192,7 @@ REGISTER_CALCULATOR(DemuxTimedCalculator);
class MuxTimedCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("SELECT").Set<int>();
cc->Inputs().Tag(kSelectTag).Set<int>();
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT");
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
data_input0->SetAny();
@ -191,7 +201,7 @@ class MuxTimedCalculator : public CalculatorBase {
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
}
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
cc->Outputs().Tag(kOutputTag).SetSameAs(data_input0);
cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus();
}
@ -598,12 +608,12 @@ class ErrorOnOpenCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
if (cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
if (cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
return absl::NotFoundError("expected error");
}
return absl::OkStatus();
@ -920,8 +930,8 @@ class SemaphoreCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("POST_SEM").Set<Semaphore*>();
cc->InputSidePackets().Tag("WAIT_SEM").Set<Semaphore*>();
cc->InputSidePackets().Tag(kPostSemTag).Set<Semaphore*>();
cc->InputSidePackets().Tag(kWaitSemTag).Set<Semaphore*>();
cc->SetTimestampOffset(TimestampDiff(0));
return absl::OkStatus();
}
@ -929,8 +939,8 @@ class SemaphoreCalculator : public CalculatorBase {
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) override {
cc->InputSidePackets().Tag("POST_SEM").Get<Semaphore*>()->Release(1);
cc->InputSidePackets().Tag("WAIT_SEM").Get<Semaphore*>()->Acquire(1);
cc->InputSidePackets().Tag(kPostSemTag).Get<Semaphore*>()->Release(1);
cc->InputSidePackets().Tag(kWaitSemTag).Get<Semaphore*>()->Acquire(1);
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return absl::OkStatus();
}
@ -1177,9 +1187,9 @@ class IncrementingStatusHandler : public StatusHandler {
static absl::Status FillExpectations(
const MediaPipeOptions& extendable_options,
PacketTypeSet* input_side_packets) {
input_side_packets->Tag("EXTRA").SetAny().Optional();
input_side_packets->Tag("COUNTER1").Set<std::unique_ptr<int>>();
input_side_packets->Tag("COUNTER2").Set<std::unique_ptr<int>>();
input_side_packets->Tag(kExtraTag).SetAny().Optional();
input_side_packets->Tag(kCounter1Tag).Set<std::unique_ptr<int>>();
input_side_packets->Tag(kCounter2Tag).Set<std::unique_ptr<int>>();
return absl::OkStatus();
}
@ -1187,7 +1197,7 @@ class IncrementingStatusHandler : public StatusHandler {
const MediaPipeOptions& extendable_options,
const PacketSet& input_side_packets, //
const absl::Status& pre_run_status) {
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER1"));
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter1Tag));
(*counter)++;
return pre_run_status_result_;
}
@@ -1195,7 +1205,7 @@ class IncrementingStatusHandler : public StatusHandler {
static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
const PacketSet& input_side_packets, //
const absl::Status& run_status) {
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER2"));
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter2Tag));
(*counter)++;
return post_run_status_result_;
}
@@ -2228,20 +2238,20 @@ class DemuxUntimedCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
cc->Inputs().Tag("INPUT").SetAny();
cc->Inputs().Tag("SELECT").Set<int>();
cc->Inputs().Tag(kInputTag).SetAny();
cc->Inputs().Tag(kSelectTag).Set<int>();
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
id < cc->Outputs().EndId("OUTPUT"); ++id) {
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag("INPUT"));
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag(kInputTag));
}
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
int index = cc->Inputs().Tag("SELECT").Get<int>();
if (!cc->Inputs().Tag("INPUT").IsEmpty()) {
int index = cc->Inputs().Tag(kSelectTag).Get<int>();
if (!cc->Inputs().Tag(kInputTag).IsEmpty()) {
cc->Outputs()
.Get("OUTPUT", index)
.AddPacket(cc->Inputs().Tag("INPUT").Value());
.AddPacket(cc->Inputs().Tag(kInputTag).Value());
} else {
cc->Outputs()
.Get("OUTPUT", index)
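The change repeated across these test files is mechanical: every string literal passed to Tag() is replaced by a constexpr char constant declared once near the top of the file, so each tag name is spelled in exactly one place. A minimal sketch of the pattern, using a hypothetical calculator and tag names chosen only for illustration (not code from this commit):

#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {
namespace {

// Each tag is declared once; every Tag() call site refers to the constant,
// so the literal cannot drift out of sync between call sites.
constexpr char kSelectTag[] = "SELECT";
constexpr char kInputTag[] = "INPUT";

// Hypothetical pass-through calculator, shown only to illustrate the call sites.
class TagConstantExampleCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->Inputs().Tag(kSelectTag).Set<int>();
    cc->Inputs().Tag(kInputTag).SetAny();
    cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Tag(kInputTag));
    return absl::OkStatus();
  }
  absl::Status Process(CalculatorContext* cc) final {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Tag(kInputTag).Value());
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(TagConstantExampleCalculator);

}  // namespace
}  // namespace mediapipe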

View File

@@ -32,6 +32,11 @@
namespace mediapipe {
namespace {
constexpr char kTag[] = "";
constexpr char kBTag[] = "B";
constexpr char kATag[] = "A";
constexpr char kSideOutputTag[] = "SIDE_OUTPUT";
// Inputs: 2 streams with ints. Headers are strings.
// Input side packets: 1.
// Outputs: 3 streams with ints. #0 and #1 will contain the negated values from
@@ -48,7 +53,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0));
cc->InputSidePackets().Index(0).SetAny();
cc->OutputSidePackets()
.Tag("SIDE_OUTPUT")
.Tag(kSideOutputTag)
.SetSameAs(&cc->InputSidePackets().Index(0));
return absl::OkStatus();
}
@@ -64,7 +69,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
Adopt(new std::string(absl::StrCat(input_header_string, i))));
}
cc->OutputSidePackets()
.Tag("SIDE_OUTPUT")
.Tag(kSideOutputTag)
.Set(cc->InputSidePackets().Index(0));
return absl::OkStatus();
}
@@ -152,7 +157,7 @@ TEST(CalculatorRunner, RunsCalculator) {
Adopt(new int(input_side_packet_content));
MP_ASSERT_OK(runner.Run());
EXPECT_EQ(input_side_packet_content,
runner.OutputSidePackets().Tag("SIDE_OUTPUT").Get<int>());
runner.OutputSidePackets().Tag(kSideOutputTag).Get<int>());
const auto& outputs = runner.Outputs();
ASSERT_EQ(3, outputs.NumEntries());
@@ -209,9 +214,9 @@ TEST(CalculatorRunner, MultiTagTestCalculatorOk) {
const auto& outputs = runner.Outputs();
ASSERT_EQ(3, outputs.NumEntries());
for (int ts = 0; ts < 5; ++ts) {
const std::vector<Packet>& a_packets = outputs.Tag("A").packets;
const std::vector<Packet>& b_packets = outputs.Tag("B").packets;
const std::vector<Packet>& c_packets = outputs.Tag("").packets;
const std::vector<Packet>& a_packets = outputs.Tag(kATag).packets;
const std::vector<Packet>& b_packets = outputs.Tag(kBTag).packets;
const std::vector<Packet>& c_packets = outputs.Tag(kTag).packets;
EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp());
EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp());
EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp());

View File

@@ -24,6 +24,10 @@
namespace mediapipe {
namespace {
constexpr char kTag2Tag[] = "TAG_2";
constexpr char kTag0Tag[] = "TAG_0";
constexpr char kTag1Tag[] = "TAG_1";
TEST(CollectionTest, BasicByIndex) {
tool::TagAndNameInfo info;
info.names.push_back("name_1");
@@ -55,14 +59,14 @@ TEST(CollectionTest, BasicByTag) {
info.names.push_back("name_2");
info.tags.push_back("TAG_2");
internal::Collection<int> collection(info);
collection.Tag("TAG_1") = 101;
collection.Tag("TAG_0") = 100;
collection.Tag("TAG_2") = 102;
collection.Tag(kTag1Tag) = 101;
collection.Tag(kTag0Tag) = 100;
collection.Tag(kTag2Tag) = 102;
// Test the stored values.
EXPECT_EQ(100, collection.Tag("TAG_0"));
EXPECT_EQ(101, collection.Tag("TAG_1"));
EXPECT_EQ(102, collection.Tag("TAG_2"));
EXPECT_EQ(100, collection.Tag(kTag0Tag));
EXPECT_EQ(101, collection.Tag(kTag1Tag));
EXPECT_EQ(102, collection.Tag(kTag2Tag));
// Test access using a range based for.
int i = 0;
for (int num : collection) {

View File

@@ -134,6 +134,21 @@ void Tensor::AllocateMtlBuffer(id<MTLDevice> device) const {
#endif // MEDIAPIPE_METAL_ENABLED
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
bool Tensor::NeedsHalfFloatRenderTarget() const {
static bool has_color_buffer_float =
gl_context_->HasGlExtension("WEBGL_color_buffer_float") ||
gl_context_->HasGlExtension("EXT_color_buffer_float");
if (!has_color_buffer_float) {
static bool has_color_buffer_half_float =
gl_context_->HasGlExtension("EXT_color_buffer_half_float");
LOG_IF(FATAL, !has_color_buffer_half_float)
<< "EXT_color_buffer_half_float or WEBGL_color_buffer_float "
<< "required on web to use MP tensor";
return true;
}
return false;
}
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
LOG_IF(FATAL, valid_ == kValidNone)
<< "Tensor must be written prior to read from.";
@@ -164,8 +179,24 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
// Set alignment for the proper value (default) to avoid address sanitizer
// error "out of boundary reading".
glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
#ifdef __EMSCRIPTEN__
// Under WebGL1, format must match in order to use glTexSubImage2D, so if we
// have a half-float texture, then uploading from GL_FLOAT here would fail.
// We change the texture's data type to float here to accommodate.
// Furthermore, for a full-image replacement operation, glTexImage2D is
// expected to be more performant than glTexSubImage2D. Note that for WebGL2
// we cannot use glTexImage2D, because we allocate using glTexStorage2D in
// that case, which is incompatible.
if (gl_context_->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
0, GL_RGBA, GL_FLOAT, temp_buffer.get());
texture_is_half_float_ = false;
} else
#endif // __EMSCRIPTEN__
{
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_,
GL_RGBA, GL_FLOAT, temp_buffer.get());
}
glBindTexture(GL_TEXTURE_2D, 0);
valid_ |= kValidOpenGlTexture2d;
}
@@ -175,6 +206,16 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const {
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
AllocateOpenGlTexture2d();
#ifdef __EMSCRIPTEN__
// On web, we may have to change type from float to half-float
if (!texture_is_half_float_ && NeedsHalfFloatRenderTarget()) {
glBindTexture(GL_TEXTURE_2D, opengl_texture2d_);
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_, 0,
GL_RGBA, GL_HALF_FLOAT_OES, 0 /* data */);
glBindTexture(GL_TEXTURE_2D, 0);
texture_is_half_float_ = true;
}
#endif
valid_ = kValidOpenGlTexture2d;
return {opengl_texture2d_, std::move(lock)};
}
@@ -255,8 +296,18 @@ void Tensor::AllocateOpenGlTexture2d() const {
<< "with GLES 2.0";
// Allocate the image data; note that it's no longer RGBA32F, so will be
// lower precision.
auto type = GL_FLOAT;
// On web, we might need to change type to half-float (e.g. for iOS-
// Safari) in order to have a valid framebuffer. See b/194442743 for more
// details.
#ifdef __EMSCRIPTEN__
if (NeedsHalfFloatRenderTarget()) {
type = GL_HALF_FLOAT_OES;
texture_is_half_float_ = true;
}
#endif // __EMSCRIPTEN__
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
0, GL_RGBA, GL_FLOAT, 0 /* data */);
0, GL_RGBA, type, 0 /* data */);
}
glBindTexture(GL_TEXTURE_2D, 0);
glGenFramebuffers(1, &frame_buffer_);
@@ -443,7 +494,6 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
glPixelStorei(GL_PACK_ALIGNMENT, 4);
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
buffer);
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
const int actual_depth_size =
BhwcDepthFromShape(shape_) * element_size();
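The tensor.cc additions above all serve one WebGL constraint: a GL_FLOAT texture can only back a framebuffer when the context exposes a float color-buffer extension, so on Emscripten builds the tensor falls back to a half-float texture when only EXT_color_buffer_half_float is available and re-specifies the texture with a matching type before uploads. A condensed sketch of that decision, with invented names and assuming the GL and logging headers already used in tensor.cc (this is not the actual Tensor API):

// Picks the GL data type for the tensor's backing texture so that the texture
// remains attachable to a framebuffer on the current (Web)GL context.
GLenum ChooseTensorTextureType(bool has_float_color_buffer,
                               bool has_half_float_color_buffer) {
  if (has_float_color_buffer) {
    return GL_FLOAT;  // e.g. EXT_color_buffer_float / WEBGL_color_buffer_float
  }
  if (has_half_float_color_buffer) {
    return GL_HALF_FLOAT_OES;  // WebGL1 / iOS Safari fallback
  }
  LOG(FATAL) << "A float or half-float color buffer extension is required.";
  return GL_FLOAT;  // unreachable
}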

View File

@@ -266,11 +266,15 @@ class Tensor {
mutable GLuint frame_buffer_ = GL_INVALID_INDEX;
mutable int texture_width_;
mutable int texture_height_;
#ifdef __EMSCRIPTEN__
mutable bool texture_is_half_float_ = false;
#endif // __EMSCRIPTEN__
void AllocateOpenGlTexture2d() const;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
mutable GLuint opengl_buffer_ = GL_INVALID_INDEX;
void AllocateOpenGlBuffer() const;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
bool NeedsHalfFloatRenderTarget() const;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
};

View File

@@ -31,6 +31,11 @@ namespace mediapipe {
namespace {
constexpr char kOutputTag[] = "OUTPUT";
constexpr char kEnableTag[] = "ENABLE";
constexpr char kSelectTag[] = "SELECT";
constexpr char kSideinputTag[] = "SIDEINPUT";
// Shows validation success for a graph and a subgraph.
TEST(GraphValidationTest, InitializeGraphFromProtos) {
auto config_1 = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@@ -323,20 +328,21 @@ TEST(GraphValidationTest, OptionalSubgraphStreamsMismatched) {
class OptionalSideInputTestCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag("SIDEINPUT").Set<std::string>().Optional();
cc->Inputs().Tag("SELECT").Set<int>().Optional();
cc->Inputs().Tag("ENABLE").Set<bool>().Optional();
cc->Outputs().Tag("OUTPUT").Set<std::string>();
cc->InputSidePackets().Tag(kSideinputTag).Set<std::string>().Optional();
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
cc->Inputs().Tag(kEnableTag).Set<bool>().Optional();
cc->Outputs().Tag(kOutputTag).Set<std::string>();
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
std::string value("default");
if (cc->InputSidePackets().HasTag("SIDEINPUT")) {
value = cc->InputSidePackets().Tag("SIDEINPUT").Get<std::string>();
if (cc->InputSidePackets().HasTag(kSideinputTag)) {
value = cc->InputSidePackets().Tag(kSideinputTag).Get<std::string>();
}
cc->Outputs().Tag("OUTPUT").Add(new std::string(value),
cc->InputTimestamp());
cc->Outputs()
.Tag(kOutputTag)
.Add(new std::string(value), cc->InputTimestamp());
return absl::OkStatus();
}
};

View File

@@ -26,17 +26,20 @@ namespace {
namespace test_ns {
constexpr char kOutTag[] = "OUT";
constexpr char kInTag[] = "IN";
class TestSinkCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
cc->Outputs().Tag("OUT").Set<int>();
cc->Inputs().Tag(kInTag).Set<mediapipe::InputOnlyProto>();
cc->Outputs().Tag(kOutTag).Set<int>();
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
cc->Outputs().Tag("OUT").AddPacket(
int x = cc->Inputs().Tag(kInTag).Get<mediapipe::InputOnlyProto>().x();
cc->Outputs().Tag(kOutTag).AddPacket(
MakePacket<int>(x).At(cc->InputTimestamp()));
return absl::OkStatus();
}

View File

@@ -34,6 +34,19 @@
namespace mediapipe {
constexpr char kOutTag[] = "OUT";
constexpr char kClockTag[] = "CLOCK";
constexpr char kSleepMicrosTag[] = "SLEEP_MICROS";
constexpr char kCloseTag[] = "CLOSE";
constexpr char kProcessTag[] = "PROCESS";
constexpr char kOpenTag[] = "OPEN";
constexpr char kTag[] = "";
constexpr char kMeanTag[] = "MEAN";
constexpr char kDataTag[] = "DATA";
constexpr char kPairTag[] = "PAIR";
constexpr char kLowTag[] = "LOW";
constexpr char kHighTag[] = "HIGH";
using RandomEngine = std::mt19937_64;
// A Calculator that outputs twice the value of its input packet (an int).
@@ -95,9 +108,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
PacketTypeSet* input_side_packets, //
PacketTypeSet* output_side_packets) {
input_side_packets->Index(0).Set<uint64>();
output_side_packets->Tag("HIGH").Set<uint32>();
output_side_packets->Tag("LOW").Set<uint32>();
output_side_packets->Tag("PAIR").Set<std::pair<uint32, uint32>>();
output_side_packets->Tag(kHighTag).Set<uint32>();
output_side_packets->Tag(kLowTag).Set<uint32>();
output_side_packets->Tag(kPairTag).Set<std::pair<uint32, uint32>>();
return absl::OkStatus();
}
@@ -108,9 +121,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
uint64 value = input_side_packets.Index(0).Get<uint64>();
uint32 high = value >> 32;
uint32 low = value & 0xFFFFFFFF;
output_side_packets->Tag("HIGH") = Adopt(new uint32(high));
output_side_packets->Tag("LOW") = Adopt(new uint32(low));
output_side_packets->Tag("PAIR") =
output_side_packets->Tag(kHighTag) = Adopt(new uint32(high));
output_side_packets->Tag(kLowTag) = Adopt(new uint32(low));
output_side_packets->Tag(kPairTag) =
Adopt(new std::pair<uint32, uint32>(high, low));
return absl::OkStatus();
}
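As a concrete check of the bit arithmetic above: an input side packet holding value = 0x0000000A00000003 produces HIGH = 0x0000000A (10), LOW = 0x00000003 (3), and PAIR = {10, 3}.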
@@ -221,8 +234,8 @@ class StdDevCalculator : public CalculatorBase {
StdDevCalculator() {}
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Tag("DATA").Set<int>();
cc->Inputs().Tag("MEAN").Set<double>();
cc->Inputs().Tag(kDataTag).Set<int>();
cc->Inputs().Tag(kMeanTag).Set<double>();
cc->Outputs().Index(0).Set<int>();
return absl::OkStatus();
}
@@ -234,15 +247,15 @@ class StdDevCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final {
if (cc->InputTimestamp() == Timestamp::PreStream()) {
RET_CHECK(cc->Inputs().Tag("DATA").Value().IsEmpty());
RET_CHECK(!cc->Inputs().Tag("MEAN").Value().IsEmpty());
mean_ = cc->Inputs().Tag("MEAN").Get<double>();
RET_CHECK(cc->Inputs().Tag(kDataTag).Value().IsEmpty());
RET_CHECK(!cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
mean_ = cc->Inputs().Tag(kMeanTag).Get<double>();
initialized_ = true;
} else {
RET_CHECK(initialized_);
RET_CHECK(!cc->Inputs().Tag("DATA").Value().IsEmpty());
RET_CHECK(cc->Inputs().Tag("MEAN").Value().IsEmpty());
double diff = cc->Inputs().Tag("DATA").Get<int>() - mean_;
RET_CHECK(!cc->Inputs().Tag(kDataTag).Value().IsEmpty());
RET_CHECK(cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
double diff = cc->Inputs().Tag(kDataTag).Get<int>() - mean_;
cummulative_variance_ += diff * diff;
++count_;
}
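In other words, the MEAN packet delivers a precomputed mean at Timestamp::PreStream(), and every subsequent DATA packet adds (x - mean)^2 to cummulative_variance_; together with count_, that accumulator is exactly what a standard deviation of the form sqrt(sum((x_i - mean)^2) / N) requires once the stream ends.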
@@ -564,8 +577,8 @@ class LambdaCalculator : public CalculatorBase {
id < cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).SetAny();
}
if (cc->InputSidePackets().HasTag("") > 0) {
cc->InputSidePackets().Tag("").Set<ProcessFunction>();
if (cc->InputSidePackets().HasTag(kTag) > 0) {
cc->InputSidePackets().Tag(kTag).Set<ProcessFunction>();
}
for (const std::string& tag : {"OPEN", "PROCESS", "CLOSE"}) {
if (cc->InputSidePackets().HasTag(tag)) {
@@ -576,24 +589,24 @@
}
absl::Status Open(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("OPEN")) {
if (cc->InputSidePackets().HasTag(kOpenTag)) {
return GetContextFn(cc, "OPEN")(cc);
}
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("PROCESS")) {
if (cc->InputSidePackets().HasTag(kProcessTag)) {
return GetContextFn(cc, "PROCESS")(cc);
}
if (cc->InputSidePackets().HasTag("") > 0) {
if (cc->InputSidePackets().HasTag(kTag) > 0) {
return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs());
}
return absl::OkStatus();
}
absl::Status Close(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("CLOSE")) {
if (cc->InputSidePackets().HasTag(kCloseTag)) {
return GetContextFn(cc, "CLOSE")(cc);
}
return absl::OkStatus();
@@ -645,17 +658,18 @@ class PassThroughWithSleepCalculator : public CalculatorBase {
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
cc->InputSidePackets().Tag(kSleepMicrosTag).Set<int>();
cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
sleep_micros_ = cc->InputSidePackets().Tag(kSleepMicrosTag).Get<int>();
if (sleep_micros_ < 0) {
return absl::InternalError("SLEEP_MICROS should be >= 0");
}
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
clock_ =
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
@@ -678,8 +692,8 @@ class MultiplyIntCalculator : public CalculatorBase {
cc->Inputs().Index(0).Set<int>();
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
RET_CHECK(cc->Outputs().HasTag("OUT"));
cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
RET_CHECK(cc->Outputs().HasTag(kOutTag));
cc->Outputs().Tag(kOutTag).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
@@ -689,7 +703,7 @@ class MultiplyIntCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final {
int x = cc->Inputs().Index(0).Value().Get<int>();
int y = cc->Inputs().Index(1).Value().Get<int>();
cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
cc->Outputs().Tag(kOutTag).Add(new int(x * y), cc->InputTimestamp());
return absl::OkStatus();
}
};

View File

@@ -60,6 +60,18 @@ def mediapipe_aar(
assets: additional assets to be included into the archive.
assets_dir: path where the assets will the packaged.
"""
# When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
# the OpenCV .so libraries will be excluded from the AAR package to
# reduce the package size.
native.config_setting(
name = "exclude_opencv_so_lib",
define_values = {
"EXCLUDE_OPENCV_SO_LIB": "1",
},
visibility = ["//visibility:public"],
)
_mediapipe_jni(
name = name + "_jni",
gen_libmediapipe = gen_libmediapipe,
@@ -133,6 +145,7 @@ EOF
] + select({
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
"//mediapipe/framework/port:disable_opencv": [],
"exclude_opencv_so_lib": [],
}),
assets = assets,
assets_dir = assets_dir,
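As the comment above notes, the exclusion is driven purely by the build command; for example (the target label here is illustrative, not taken from this change):

bazel build -c opt --define EXCLUDE_OPENCV_SO_LIB=1 //path/to/your:custom_mediapipe_aar

Without the define, the select() falls through to the default branch and the OpenCV JNI library is packaged as before.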

View File

@@ -245,14 +245,6 @@ void Box::Fit(const std::vector<T>& vertices) {
auto system_g = system_h.colPivHouseholderQr();
auto solution = system_g.solve(v).eval();
transformation_.topLeftCorner<3, 4>() = solution.transpose();
// Adjust rotation matrix to its nearest orthogonal matrix.
const auto rotation = transformation_.topLeftCorner<3, 3>();
Eigen::JacobiSVD<Eigen::Matrix3f> svd(
rotation, Eigen::ComputeFullV | Eigen::ComputeFullU);
const Eigen::Matrix3f matrix_u = svd.matrixU();
const Eigen::Matrix3f matrix_v = svd.matrixV();
transformation_.topLeftCorner<3, 3>() = matrix_u * matrix_v.transpose();
Update();
}
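For reference, the deleted block implemented the standard orthogonal Procrustes projection: given the estimated 3x3 linear block R with SVD R = U * S * V^T, the orthogonal matrix nearest to R in the Frobenius norm is U * V^T, which is what the Eigen::JacobiSVD code wrote back into transformation_ before calling Update().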

View File

@@ -254,11 +254,11 @@ class SolutionBase:
for stream_name in self._output_stream_type_info.keys():
self._graph.observe_output_stream(stream_name, callback, True)
input_side_packets = {
self._input_side_packets = {
name: self._make_packet(self._side_input_type_info[name], data)
for name, data in (side_inputs or {}).items()
}
self._graph.start_run(input_side_packets)
self._graph.start_run(self._input_side_packets)
# TODO: Use "inspect.Parameter" to fetch the input argument names and
# types from "_input_stream_type_info" and then auto generate the process
@@ -353,6 +353,12 @@ class SolutionBase:
self._input_stream_type_info = None
self._output_stream_type_info = None
def reset(self) -> None:
"""Resets the graph for another run."""
if self._graph:
self._graph.close()
self._graph.start_run(self._input_side_packets)
def _initialize_graph_interface(
self,
validated_graph: validated_graph_config.ValidatedGraphConfig,

View File

@@ -298,6 +298,56 @@ class SolutionBaseTest(parameterized.TestCase):
'ImageTransformation.output_height': 0
})
@parameterized.named_parameters(('graph_without_side_packets', """
input_stream: 'image_in'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:transformed_image_in'
output_stream: 'IMAGE:image_out'
}
""", None), ('graph_with_side_packets', """
input_stream: 'image_in'
input_side_packet: 'allow_signal'
input_side_packet: 'rotation_degrees'
output_stream: 'image_out'
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_in'
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
output_stream: 'IMAGE:transformed_image_in'
}
node {
calculator: 'GateCalculator'
input_stream: 'transformed_image_in'
input_side_packet: 'ALLOW:allow_signal'
output_stream: 'image_out_to_transform'
}
node {
calculator: 'ImageTransformationCalculator'
input_stream: 'IMAGE:image_out_to_transform'
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
output_stream: 'IMAGE:image_out'
}""", {
'allow_signal': True,
'rotation_degrees': 0
}))
def test_solution_reset(self, text_config, side_inputs):
config_proto = text_format.Parse(text_config,
calculator_pb2.CalculatorGraphConfig())
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
with solution_base.SolutionBase(
graph_config=config_proto, side_inputs=side_inputs) as solution:
for _ in range(20):
outputs = solution.process(input_image)
self.assertTrue(np.array_equal(input_image, outputs.image_out))
solution.reset()
def _process_and_verify(self,
config_proto,
side_inputs=None,

View File

@@ -26,7 +26,7 @@ import numpy.testing as npt
from mediapipe.python.solutions import objectron as mp_objectron
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
DIFF_THRESHOLD = 35 # pixels
DIFF_THRESHOLD = 30 # pixels
EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457],
[383, 505], [80, 478], [408, 345],
[130, 347], [384, 355], [72, 353]],