Project import generated by Copybara.
GitOrigin-RevId: 8e1da4611d93ccb7d9674713157d43be0348d98f
This commit is contained in:
parent
50c92c6623
commit
b899d17f18
|
@ -220,6 +220,7 @@ import cv2
|
|||
import mediapipe as mp
|
||||
mp_drawing = mp.solutions.drawing_utils
|
||||
mp_hands = mp.solutions.hands
|
||||
drawing_styles = mp.solutions.drawing_styles
|
||||
|
||||
# For static images:
|
||||
IMAGE_FILES = []
|
||||
|
@ -248,7 +249,9 @@ with mp_hands.Hands(
|
|||
f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})'
|
||||
)
|
||||
mp_drawing.draw_landmarks(
|
||||
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
|
||||
annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
|
||||
drawing_styles.get_default_hand_landmark_style(),
|
||||
drawing_styles.get_default_hand_connection_style())
|
||||
cv2.imwrite(
|
||||
'/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1))
|
||||
|
||||
|
@ -278,7 +281,9 @@ with mp_hands.Hands(
|
|||
if results.multi_hand_landmarks:
|
||||
for hand_landmarks in results.multi_hand_landmarks:
|
||||
mp_drawing.draw_landmarks(
|
||||
image, hand_landmarks, mp_hands.HAND_CONNECTIONS)
|
||||
image, hand_landmarks, mp_hands.HAND_CONNECTIONS,
|
||||
drawing_styles.get_default_hand_landmark_style(),
|
||||
drawing_styles.get_default_hand_connection_style())
|
||||
cv2.imshow('MediaPipe Hands', image)
|
||||
if cv2.waitKey(5) & 0xFF == 27:
|
||||
break
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kDataTag[] = "DATA";
|
||||
constexpr char kHeaderTag[] = "HEADER";
|
||||
|
||||
class AddHeaderCalculatorTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(AddHeaderCalculatorTest, HeaderStream) {
|
||||
|
@ -36,11 +39,11 @@ TEST_F(AddHeaderCalculatorTest, HeaderStream) {
|
|||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableInputs()->Tag("HEADER").header =
|
||||
runner.MutableInputs()->Tag(kHeaderTag).header =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
|
@ -85,13 +88,14 @@ TEST_F(AddHeaderCalculatorTest, NoPacketsOnHeaderStream) {
|
|||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableInputs()->Tag("HEADER").header =
|
||||
runner.MutableInputs()->Tag(kHeaderTag).header =
|
||||
Adopt(new std::string("my_header"));
|
||||
runner.MutableInputs()->Tag("HEADER").packets.push_back(
|
||||
Adopt(new std::string("not allowed")));
|
||||
runner.MutableInputs()
|
||||
->Tag(kHeaderTag)
|
||||
.packets.push_back(Adopt(new std::string("not allowed")));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
|
@ -108,11 +112,11 @@ TEST_F(AddHeaderCalculatorTest, InputSidePacket) {
|
|||
CalculatorRunner runner(node);
|
||||
|
||||
// Set header and add 5 packets.
|
||||
runner.MutableSidePackets()->Tag("HEADER") =
|
||||
runner.MutableSidePackets()->Tag(kHeaderTag) =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run calculator.
|
||||
|
@ -143,13 +147,13 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) {
|
|||
CalculatorRunner runner(node);
|
||||
|
||||
// Set both headers and add 5 packets.
|
||||
runner.MutableSidePackets()->Tag("HEADER") =
|
||||
runner.MutableSidePackets()->Tag(kHeaderTag) =
|
||||
Adopt(new std::string("my_header"));
|
||||
runner.MutableSidePackets()->Tag("HEADER") =
|
||||
runner.MutableSidePackets()->Tag(kHeaderTag) =
|
||||
Adopt(new std::string("my_header"));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Packet packet = Adopt(new int(i)).At(Timestamp(i * 1000));
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(packet);
|
||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(packet);
|
||||
}
|
||||
|
||||
// Run should fail because header can only be provided one way.
|
||||
|
|
|
@ -19,6 +19,13 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kIncrementTag[] = "INCREMENT";
|
||||
constexpr char kInitialValueTag[] = "INITIAL_VALUE";
|
||||
constexpr char kBatchSizeTag[] = "BATCH_SIZE";
|
||||
constexpr char kErrorCountTag[] = "ERROR_COUNT";
|
||||
constexpr char kMaxCountTag[] = "MAX_COUNT";
|
||||
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
|
||||
|
||||
// Source calculator that produces MAX_COUNT*BATCH_SIZE int packets of
|
||||
// sequential numbers from INITIAL_VALUE (default 0) with a common
|
||||
// difference of INCREMENT (default 1) between successive numbers (with
|
||||
|
@ -33,53 +40,53 @@ class CountingSourceCalculator : public CalculatorBase {
|
|||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
|
||||
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) {
|
||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
|
||||
if (cc->InputSidePackets().HasTag(kErrorOnOpenTag)) {
|
||||
cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
|
||||
}
|
||||
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("MAX_COUNT") ||
|
||||
cc->InputSidePackets().HasTag("ERROR_COUNT"));
|
||||
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
|
||||
cc->InputSidePackets().Tag("MAX_COUNT").Set<int>();
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kMaxCountTag) ||
|
||||
cc->InputSidePackets().HasTag(kErrorCountTag));
|
||||
if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
|
||||
cc->InputSidePackets().Tag(kMaxCountTag).Set<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
|
||||
cc->InputSidePackets().Tag("ERROR_COUNT").Set<int>();
|
||||
if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
|
||||
cc->InputSidePackets().Tag(kErrorCountTag).Set<int>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
|
||||
cc->InputSidePackets().Tag("BATCH_SIZE").Set<int>();
|
||||
if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
|
||||
cc->InputSidePackets().Tag(kBatchSizeTag).Set<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
|
||||
cc->InputSidePackets().Tag("INITIAL_VALUE").Set<int>();
|
||||
if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
|
||||
cc->InputSidePackets().Tag(kInitialValueTag).Set<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INCREMENT")) {
|
||||
cc->InputSidePackets().Tag("INCREMENT").Set<int>();
|
||||
if (cc->InputSidePackets().HasTag(kIncrementTag)) {
|
||||
cc->InputSidePackets().Tag(kIncrementTag).Set<int>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") &&
|
||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
|
||||
if (cc->InputSidePackets().HasTag(kErrorOnOpenTag) &&
|
||||
cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
|
||||
return absl::NotFoundError("expected error");
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("ERROR_COUNT")) {
|
||||
error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kErrorCountTag)) {
|
||||
error_count_ = cc->InputSidePackets().Tag(kErrorCountTag).Get<int>();
|
||||
RET_CHECK_LE(0, error_count_);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("MAX_COUNT")) {
|
||||
max_count_ = cc->InputSidePackets().Tag("MAX_COUNT").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kMaxCountTag)) {
|
||||
max_count_ = cc->InputSidePackets().Tag(kMaxCountTag).Get<int>();
|
||||
RET_CHECK_LE(0, max_count_);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("BATCH_SIZE")) {
|
||||
batch_size_ = cc->InputSidePackets().Tag("BATCH_SIZE").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kBatchSizeTag)) {
|
||||
batch_size_ = cc->InputSidePackets().Tag(kBatchSizeTag).Get<int>();
|
||||
RET_CHECK_LT(0, batch_size_);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INITIAL_VALUE")) {
|
||||
counter_ = cc->InputSidePackets().Tag("INITIAL_VALUE").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kInitialValueTag)) {
|
||||
counter_ = cc->InputSidePackets().Tag(kInitialValueTag).Get<int>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("INCREMENT")) {
|
||||
increment_ = cc->InputSidePackets().Tag("INCREMENT").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kIncrementTag)) {
|
||||
increment_ = cc->InputSidePackets().Tag(kIncrementTag).Get<int>();
|
||||
RET_CHECK_LT(0, increment_);
|
||||
}
|
||||
RET_CHECK(error_count_ >= 0 || max_count_ >= 0);
|
||||
|
|
|
@ -35,11 +35,14 @@
|
|||
// }
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||
constexpr char kEncodedTag[] = "ENCODED";
|
||||
|
||||
class DequantizeByteArrayCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("ENCODED").Set<std::string>();
|
||||
cc->Outputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
||||
cc->Inputs().Tag(kEncodedTag).Set<std::string>();
|
||||
cc->Outputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -66,7 +69,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
|
|||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
const std::string& encoded =
|
||||
cc->Inputs().Tag("ENCODED").Value().Get<std::string>();
|
||||
cc->Inputs().Tag(kEncodedTag).Value().Get<std::string>();
|
||||
std::vector<float> float_vector;
|
||||
float_vector.reserve(encoded.length());
|
||||
for (int i = 0; i < encoded.length(); ++i) {
|
||||
|
@ -74,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase {
|
|||
static_cast<unsigned char>(encoded.at(i)) * scalar_ + bias_);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("FLOAT_VECTOR")
|
||||
.Tag(kFloatVectorTag)
|
||||
.AddPacket(MakePacket<std::vector<float>>(float_vector)
|
||||
.At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||
constexpr char kEncodedTag[] = "ENCODED";
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
|
@ -39,8 +42,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
|||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::string empty_string;
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kEncodedTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
|
@ -64,8 +69,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
|
|||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::string empty_string;
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kEncodedTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
|
@ -89,8 +96,10 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
|
|||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
std::string empty_string;
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kEncodedTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::string>(empty_string).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
|
@ -114,14 +123,16 @@ TEST(DequantizeByteArrayCalculatorTest, TestDequantization) {
|
|||
)pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
unsigned char input[4] = {0x7F, 0xFF, 0x00, 0x01};
|
||||
runner.MutableInputs()->Tag("ENCODED").packets.push_back(
|
||||
MakePacket<std::string>(
|
||||
std::string(reinterpret_cast<char const*>(input), 4))
|
||||
.At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kEncodedTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::string>(
|
||||
std::string(reinterpret_cast<char const*>(input), 4))
|
||||
.At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs =
|
||||
runner.Outputs().Tag("FLOAT_VECTOR").packets;
|
||||
runner.Outputs().Tag(kFloatVectorTag).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
const std::vector<float>& result = outputs[0].Get<std::vector<float>>();
|
||||
ASSERT_FALSE(result.empty());
|
||||
|
|
|
@ -24,6 +24,11 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kFinishedTag[] = "FINISHED";
|
||||
constexpr char kAllowTag[] = "ALLOW";
|
||||
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
|
||||
constexpr char kOptionsTag[] = "OPTIONS";
|
||||
|
||||
// FlowLimiterCalculator is used to limit the number of frames in flight
|
||||
// by dropping input frames when necessary.
|
||||
//
|
||||
|
@ -69,16 +74,19 @@ class FlowLimiterCalculator : public CalculatorBase {
|
|||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
auto& side_inputs = cc->InputSidePackets();
|
||||
side_inputs.Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
|
||||
cc->Inputs().Tag("OPTIONS").Set<FlowLimiterCalculatorOptions>().Optional();
|
||||
side_inputs.Tag(kOptionsTag).Set<FlowLimiterCalculatorOptions>().Optional();
|
||||
cc->Inputs()
|
||||
.Tag(kOptionsTag)
|
||||
.Set<FlowLimiterCalculatorOptions>()
|
||||
.Optional();
|
||||
RET_CHECK_GE(cc->Inputs().NumEntries(""), 1);
|
||||
for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) {
|
||||
cc->Inputs().Get("", i).SetAny();
|
||||
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
||||
}
|
||||
cc->Inputs().Get("FINISHED", 0).SetAny();
|
||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>().Optional();
|
||||
cc->Outputs().Tag("ALLOW").Set<bool>().Optional();
|
||||
cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>().Optional();
|
||||
cc->Outputs().Tag(kAllowTag).Set<bool>().Optional();
|
||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
|
@ -87,9 +95,9 @@ class FlowLimiterCalculator : public CalculatorBase {
|
|||
absl::Status Open(CalculatorContext* cc) final {
|
||||
options_ = cc->Options<FlowLimiterCalculatorOptions>();
|
||||
options_ = tool::RetrieveOptions(options_, cc->InputSidePackets());
|
||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
||||
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
|
||||
options_.set_max_in_flight(
|
||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>());
|
||||
cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
|
||||
}
|
||||
input_queues_.resize(cc->Inputs().NumEntries(""));
|
||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
|
||||
|
@ -104,8 +112,8 @@ class FlowLimiterCalculator : public CalculatorBase {
|
|||
|
||||
// Outputs a packet indicating whether a frame was sent or dropped.
|
||||
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
|
||||
if (cc->Outputs().HasTag("ALLOW")) {
|
||||
cc->Outputs().Tag("ALLOW").AddPacket(MakePacket<bool>(allow).At(ts));
|
||||
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||
cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -155,7 +163,7 @@ class FlowLimiterCalculator : public CalculatorBase {
|
|||
options_ = tool::RetrieveOptions(options_, cc->Inputs());
|
||||
|
||||
// Process the FINISHED input stream.
|
||||
Packet finished_packet = cc->Inputs().Tag("FINISHED").Value();
|
||||
Packet finished_packet = cc->Inputs().Tag(kFinishedTag).Value();
|
||||
if (finished_packet.Timestamp() == cc->InputTimestamp()) {
|
||||
while (!frames_in_flight_.empty() &&
|
||||
frames_in_flight_.front() <= finished_packet.Timestamp()) {
|
||||
|
@ -210,8 +218,8 @@ class FlowLimiterCalculator : public CalculatorBase {
|
|||
Timestamp bound =
|
||||
cc->Inputs().Get("", 0).Value().Timestamp().NextAllowedInStream();
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", 0));
|
||||
if (cc->Outputs().HasTag("ALLOW")) {
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Tag("ALLOW"));
|
||||
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||
SetNextTimestampBound(bound, &cc->Outputs().Tag(kAllowTag));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,13 @@
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kDropTimestampsTag[] = "DROP_TIMESTAMPS";
|
||||
constexpr char kClockTag[] = "CLOCK";
|
||||
constexpr char kWarmupTimeTag[] = "WARMUP_TIME";
|
||||
constexpr char kSleepTimeTag[] = "SLEEP_TIME";
|
||||
constexpr char kPacketTag[] = "PACKET";
|
||||
|
||||
// A simple Semaphore for synchronizing test threads.
|
||||
class AtomicSemaphore {
|
||||
public:
|
||||
|
@ -204,17 +211,17 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) {
|
|||
class SleepCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("PACKET").SetAny();
|
||||
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
|
||||
cc->InputSidePackets().Tag("SLEEP_TIME").Set<int64>();
|
||||
cc->InputSidePackets().Tag("WARMUP_TIME").Set<int64>();
|
||||
cc->InputSidePackets().Tag("CLOCK").Set<mediapipe::Clock*>();
|
||||
cc->Inputs().Tag(kPacketTag).SetAny();
|
||||
cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
|
||||
cc->InputSidePackets().Tag(kSleepTimeTag).Set<int64>();
|
||||
cc->InputSidePackets().Tag(kWarmupTimeTag).Set<int64>();
|
||||
cc->InputSidePackets().Tag(kClockTag).Set<mediapipe::Clock*>();
|
||||
cc->SetTimestampOffset(0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<mediapipe::Clock*>();
|
||||
clock_ = cc->InputSidePackets().Tag(kClockTag).Get<mediapipe::Clock*>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -222,10 +229,12 @@ class SleepCalculator : public CalculatorBase {
|
|||
++packet_count;
|
||||
absl::Duration sleep_time = absl::Microseconds(
|
||||
packet_count == 1
|
||||
? cc->InputSidePackets().Tag("WARMUP_TIME").Get<int64>()
|
||||
: cc->InputSidePackets().Tag("SLEEP_TIME").Get<int64>());
|
||||
? cc->InputSidePackets().Tag(kWarmupTimeTag).Get<int64>()
|
||||
: cc->InputSidePackets().Tag(kSleepTimeTag).Get<int64>());
|
||||
clock_->Sleep(sleep_time);
|
||||
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
|
||||
cc->Outputs()
|
||||
.Tag(kPacketTag)
|
||||
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -240,24 +249,27 @@ REGISTER_CALCULATOR(SleepCalculator);
|
|||
class DropCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("PACKET").SetAny();
|
||||
cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET"));
|
||||
cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set<bool>();
|
||||
cc->Inputs().Tag(kPacketTag).SetAny();
|
||||
cc->Outputs().Tag(kPacketTag).SetSameAs(&cc->Inputs().Tag(kPacketTag));
|
||||
cc->InputSidePackets().Tag(kDropTimestampsTag).Set<bool>();
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
|
||||
if (!cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
|
||||
++packet_count;
|
||||
}
|
||||
bool drop = (packet_count == 3);
|
||||
if (!drop && !cc->Inputs().Tag("PACKET").Value().IsEmpty()) {
|
||||
cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value());
|
||||
if (!drop && !cc->Inputs().Tag(kPacketTag).Value().IsEmpty()) {
|
||||
cc->Outputs()
|
||||
.Tag(kPacketTag)
|
||||
.AddPacket(cc->Inputs().Tag(kPacketTag).Value());
|
||||
}
|
||||
if (!drop || !cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Get<bool>()) {
|
||||
cc->Outputs().Tag("PACKET").SetNextTimestampBound(
|
||||
cc->InputTimestamp().NextAllowedInStream());
|
||||
if (!drop || !cc->InputSidePackets().Tag(kDropTimestampsTag).Get<bool>()) {
|
||||
cc->Outputs()
|
||||
.Tag(kPacketTag)
|
||||
.SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -21,6 +21,11 @@
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kStateChangeTag[] = "STATE_CHANGE";
|
||||
constexpr char kDisallowTag[] = "DISALLOW";
|
||||
constexpr char kAllowTag[] = "ALLOW";
|
||||
|
||||
enum GateState {
|
||||
GATE_UNINITIALIZED,
|
||||
GATE_ALLOW,
|
||||
|
@ -83,30 +88,31 @@ class GateCalculator : public CalculatorBase {
|
|||
GateCalculator() {}
|
||||
|
||||
static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) {
|
||||
bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") ||
|
||||
cc->InputSidePackets().HasTag("DISALLOW");
|
||||
bool input_via_side_packet = cc->InputSidePackets().HasTag(kAllowTag) ||
|
||||
cc->InputSidePackets().HasTag(kDisallowTag);
|
||||
bool input_via_stream =
|
||||
cc->Inputs().HasTag("ALLOW") || cc->Inputs().HasTag("DISALLOW");
|
||||
cc->Inputs().HasTag(kAllowTag) || cc->Inputs().HasTag(kDisallowTag);
|
||||
// Only one of input_side_packet or input_stream may specify ALLOW/DISALLOW
|
||||
// input.
|
||||
RET_CHECK(input_via_side_packet ^ input_via_stream);
|
||||
|
||||
if (input_via_side_packet) {
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("ALLOW") ^
|
||||
cc->InputSidePackets().HasTag("DISALLOW"));
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kAllowTag) ^
|
||||
cc->InputSidePackets().HasTag(kDisallowTag));
|
||||
|
||||
if (cc->InputSidePackets().HasTag("ALLOW")) {
|
||||
cc->InputSidePackets().Tag("ALLOW").Set<bool>();
|
||||
if (cc->InputSidePackets().HasTag(kAllowTag)) {
|
||||
cc->InputSidePackets().Tag(kAllowTag).Set<bool>();
|
||||
} else {
|
||||
cc->InputSidePackets().Tag("DISALLOW").Set<bool>();
|
||||
cc->InputSidePackets().Tag(kDisallowTag).Set<bool>();
|
||||
}
|
||||
} else {
|
||||
RET_CHECK(cc->Inputs().HasTag("ALLOW") ^ cc->Inputs().HasTag("DISALLOW"));
|
||||
RET_CHECK(cc->Inputs().HasTag(kAllowTag) ^
|
||||
cc->Inputs().HasTag(kDisallowTag));
|
||||
|
||||
if (cc->Inputs().HasTag("ALLOW")) {
|
||||
cc->Inputs().Tag("ALLOW").Set<bool>();
|
||||
if (cc->Inputs().HasTag(kAllowTag)) {
|
||||
cc->Inputs().Tag(kAllowTag).Set<bool>();
|
||||
} else {
|
||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
||||
cc->Inputs().Tag(kDisallowTag).Set<bool>();
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -125,8 +131,8 @@ class GateCalculator : public CalculatorBase {
|
|||
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
||||
cc->Outputs().Tag("STATE_CHANGE").Set<bool>();
|
||||
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
||||
cc->Outputs().Tag(kStateChangeTag).Set<bool>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -134,14 +140,14 @@ class GateCalculator : public CalculatorBase {
|
|||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
use_side_packet_for_allow_disallow_ = false;
|
||||
if (cc->InputSidePackets().HasTag("ALLOW")) {
|
||||
if (cc->InputSidePackets().HasTag(kAllowTag)) {
|
||||
use_side_packet_for_allow_disallow_ = true;
|
||||
allow_by_side_packet_decision_ =
|
||||
cc->InputSidePackets().Tag("ALLOW").Get<bool>();
|
||||
} else if (cc->InputSidePackets().HasTag("DISALLOW")) {
|
||||
cc->InputSidePackets().Tag(kAllowTag).Get<bool>();
|
||||
} else if (cc->InputSidePackets().HasTag(kDisallowTag)) {
|
||||
use_side_packet_for_allow_disallow_ = true;
|
||||
allow_by_side_packet_decision_ =
|
||||
!cc->InputSidePackets().Tag("DISALLOW").Get<bool>();
|
||||
!cc->InputSidePackets().Tag(kDisallowTag).Get<bool>();
|
||||
}
|
||||
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
@ -160,18 +166,18 @@ class GateCalculator : public CalculatorBase {
|
|||
if (use_side_packet_for_allow_disallow_) {
|
||||
allow = allow_by_side_packet_decision_;
|
||||
} else {
|
||||
if (cc->Inputs().HasTag("ALLOW") &&
|
||||
!cc->Inputs().Tag("ALLOW").IsEmpty()) {
|
||||
allow = cc->Inputs().Tag("ALLOW").Get<bool>();
|
||||
if (cc->Inputs().HasTag(kAllowTag) &&
|
||||
!cc->Inputs().Tag(kAllowTag).IsEmpty()) {
|
||||
allow = cc->Inputs().Tag(kAllowTag).Get<bool>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("DISALLOW") &&
|
||||
!cc->Inputs().Tag("DISALLOW").IsEmpty()) {
|
||||
allow = !cc->Inputs().Tag("DISALLOW").Get<bool>();
|
||||
if (cc->Inputs().HasTag(kDisallowTag) &&
|
||||
!cc->Inputs().Tag(kDisallowTag).IsEmpty()) {
|
||||
allow = !cc->Inputs().Tag(kDisallowTag).Get<bool>();
|
||||
}
|
||||
}
|
||||
const GateState new_gate_state = allow ? GATE_ALLOW : GATE_DISALLOW;
|
||||
|
||||
if (cc->Outputs().HasTag("STATE_CHANGE")) {
|
||||
if (cc->Outputs().HasTag(kStateChangeTag)) {
|
||||
if (last_gate_state_ != GATE_UNINITIALIZED &&
|
||||
last_gate_state_ != new_gate_state) {
|
||||
VLOG(2) << "State transition in " << cc->NodeName() << " @ "
|
||||
|
@ -179,7 +185,7 @@ class GateCalculator : public CalculatorBase {
|
|||
<< ToString(last_gate_state_) << " to "
|
||||
<< ToString(new_gate_state);
|
||||
cc->Outputs()
|
||||
.Tag("STATE_CHANGE")
|
||||
.Tag(kStateChangeTag)
|
||||
.AddPacket(MakePacket<bool>(allow).At(cc->InputTimestamp()));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,9 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kDisallowTag[] = "DISALLOW";
|
||||
constexpr char kAllowTag[] = "ALLOW";
|
||||
|
||||
class GateCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
// Helper to run a graph and return status.
|
||||
|
@ -117,7 +120,7 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
|
|||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(true));
|
||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
|
@ -139,7 +142,7 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
|
|||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(false));
|
||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
|
@ -161,7 +164,7 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
|
|||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("ALLOW") = Adopt(new bool(false));
|
||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
|
@ -179,7 +182,7 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
|
|||
input_stream: "test_input"
|
||||
output_stream: "test_output"
|
||||
)");
|
||||
runner()->MutableSidePackets()->Tag("DISALLOW") = Adopt(new bool(true));
|
||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
|
|
|
@ -39,20 +39,24 @@ using testing::ElementsAre;
|
|||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kClockTag[] = "CLOCK";
|
||||
|
||||
using mediapipe::Clock;
|
||||
|
||||
// A Calculator with a fixed Process call latency.
|
||||
class SleepCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
|
||||
cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
cc->SetTimestampOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
|
||||
clock_ =
|
||||
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,9 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kMinuendTag[] = "MINUEND";
|
||||
constexpr char kSubtrahendTag[] = "SUBTRAHEND";
|
||||
|
||||
// A 3x4 Matrix of random integers in [0,1000).
|
||||
const char kMatrixText[] =
|
||||
"rows: 3\n"
|
||||
|
@ -104,12 +107,13 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromInput) {
|
|||
CalculatorRunner runner(node_config);
|
||||
Matrix* side_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText, side_matrix);
|
||||
runner.MutableSidePackets()->Tag("SUBTRAHEND") = Adopt(side_matrix);
|
||||
runner.MutableSidePackets()->Tag(kSubtrahendTag) = Adopt(side_matrix);
|
||||
|
||||
Matrix* input_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText2, input_matrix);
|
||||
runner.MutableInputs()->Tag("MINUEND").packets.push_back(
|
||||
Adopt(input_matrix).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kMinuendTag)
|
||||
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(1, runner.Outputs().Index(0).packets.size());
|
||||
|
@ -133,12 +137,12 @@ TEST(MatrixSubtractCalculatorTest, SubtractFromSideMatrix) {
|
|||
CalculatorRunner runner(node_config);
|
||||
Matrix* side_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText, side_matrix);
|
||||
runner.MutableSidePackets()->Tag("MINUEND") = Adopt(side_matrix);
|
||||
runner.MutableSidePackets()->Tag(kMinuendTag) = Adopt(side_matrix);
|
||||
|
||||
Matrix* input_matrix = new Matrix();
|
||||
MatrixFromTextProto(kMatrixText2, input_matrix);
|
||||
runner.MutableInputs()
|
||||
->Tag("SUBTRAHEND")
|
||||
->Tag(kSubtrahendTag)
|
||||
.packets.push_back(Adopt(input_matrix).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kPresenceTag[] = "PRESENCE";
|
||||
constexpr char kPacketTag[] = "PACKET";
|
||||
|
||||
// For each non empty input packet, emits a single output packet containing a
|
||||
// boolean value "true", "false" in response to empty packets (a.k.a. timestamp
|
||||
// bound updates) This can be used to "flag" the presence of an arbitrary packet
|
||||
|
@ -58,8 +61,8 @@ namespace mediapipe {
|
|||
class PacketPresenceCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("PACKET").SetAny();
|
||||
cc->Outputs().Tag("PRESENCE").Set<bool>();
|
||||
cc->Inputs().Tag(kPacketTag).SetAny();
|
||||
cc->Outputs().Tag(kPresenceTag).Set<bool>();
|
||||
// Process() function is invoked in response to input stream timestamp
|
||||
// bound updates.
|
||||
cc->SetProcessTimestampBounds(true);
|
||||
|
@ -73,8 +76,8 @@ class PacketPresenceCalculator : public CalculatorBase {
|
|||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
cc->Outputs()
|
||||
.Tag("PRESENCE")
|
||||
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag("PACKET").IsEmpty())
|
||||
.Tag(kPresenceTag)
|
||||
.AddPacket(MakePacket<bool>(!cc->Inputs().Tag(kPacketTag).IsEmpty())
|
||||
.At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -39,6 +39,11 @@ namespace mediapipe {
|
|||
|
||||
REGISTER_CALCULATOR(PacketResamplerCalculator);
|
||||
namespace {
|
||||
|
||||
constexpr char kSeedTag[] = "SEED";
|
||||
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
|
||||
constexpr char kOptionsTag[] = "OPTIONS";
|
||||
|
||||
// Returns a TimestampDiff (assuming microseconds) corresponding to the
|
||||
// given time in seconds.
|
||||
TimestampDiff TimestampDiffFromSeconds(double seconds) {
|
||||
|
@ -50,16 +55,16 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) {
|
|||
absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
|
||||
const auto& resampler_options =
|
||||
cc->Options<PacketResamplerCalculatorOptions>();
|
||||
if (cc->InputSidePackets().HasTag("OPTIONS")) {
|
||||
cc->InputSidePackets().Tag("OPTIONS").Set<CalculatorOptions>();
|
||||
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||
cc->InputSidePackets().Tag(kOptionsTag).Set<CalculatorOptions>();
|
||||
}
|
||||
CollectionItemId input_data_id = cc->Inputs().GetId("DATA", 0);
|
||||
if (!input_data_id.IsValid()) {
|
||||
input_data_id = cc->Inputs().GetId("", 0);
|
||||
}
|
||||
cc->Inputs().Get(input_data_id).SetAny();
|
||||
if (cc->Inputs().HasTag("VIDEO_HEADER")) {
|
||||
cc->Inputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
|
||||
if (cc->Inputs().HasTag(kVideoHeaderTag)) {
|
||||
cc->Inputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
|
||||
}
|
||||
|
||||
CollectionItemId output_data_id = cc->Outputs().GetId("DATA", 0);
|
||||
|
@ -67,15 +72,15 @@ absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) {
|
|||
output_data_id = cc->Outputs().GetId("", 0);
|
||||
}
|
||||
cc->Outputs().Get(output_data_id).SetSameAs(&cc->Inputs().Get(input_data_id));
|
||||
if (cc->Outputs().HasTag("VIDEO_HEADER")) {
|
||||
cc->Outputs().Tag("VIDEO_HEADER").Set<VideoHeader>();
|
||||
if (cc->Outputs().HasTag(kVideoHeaderTag)) {
|
||||
cc->Outputs().Tag(kVideoHeaderTag).Set<VideoHeader>();
|
||||
}
|
||||
|
||||
if (resampler_options.jitter() != 0.0) {
|
||||
RET_CHECK_GT(resampler_options.jitter(), 0.0);
|
||||
RET_CHECK_LE(resampler_options.jitter(), 1.0);
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("SEED"));
|
||||
cc->InputSidePackets().Tag("SEED").Set<std::string>();
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kSeedTag));
|
||||
cc->InputSidePackets().Tag(kSeedTag).Set<std::string>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -143,9 +148,9 @@ absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->InputTimestamp() == Timestamp::PreStream() &&
|
||||
cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") &&
|
||||
!cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) {
|
||||
video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get<VideoHeader>();
|
||||
cc->Inputs().UsesTags() && cc->Inputs().HasTag(kVideoHeaderTag) &&
|
||||
!cc->Inputs().Tag(kVideoHeaderTag).IsEmpty()) {
|
||||
video_header_ = cc->Inputs().Tag(kVideoHeaderTag).Get<VideoHeader>();
|
||||
video_header_.frame_rate = frame_rate_;
|
||||
if (cc->Inputs().Get(input_data_id_).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
|
@ -234,7 +239,7 @@ absl::Status LegacyJitterWithReflectionStrategy::Open(CalculatorContext* cc) {
|
|||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
|
@ -357,7 +362,7 @@ absl::Status ReproducibleJitterWithReflectionStrategy::Open(
|
|||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
|
@ -504,7 +509,7 @@ absl::Status JitterWithoutReflectionStrategy::Open(CalculatorContext* cc) {
|
|||
"ignored, because we are adding jitter.";
|
||||
}
|
||||
|
||||
const auto& seed = cc->InputSidePackets().Tag("SEED").Get<std::string>();
|
||||
const auto& seed = cc->InputSidePackets().Tag(kSeedTag).Get<std::string>();
|
||||
random_ = CreateSecureRandom(seed);
|
||||
if (random_ == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
|
@ -635,9 +640,9 @@ absl::Status NoJitterStrategy::Process(CalculatorContext* cc) {
|
|||
base_timestamp_ +
|
||||
TimestampDiffFromSeconds(first_index / calculator_->frame_rate_);
|
||||
}
|
||||
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag("VIDEO_HEADER")) {
|
||||
if (cc->Outputs().UsesTags() && cc->Outputs().HasTag(kVideoHeaderTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("VIDEO_HEADER")
|
||||
.Tag(kVideoHeaderTag)
|
||||
.Add(new VideoHeader(calculator_->video_header_),
|
||||
Timestamp::PreStream());
|
||||
}
|
||||
|
|
|
@ -32,6 +32,12 @@ namespace mediapipe {
|
|||
|
||||
using ::testing::ElementsAre;
|
||||
namespace {
|
||||
|
||||
constexpr char kOptionsTag[] = "OPTIONS";
|
||||
constexpr char kSeedTag[] = "SEED";
|
||||
constexpr char kVideoHeaderTag[] = "VIDEO_HEADER";
|
||||
constexpr char kDataTag[] = "DATA";
|
||||
|
||||
// A simple version of CalculatorRunner with built-in convenience
|
||||
// methods for setting inputs from a vector and checking outputs
|
||||
// against expected outputs (both timestamps and contents).
|
||||
|
@ -464,7 +470,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
|||
)pb"));
|
||||
|
||||
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
|
||||
runner.MutableInputs()->Tag("DATA").packets.push_back(
|
||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
|
||||
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
||||
}
|
||||
VideoHeader video_header_in;
|
||||
|
@ -474,16 +480,16 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
|||
video_header_in.duration = 1.0;
|
||||
video_header_in.format = ImageFormat::SRGB;
|
||||
runner.MutableInputs()
|
||||
->Tag("VIDEO_HEADER")
|
||||
->Tag(kVideoHeaderTag)
|
||||
.packets.push_back(
|
||||
Adopt(new VideoHeader(video_header_in)).At(Timestamp::PreStream()));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
ASSERT_EQ(1, runner.Outputs().Tag("VIDEO_HEADER").packets.size());
|
||||
ASSERT_EQ(1, runner.Outputs().Tag(kVideoHeaderTag).packets.size());
|
||||
EXPECT_EQ(Timestamp::PreStream(),
|
||||
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Timestamp());
|
||||
runner.Outputs().Tag(kVideoHeaderTag).packets[0].Timestamp());
|
||||
const VideoHeader& video_header_out =
|
||||
runner.Outputs().Tag("VIDEO_HEADER").packets[0].Get<VideoHeader>();
|
||||
runner.Outputs().Tag(kVideoHeaderTag).packets[0].Get<VideoHeader>();
|
||||
EXPECT_EQ(video_header_in.width, video_header_out.width);
|
||||
EXPECT_EQ(video_header_in.height, video_header_out.height);
|
||||
EXPECT_DOUBLE_EQ(50.0, video_header_out.frame_rate);
|
||||
|
@ -725,7 +731,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
|
|||
[mediapipe.PacketResamplerCalculatorOptions.ext] {
|
||||
frame_rate: 30
|
||||
})pb"));
|
||||
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
|
||||
runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(6, runner.Outputs().Index(0).packets.size());
|
||||
|
@ -740,7 +746,7 @@ TEST(PacketResamplerCalculatorTest, OptionsSidePacket) {
|
|||
frame_rate: 30
|
||||
base_timestamp: 0
|
||||
})pb"));
|
||||
runner.MutableSidePackets()->Tag("OPTIONS") = Adopt(options);
|
||||
runner.MutableSidePackets()->Tag(kOptionsTag) = Adopt(options);
|
||||
|
||||
runner.SetInput({-222, 15000, 32000, 49999, 150000});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
|
|
@ -29,6 +29,8 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kPeriodTag[] = "PERIOD";
|
||||
|
||||
// A simple version of CalculatorRunner with built-in convenience methods for
|
||||
// setting inputs from a vector and checking outputs against a vector of
|
||||
// expected outputs.
|
||||
|
@ -121,7 +123,7 @@ TEST(PacketThinnerCalculatorTest, ASyncUniformStreamThinningTestBySidePacket) {
|
|||
|
||||
SimpleRunner runner(node);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
|
||||
runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 8, 14};
|
||||
|
@ -160,7 +162,7 @@ TEST(PacketThinnerCalculatorTest, SyncUniformStreamThinningTestBySidePacket1) {
|
|||
|
||||
SimpleRunner runner(node);
|
||||
runner.SetInput({2, 4, 6, 8, 10, 12, 14});
|
||||
runner.MutableSidePackets()->Tag("PERIOD") = MakePacket<int64>(5);
|
||||
runner.MutableSidePackets()->Tag(kPeriodTag) = MakePacket<int64>(5);
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const std::vector<int64> expected_timestamps = {2, 6, 10, 14};
|
||||
|
|
|
@ -39,6 +39,8 @@ using ::testing::Pair;
|
|||
using ::testing::Value;
|
||||
namespace {
|
||||
|
||||
constexpr char kDisallowTag[] = "DISALLOW";
|
||||
|
||||
// Returns the timestamp values for a vector of Packets.
|
||||
// TODO: puth this kind of test util in a common place.
|
||||
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
||||
|
@ -702,14 +704,14 @@ class DroppingGateCalculator : public CalculatorBase {
|
|||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Inputs().Tag("DISALLOW").Set<bool>();
|
||||
cc->Inputs().Tag(kDisallowTag).Set<bool>();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (!cc->Inputs().Index(0).IsEmpty() &&
|
||||
!cc->Inputs().Tag("DISALLOW").Get<bool>()) {
|
||||
!cc->Inputs().Tag(kDisallowTag).Get<bool>()) {
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -41,11 +41,14 @@
|
|||
// }
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kEncodedTag[] = "ENCODED";
|
||||
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||
|
||||
class QuantizeFloatVectorCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("FLOAT_VECTOR").Set<std::vector<float>>();
|
||||
cc->Outputs().Tag("ENCODED").Set<std::string>();
|
||||
cc->Inputs().Tag(kFloatVectorTag).Set<std::vector<float>>();
|
||||
cc->Outputs().Tag(kEncodedTag).Set<std::string>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -70,7 +73,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
|
|||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
const std::vector<float>& float_vector =
|
||||
cc->Inputs().Tag("FLOAT_VECTOR").Value().Get<std::vector<float>>();
|
||||
cc->Inputs().Tag(kFloatVectorTag).Value().Get<std::vector<float>>();
|
||||
int feature_size = float_vector.size();
|
||||
std::string encoded_features;
|
||||
encoded_features.reserve(feature_size);
|
||||
|
@ -86,8 +89,10 @@ class QuantizeFloatVectorCalculator : public CalculatorBase {
|
|||
(old_value - min_quantized_value_) * (255.0 / range_));
|
||||
encoded_features += encoded;
|
||||
}
|
||||
cc->Outputs().Tag("ENCODED").AddPacket(
|
||||
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
|
||||
cc->Outputs()
|
||||
.Tag(kEncodedTag)
|
||||
.AddPacket(
|
||||
MakePacket<std::string>(encoded_features).At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kEncodedTag[] = "ENCODED";
|
||||
constexpr char kFloatVectorTag[] = "FLOAT_VECTOR";
|
||||
|
||||
TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
|
@ -40,7 +43,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig) {
|
|||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
->Tag(kFloatVectorTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
|
@ -67,7 +70,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig2) {
|
|||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
->Tag(kFloatVectorTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
|
@ -94,7 +97,7 @@ TEST(QuantizeFloatVectorCalculatorTest, WrongConfig3) {
|
|||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
->Tag(kFloatVectorTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
auto status = runner.Run();
|
||||
|
@ -121,11 +124,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestEmptyVector) {
|
|||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> empty_vector;
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
->Tag(kFloatVectorTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(empty_vector).At(Timestamp(0)));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
||||
const std::vector<Packet>& outputs =
|
||||
runner.Outputs().Tag(kEncodedTag).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
EXPECT_TRUE(outputs[0].Get<std::string>().empty());
|
||||
EXPECT_EQ(Timestamp(0), outputs[0].Timestamp());
|
||||
|
@ -147,11 +151,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestNonEmptyVector) {
|
|||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> vector = {0.0f, -64.0f, 64.0f, -32.0f, 32.0f};
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
->Tag(kFloatVectorTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
||||
const std::vector<Packet>& outputs =
|
||||
runner.Outputs().Tag(kEncodedTag).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
const std::string& result = outputs[0].Get<std::string>();
|
||||
ASSERT_FALSE(result.empty());
|
||||
|
@ -185,11 +190,12 @@ TEST(QuantizeFloatVectorCalculatorTest, TestSaturation) {
|
|||
CalculatorRunner runner(node_config);
|
||||
std::vector<float> vector = {-65.0f, 65.0f};
|
||||
runner.MutableInputs()
|
||||
->Tag("FLOAT_VECTOR")
|
||||
->Tag(kFloatVectorTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(vector).At(Timestamp(0)));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& outputs = runner.Outputs().Tag("ENCODED").packets;
|
||||
const std::vector<Packet>& outputs =
|
||||
runner.Outputs().Tag(kEncodedTag).packets;
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
const std::string& result = outputs[0].Get<std::string>();
|
||||
ASSERT_FALSE(result.empty());
|
||||
|
|
|
@ -23,6 +23,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kAllowTag[] = "ALLOW";
|
||||
constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
|
||||
|
||||
// RealTimeFlowLimiterCalculator is used to limit the number of pipelined
|
||||
// processing operations in a section of the graph.
|
||||
//
|
||||
|
@ -86,11 +89,11 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
|
|||
cc->Outputs().Get("", i).SetSameAs(&(cc->Inputs().Get("", i)));
|
||||
}
|
||||
cc->Inputs().Get("FINISHED", 0).SetAny();
|
||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
||||
cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Set<int>();
|
||||
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
|
||||
cc->InputSidePackets().Tag(kMaxInFlightTag).Set<int>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("ALLOW")) {
|
||||
cc->Outputs().Tag("ALLOW").Set<bool>();
|
||||
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||
cc->Outputs().Tag(kAllowTag).Set<bool>();
|
||||
}
|
||||
|
||||
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||
|
@ -101,8 +104,8 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase {
|
|||
absl::Status Open(CalculatorContext* cc) final {
|
||||
finished_id_ = cc->Inputs().GetId("FINISHED", 0);
|
||||
max_in_flight_ = 1;
|
||||
if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) {
|
||||
max_in_flight_ = cc->InputSidePackets().Tag("MAX_IN_FLIGHT").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kMaxInFlightTag)) {
|
||||
max_in_flight_ = cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>();
|
||||
}
|
||||
RET_CHECK_GE(max_in_flight_, 1);
|
||||
num_in_flight_ = 0;
|
||||
|
|
|
@ -33,6 +33,9 @@
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kFinishedTag[] = "FINISHED";
|
||||
|
||||
// A simple Semaphore for synchronizing test threads.
|
||||
class AtomicSemaphore {
|
||||
public:
|
||||
|
@ -112,7 +115,7 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) {
|
|||
Timestamp timestamp =
|
||||
Timestamp((i + 1) * Timestamp::kTimestampUnitsPerSecond);
|
||||
runner.MutableInputs()
|
||||
->Tag("FINISHED")
|
||||
->Tag(kFinishedTag)
|
||||
.packets.push_back(MakePacket<bool>(true).At(timestamp));
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,8 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kPacketOffsetTag[] = "PACKET_OFFSET";
|
||||
|
||||
// Adds packets containing integers equal to their original timestamp.
|
||||
void AddPackets(CalculatorRunner* runner) {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
|
@ -111,7 +113,7 @@ TEST(SequenceShiftCalculatorTest, SidePacketOffset) {
|
|||
|
||||
CalculatorRunner runner(node);
|
||||
AddPackets(&runner);
|
||||
runner.MutableSidePackets()->Tag("PACKET_OFFSET") = Adopt(new int(-2));
|
||||
runner.MutableSidePackets()->Tag(kPacketOffsetTag) = Adopt(new int(-2));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& input_packets =
|
||||
runner.MutableInputs()->Index(0).packets;
|
||||
|
|
|
@ -102,6 +102,10 @@ mediapipe::ScaleMode_Mode ParseScaleMode(
|
|||
// IMAGE: ImageFrame representing the input image.
|
||||
// IMAGE_GPU: GpuBuffer representing the input image.
|
||||
//
|
||||
// OUTPUT_DIMENSIONS (optional): The output width and height in pixels as
|
||||
// pair<int, int>. If set, it will override corresponding field in calculator
|
||||
// options and input side packet.
|
||||
//
|
||||
// ROTATION_DEGREES (optional): The counterclockwise rotation angle in
|
||||
// degrees. This allows different rotation angles for different frames. It has
|
||||
// to be a multiple of 90 degrees. If provided, it overrides the
|
||||
|
@ -221,6 +225,10 @@ absl::Status ImageTransformationCalculator::GetContract(
|
|||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) {
|
||||
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Set<std::pair<int, int>>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
|
||||
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
|
||||
}
|
||||
|
@ -329,6 +337,13 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) {
|
|||
!cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) {
|
||||
flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get<bool>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS") &&
|
||||
!cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) {
|
||||
const auto& image_size =
|
||||
cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get<std::pair<int, int>>();
|
||||
output_width_ = image_size.first;
|
||||
output_height_ = image_size.second;
|
||||
}
|
||||
|
||||
if (use_gpu_) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
|
|
@ -88,6 +88,13 @@ proto_library(
|
|||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "tensor_to_vector_string_calculator_options_proto",
|
||||
srcs = ["tensor_to_vector_string_calculator_options.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:calculator_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "unpack_media_sequence_calculator_proto",
|
||||
srcs = ["unpack_media_sequence_calculator.proto"],
|
||||
|
@ -257,6 +264,14 @@ mediapipe_cc_proto_library(
|
|||
deps = [":tensor_to_vector_float_calculator_options_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "tensor_to_vector_string_calculator_options_cc_proto",
|
||||
srcs = ["tensor_to_vector_string_calculator_options.proto"],
|
||||
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":tensor_to_vector_string_calculator_options_proto"],
|
||||
)
|
||||
|
||||
mediapipe_cc_proto_library(
|
||||
name = "unpack_media_sequence_calculator_cc_proto",
|
||||
srcs = ["unpack_media_sequence_calculator.proto"],
|
||||
|
@ -694,6 +709,26 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_to_vector_string_calculator",
|
||||
srcs = ["tensor_to_vector_string_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
":tensor_to_vector_string_calculator_options_cc_proto",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
],
|
||||
"//mediapipe:android": [
|
||||
"@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "unpack_media_sequence_calculator",
|
||||
srcs = ["unpack_media_sequence_calculator.cc"],
|
||||
|
@ -1059,6 +1094,20 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_to_vector_string_calculator_test",
|
||||
srcs = ["tensor_to_vector_string_calculator_test.cc"],
|
||||
deps = [
|
||||
":tensor_to_vector_string_calculator",
|
||||
":tensor_to_vector_string_calculator_options_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@org_tensorflow//tensorflow/core:framework",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "unpack_media_sequence_calculator_test",
|
||||
srcs = ["unpack_media_sequence_calculator_test.cc"],
|
||||
|
|
|
@ -40,6 +40,24 @@ namespace {
|
|||
namespace tf = ::tensorflow;
|
||||
namespace mpms = mediapipe::mediasequence;
|
||||
|
||||
constexpr char kBboxTag[] = "BBOX";
|
||||
constexpr char kEncodedMediaStartTimestampTag[] =
|
||||
"ENCODED_MEDIA_START_TIMESTAMP";
|
||||
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
|
||||
constexpr char kClassSegmentationTag[] = "CLASS_SEGMENTATION";
|
||||
constexpr char kKeypointsTestTag[] = "KEYPOINTS_TEST";
|
||||
constexpr char kBboxPredictedTag[] = "BBOX_PREDICTED";
|
||||
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
|
||||
constexpr char kAudioTestTag[] = "AUDIO_TEST";
|
||||
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
|
||||
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
|
||||
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
|
||||
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
||||
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
|
||||
class PackMediaSequenceCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpCalculator(const std::vector<std::string>& input_streams,
|
||||
|
@ -83,17 +101,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) {
|
|||
for (int i = 0; i < num_images; ++i) {
|
||||
auto image_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||
Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -127,17 +145,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) {
|
|||
auto image_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||
runner_->MutableInputs()
|
||||
->Tag("IMAGE_PREFIX")
|
||||
->Tag(kImagePrefixTag)
|
||||
.packets.push_back(Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -161,21 +179,21 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoFloatLists) {
|
|||
for (int i = 0; i < num_timesteps; ++i) {
|
||||
auto vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FLOAT_FEATURE_TEST")
|
||||
->Tag(kFloatFeatureTestTag)
|
||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
||||
vf_ptr = ::absl::make_unique<std::vector<float>>(2, 2 << i);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FLOAT_FEATURE_OTHER")
|
||||
->Tag(kFloatFeatureOtherTag)
|
||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -228,20 +246,20 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
|
|||
|
||||
auto vf_ptr = absl::make_unique<std::vector<float>>(2, 3);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FLOAT_CONTEXT_FEATURE_TEST")
|
||||
->Tag(kFloatContextFeatureTestTag)
|
||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
|
||||
vf_ptr = absl::make_unique<std::vector<float>>(2, 4);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FLOAT_CONTEXT_FEATURE_OTHER")
|
||||
->Tag(kFloatContextFeatureOtherTag)
|
||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp::PostStream()));
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -259,7 +277,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
|||
SetUpCalculator({"IMAGE:images"}, context, false, true);
|
||||
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
|
||||
std::vector<uchar> bytes;
|
||||
|
@ -268,13 +286,13 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
|||
encoded_image.set_encoded_image(bytes.data(), bytes.size());
|
||||
auto image_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||
Adopt(image_ptr.release()).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -307,17 +325,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) {
|
|||
auto flow_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FORWARD_FLOW_ENCODED")
|
||||
->Tag(kForwardFlowEncodedTag)
|
||||
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -371,17 +389,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
|
|||
detections->push_back(detection);
|
||||
|
||||
runner_->MutableInputs()
|
||||
->Tag("BBOX_PREDICTED")
|
||||
->Tag(kBboxPredictedTag)
|
||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -450,11 +468,11 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) {
|
|||
detections->push_back(detection);
|
||||
|
||||
runner_->MutableInputs()
|
||||
->Tag("BBOX_PREDICTED")
|
||||
->Tag(kBboxPredictedTag)
|
||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
auto status = runner_->Run();
|
||||
|
@ -498,7 +516,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
|
|||
detections->push_back(detection);
|
||||
|
||||
runner_->MutableInputs()
|
||||
->Tag("BBOX_PREDICTED")
|
||||
->Tag(kBboxPredictedTag)
|
||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||
}
|
||||
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
|
||||
|
@ -513,16 +531,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
|
|||
for (int i = 0; i < num_images; ++i) {
|
||||
auto image_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||
Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -564,18 +582,18 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
|
|||
absl::flat_hash_map<std::string, std::vector<std::pair<float, float>>>
|
||||
points = {{"HEAD", {{0.1, 0.2}, {0.3, 0.4}}}, {"TAIL", {{0.5, 0.6}}}};
|
||||
runner_->MutableInputs()
|
||||
->Tag("KEYPOINTS_TEST")
|
||||
->Tag(kKeypointsTestTag)
|
||||
.packets.push_back(PointToForeign(&points).At(Timestamp(0)));
|
||||
runner_->MutableInputs()
|
||||
->Tag("KEYPOINTS_TEST")
|
||||
->Tag(kKeypointsTestTag)
|
||||
.packets.push_back(PointToForeign(&points).At(Timestamp(1)));
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -615,17 +633,17 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
|
|||
detections->push_back(detection);
|
||||
|
||||
runner_->MutableInputs()
|
||||
->Tag("CLASS_SEGMENTATION")
|
||||
->Tag(kClassSegmentationTag)
|
||||
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -664,17 +682,17 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) {
|
|||
auto flow_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FORWARD_FLOW_ENCODED")
|
||||
->Tag(kForwardFlowEncodedTag)
|
||||
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -710,11 +728,11 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) {
|
|||
auto flow_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_flow);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FORWARD_FLOW_ENCODED")
|
||||
->Tag(kForwardFlowEncodedTag)
|
||||
.packets.push_back(Adopt(flow_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
absl::Status status = runner_->Run();
|
||||
|
@ -731,13 +749,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingImages) {
|
|||
mpms::AddImageTimestamp(1, input_sequence.get());
|
||||
mpms::AddImageTimestamp(2, input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -757,13 +775,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFlowImages) {
|
|||
mpms::AddForwardFlowTimestamp(1, input_sequence.get());
|
||||
mpms::AddForwardFlowTimestamp(2, input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -794,13 +812,13 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReplacingFloatVectors) {
|
|||
mpms::GetFeatureTimestampSize("OTHER", *input_sequence));
|
||||
ASSERT_EQ(num_timesteps,
|
||||
mpms::GetFeatureFloatsSize("OTHER", *input_sequence));
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -826,7 +844,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
|
|||
for (int i = 0; i < num_images; ++i) {
|
||||
auto image_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||
Adopt(image_ptr.release()).At(Timestamp((i + 1) * 10)));
|
||||
}
|
||||
|
||||
|
@ -838,11 +856,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) {
|
|||
mpms::AddBBoxTimestamp("PREFIX", 9, input_sequence.get());
|
||||
mpms::AddBBoxTimestamp("PREFIX", 22, input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
@ -879,7 +897,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
|||
for (int i = 0; i < num_images; ++i) {
|
||||
auto image_ptr =
|
||||
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
|
||||
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||
runner_->MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||
Adopt(image_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
|
@ -893,7 +911,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
|||
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
|
||||
.ConvertToProto(detection.mutable_location_data());
|
||||
detections->push_back(detection);
|
||||
runner_->MutableInputs()->Tag("BBOX").packets.push_back(
|
||||
runner_->MutableInputs()->Tag(kBboxTag).packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
|
@ -909,7 +927,7 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) {
|
|||
mpms::AddBBoxTrackIndex({-1}, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
// If the all the previous values aren't cleared, this assert will fail.
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
@ -925,11 +943,11 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) {
|
|||
for (int i = 0; i < num_timesteps; ++i) {
|
||||
auto vf_ptr = ::absl::make_unique<std::vector<float>>(1000000, i);
|
||||
runner_->MutableInputs()
|
||||
->Tag("FLOAT_FEATURE_TEST")
|
||||
->Tag(kFloatFeatureTestTag)
|
||||
.packets.push_back(Adopt(vf_ptr.release()).At(Timestamp(i)));
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
ASSERT_FALSE(runner_->Run().ok());
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ namespace mediapipe {
|
|||
namespace tf = ::tensorflow;
|
||||
namespace {
|
||||
|
||||
constexpr char kReferenceTag[] = "REFERENCE";
|
||||
|
||||
constexpr char kMatrix[] = "MATRIX";
|
||||
constexpr char kTensor[] = "TENSOR";
|
||||
|
||||
|
@ -68,7 +70,8 @@ class TensorToMatrixCalculatorTest : public ::testing::Test {
|
|||
if (include_rate) {
|
||||
header->set_packet_rate(1.0);
|
||||
}
|
||||
runner_->MutableInputs()->Tag("REFERENCE").header = Adopt(header.release());
|
||||
runner_->MutableInputs()->Tag(kReferenceTag).header =
|
||||
Adopt(header.release());
|
||||
}
|
||||
|
||||
std::unique_ptr<CalculatorRunner> runner_;
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Calculator converts from one-dimensional Tensor of DT_STRING to
|
||||
// vector<std::string> OR from (batched) two-dimensional Tensor of DT_STRING to
|
||||
// vector<vector<std::string>.
|
||||
|
||||
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
class TensorToVectorStringCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
TensorToVectorStringCalculatorOptions options_;
|
||||
};
|
||||
REGISTER_CALCULATOR(TensorToVectorStringCalculator);
|
||||
|
||||
absl::Status TensorToVectorStringCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
// Start with only one input packet.
|
||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 1)
|
||||
<< "Only one input stream is supported.";
|
||||
cc->Inputs().Index(0).Set<tf::Tensor>(
|
||||
// Input Tensor
|
||||
);
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1)
|
||||
<< "Only one output stream is supported.";
|
||||
const auto& options = cc->Options<TensorToVectorStringCalculatorOptions>();
|
||||
if (options.tensor_is_2d()) {
|
||||
RET_CHECK(!options.flatten_nd());
|
||||
cc->Outputs().Index(0).Set<std::vector<std::vector<std::string>>>(
|
||||
/* "Output vector<vector<std::string>>." */);
|
||||
} else {
|
||||
cc->Outputs().Index(0).Set<std::vector<std::string>>(
|
||||
// Output vector<std::string>.
|
||||
);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorToVectorStringCalculator::Open(CalculatorContext* cc) {
|
||||
options_ = cc->Options<TensorToVectorStringCalculatorOptions>();
|
||||
|
||||
// Inform mediapipe that this calculator produces an output at time t for
|
||||
// each input received at time t (i.e. this calculator does not buffer
|
||||
// inputs). This enables mediapipe to propagate time of arrival estimates in
|
||||
// mediapipe graphs through this calculator.
|
||||
cc->SetOffset(/*offset=*/0);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorToVectorStringCalculator::Process(CalculatorContext* cc) {
|
||||
const tf::Tensor& input_tensor =
|
||||
cc->Inputs().Index(0).Value().Get<tf::Tensor>();
|
||||
RET_CHECK(tf::DT_STRING == input_tensor.dtype())
|
||||
<< "expected DT_STRING input but got "
|
||||
<< tensorflow::DataTypeString(input_tensor.dtype());
|
||||
|
||||
if (options_.tensor_is_2d()) {
|
||||
RET_CHECK(2 == input_tensor.dims())
|
||||
<< "Expected 2-dimensional Tensor, but the tensor shape is: "
|
||||
<< input_tensor.shape().DebugString();
|
||||
auto output = absl::make_unique<std::vector<std::vector<std::string>>>(
|
||||
input_tensor.dim_size(0),
|
||||
std::vector<std::string>(input_tensor.dim_size(1)));
|
||||
for (int i = 0; i < input_tensor.dim_size(0); ++i) {
|
||||
auto& instance_output = output->at(i);
|
||||
const auto& slice =
|
||||
input_tensor.Slice(i, i + 1).unaligned_flat<tensorflow::tstring>();
|
||||
for (int j = 0; j < input_tensor.dim_size(1); ++j) {
|
||||
instance_output.at(j) = slice(j);
|
||||
}
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
} else {
|
||||
if (!options_.flatten_nd()) {
|
||||
RET_CHECK(1 == input_tensor.dims())
|
||||
<< "`flatten_nd` is not set. Expected 1-dimensional Tensor, but the "
|
||||
<< "tensor shape is: " << input_tensor.shape().DebugString();
|
||||
}
|
||||
auto output =
|
||||
absl::make_unique<std::vector<std::string>>(input_tensor.NumElements());
|
||||
const auto& tensor_values = input_tensor.flat<tensorflow::tstring>();
|
||||
for (int i = 0; i < input_tensor.NumElements(); ++i) {
|
||||
output->at(i) = tensor_values(i);
|
||||
}
|
||||
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message TensorToVectorStringCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TensorToVectorStringCalculatorOptions ext = 386534187;
|
||||
}
|
||||
|
||||
// If true, unpack a 2d tensor (matrix) into a vector<vector<string>>. If
|
||||
// false, convert a 1d tensor (vector) into a vector<string>.
|
||||
optional bool tensor_is_2d = 1 [default = false];
|
||||
|
||||
// If true, an N-D tensor will be flattened to a vector<string>. This is
|
||||
// exclusive with tensor_is_2d.
|
||||
optional bool flatten_nd = 2 [default = false];
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright 2018 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/tensorflow/tensor_to_vector_string_calculator_options.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
class TensorToVectorStringCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd) {
|
||||
CalculatorGraphConfig::Node config;
|
||||
config.set_calculator("TensorToVectorStringCalculator");
|
||||
config.add_input_stream("input_tensor");
|
||||
config.add_output_stream("output_tensor");
|
||||
auto options = config.mutable_options()->MutableExtension(
|
||||
TensorToVectorStringCalculatorOptions::ext);
|
||||
options->set_tensor_is_2d(tensor_is_2d);
|
||||
options->set_flatten_nd(flatten_nd);
|
||||
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||
}
|
||||
|
||||
std::unique_ptr<CalculatorRunner> runner_;
|
||||
};
|
||||
|
||||
TEST_F(TensorToVectorStringCalculatorTest, ConvertsToVectorFloat) {
|
||||
SetUpRunner(false, false);
|
||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{5});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
|
||||
auto tensor_vec = tensor->vec<tensorflow::tstring>();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
tensor_vec(i) = absl::StrCat("foo", i);
|
||||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const std::vector<std::string>& output_vector =
|
||||
output_packets[0].Get<std::vector<std::string>>();
|
||||
|
||||
EXPECT_EQ(5, output_vector.size());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
const std::string expected = absl::StrCat("foo", i);
|
||||
EXPECT_EQ(expected, output_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorToVectorStringCalculatorTest, ConvertsBatchedToVectorVectorFloat) {
|
||||
SetUpRunner(true, false);
|
||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{1, 5});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
|
||||
auto slice = tensor->Slice(0, 1).flat<tensorflow::tstring>();
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
slice(i) = absl::StrCat("foo", i);
|
||||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const std::vector<std::vector<std::string>>& output_vectors =
|
||||
output_packets[0].Get<std::vector<std::vector<std::string>>>();
|
||||
ASSERT_EQ(1, output_vectors.size());
|
||||
const std::vector<std::string>& output_vector = output_vectors[0];
|
||||
EXPECT_EQ(5, output_vector.size());
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
const std::string expected = absl::StrCat("foo", i);
|
||||
EXPECT_EQ(expected, output_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorToVectorStringCalculatorTest, FlattenShouldTakeAllDimensions) {
|
||||
SetUpRunner(false, true);
|
||||
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 2, 2});
|
||||
auto tensor = absl::make_unique<tf::Tensor>(tf::DT_STRING, tensor_shape);
|
||||
auto slice = tensor->flat<tensorflow::tstring>();
|
||||
for (int i = 0; i < 2 * 2 * 2; ++i) {
|
||||
slice(i) = absl::StrCat("foo", i);
|
||||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(tensor.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Index(0).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const std::vector<std::string>& output_vector =
|
||||
output_packets[0].Get<std::vector<std::string>>();
|
||||
EXPECT_EQ(2 * 2 * 2, output_vector.size());
|
||||
for (int i = 0; i < 2 * 2 * 2; ++i) {
|
||||
const std::string expected = absl::StrCat("foo", i);
|
||||
EXPECT_EQ(expected, output_vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -49,6 +49,11 @@ namespace tf = ::tensorflow;
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
|
||||
|
||||
// This is a simple implementation of a semaphore using standard C++ libraries.
|
||||
// It is supposed to be used only by TensorflowInferenceCalculator to throttle
|
||||
// the concurrent calls of Tensorflow Session::Run. This is useful when multiple
|
||||
|
@ -252,10 +257,10 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
}
|
||||
// A mediapipe::TensorFlowSession with a model loaded and ready for use.
|
||||
// For this calculator it must include a tag_to_tensor_map.
|
||||
cc->InputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) {
|
||||
cc->InputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
|
||||
if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) {
|
||||
cc->InputSidePackets()
|
||||
.Tag("RECURRENT_INIT_TENSORS")
|
||||
.Tag(kRecurrentInitTensorsTag)
|
||||
.Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -265,11 +270,11 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
std::unique_ptr<InferenceState> inference_state =
|
||||
absl::make_unique<InferenceState>();
|
||||
if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") &&
|
||||
!cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) {
|
||||
if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) &&
|
||||
!cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) {
|
||||
std::map<std::string, tf::Tensor>* init_tensor_map;
|
||||
init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
|
||||
cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS"));
|
||||
cc->InputSidePackets().Tag(kRecurrentInitTensorsTag));
|
||||
for (const auto& p : *init_tensor_map) {
|
||||
inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
|
||||
}
|
||||
|
@ -280,13 +285,13 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
absl::Status Open(CalculatorContext* cc) override {
|
||||
options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();
|
||||
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("SESSION"));
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag));
|
||||
session_ = cc->InputSidePackets()
|
||||
.Tag("SESSION")
|
||||
.Tag(kSessionTag)
|
||||
.Get<TensorFlowSession>()
|
||||
.session.get();
|
||||
tag_to_tensor_map_ = cc->InputSidePackets()
|
||||
.Tag("SESSION")
|
||||
.Tag(kSessionTag)
|
||||
.Get<TensorFlowSession>()
|
||||
.tag_to_tensor_map;
|
||||
|
||||
|
|
|
@ -41,6 +41,11 @@ namespace mediapipe {
|
|||
namespace tf = ::tensorflow;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kMultipliedTag[] = "MULTIPLIED";
|
||||
constexpr char kBTag[] = "B";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
std::string GetGraphDefPath() {
|
||||
#ifdef __APPLE__
|
||||
char path[1024];
|
||||
|
@ -86,8 +91,8 @@ class TensorflowInferenceCalculatorTest : public ::testing::Test {
|
|||
MEDIAPIPE_CHECK_OK(tool::RunGenerateAndValidateTypes(
|
||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options,
|
||||
input_side_packets, &output_side_packets));
|
||||
runner_->MutableSidePackets()->Tag("SESSION") =
|
||||
output_side_packets.Tag("SESSION");
|
||||
runner_->MutableSidePackets()->Tag(kSessionTag) =
|
||||
output_side_packets.Tag(kSessionTag);
|
||||
}
|
||||
|
||||
Packet CreateTensorPacket(const std::vector<int32>& input, int64 time) {
|
||||
|
@ -140,7 +145,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_b =
|
||||
runner_->Outputs().Tag("B").packets;
|
||||
runner_->Outputs().Tag(kBTag).packets;
|
||||
ASSERT_EQ(output_packets_b.size(), 1);
|
||||
const tf::Tensor& tensor_b = output_packets_b[0].Get<tf::Tensor>();
|
||||
tf::TensorShape expected_shape({1, 3});
|
||||
|
@ -148,7 +153,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetConstants) {
|
|||
tf::test::ExpectTensorEqual<int32>(expected_tensor, tensor_b);
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(1, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
expected_tensor = tf::test::AsTensor<int32>({0, 0, 0}, expected_shape);
|
||||
|
@ -181,7 +186,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(1, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
tf::TensorShape expected_shape({3});
|
||||
|
@ -220,7 +225,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(1, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
tf::TensorShape expected_shape({3});
|
||||
|
@ -274,7 +279,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(2, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||
|
@ -311,7 +316,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed_MaxInFlight) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(2, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||
|
@ -351,7 +356,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(3, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||
|
@ -392,7 +397,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetSingleBatchComputed) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(2, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||
|
@ -430,7 +435,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetCloseBatchComputed) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(2, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||
|
@ -481,7 +486,7 @@ TEST_F(TensorflowInferenceCalculatorTest, GetBatchComputed_MaxInFlight) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(5, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({6, 8, 10});
|
||||
|
@ -528,7 +533,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStates) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(2, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
LOG(INFO) << "timestamp: " << 0;
|
||||
|
@ -569,7 +574,7 @@ TEST_F(TensorflowInferenceCalculatorTest, TestRecurrentStateOverride) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(2, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
LOG(INFO) << "timestamp: " << 0;
|
||||
|
@ -662,7 +667,7 @@ TEST_F(TensorflowInferenceCalculatorTest, MissingInputFeature_Skip) {
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(0, output_packets_mult.size());
|
||||
}
|
||||
|
||||
|
@ -691,7 +696,7 @@ TEST_F(TensorflowInferenceCalculatorTest,
|
|||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets_mult =
|
||||
runner_->Outputs().Tag("MULTIPLIED").packets;
|
||||
runner_->Outputs().Tag(kMultipliedTag).packets;
|
||||
ASSERT_EQ(1, output_packets_mult.size());
|
||||
const tf::Tensor& tensor_mult = output_packets_mult[0].Get<tf::Tensor>();
|
||||
auto expected_tensor = tf::test::AsTensor<int32>({9, 12, 15});
|
||||
|
|
|
@ -47,6 +47,11 @@ namespace mediapipe {
|
|||
namespace tf = ::tensorflow;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||
|
||||
// Updates the graph nodes to use the device as specified by device_id.
|
||||
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
|
||||
for (auto& node : *graph_def->mutable_node()) {
|
||||
|
@ -64,30 +69,32 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
|
|||
cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
|
||||
bool has_exactly_one_model =
|
||||
!options.graph_proto_path().empty()
|
||||
? !(cc->InputSidePackets().HasTag("STRING_MODEL") |
|
||||
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH"))
|
||||
: (cc->InputSidePackets().HasTag("STRING_MODEL") ^
|
||||
cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH"));
|
||||
? !(cc->InputSidePackets().HasTag(kStringModelTag) |
|
||||
cc->InputSidePackets().HasTag(kStringModelFilePathTag))
|
||||
: (cc->InputSidePackets().HasTag(kStringModelTag) ^
|
||||
cc->InputSidePackets().HasTag(kStringModelFilePathTag));
|
||||
RET_CHECK(has_exactly_one_model)
|
||||
<< "Must have exactly one of graph_proto_path in options or "
|
||||
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
|
||||
if (cc->InputSidePackets().HasTag("STRING_MODEL")) {
|
||||
if (cc->InputSidePackets().HasTag(kStringModelTag)) {
|
||||
cc->InputSidePackets()
|
||||
.Tag("STRING_MODEL")
|
||||
.Tag(kStringModelTag)
|
||||
.Set<std::string>(
|
||||
// String model from embedded path
|
||||
);
|
||||
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) {
|
||||
} else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
|
||||
cc->InputSidePackets()
|
||||
.Tag("STRING_MODEL_FILE_PATH")
|
||||
.Tag(kStringModelFilePathTag)
|
||||
.Set<std::string>(
|
||||
// Filename of std::string model.
|
||||
);
|
||||
}
|
||||
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>(
|
||||
// A TensorFlow model loaded and ready for use along with
|
||||
// a map from tags to tensor names.
|
||||
);
|
||||
cc->OutputSidePackets()
|
||||
.Tag(kSessionTag)
|
||||
.Set<TensorFlowSession>(
|
||||
// A TensorFlow model loaded and ready for use along with
|
||||
// a map from tags to tensor names.
|
||||
);
|
||||
RET_CHECK_GT(options.tag_to_tensor_names().size(), 0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -111,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
|
|||
session->session.reset(tf::NewSession(session_options));
|
||||
|
||||
std::string graph_def_serialized;
|
||||
if (cc->InputSidePackets().HasTag("STRING_MODEL")) {
|
||||
if (cc->InputSidePackets().HasTag(kStringModelTag)) {
|
||||
graph_def_serialized =
|
||||
cc->InputSidePackets().Tag("STRING_MODEL").Get<std::string>();
|
||||
} else if (cc->InputSidePackets().HasTag("STRING_MODEL_FILE_PATH")) {
|
||||
cc->InputSidePackets().Tag(kStringModelTag).Get<std::string>();
|
||||
} else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
|
||||
const std::string& frozen_graph = cc->InputSidePackets()
|
||||
.Tag("STRING_MODEL_FILE_PATH")
|
||||
.Tag(kStringModelFilePathTag)
|
||||
.Get<std::string>();
|
||||
RET_CHECK_OK(
|
||||
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
|
||||
|
@ -147,7 +154,7 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
|
|||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||
}
|
||||
|
||||
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
|
||||
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
|
||||
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
|
||||
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
|
||||
<< " microseconds.";
|
||||
|
|
|
@ -37,6 +37,10 @@ namespace {
|
|||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
std::string GetGraphDefPath() {
|
||||
return mediapipe::file::JoinPath("./",
|
||||
"mediapipe/calculators/tensorflow/"
|
||||
|
@ -112,7 +116,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
|||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
VerifySignatureMap(session);
|
||||
}
|
||||
|
||||
|
@ -190,12 +194,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
|||
std::string serialized_graph_contents;
|
||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||
&serialized_graph_contents));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelTag) =
|
||||
Adopt(new std::string(serialized_graph_contents));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
VerifySignatureMap(session);
|
||||
}
|
||||
|
||||
|
@ -213,12 +217,12 @@ TEST_F(
|
|||
}
|
||||
})",
|
||||
calculator_options_->DebugString()));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
VerifySignatureMap(session);
|
||||
}
|
||||
|
||||
|
@ -234,7 +238,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
|||
}
|
||||
})",
|
||||
calculator_options_->DebugString()));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
auto run_status = runner.Run();
|
||||
EXPECT_THAT(
|
||||
|
@ -255,12 +259,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
|||
}
|
||||
})",
|
||||
calculator_options_->DebugString()));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
std::string serialized_graph_contents;
|
||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||
&serialized_graph_contents));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelTag) =
|
||||
Adopt(new std::string(serialized_graph_contents));
|
||||
auto run_status = runner.Run();
|
||||
EXPECT_THAT(
|
||||
|
@ -282,12 +286,12 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
|||
}
|
||||
})",
|
||||
calculator_options_->DebugString()));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL_FILE_PATH") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
std::string serialized_graph_contents;
|
||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||
&serialized_graph_contents));
|
||||
runner.MutableSidePackets()->Tag("STRING_MODEL") =
|
||||
runner.MutableSidePackets()->Tag(kStringModelTag) =
|
||||
Adopt(new std::string(serialized_graph_contents));
|
||||
auto run_status = runner.Run();
|
||||
EXPECT_THAT(
|
||||
|
@ -310,7 +314,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest,
|
|||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
VerifySignatureMap(session);
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,11 @@ namespace mediapipe {
|
|||
namespace tf = ::tensorflow;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||
|
||||
// Updates the graph nodes to use the device as specified by device_id.
|
||||
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
|
||||
for (auto& node : *graph_def->mutable_node()) {
|
||||
|
@ -64,28 +69,29 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
|
|||
TensorFlowSessionFromFrozenGraphGeneratorOptions::ext);
|
||||
bool has_exactly_one_model =
|
||||
!options.graph_proto_path().empty()
|
||||
? !(input_side_packets->HasTag("STRING_MODEL") |
|
||||
input_side_packets->HasTag("STRING_MODEL_FILE_PATH"))
|
||||
: (input_side_packets->HasTag("STRING_MODEL") ^
|
||||
input_side_packets->HasTag("STRING_MODEL_FILE_PATH"));
|
||||
? !(input_side_packets->HasTag(kStringModelTag) |
|
||||
input_side_packets->HasTag(kStringModelFilePathTag))
|
||||
: (input_side_packets->HasTag(kStringModelTag) ^
|
||||
input_side_packets->HasTag(kStringModelFilePathTag));
|
||||
RET_CHECK(has_exactly_one_model)
|
||||
<< "Must have exactly one of graph_proto_path in options or "
|
||||
"input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
|
||||
if (input_side_packets->HasTag("STRING_MODEL")) {
|
||||
input_side_packets->Tag("STRING_MODEL")
|
||||
if (input_side_packets->HasTag(kStringModelTag)) {
|
||||
input_side_packets->Tag(kStringModelTag)
|
||||
.Set<std::string>(
|
||||
// String model from embedded path
|
||||
);
|
||||
} else if (input_side_packets->HasTag("STRING_MODEL_FILE_PATH")) {
|
||||
input_side_packets->Tag("STRING_MODEL_FILE_PATH")
|
||||
} else if (input_side_packets->HasTag(kStringModelFilePathTag)) {
|
||||
input_side_packets->Tag(kStringModelFilePathTag)
|
||||
.Set<std::string>(
|
||||
// Filename of std::string model.
|
||||
);
|
||||
}
|
||||
output_side_packets->Tag("SESSION").Set<TensorFlowSession>(
|
||||
// A TensorFlow model loaded and ready for use along with
|
||||
// a map from tags to tensor names.
|
||||
);
|
||||
output_side_packets->Tag(kSessionTag)
|
||||
.Set<TensorFlowSession>(
|
||||
// A TensorFlow model loaded and ready for use along with
|
||||
// a map from tags to tensor names.
|
||||
);
|
||||
RET_CHECK_GT(options.tag_to_tensor_names().size(), 0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -112,12 +118,12 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
|
|||
session->session.reset(tf::NewSession(session_options));
|
||||
|
||||
std::string graph_def_serialized;
|
||||
if (input_side_packets.HasTag("STRING_MODEL")) {
|
||||
if (input_side_packets.HasTag(kStringModelTag)) {
|
||||
graph_def_serialized =
|
||||
input_side_packets.Tag("STRING_MODEL").Get<std::string>();
|
||||
} else if (input_side_packets.HasTag("STRING_MODEL_FILE_PATH")) {
|
||||
input_side_packets.Tag(kStringModelTag).Get<std::string>();
|
||||
} else if (input_side_packets.HasTag(kStringModelFilePathTag)) {
|
||||
const std::string& frozen_graph =
|
||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH").Get<std::string>();
|
||||
input_side_packets.Tag(kStringModelFilePathTag).Get<std::string>();
|
||||
RET_CHECK_OK(
|
||||
mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
|
||||
} else {
|
||||
|
@ -147,7 +153,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator {
|
|||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||
}
|
||||
|
||||
output_side_packets->Tag("SESSION") = Adopt(session.release());
|
||||
output_side_packets->Tag(kSessionTag) = Adopt(session.release());
|
||||
const uint64 end_time = absl::ToUnixMicros(clock->TimeNow());
|
||||
LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
|
||||
<< " microseconds.";
|
||||
|
|
|
@ -37,6 +37,10 @@ namespace {
|
|||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
|
||||
constexpr char kStringModelTag[] = "STRING_MODEL";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
std::string GetGraphDefPath() {
|
||||
return mediapipe::file::JoinPath("./",
|
||||
"mediapipe/calculators/tensorflow/"
|
||||
|
@ -72,7 +76,7 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test {
|
|||
|
||||
void VerifySignatureMap(PacketSet* output_side_packets) {
|
||||
const TensorFlowSession& session =
|
||||
output_side_packets->Tag("SESSION").Get<TensorFlowSession>();
|
||||
output_side_packets->Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
|
||||
|
@ -179,7 +183,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
|||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||
&serialized_graph_contents));
|
||||
generator_options_->clear_graph_proto_path();
|
||||
input_side_packets.Tag("STRING_MODEL") =
|
||||
input_side_packets.Tag(kStringModelTag) =
|
||||
Adopt(new std::string(serialized_graph_contents));
|
||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
||||
|
@ -196,7 +200,7 @@ TEST_F(
|
|||
PacketSet output_side_packets(
|
||||
tool::CreateTagMap({"SESSION:session"}).value());
|
||||
generator_options_->clear_graph_proto_path();
|
||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
||||
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
||||
|
@ -211,7 +215,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
|||
tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value());
|
||||
PacketSet output_side_packets(
|
||||
tool::CreateTagMap({"SESSION:session"}).value());
|
||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
||||
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||
"TensorFlowSessionFromFrozenGraphGenerator", extendable_options_,
|
||||
|
@ -233,9 +237,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
|||
std::string serialized_graph_contents;
|
||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||
&serialized_graph_contents));
|
||||
input_side_packets.Tag("STRING_MODEL") =
|
||||
input_side_packets.Tag(kStringModelTag) =
|
||||
Adopt(new std::string(serialized_graph_contents));
|
||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
||||
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
|
||||
absl::Status run_status = tool::RunGenerateAndValidateTypes(
|
||||
|
@ -258,9 +262,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest,
|
|||
std::string serialized_graph_contents;
|
||||
MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(),
|
||||
&serialized_graph_contents));
|
||||
input_side_packets.Tag("STRING_MODEL") =
|
||||
input_side_packets.Tag(kStringModelTag) =
|
||||
Adopt(new std::string(serialized_graph_contents));
|
||||
input_side_packets.Tag("STRING_MODEL_FILE_PATH") =
|
||||
input_side_packets.Tag(kStringModelFilePathTag) =
|
||||
Adopt(new std::string(GetGraphDefPath()));
|
||||
generator_options_->clear_graph_proto_path();
|
||||
|
||||
|
|
|
@ -31,6 +31,9 @@
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
||||
|
||||
// Given the path to a directory containing multiple tensorflow saved models
|
||||
|
@ -108,7 +111,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
|
|||
cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
|
||||
}
|
||||
// A TensorFlow model loaded and ready for use along with tensor
|
||||
cc->OutputSidePackets().Tag("SESSION").Set<TensorFlowSession>();
|
||||
cc->OutputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -160,7 +163,7 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
|
|||
output_signature.first, options)] = output_signature.second.name();
|
||||
}
|
||||
|
||||
cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release()));
|
||||
cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,9 @@ namespace {
|
|||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
std::string GetSavedModelDir() {
|
||||
std::string out_path =
|
||||
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
|
||||
|
@ -79,7 +82,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
|||
options_->DebugString()));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
|
||||
|
@ -119,11 +122,11 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
|||
}
|
||||
})",
|
||||
options_->DebugString()));
|
||||
runner.MutableSidePackets()->Tag("STRING_SAVED_MODEL_PATH") =
|
||||
runner.MutableSidePackets()->Tag(kStringSavedModelPathTag) =
|
||||
MakePacket<std::string>(GetSavedModelDir());
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
}
|
||||
|
@ -201,7 +204,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
|||
options_->DebugString()));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
}
|
||||
|
@ -224,7 +227,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest,
|
|||
options_->DebugString()));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const TensorFlowSession& session =
|
||||
runner.OutputSidePackets().Tag("SESSION").Get<TensorFlowSession>();
|
||||
runner.OutputSidePackets().Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
std::vector<tensorflow::DeviceAttributes> devices;
|
||||
|
|
|
@ -33,6 +33,9 @@
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
|
||||
|
||||
// Given the path to a directory containing multiple tensorflow saved models
|
||||
|
@ -100,7 +103,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
|||
input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
|
||||
}
|
||||
// A TensorFlow model loaded and ready for use along with tensor
|
||||
output_side_packets->Tag("SESSION").Set<TensorFlowSession>();
|
||||
output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -153,7 +156,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
|
|||
output_signature.first, options)] = output_signature.second.name();
|
||||
}
|
||||
|
||||
output_side_packets->Tag("SESSION") = Adopt(session.release());
|
||||
output_side_packets->Tag(kSessionTag) = Adopt(session.release());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -34,6 +34,9 @@ namespace {
|
|||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
constexpr char kStringSavedModelPathTag[] = "STRING_SAVED_MODEL_PATH";
|
||||
constexpr char kSessionTag[] = "SESSION";
|
||||
|
||||
std::string GetSavedModelDir() {
|
||||
std::string out_path =
|
||||
file::JoinPath("./", "mediapipe/calculators/tensorflow/testdata/",
|
||||
|
@ -75,7 +78,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
|||
input_side_packets, &output_side_packets);
|
||||
MP_EXPECT_OK(run_status) << run_status.message();
|
||||
const TensorFlowSession& session =
|
||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
||||
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
|
||||
|
@ -107,7 +110,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
|||
generator_options_->clear_saved_model_path();
|
||||
PacketSet input_side_packets(
|
||||
tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value());
|
||||
input_side_packets.Tag("STRING_SAVED_MODEL_PATH") =
|
||||
input_side_packets.Tag(kStringSavedModelPathTag) =
|
||||
Adopt(new std::string(GetSavedModelDir()));
|
||||
PacketSet output_side_packets(
|
||||
tool::CreateTagMap({"SESSION:session"}).value());
|
||||
|
@ -116,7 +119,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
|||
input_side_packets, &output_side_packets);
|
||||
MP_EXPECT_OK(run_status) << run_status.message();
|
||||
const TensorFlowSession& session =
|
||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
||||
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
}
|
||||
|
@ -192,7 +195,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
|||
input_side_packets, &output_side_packets);
|
||||
MP_EXPECT_OK(run_status) << run_status.message();
|
||||
const TensorFlowSession& session =
|
||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
||||
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
}
|
||||
|
@ -213,7 +216,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest,
|
|||
input_side_packets, &output_side_packets);
|
||||
MP_EXPECT_OK(run_status) << run_status.message();
|
||||
const TensorFlowSession& session =
|
||||
output_side_packets.Tag("SESSION").Get<TensorFlowSession>();
|
||||
output_side_packets.Tag(kSessionTag).Get<TensorFlowSession>();
|
||||
// Session must be set.
|
||||
ASSERT_NE(session.session, nullptr);
|
||||
std::vector<tensorflow::DeviceAttributes> devices;
|
||||
|
|
|
@ -33,6 +33,31 @@ namespace {
|
|||
namespace tf = ::tensorflow;
|
||||
namespace mpms = mediapipe::mediasequence;
|
||||
|
||||
constexpr char kImageFrameRateTag[] = "IMAGE_FRAME_RATE";
|
||||
constexpr char kEncodedMediaStartTimestampTag[] =
|
||||
"ENCODED_MEDIA_START_TIMESTAMP";
|
||||
constexpr char kEncodedMediaTag[] = "ENCODED_MEDIA";
|
||||
constexpr char kResamplerOptionsTag[] = "RESAMPLER_OPTIONS";
|
||||
constexpr char kSandboxedDecoderOptionsTag[] = "SANDBOXED_DECODER_OPTIONS";
|
||||
constexpr char kDecoderOptionsTag[] = "DECODER_OPTIONS";
|
||||
constexpr char kAudioDecoderOptionsTag[] = "AUDIO_DECODER_OPTIONS";
|
||||
constexpr char kDataPathTag[] = "DATA_PATH";
|
||||
constexpr char kDatasetRootTag[] = "DATASET_ROOT";
|
||||
constexpr char kMediaIdTag[] = "MEDIA_ID";
|
||||
constexpr char kFloatFeatureFdenseMaxTag[] = "FLOAT_FEATURE_FDENSE_MAX";
|
||||
constexpr char kFloatFeatureFdenseAvgTag[] = "FLOAT_FEATURE_FDENSE_AVG";
|
||||
constexpr char kAudioOtherTag[] = "AUDIO_OTHER";
|
||||
constexpr char kAudioTestTag[] = "AUDIO_TEST";
|
||||
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||
constexpr char kBboxPrefixTag[] = "BBOX_PREFIX";
|
||||
constexpr char kKeypointsTag[] = "KEYPOINTS";
|
||||
constexpr char kBboxTag[] = "BBOX";
|
||||
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
|
||||
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||
|
||||
class UnpackMediaSequenceCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpCalculator(const std::vector<std::string>& output_streams,
|
||||
|
@ -95,13 +120,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneImage) {
|
|||
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("IMAGE").packets;
|
||||
runner_->Outputs().Tag(kImageTag).packets;
|
||||
ASSERT_EQ(num_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
|
@ -124,13 +149,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoImages) {
|
|||
mpms::AddImageEncoded(test_image_string, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("IMAGE").packets;
|
||||
runner_->Outputs().Tag(kImageTag).packets;
|
||||
ASSERT_EQ(num_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
|
@ -154,13 +179,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPrefixedImages) {
|
|||
mpms::AddImageEncoded(prefix, test_image_string, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("IMAGE_PREFIX").packets;
|
||||
runner_->Outputs().Tag(kImagePrefixTag).packets;
|
||||
ASSERT_EQ(num_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
|
@ -182,12 +207,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksOneForwardFlowImage) {
|
|||
mpms::AddForwardFlowEncoded(test_image_string, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets;
|
||||
runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
|
||||
ASSERT_EQ(num_forward_flow_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_forward_flow_images; ++i) {
|
||||
|
@ -211,12 +236,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoForwardFlowImages) {
|
|||
mpms::AddForwardFlowEncoded(test_image_strings[i], input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("FORWARD_FLOW_ENCODED").packets;
|
||||
runner_->Outputs().Tag(kForwardFlowEncodedTag).packets;
|
||||
ASSERT_EQ(num_forward_flow_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_forward_flow_images; ++i) {
|
||||
|
@ -240,13 +265,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksBBoxes) {
|
|||
mpms::AddBBoxTimestamp(i, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("BBOX").packets;
|
||||
runner_->Outputs().Tag(kBboxTag).packets;
|
||||
ASSERT_EQ(bboxes.size(), output_packets.size());
|
||||
|
||||
for (int i = 0; i < bboxes.size(); ++i) {
|
||||
|
@ -274,13 +299,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPrefixedBBoxes) {
|
|||
mpms::AddBBoxTimestamp(prefix, i, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("BBOX_PREFIX").packets;
|
||||
runner_->Outputs().Tag(kBboxPrefixTag).packets;
|
||||
ASSERT_EQ(bboxes.size(), output_packets.size());
|
||||
|
||||
for (int i = 0; i < bboxes.size(); ++i) {
|
||||
|
@ -306,13 +331,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
|
|||
mpms::AddFeatureTimestamp("OTHER", i, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_TEST").packets;
|
||||
runner_->Outputs().Tag(kFloatFeatureTestTag).packets;
|
||||
ASSERT_EQ(num_float_lists, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_float_lists; ++i) {
|
||||
|
@ -322,7 +347,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoFloatLists) {
|
|||
}
|
||||
|
||||
const std::vector<Packet>& output_packets_other =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets;
|
||||
runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
|
||||
ASSERT_EQ(num_float_lists, output_packets_other.size());
|
||||
|
||||
for (int i = 0; i < num_float_lists; ++i) {
|
||||
|
@ -352,12 +377,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
|
|||
mpms::AddFeatureTimestamp("OTHER", i + 5, input_sequence.get());
|
||||
}
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("IMAGE").packets;
|
||||
runner_->Outputs().Tag(kImageTag).packets;
|
||||
ASSERT_EQ(num_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
|
@ -366,7 +391,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksNonOverlappingTimestamps) {
|
|||
}
|
||||
|
||||
const std::vector<Packet>& output_packets_other =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_OTHER").packets;
|
||||
runner_->Outputs().Tag(kFloatFeatureOtherTag).packets;
|
||||
ASSERT_EQ(num_float_lists, output_packets_other.size());
|
||||
|
||||
for (int i = 0; i < num_float_lists; ++i) {
|
||||
|
@ -389,12 +414,12 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
|
|||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||
input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& fdense_avg_packets =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_AVG").packets;
|
||||
runner_->Outputs().Tag(kFloatFeatureFdenseAvgTag).packets;
|
||||
ASSERT_EQ(fdense_avg_packets.size(), 1);
|
||||
const auto& fdense_avg_vector =
|
||||
fdense_avg_packets[0].Get<std::vector<float>>();
|
||||
|
@ -403,7 +428,7 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksTwoPostStreamFloatLists) {
|
|||
::testing::Eq(Timestamp::PostStream()));
|
||||
|
||||
const std::vector<Packet>& fdense_max_packets =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
|
||||
runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
|
||||
ASSERT_EQ(fdense_max_packets.size(), 1);
|
||||
const auto& fdense_max_vector =
|
||||
fdense_max_packets[0].Get<std::vector<float>>();
|
||||
|
@ -430,13 +455,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksImageWithPostStreamFloatList) {
|
|||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||
input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("IMAGE").packets;
|
||||
runner_->Outputs().Tag(kImageTag).packets;
|
||||
ASSERT_EQ(num_images, output_packets.size());
|
||||
|
||||
for (int i = 0; i < num_images; ++i) {
|
||||
|
@ -463,13 +488,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
|
|||
mpms::AddFeatureTimestamp("FDENSE_MAX", Timestamp::PostStream().Value(),
|
||||
input_sequence.get());
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& fdense_max_packets =
|
||||
runner_->Outputs().Tag("FLOAT_FEATURE_FDENSE_MAX").packets;
|
||||
runner_->Outputs().Tag(kFloatFeatureFdenseMaxTag).packets;
|
||||
ASSERT_EQ(fdense_max_packets.size(), 1);
|
||||
const auto& fdense_max_vector =
|
||||
fdense_max_packets[0].Get<std::vector<float>>();
|
||||
|
@ -481,17 +506,17 @@ TEST_F(UnpackMediaSequenceCalculatorTest, UnpacksPostStreamFloatListWithImage) {
|
|||
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromPacket) {
|
||||
SetUpCalculator({}, {"DATA_PATH:data_path"}, {"DATASET_ROOT:root"});
|
||||
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
|
||||
std::string root = "test_root";
|
||||
runner_->MutableSidePackets()->Tag("DATASET_ROOT") = PointToForeign(&root);
|
||||
runner_->MutableSidePackets()->Tag(kDatasetRootTag) = PointToForeign(&root);
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
MP_ASSERT_OK(runner_->OutputSidePackets()
|
||||
.Tag("DATA_PATH")
|
||||
.Tag(kDataPathTag)
|
||||
.ValidateAsType<std::string>());
|
||||
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
|
||||
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
|
||||
root + "/" + data_path_);
|
||||
}
|
||||
|
||||
|
@ -501,28 +526,28 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromOptions) {
|
|||
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
|
||||
->set_dataset_root_directory(root);
|
||||
SetUpCalculator({}, {"DATA_PATH:data_path"}, {}, &options);
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
MP_ASSERT_OK(runner_->OutputSidePackets()
|
||||
.Tag("DATA_PATH")
|
||||
.Tag(kDataPathTag)
|
||||
.ValidateAsType<std::string>());
|
||||
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
|
||||
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
|
||||
root + "/" + data_path_);
|
||||
}
|
||||
|
||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
|
||||
SetUpCalculator({}, {"DATA_PATH:data_path"});
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
MP_ASSERT_OK(runner_->OutputSidePackets()
|
||||
.Tag("DATA_PATH")
|
||||
.Tag(kDataPathTag)
|
||||
.ValidateAsType<std::string>());
|
||||
ASSERT_EQ(runner_->OutputSidePackets().Tag("DATA_PATH").Get<std::string>(),
|
||||
ASSERT_EQ(runner_->OutputSidePackets().Tag(kDataPathTag).Get<std::string>(),
|
||||
data_path_);
|
||||
}
|
||||
|
||||
|
@ -534,20 +559,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) {
|
|||
->set_padding_after_label(2);
|
||||
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
|
||||
&options);
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||
.Tag("AUDIO_DECODER_OPTIONS")
|
||||
.Tag(kAudioDecoderOptionsTag)
|
||||
.ValidateAsType<AudioDecoderOptions>());
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("AUDIO_DECODER_OPTIONS")
|
||||
.Tag(kAudioDecoderOptionsTag)
|
||||
.Get<AudioDecoderOptions>()
|
||||
.start_time(),
|
||||
2.0, 1e-5);
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("AUDIO_DECODER_OPTIONS")
|
||||
.Tag(kAudioDecoderOptionsTag)
|
||||
.Get<AudioDecoderOptions>()
|
||||
.end_time(),
|
||||
7.0, 1e-5);
|
||||
|
@ -563,20 +588,20 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) {
|
|||
->set_force_decoding_from_start_of_media(true);
|
||||
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
|
||||
&options);
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||
.Tag("AUDIO_DECODER_OPTIONS")
|
||||
.Tag(kAudioDecoderOptionsTag)
|
||||
.ValidateAsType<AudioDecoderOptions>());
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("AUDIO_DECODER_OPTIONS")
|
||||
.Tag(kAudioDecoderOptionsTag)
|
||||
.Get<AudioDecoderOptions>()
|
||||
.start_time(),
|
||||
0.0, 1e-5);
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("AUDIO_DECODER_OPTIONS")
|
||||
.Tag(kAudioDecoderOptionsTag)
|
||||
.Get<AudioDecoderOptions>()
|
||||
.end_time(),
|
||||
7.0, 1e-5);
|
||||
|
@ -594,27 +619,27 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
|
|||
->mutable_base_packet_resampler_options()
|
||||
->set_frame_rate(1.0);
|
||||
SetUpCalculator({}, {"RESAMPLER_OPTIONS:resampler_options"}, {}, &options);
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||
.Tag("RESAMPLER_OPTIONS")
|
||||
.Tag(kResamplerOptionsTag)
|
||||
.ValidateAsType<CalculatorOptions>());
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("RESAMPLER_OPTIONS")
|
||||
.Tag(kResamplerOptionsTag)
|
||||
.Get<CalculatorOptions>()
|
||||
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
||||
.start_time(),
|
||||
2000000, 1);
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("RESAMPLER_OPTIONS")
|
||||
.Tag(kResamplerOptionsTag)
|
||||
.Get<CalculatorOptions>()
|
||||
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
||||
.end_time(),
|
||||
7000000, 1);
|
||||
EXPECT_NEAR(runner_->OutputSidePackets()
|
||||
.Tag("RESAMPLER_OPTIONS")
|
||||
.Tag(kResamplerOptionsTag)
|
||||
.Get<CalculatorOptions>()
|
||||
.GetExtension(PacketResamplerCalculatorOptions::ext)
|
||||
.frame_rate(),
|
||||
|
@ -623,13 +648,13 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
|
|||
|
||||
TEST_F(UnpackMediaSequenceCalculatorTest, GetFrameRateFromExample) {
|
||||
SetUpCalculator({}, {"IMAGE_FRAME_RATE:frame_rate"});
|
||||
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(sequence_.release());
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
MP_EXPECT_OK(runner_->OutputSidePackets()
|
||||
.Tag("IMAGE_FRAME_RATE")
|
||||
.Tag(kImageFrameRateTag)
|
||||
.ValidateAsType<double>());
|
||||
EXPECT_EQ(runner_->OutputSidePackets().Tag("IMAGE_FRAME_RATE").Get<double>(),
|
||||
EXPECT_EQ(runner_->OutputSidePackets().Tag(kImageFrameRateTag).Get<double>(),
|
||||
image_frame_rate_);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,10 @@ namespace {
|
|||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
constexpr char kSingleIntTag[] = "SINGLE_INT";
|
||||
constexpr char kTensorOutTag[] = "TENSOR_OUT";
|
||||
constexpr char kVectorIntTag[] = "VECTOR_INT";
|
||||
|
||||
class VectorIntToTensorCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUpRunner(
|
||||
|
@ -61,13 +65,13 @@ class VectorIntToTensorCalculatorTest : public ::testing::Test {
|
|||
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()
|
||||
->Tag("VECTOR_INT")
|
||||
->Tag(kVectorIntTag)
|
||||
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
||||
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||
|
@ -95,13 +99,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestSingleValue) {
|
|||
tensorflow::DT_INT32, false, true);
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()
|
||||
->Tag("SINGLE_INT")
|
||||
->Tag(kSingleIntTag)
|
||||
.packets.push_back(MakePacket<int>(1).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
||||
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||
|
@ -121,13 +125,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TesOneDim) {
|
|||
}
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()
|
||||
->Tag("VECTOR_INT")
|
||||
->Tag(kVectorIntTag)
|
||||
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
||||
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||
|
@ -152,13 +156,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestInt64) {
|
|||
tensorflow::DT_INT64, false, true);
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()
|
||||
->Tag("SINGLE_INT")
|
||||
->Tag(kSingleIntTag)
|
||||
.packets.push_back(MakePacket<int>(1LL << 31).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
||||
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||
|
@ -179,13 +183,13 @@ TEST_F(VectorIntToTensorCalculatorTest, TestUint8) {
|
|||
}
|
||||
const int64 time = 1234;
|
||||
runner_->MutableInputs()
|
||||
->Tag("VECTOR_INT")
|
||||
->Tag(kVectorIntTag)
|
||||
.packets.push_back(Adopt(input.release()).At(Timestamp(time)));
|
||||
|
||||
EXPECT_TRUE(runner_->Run().ok());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag("TENSOR_OUT").packets;
|
||||
runner_->Outputs().Tag(kTensorOutTag).packets;
|
||||
EXPECT_EQ(1, output_packets.size());
|
||||
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||
const tf::Tensor& output_tensor = output_packets[0].Get<tf::Tensor>();
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kFloatsTag[] = "FLOATS";
|
||||
constexpr char kFloatTag[] = "FLOAT";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
|
||||
// A calculator for converting TFLite tensors to to a float or a float vector.
|
||||
//
|
||||
// Input:
|
||||
|
@ -48,15 +52,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator);
|
|||
|
||||
absl::Status TfLiteTensorsToFloatsCalculator::GetContract(
|
||||
CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("TENSORS"));
|
||||
RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT"));
|
||||
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
|
||||
RET_CHECK(cc->Outputs().HasTag(kFloatsTag) ||
|
||||
cc->Outputs().HasTag(kFloatTag));
|
||||
|
||||
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
|
||||
if (cc->Outputs().HasTag("FLOATS")) {
|
||||
cc->Outputs().Tag("FLOATS").Set<std::vector<float>>();
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
if (cc->Outputs().HasTag(kFloatsTag)) {
|
||||
cc->Outputs().Tag(kFloatsTag).Set<std::vector<float>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("FLOAT")) {
|
||||
cc->Outputs().Tag("FLOAT").Set<float>();
|
||||
if (cc->Outputs().HasTag(kFloatTag)) {
|
||||
cc->Outputs().Tag(kFloatTag).Set<float>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -69,10 +74,10 @@ absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
||||
RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty());
|
||||
RET_CHECK(!cc->Inputs().Tag(kTensorsTag).IsEmpty());
|
||||
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
|
||||
// TODO: Add option to specify which tensor to take from.
|
||||
const TfLiteTensor* raw_tensor = &input_tensors[0];
|
||||
const float* raw_floats = raw_tensor->data.f;
|
||||
|
@ -82,18 +87,19 @@ absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) {
|
|||
num_values *= raw_tensor->dims->data[i];
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("FLOAT")) {
|
||||
if (cc->Outputs().HasTag(kFloatTag)) {
|
||||
// TODO: Could add an index in the option to specifiy returning one
|
||||
// value of a float array.
|
||||
RET_CHECK_EQ(num_values, 1);
|
||||
cc->Outputs().Tag("FLOAT").AddPacket(
|
||||
cc->Outputs().Tag(kFloatTag).AddPacket(
|
||||
MakePacket<float>(raw_floats[0]).At(cc->InputTimestamp()));
|
||||
}
|
||||
if (cc->Outputs().HasTag("FLOATS")) {
|
||||
if (cc->Outputs().HasTag(kFloatsTag)) {
|
||||
auto output_floats = absl::make_unique<std::vector<float>>(
|
||||
raw_floats, raw_floats + num_values);
|
||||
cc->Outputs().Tag("FLOATS").Add(output_floats.release(),
|
||||
cc->InputTimestamp());
|
||||
cc->Outputs()
|
||||
.Tag(kFloatsTag)
|
||||
.Add(output_floats.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -87,7 +87,7 @@ absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) {
|
|||
// Initialize the clock.
|
||||
if (cc->InputSidePackets().HasTag(kClockTag)) {
|
||||
clock_ = cc->InputSidePackets()
|
||||
.Tag("CLOCK")
|
||||
.Tag(kClockTag)
|
||||
.Get<std::shared_ptr<::mediapipe::Clock>>();
|
||||
} else {
|
||||
clock_.reset(
|
||||
|
|
|
@ -27,6 +27,8 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kIterableTag[] = "ITERABLE";
|
||||
|
||||
typedef CollectionHasMinSizeCalculator<std::vector<int>>
|
||||
TestIntCollectionHasMinSizeCalculator;
|
||||
REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
|
||||
|
@ -34,7 +36,7 @@ REGISTER_CALCULATOR(TestIntCollectionHasMinSizeCalculator);
|
|||
void AddInputVector(const std::vector<int>& input, int64 timestamp,
|
||||
CalculatorRunner* runner) {
|
||||
runner->MutableInputs()
|
||||
->Tag("ITERABLE")
|
||||
->Tag(kIterableTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int>>(input).At(Timestamp(timestamp)));
|
||||
}
|
||||
|
|
|
@ -144,7 +144,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
cc->Outputs()
|
||||
.Tag("DETECTIONS")
|
||||
.Tag(kDetectionsTag)
|
||||
.Add(output_detections.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
|
||||
LocationData CreateRelativeLocationData(double xmin, double ymin, double width,
|
||||
double height) {
|
||||
LocationData location_data;
|
||||
|
@ -76,19 +79,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingLeftRight) {
|
|||
detections->push_back(
|
||||
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
|
||||
runner.MutableInputs()
|
||||
->Tag("LETTERBOX_PADDING")
|
||||
->Tag(kLetterboxPaddingTag)
|
||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("DETECTIONS").packets;
|
||||
runner.Outputs().Tag(kDetectionsTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
||||
|
||||
|
@ -124,19 +127,19 @@ TEST(DetectionLetterboxRemovalCalculatorTest, PaddingTopBottom) {
|
|||
detections->push_back(
|
||||
CreateDetection({label}, {}, {0.3f}, location_data, "feature_tag"));
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||
std::array<float, 4>{0.f, 0.2f, 0.f, 0.3f});
|
||||
runner.MutableInputs()
|
||||
->Tag("LETTERBOX_PADDING")
|
||||
->Tag(kLetterboxPaddingTag)
|
||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("DETECTIONS").packets;
|
||||
runner.Outputs().Tag(kDetectionsTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
||||
|
||||
|
|
|
@ -31,6 +31,9 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::FloatNear;
|
||||
|
||||
|
@ -74,19 +77,19 @@ absl::StatusOr<Detection> RunProjectionCalculator(
|
|||
)pb"));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(MakePacket<std::vector<Detection>>(
|
||||
std::vector<Detection>({std::move(detection)}))
|
||||
.At(Timestamp::PostStream()));
|
||||
runner.MutableInputs()
|
||||
->Tag("PROJECTION_MATRIX")
|
||||
->Tag(kProjectionMatrixTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::array<float, 16>>(std::move(project_mat))
|
||||
.At(Timestamp::PostStream()));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("DETECTIONS").packets;
|
||||
runner.Outputs().Tag(kDetectionsTag).packets;
|
||||
RET_CHECK_EQ(output.size(), 1);
|
||||
const auto& output_detections = output[0].Get<std::vector<Detection>>();
|
||||
|
||||
|
|
|
@ -32,6 +32,14 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kNormRectsTag[] = "NORM_RECTS";
|
||||
constexpr char kRectsTag[] = "RECTS";
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kRectTag[] = "RECT";
|
||||
constexpr char kDetectionTag[] = "DETECTION";
|
||||
|
||||
MATCHER_P4(RectEq, x_center, y_center, width, height, "") {
|
||||
return testing::Value(arg.x_center(), testing::Eq(x_center)) &&
|
||||
testing::Value(arg.y_center(), testing::Eq(y_center)) &&
|
||||
|
@ -94,12 +102,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) {
|
|||
DetectionWithLocationData(100, 200, 300, 400));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION")
|
||||
->Tag(kDetectionTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rect = output[0].Get<Rect>();
|
||||
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
|
||||
|
@ -120,16 +128,16 @@ absl::StatusOr<Rect> RunDetectionKeyPointsToRectCalculation(
|
|||
)pb"));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION")
|
||||
->Tag(kDetectionTag)
|
||||
.packets.push_back(MakePacket<Detection>(std::move(detection))
|
||||
.At(Timestamp::PostStream()));
|
||||
runner.MutableInputs()
|
||||
->Tag("IMAGE_SIZE")
|
||||
->Tag(kImageSizeTag)
|
||||
.packets.push_back(MakePacket<std::pair<int, int>>(image_size)
|
||||
.At(Timestamp::PostStream()));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
|
||||
RET_CHECK_EQ(output.size(), 1);
|
||||
return output[0].Get<Rect>();
|
||||
}
|
||||
|
@ -176,12 +184,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) {
|
|||
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION")
|
||||
->Tag(kDetectionTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag(kNormRectTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rect = output[0].Get<NormalizedRect>();
|
||||
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
|
||||
|
@ -201,12 +210,13 @@ absl::StatusOr<NormalizedRect> RunDetectionKeyPointsToNormRectCalculation(
|
|||
)pb"));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION")
|
||||
->Tag(kDetectionTag)
|
||||
.packets.push_back(MakePacket<Detection>(std::move(detection))
|
||||
.At(Timestamp::PostStream()));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag(kNormRectTag).packets;
|
||||
RET_CHECK_EQ(output.size(), 1);
|
||||
return output[0].Get<NormalizedRect>();
|
||||
}
|
||||
|
@ -248,12 +258,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) {
|
|||
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECT").packets;
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag(kRectTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rect = output[0].Get<Rect>();
|
||||
EXPECT_THAT(rect, RectEq(250, 400, 300, 400));
|
||||
|
@ -271,12 +281,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRect) {
|
|||
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("NORM_RECT").packets;
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag(kNormRectTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rect = output[0].Get<NormalizedRect>();
|
||||
EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f));
|
||||
|
@ -294,12 +305,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToRects) {
|
|||
detections->push_back(DetectionWithLocationData(200, 300, 400, 500));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets;
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rects = output[0].Get<std::vector<Rect>>();
|
||||
ASSERT_EQ(rects.size(), 2);
|
||||
|
@ -319,13 +330,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionsToNormalizedRects) {
|
|||
detections->push_back(DetectionWithRelativeLocationData(0.2, 0.3, 0.4, 0.5));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("NORM_RECTS").packets;
|
||||
runner.Outputs().Tag(kNormRectsTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||
ASSERT_EQ(rects.size(), 2);
|
||||
|
@ -344,12 +355,12 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRects) {
|
|||
DetectionWithLocationData(100, 200, 300, 400));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION")
|
||||
->Tag(kDetectionTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("RECTS").packets;
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag(kRectsTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rects = output[0].Get<std::vector<Rect>>();
|
||||
EXPECT_EQ(rects.size(), 1);
|
||||
|
@ -367,13 +378,13 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRects) {
|
|||
DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION")
|
||||
->Tag(kDetectionTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("NORM_RECTS").packets;
|
||||
runner.Outputs().Tag(kNormRectsTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||
ASSERT_EQ(rects.size(), 1);
|
||||
|
@ -391,7 +402,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToRect) {
|
|||
detections->push_back(DetectionWithRelativeLocationData(0.1, 0.2, 0.3, 0.4));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
|
@ -411,7 +422,7 @@ TEST(DetectionsToRectsCalculatorTest, WrongInputToNormalizedRect) {
|
|||
detections->push_back(DetectionWithLocationData(100, 200, 300, 400));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
|
|
|
@ -30,6 +30,10 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kRenderDataTag[] = "RENDER_DATA";
|
||||
constexpr char kDetectionListTag[] = "DETECTION_LIST";
|
||||
|
||||
using ::testing::DoubleNear;
|
||||
|
||||
// Error tolerance for pixels, distances, etc.
|
||||
|
@ -97,13 +101,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionList) {
|
|||
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag");
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION_LIST")
|
||||
->Tag(kDetectionListTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("RENDER_DATA").packets;
|
||||
runner.Outputs().Tag(kRenderDataTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& actual = output[0].Get<RenderData>();
|
||||
EXPECT_EQ(actual.render_annotations_size(), 3);
|
||||
|
@ -131,13 +135,13 @@ TEST(DetectionsToRenderDataCalculatorTest, OnlyDetecctionVector) {
|
|||
CreateDetection({"label1"}, {}, {0.3}, location_data, "feature_tag"));
|
||||
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag("RENDER_DATA").packets;
|
||||
runner.Outputs().Tag(kRenderDataTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& actual = output[0].Get<RenderData>();
|
||||
EXPECT_EQ(actual.render_annotations_size(), 3);
|
||||
|
@ -165,7 +169,7 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
|
|||
*(detection_list->add_detection()) =
|
||||
CreateDetection({"label1"}, {}, {0.3}, location_data1, "feature_tag1");
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTION_LIST")
|
||||
->Tag(kDetectionListTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection_list.release()).At(Timestamp::PostStream()));
|
||||
|
||||
|
@ -174,13 +178,13 @@ TEST(DetectionsToRenderDataCalculatorTest, BothDetecctionListAndVector) {
|
|||
detections->push_back(
|
||||
CreateDetection({"label2"}, {}, {0.6}, location_data2, "feature_tag2"));
|
||||
runner.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& actual =
|
||||
runner.Outputs().Tag("RENDER_DATA").packets;
|
||||
runner.Outputs().Tag(kRenderDataTag).packets;
|
||||
ASSERT_EQ(1, actual.size());
|
||||
// Check the feature tag for item from detection list.
|
||||
EXPECT_EQ(
|
||||
|
@ -209,19 +213,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
|
|||
|
||||
auto detection_list1(absl::make_unique<DetectionList>());
|
||||
runner1.MutableInputs()
|
||||
->Tag("DETECTION_LIST")
|
||||
->Tag(kDetectionListTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection_list1.release()).At(Timestamp::PostStream()));
|
||||
|
||||
auto detections1(absl::make_unique<std::vector<Detection>>());
|
||||
runner1.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections1.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner1.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& exact1 =
|
||||
runner1.Outputs().Tag("RENDER_DATA").packets;
|
||||
runner1.Outputs().Tag(kRenderDataTag).packets;
|
||||
ASSERT_EQ(0, exact1.size());
|
||||
|
||||
// Check when produce_empty_packet is true.
|
||||
|
@ -240,19 +244,19 @@ TEST(DetectionsToRenderDataCalculatorTest, ProduceEmptyPacket) {
|
|||
|
||||
auto detection_list2(absl::make_unique<DetectionList>());
|
||||
runner2.MutableInputs()
|
||||
->Tag("DETECTION_LIST")
|
||||
->Tag(kDetectionListTag)
|
||||
.packets.push_back(
|
||||
Adopt(detection_list2.release()).At(Timestamp::PostStream()));
|
||||
|
||||
auto detections2(absl::make_unique<std::vector<Detection>>());
|
||||
runner2.MutableInputs()
|
||||
->Tag("DETECTIONS")
|
||||
->Tag(kDetectionsTag)
|
||||
.packets.push_back(
|
||||
Adopt(detections2.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner2.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& exact2 =
|
||||
runner2.Outputs().Tag("RENDER_DATA").packets;
|
||||
runner2.Outputs().Tag(kRenderDataTag).packets;
|
||||
ASSERT_EQ(1, exact2.size());
|
||||
EXPECT_EQ(exact2[0].Get<RenderData>().render_annotations_size(), 0);
|
||||
}
|
||||
|
|
|
@ -32,6 +32,12 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kRenderDataTag[] = "RENDER_DATA";
|
||||
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||
constexpr char kScoresTag[] = "SCORES";
|
||||
constexpr char kLabelsTag[] = "LABELS";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
|
||||
constexpr float kFontHeightScale = 1.25f;
|
||||
|
||||
// A calculator takes in pairs of labels and scores or classifications, outputs
|
||||
|
@ -74,20 +80,20 @@ class LabelsToRenderDataCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(LabelsToRenderDataCalculator);
|
||||
|
||||
absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) {
|
||||
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
|
||||
cc->Inputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
|
||||
if (cc->Inputs().HasTag(kClassificationsTag)) {
|
||||
cc->Inputs().Tag(kClassificationsTag).Set<ClassificationList>();
|
||||
} else {
|
||||
RET_CHECK(cc->Inputs().HasTag("LABELS"))
|
||||
RET_CHECK(cc->Inputs().HasTag(kLabelsTag))
|
||||
<< "Must provide input stream \"LABELS\"";
|
||||
cc->Inputs().Tag("LABELS").Set<std::vector<std::string>>();
|
||||
if (cc->Inputs().HasTag("SCORES")) {
|
||||
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
|
||||
cc->Inputs().Tag(kLabelsTag).Set<std::vector<std::string>>();
|
||||
if (cc->Inputs().HasTag(kScoresTag)) {
|
||||
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
|
||||
}
|
||||
}
|
||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
|
||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
||||
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
|
||||
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||
}
|
||||
cc->Outputs().Tag("RENDER_DATA").Set<RenderData>();
|
||||
cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -100,10 +106,10 @@ absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM") &&
|
||||
if (cc->Inputs().HasTag(kVideoPrestreamTag) &&
|
||||
cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||
const VideoHeader& video_header =
|
||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
|
||||
cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
|
||||
video_width_ = video_header.width;
|
||||
video_height_ = video_header.height;
|
||||
return absl::OkStatus();
|
||||
|
@ -114,9 +120,9 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
std::vector<std::string> labels;
|
||||
std::vector<float> scores;
|
||||
if (cc->Inputs().HasTag("CLASSIFICATIONS")) {
|
||||
if (cc->Inputs().HasTag(kClassificationsTag)) {
|
||||
const ClassificationList& classifications =
|
||||
cc->Inputs().Tag("CLASSIFICATIONS").Get<ClassificationList>();
|
||||
cc->Inputs().Tag(kClassificationsTag).Get<ClassificationList>();
|
||||
labels.resize(classifications.classification_size());
|
||||
scores.resize(classifications.classification_size());
|
||||
for (int i = 0; i < classifications.classification_size(); ++i) {
|
||||
|
@ -129,15 +135,15 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
} else {
|
||||
const std::vector<std::string>& label_vector =
|
||||
cc->Inputs().Tag("LABELS").Get<std::vector<std::string>>();
|
||||
cc->Inputs().Tag(kLabelsTag).Get<std::vector<std::string>>();
|
||||
labels.resize(label_vector.size());
|
||||
for (int i = 0; i < label_vector.size(); ++i) {
|
||||
labels[i] = label_vector[i];
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("SCORES")) {
|
||||
if (cc->Inputs().HasTag(kScoresTag)) {
|
||||
std::vector<float> score_vector =
|
||||
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
|
||||
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
|
||||
CHECK_EQ(label_vector.size(), score_vector.size());
|
||||
scores.resize(label_vector.size());
|
||||
for (int i = 0; i < label_vector.size(); ++i) {
|
||||
|
@ -169,7 +175,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
auto* text = label_annotation->mutable_text();
|
||||
std::string display_text = labels[i];
|
||||
if (cc->Inputs().HasTag("SCORES")) {
|
||||
if (cc->Inputs().HasTag(kScoresTag)) {
|
||||
absl::StrAppend(&display_text, ":", scores[i]);
|
||||
}
|
||||
text->set_display_text(display_text);
|
||||
|
@ -179,7 +185,7 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
|
|||
text->set_font_face(options_.font_face());
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag("RENDER_DATA")
|
||||
.Tag(kRenderDataTag)
|
||||
.AddPacket(MakePacket<RenderData>(render_data).At(cc->InputTimestamp()));
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
|
||||
constexpr char kLandmarksTag[] = "LANDMARKS";
|
||||
|
||||
NormalizedLandmark CreateLandmark(float x, float y) {
|
||||
NormalizedLandmark landmark;
|
||||
landmark.set_x(x);
|
||||
|
@ -48,18 +51,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingLeftRight) {
|
|||
*landmarks->add_landmark() = CreateLandmark(0.2f, 0.2f);
|
||||
*landmarks->add_landmark() = CreateLandmark(0.7f, 0.7f);
|
||||
runner.MutableInputs()
|
||||
->Tag("LANDMARKS")
|
||||
->Tag(kLandmarksTag)
|
||||
.packets.push_back(
|
||||
Adopt(landmarks.release()).At(Timestamp::PostStream()));
|
||||
|
||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||
std::array<float, 4>{0.2f, 0.f, 0.3f, 0.f});
|
||||
runner.MutableInputs()
|
||||
->Tag("LETTERBOX_PADDING")
|
||||
->Tag(kLetterboxPaddingTag)
|
||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets;
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag(kLandmarksTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
|
||||
|
||||
|
@ -84,18 +88,19 @@ TEST(LandmarkLetterboxRemovalCalculatorTest, PaddingTopBottom) {
|
|||
landmark = landmarks->add_landmark();
|
||||
*landmark = CreateLandmark(0.7f, 0.7f);
|
||||
runner.MutableInputs()
|
||||
->Tag("LANDMARKS")
|
||||
->Tag(kLandmarksTag)
|
||||
.packets.push_back(
|
||||
Adopt(landmarks.release()).At(Timestamp::PostStream()));
|
||||
|
||||
auto padding = absl::make_unique<std::array<float, 4>>(
|
||||
std::array<float, 4>{0.0f, 0.2f, 0.0f, 0.3f});
|
||||
runner.MutableInputs()
|
||||
->Tag("LETTERBOX_PADDING")
|
||||
->Tag(kLetterboxPaddingTag)
|
||||
.packets.push_back(Adopt(padding.release()).At(Timestamp::PostStream()));
|
||||
|
||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||
const std::vector<Packet>& output = runner.Outputs().Tag("LANDMARKS").packets;
|
||||
const std::vector<Packet>& output =
|
||||
runner.Outputs().Tag(kLandmarksTag).packets;
|
||||
ASSERT_EQ(1, output.size());
|
||||
const auto& output_landmarks = output[0].Get<NormalizedLandmarkList>();
|
||||
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
|
||||
|
||||
absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
||||
mediapipe::NormalizedLandmarkList input, mediapipe::NormalizedRect rect) {
|
||||
mediapipe::CalculatorRunner runner(
|
||||
|
@ -26,17 +30,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
|||
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||
)pb"));
|
||||
runner.MutableInputs()
|
||||
->Tag("NORM_LANDMARKS")
|
||||
->Tag(kNormLandmarksTag)
|
||||
.packets.push_back(
|
||||
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||
.At(Timestamp(1)));
|
||||
runner.MutableInputs()
|
||||
->Tag("NORM_RECT")
|
||||
->Tag(kNormRectTag)
|
||||
.packets.push_back(MakePacket<mediapipe::NormalizedRect>(std::move(rect))
|
||||
.At(Timestamp(1)));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
||||
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
|
||||
RET_CHECK_EQ(output_packets.size(), 1);
|
||||
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||
}
|
||||
|
@ -104,17 +108,17 @@ absl::StatusOr<mediapipe::NormalizedLandmarkList> RunCalculator(
|
|||
output_stream: "NORM_LANDMARKS:projected_landmarks"
|
||||
)pb"));
|
||||
runner.MutableInputs()
|
||||
->Tag("NORM_LANDMARKS")
|
||||
->Tag(kNormLandmarksTag)
|
||||
.packets.push_back(
|
||||
MakePacket<mediapipe::NormalizedLandmarkList>(std::move(input))
|
||||
.At(Timestamp(1)));
|
||||
runner.MutableInputs()
|
||||
->Tag("PROJECTION_MATRIX")
|
||||
->Tag(kProjectionMatrixTag)
|
||||
.packets.push_back(MakePacket<std::array<float, 16>>(std::move(matrix))
|
||||
.At(Timestamp(1)));
|
||||
|
||||
MP_RETURN_IF_ERROR(runner.Run());
|
||||
const auto& output_packets = runner.Outputs().Tag("NORM_LANDMARKS").packets;
|
||||
const auto& output_packets = runner.Outputs().Tag(kNormLandmarksTag).packets;
|
||||
RET_CHECK_EQ(output_packets.size(), 1);
|
||||
return output_packets[0].Get<mediapipe::NormalizedLandmarkList>();
|
||||
}
|
||||
|
|
|
@ -20,6 +20,11 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kContentsTag[] = "CONTENTS";
|
||||
constexpr char kFileSuffixTag[] = "FILE_SUFFIX";
|
||||
constexpr char kFileDirectoryTag[] = "FILE_DIRECTORY";
|
||||
|
||||
// The calculator takes the path to local directory and desired file suffix to
|
||||
// mach as input side packets, and outputs the contents of those files that
|
||||
// match the pattern. Those matched files will be sent sequentially through the
|
||||
|
@ -35,16 +40,16 @@ namespace mediapipe {
|
|||
class LocalFilePatternContentsCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("FILE_DIRECTORY").Set<std::string>();
|
||||
cc->InputSidePackets().Tag("FILE_SUFFIX").Set<std::string>();
|
||||
cc->Outputs().Tag("CONTENTS").Set<std::string>();
|
||||
cc->InputSidePackets().Tag(kFileDirectoryTag).Set<std::string>();
|
||||
cc->InputSidePackets().Tag(kFileSuffixTag).Set<std::string>();
|
||||
cc->Outputs().Tag(kContentsTag).Set<std::string>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory(
|
||||
cc->InputSidePackets().Tag("FILE_DIRECTORY").Get<std::string>(),
|
||||
cc->InputSidePackets().Tag("FILE_SUFFIX").Get<std::string>(),
|
||||
cc->InputSidePackets().Tag(kFileDirectoryTag).Get<std::string>(),
|
||||
cc->InputSidePackets().Tag(kFileSuffixTag).Get<std::string>(),
|
||||
&filenames_));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -57,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase {
|
|||
filenames_[current_output_], contents.get()));
|
||||
++current_output_;
|
||||
cc->Outputs()
|
||||
.Tag("CONTENTS")
|
||||
.Tag(kContentsTag)
|
||||
.Add(contents.release(), Timestamp(current_output_));
|
||||
} else {
|
||||
return tool::StatusStop();
|
||||
|
|
|
@ -217,7 +217,7 @@ absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) {
|
|||
// Initialize the clock.
|
||||
if (cc->InputSidePackets().HasTag(kClockTag)) {
|
||||
clock_ = cc->InputSidePackets()
|
||||
.Tag("CLOCK")
|
||||
.Tag(kClockTag)
|
||||
.Get<std::shared_ptr<::mediapipe::Clock>>();
|
||||
} else {
|
||||
clock_ = std::shared_ptr<::mediapipe::Clock>(
|
||||
|
|
|
@ -17,6 +17,12 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kThresholdTag[] = "THRESHOLD";
|
||||
constexpr char kRejectTag[] = "REJECT";
|
||||
constexpr char kAcceptTag[] = "ACCEPT";
|
||||
constexpr char kFlagTag[] = "FLAG";
|
||||
constexpr char kFloatTag[] = "FLOAT";
|
||||
|
||||
// Applies a threshold on a stream of numeric values and outputs a flag and/or
|
||||
// accept/reject stream. The threshold can be specified by one of the following:
|
||||
// 1) Input stream.
|
||||
|
@ -61,24 +67,24 @@ class ThresholdingCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(ThresholdingCalculator);
|
||||
|
||||
absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("FLOAT"));
|
||||
cc->Inputs().Tag("FLOAT").Set<float>();
|
||||
RET_CHECK(cc->Inputs().HasTag(kFloatTag));
|
||||
cc->Inputs().Tag(kFloatTag).Set<float>();
|
||||
|
||||
if (cc->Outputs().HasTag("FLAG")) {
|
||||
cc->Outputs().Tag("FLAG").Set<bool>();
|
||||
if (cc->Outputs().HasTag(kFlagTag)) {
|
||||
cc->Outputs().Tag(kFlagTag).Set<bool>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("ACCEPT")) {
|
||||
cc->Outputs().Tag("ACCEPT").Set<bool>();
|
||||
if (cc->Outputs().HasTag(kAcceptTag)) {
|
||||
cc->Outputs().Tag(kAcceptTag).Set<bool>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("REJECT")) {
|
||||
cc->Outputs().Tag("REJECT").Set<bool>();
|
||||
if (cc->Outputs().HasTag(kRejectTag)) {
|
||||
cc->Outputs().Tag(kRejectTag).Set<bool>();
|
||||
}
|
||||
if (cc->Inputs().HasTag("THRESHOLD")) {
|
||||
cc->Inputs().Tag("THRESHOLD").Set<double>();
|
||||
if (cc->Inputs().HasTag(kThresholdTag)) {
|
||||
cc->Inputs().Tag(kThresholdTag).Set<double>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
||||
cc->InputSidePackets().Tag("THRESHOLD").Set<double>();
|
||||
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD"))
|
||||
if (cc->InputSidePackets().HasTag(kThresholdTag)) {
|
||||
cc->InputSidePackets().Tag(kThresholdTag).Set<double>();
|
||||
RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
|
||||
<< "Using both the threshold input side packet and input stream is not "
|
||||
"supported.";
|
||||
}
|
||||
|
@ -92,43 +98,45 @@ absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) {
|
|||
const auto& options =
|
||||
cc->Options<::mediapipe::ThresholdingCalculatorOptions>();
|
||||
if (options.has_threshold()) {
|
||||
RET_CHECK(!cc->Inputs().HasTag("THRESHOLD"))
|
||||
RET_CHECK(!cc->Inputs().HasTag(kThresholdTag))
|
||||
<< "Using both the threshold option and input stream is not supported.";
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag("THRESHOLD"))
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag(kThresholdTag))
|
||||
<< "Using both the threshold option and input side packet is not "
|
||||
"supported.";
|
||||
threshold_ = options.threshold();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("THRESHOLD")) {
|
||||
threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get<double>();
|
||||
if (cc->InputSidePackets().HasTag(kThresholdTag)) {
|
||||
threshold_ = cc->InputSidePackets().Tag(kThresholdTag).Get<double>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->Inputs().HasTag("THRESHOLD") &&
|
||||
!cc->Inputs().Tag("THRESHOLD").IsEmpty()) {
|
||||
threshold_ = cc->Inputs().Tag("THRESHOLD").Get<double>();
|
||||
if (cc->Inputs().HasTag(kThresholdTag) &&
|
||||
!cc->Inputs().Tag(kThresholdTag).IsEmpty()) {
|
||||
threshold_ = cc->Inputs().Tag(kThresholdTag).Get<double>();
|
||||
}
|
||||
|
||||
bool accept = false;
|
||||
RET_CHECK(!cc->Inputs().Tag("FLOAT").IsEmpty());
|
||||
accept =
|
||||
static_cast<double>(cc->Inputs().Tag("FLOAT").Get<float>()) > threshold_;
|
||||
RET_CHECK(!cc->Inputs().Tag(kFloatTag).IsEmpty());
|
||||
accept = static_cast<double>(cc->Inputs().Tag(kFloatTag).Get<float>()) >
|
||||
threshold_;
|
||||
|
||||
if (cc->Outputs().HasTag("FLAG")) {
|
||||
cc->Outputs().Tag("FLAG").AddPacket(
|
||||
if (cc->Outputs().HasTag(kFlagTag)) {
|
||||
cc->Outputs().Tag(kFlagTag).AddPacket(
|
||||
MakePacket<bool>(accept).At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (accept && cc->Outputs().HasTag("ACCEPT")) {
|
||||
cc->Outputs().Tag("ACCEPT").AddPacket(
|
||||
MakePacket<bool>(true).At(cc->InputTimestamp()));
|
||||
if (accept && cc->Outputs().HasTag(kAcceptTag)) {
|
||||
cc->Outputs()
|
||||
.Tag(kAcceptTag)
|
||||
.AddPacket(MakePacket<bool>(true).At(cc->InputTimestamp()));
|
||||
}
|
||||
if (!accept && cc->Outputs().HasTag("REJECT")) {
|
||||
cc->Outputs().Tag("REJECT").AddPacket(
|
||||
MakePacket<bool>(false).At(cc->InputTimestamp()));
|
||||
if (!accept && cc->Outputs().HasTag(kRejectTag)) {
|
||||
cc->Outputs()
|
||||
.Tag(kRejectTag)
|
||||
.AddPacket(MakePacket<bool>(false).At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -39,6 +39,14 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
|
||||
constexpr char kSummaryTag[] = "SUMMARY";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
|
||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||
constexpr char kScoresTag[] = "SCORES";
|
||||
|
||||
// A calculator that takes a vector of scores and returns the indexes, scores,
|
||||
// labels of the top k elements, classification protos, and summary std::string
|
||||
// (in csv format).
|
||||
|
@ -79,22 +87,22 @@ class TopKScoresCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(TopKScoresCalculator);
|
||||
|
||||
absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("SCORES"));
|
||||
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
|
||||
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
|
||||
cc->Outputs().Tag("TOP_K_INDEXES").Set<std::vector<int>>();
|
||||
RET_CHECK(cc->Inputs().HasTag(kScoresTag));
|
||||
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
|
||||
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
|
||||
cc->Outputs().Tag(kTopKIndexesTag).Set<std::vector<int>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
|
||||
cc->Outputs().Tag("TOP_K_SCORES").Set<std::vector<float>>();
|
||||
if (cc->Outputs().HasTag(kTopKScoresTag)) {
|
||||
cc->Outputs().Tag(kTopKScoresTag).Set<std::vector<float>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
||||
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
|
||||
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("CLASSIFICATIONS")) {
|
||||
cc->Outputs().Tag("CLASSIFICATIONS").Set<ClassificationList>();
|
||||
if (cc->Outputs().HasTag(kClassificationsTag)) {
|
||||
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("SUMMARY")) {
|
||||
cc->Outputs().Tag("SUMMARY").Set<std::string>();
|
||||
if (cc->Outputs().HasTag(kSummaryTag)) {
|
||||
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -114,7 +122,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
|
|||
if (options.has_label_map_path()) {
|
||||
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
|
||||
}
|
||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
||||
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||
RET_CHECK(!label_map_.empty());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -122,7 +130,7 @@ absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
||||
const std::vector<float>& input_vector =
|
||||
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
|
||||
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
|
||||
std::vector<int> top_k_indexes;
|
||||
|
||||
std::vector<float> top_k_scores;
|
||||
|
@ -166,26 +174,26 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
|||
top_k_labels.push_back(label_map_[index]);
|
||||
}
|
||||
}
|
||||
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
|
||||
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("TOP_K_INDEXES")
|
||||
.Tag(kTopKIndexesTag)
|
||||
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
|
||||
if (cc->Outputs().HasTag(kTopKScoresTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("TOP_K_SCORES")
|
||||
.Tag(kTopKScoresTag)
|
||||
.AddPacket(MakePacket<std::vector<float>>(top_k_scores)
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
|
||||
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("TOP_K_LABELS")
|
||||
.Tag(kTopKLabelsTag)
|
||||
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("SUMMARY")) {
|
||||
if (cc->Outputs().HasTag(kSummaryTag)) {
|
||||
std::vector<std::string> results;
|
||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||
if (label_map_loaded_) {
|
||||
|
@ -196,12 +204,13 @@ absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
|
|||
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
|
||||
}
|
||||
}
|
||||
cc->Outputs().Tag("SUMMARY").AddPacket(
|
||||
MakePacket<std::string>(absl::StrJoin(results, ","))
|
||||
.At(cc->InputTimestamp()));
|
||||
cc->Outputs()
|
||||
.Tag(kSummaryTag)
|
||||
.AddPacket(MakePacket<std::string>(absl::StrJoin(results, ","))
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("TOP_K_CLASSIFICATION")) {
|
||||
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
|
||||
auto classification_list = absl::make_unique<ClassificationList>();
|
||||
for (int index = 0; index < top_k_indexes.size(); ++index) {
|
||||
Classification* classification =
|
||||
|
|
|
@ -23,6 +23,10 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
|
||||
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
|
||||
constexpr char kScoresTag[] = "SCORES";
|
||||
|
||||
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
|
||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
calculator: "TopKScoresCalculator"
|
||||
|
@ -55,19 +59,21 @@ TEST(TopKScoresCalculatorTest, TestTopKOnly) {
|
|||
|
||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||
|
||||
runner.MutableInputs()->Tag("SCORES").packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kScoresTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& indexes_outputs =
|
||||
runner.Outputs().Tag("TOP_K_INDEXES").packets;
|
||||
runner.Outputs().Tag(kTopKIndexesTag).packets;
|
||||
ASSERT_EQ(1, indexes_outputs.size());
|
||||
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
||||
EXPECT_EQ(2, indexes.size());
|
||||
EXPECT_EQ(3, indexes[0]);
|
||||
EXPECT_EQ(0, indexes[1]);
|
||||
const std::vector<Packet>& scores_outputs =
|
||||
runner.Outputs().Tag("TOP_K_SCORES").packets;
|
||||
runner.Outputs().Tag(kTopKScoresTag).packets;
|
||||
ASSERT_EQ(1, scores_outputs.size());
|
||||
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
||||
EXPECT_EQ(2, scores.size());
|
||||
|
@ -88,12 +94,14 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
|
|||
|
||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||
|
||||
runner.MutableInputs()->Tag("SCORES").packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kScoresTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& indexes_outputs =
|
||||
runner.Outputs().Tag("TOP_K_INDEXES").packets;
|
||||
runner.Outputs().Tag(kTopKIndexesTag).packets;
|
||||
ASSERT_EQ(1, indexes_outputs.size());
|
||||
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
||||
EXPECT_EQ(4, indexes.size());
|
||||
|
@ -102,7 +110,7 @@ TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
|
|||
EXPECT_EQ(2, indexes[2]);
|
||||
EXPECT_EQ(1, indexes[3]);
|
||||
const std::vector<Packet>& scores_outputs =
|
||||
runner.Outputs().Tag("TOP_K_SCORES").packets;
|
||||
runner.Outputs().Tag(kTopKScoresTag).packets;
|
||||
ASSERT_EQ(1, scores_outputs.size());
|
||||
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
||||
EXPECT_EQ(4, scores.size());
|
||||
|
@ -125,12 +133,14 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
|
|||
|
||||
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
|
||||
|
||||
runner.MutableInputs()->Tag("SCORES").packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
runner.MutableInputs()
|
||||
->Tag(kScoresTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
|
||||
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
const std::vector<Packet>& indexes_outputs =
|
||||
runner.Outputs().Tag("TOP_K_INDEXES").packets;
|
||||
runner.Outputs().Tag(kTopKIndexesTag).packets;
|
||||
ASSERT_EQ(1, indexes_outputs.size());
|
||||
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
|
||||
EXPECT_EQ(3, indexes.size());
|
||||
|
@ -138,7 +148,7 @@ TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
|
|||
EXPECT_EQ(0, indexes[1]);
|
||||
EXPECT_EQ(2, indexes[2]);
|
||||
const std::vector<Packet>& scores_outputs =
|
||||
runner.Outputs().Tag("TOP_K_SCORES").packets;
|
||||
runner.Outputs().Tag(kTopKScoresTag).packets;
|
||||
ASSERT_EQ(1, scores_outputs.size());
|
||||
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
|
||||
EXPECT_EQ(3, scores.size());
|
||||
|
|
|
@ -47,6 +47,21 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kFrameAlignmentTag[] = "FRAME_ALIGNMENT";
|
||||
constexpr char kOutputIndexFilenameTag[] = "OUTPUT_INDEX_FILENAME";
|
||||
constexpr char kIndexProtoStringTag[] = "INDEX_PROTO_STRING";
|
||||
constexpr char kVizTag[] = "VIZ";
|
||||
constexpr char kBoxesTag[] = "BOXES";
|
||||
constexpr char kReacqSwitchTag[] = "REACQ_SWITCH";
|
||||
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
|
||||
constexpr char kAddIndexTag[] = "ADD_INDEX";
|
||||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kDescriptorsTag[] = "DESCRIPTORS";
|
||||
constexpr char kFeaturesTag[] = "FEATURES";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
constexpr char kTrackedBoxesTag[] = "TRACKED_BOXES";
|
||||
constexpr char kTrackingTag[] = "TRACKING";
|
||||
|
||||
// A calculator to detect reappeared box positions from single frame.
|
||||
//
|
||||
// Input stream:
|
||||
|
@ -110,66 +125,66 @@ class BoxDetectorCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(BoxDetectorCalculator);
|
||||
|
||||
absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) {
|
||||
if (cc->Inputs().HasTag("TRACKING")) {
|
||||
cc->Inputs().Tag("TRACKING").Set<TrackingData>();
|
||||
if (cc->Inputs().HasTag(kTrackingTag)) {
|
||||
cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("TRACKED_BOXES")) {
|
||||
cc->Inputs().Tag("TRACKED_BOXES").Set<TimedBoxProtoList>();
|
||||
if (cc->Inputs().HasTag(kTrackedBoxesTag)) {
|
||||
cc->Inputs().Tag(kTrackedBoxesTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("VIDEO")) {
|
||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
||||
if (cc->Inputs().HasTag(kVideoTag)) {
|
||||
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("FEATURES")) {
|
||||
RET_CHECK(cc->Inputs().HasTag("DESCRIPTORS"))
|
||||
if (cc->Inputs().HasTag(kFeaturesTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kDescriptorsTag))
|
||||
<< "FEATURES and DESCRIPTORS need to be specified together.";
|
||||
cc->Inputs().Tag("FEATURES").Set<std::vector<cv::KeyPoint>>();
|
||||
cc->Inputs().Tag(kFeaturesTag).Set<std::vector<cv::KeyPoint>>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("DESCRIPTORS")) {
|
||||
RET_CHECK(cc->Inputs().HasTag("FEATURES"))
|
||||
if (cc->Inputs().HasTag(kDescriptorsTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kFeaturesTag))
|
||||
<< "FEATURES and DESCRIPTORS need to be specified together.";
|
||||
cc->Inputs().Tag("DESCRIPTORS").Set<std::vector<float>>();
|
||||
cc->Inputs().Tag(kDescriptorsTag).Set<std::vector<float>>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("IMAGE_SIZE")) {
|
||||
cc->Inputs().Tag("IMAGE_SIZE").Set<std::pair<int, int>>();
|
||||
if (cc->Inputs().HasTag(kImageSizeTag)) {
|
||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("ADD_INDEX")) {
|
||||
cc->Inputs().Tag("ADD_INDEX").Set<std::string>();
|
||||
if (cc->Inputs().HasTag(kAddIndexTag)) {
|
||||
cc->Inputs().Tag(kAddIndexTag).Set<std::string>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) {
|
||||
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>();
|
||||
if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
|
||||
cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("REACQ_SWITCH")) {
|
||||
cc->Inputs().Tag("REACQ_SWITCH").Set<bool>();
|
||||
if (cc->Inputs().HasTag(kReacqSwitchTag)) {
|
||||
cc->Inputs().Tag(kReacqSwitchTag).Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("BOXES")) {
|
||||
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>();
|
||||
if (cc->Outputs().HasTag(kBoxesTag)) {
|
||||
cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("VIZ")) {
|
||||
RET_CHECK(cc->Inputs().HasTag("VIDEO"))
|
||||
if (cc->Outputs().HasTag(kVizTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kVideoTag))
|
||||
<< "Output stream VIZ requires VIDEO to be present.";
|
||||
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
|
||||
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) {
|
||||
cc->InputSidePackets().Tag("INDEX_PROTO_STRING").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
|
||||
cc->InputSidePackets().Tag(kIndexProtoStringTag).Set<std::string>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) {
|
||||
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
|
||||
cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Set<std::string>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) {
|
||||
cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set<int>();
|
||||
if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
|
||||
cc->InputSidePackets().Tag(kFrameAlignmentTag).Set<int>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -179,10 +194,10 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
|
|||
options_ = cc->Options<BoxDetectorCalculatorOptions>();
|
||||
box_detector_ = BoxDetectorInterface::Create(options_.detector_options());
|
||||
|
||||
if (cc->InputSidePackets().HasTag("INDEX_PROTO_STRING")) {
|
||||
if (cc->InputSidePackets().HasTag(kIndexProtoStringTag)) {
|
||||
BoxDetectorIndex predefined_index;
|
||||
if (!predefined_index.ParseFromString(cc->InputSidePackets()
|
||||
.Tag("INDEX_PROTO_STRING")
|
||||
.Tag(kIndexProtoStringTag)
|
||||
.Get<std::string>())) {
|
||||
LOG(FATAL) << "failed to parse BoxDetectorIndex from INDEX_PROTO_STRING";
|
||||
}
|
||||
|
@ -202,12 +217,13 @@ absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) {
|
|||
box_detector_->AddBoxDetectorIndex(predefined_index);
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("OUTPUT_INDEX_FILENAME")) {
|
||||
if (cc->InputSidePackets().HasTag(kOutputIndexFilenameTag)) {
|
||||
write_index_ = true;
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("FRAME_ALIGNMENT")) {
|
||||
frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get<int>();
|
||||
if (cc->InputSidePackets().HasTag(kFrameAlignmentTag)) {
|
||||
frame_alignment_ =
|
||||
cc->InputSidePackets().Tag(kFrameAlignmentTag).Get<int>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -218,16 +234,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
const int64 timestamp_msec = timestamp.Value() / 1000;
|
||||
|
||||
InputStream* cancel_object_id_stream =
|
||||
cc->Inputs().HasTag("CANCEL_OBJECT_ID")
|
||||
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID"))
|
||||
cc->Inputs().HasTag(kCancelObjectIdTag)
|
||||
? &(cc->Inputs().Tag(kCancelObjectIdTag))
|
||||
: nullptr;
|
||||
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
|
||||
const int cancel_object_id = cancel_object_id_stream->Get<int>();
|
||||
box_detector_->CancelBoxDetection(cancel_object_id);
|
||||
}
|
||||
|
||||
InputStream* add_index_stream = cc->Inputs().HasTag("ADD_INDEX")
|
||||
? &(cc->Inputs().Tag("ADD_INDEX"))
|
||||
InputStream* add_index_stream = cc->Inputs().HasTag(kAddIndexTag)
|
||||
? &(cc->Inputs().Tag(kAddIndexTag))
|
||||
: nullptr;
|
||||
if (add_index_stream && !add_index_stream->IsEmpty()) {
|
||||
BoxDetectorIndex predefined_index;
|
||||
|
@ -238,8 +254,8 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
box_detector_->AddBoxDetectorIndex(predefined_index);
|
||||
}
|
||||
|
||||
InputStream* reacq_switch_stream = cc->Inputs().HasTag("REACQ_SWITCH")
|
||||
? &(cc->Inputs().Tag("REACQ_SWITCH"))
|
||||
InputStream* reacq_switch_stream = cc->Inputs().HasTag(kReacqSwitchTag)
|
||||
? &(cc->Inputs().Tag(kReacqSwitchTag))
|
||||
: nullptr;
|
||||
if (reacq_switch_stream && !reacq_switch_stream->IsEmpty()) {
|
||||
detector_switch_ = reacq_switch_stream->Get<bool>();
|
||||
|
@ -249,16 +265,16 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
InputStream* track_stream = cc->Inputs().HasTag("TRACKING")
|
||||
? &(cc->Inputs().Tag("TRACKING"))
|
||||
InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
|
||||
? &(cc->Inputs().Tag(kTrackingTag))
|
||||
: nullptr;
|
||||
InputStream* video_stream =
|
||||
cc->Inputs().HasTag("VIDEO") ? &(cc->Inputs().Tag("VIDEO")) : nullptr;
|
||||
InputStream* feature_stream = cc->Inputs().HasTag("FEATURES")
|
||||
? &(cc->Inputs().Tag("FEATURES"))
|
||||
cc->Inputs().HasTag(kVideoTag) ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
|
||||
InputStream* feature_stream = cc->Inputs().HasTag(kFeaturesTag)
|
||||
? &(cc->Inputs().Tag(kFeaturesTag))
|
||||
: nullptr;
|
||||
InputStream* descriptor_stream = cc->Inputs().HasTag("DESCRIPTORS")
|
||||
? &(cc->Inputs().Tag("DESCRIPTORS"))
|
||||
InputStream* descriptor_stream = cc->Inputs().HasTag(kDescriptorsTag)
|
||||
? &(cc->Inputs().Tag(kDescriptorsTag))
|
||||
: nullptr;
|
||||
|
||||
CHECK(track_stream != nullptr || video_stream != nullptr ||
|
||||
|
@ -266,9 +282,10 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
<< "One and only one of {tracking_data, input image frame, "
|
||||
"feature/descriptor} need to be valid.";
|
||||
|
||||
InputStream* tracked_boxes_stream = cc->Inputs().HasTag("TRACKED_BOXES")
|
||||
? &(cc->Inputs().Tag("TRACKED_BOXES"))
|
||||
: nullptr;
|
||||
InputStream* tracked_boxes_stream =
|
||||
cc->Inputs().HasTag(kTrackedBoxesTag)
|
||||
? &(cc->Inputs().Tag(kTrackedBoxesTag))
|
||||
: nullptr;
|
||||
std::unique_ptr<TimedBoxProtoList> detected_boxes(new TimedBoxProtoList());
|
||||
|
||||
if (track_stream != nullptr) {
|
||||
|
@ -309,7 +326,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
const auto& image_size =
|
||||
cc->Inputs().Tag("IMAGE_SIZE").Get<std::pair<int, int>>();
|
||||
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||
float inv_scale = 1.0f / std::max(image_size.first, image_size.second);
|
||||
|
||||
TimedBoxProtoList tracked_boxes;
|
||||
|
@ -359,7 +376,7 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
detected_boxes.get());
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("VIZ")) {
|
||||
if (cc->Outputs().HasTag(kVizTag)) {
|
||||
cv::Mat viz_view;
|
||||
std::unique_ptr<ImageFrame> viz_frame;
|
||||
if (video_stream != nullptr && !video_stream->IsEmpty()) {
|
||||
|
@ -370,11 +387,11 @@ absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) {
|
|||
for (const auto& box : detected_boxes->box()) {
|
||||
RenderBox(box, &viz_view);
|
||||
}
|
||||
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp);
|
||||
cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("BOXES")) {
|
||||
cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp);
|
||||
if (cc->Outputs().HasTag(kBoxesTag)) {
|
||||
cc->Outputs().Tag(kBoxesTag).Add(detected_boxes.release(), timestamp);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -384,7 +401,7 @@ absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) {
|
|||
if (write_index_) {
|
||||
BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex();
|
||||
MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents(
|
||||
cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get<std::string>(),
|
||||
cc->InputSidePackets().Tag(kOutputIndexFilenameTag).Get<std::string>(),
|
||||
index.SerializeAsString()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -293,6 +293,22 @@ const int BoxTrackerCalculator::kMotionBoxPathMinQueueSize = 2;
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kCacheDirTag[] = "CACHE_DIR";
|
||||
constexpr char kInitialPosTag[] = "INITIAL_POS";
|
||||
constexpr char kRaBoxesTag[] = "RA_BOXES";
|
||||
constexpr char kBoxesTag[] = "BOXES";
|
||||
constexpr char kVizTag[] = "VIZ";
|
||||
constexpr char kRaTrackProtoStringTag[] = "RA_TRACK_PROTO_STRING";
|
||||
constexpr char kRaTrackTag[] = "RA_TRACK";
|
||||
constexpr char kCancelObjectIdTag[] = "CANCEL_OBJECT_ID";
|
||||
constexpr char kRestartPosTag[] = "RESTART_POS";
|
||||
constexpr char kStartPosProtoStringTag[] = "START_POS_PROTO_STRING";
|
||||
constexpr char kStartPosTag[] = "START_POS";
|
||||
constexpr char kStartTag[] = "START";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
constexpr char kTrackTimeTag[] = "TRACK_TIME";
|
||||
constexpr char kTrackingTag[] = "TRACKING";
|
||||
|
||||
// Convert box position according to rotation angle in degrees.
|
||||
void ConvertCoordinateForRotation(float in_top, float in_left, float in_bottom,
|
||||
float in_right, int rotation, float* out_top,
|
||||
|
@ -374,78 +390,78 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec,
|
|||
} // namespace.
|
||||
|
||||
absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) {
|
||||
if (cc->Inputs().HasTag("TRACKING")) {
|
||||
cc->Inputs().Tag("TRACKING").Set<TrackingData>();
|
||||
if (cc->Inputs().HasTag(kTrackingTag)) {
|
||||
cc->Inputs().Tag(kTrackingTag).Set<TrackingData>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("TRACK_TIME")) {
|
||||
RET_CHECK(cc->Inputs().HasTag("TRACKING"))
|
||||
if (cc->Inputs().HasTag(kTrackTimeTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kTrackingTag))
|
||||
<< "TRACK_TIME needs TRACKING input";
|
||||
cc->Inputs().Tag("TRACK_TIME").SetAny();
|
||||
cc->Inputs().Tag(kTrackTimeTag).SetAny();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("VIDEO")) {
|
||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
||||
if (cc->Inputs().HasTag(kVideoTag)) {
|
||||
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("START")) {
|
||||
if (cc->Inputs().HasTag(kStartTag)) {
|
||||
// Actual packet content does not matter.
|
||||
cc->Inputs().Tag("START").SetAny();
|
||||
cc->Inputs().Tag(kStartTag).SetAny();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("START_POS")) {
|
||||
cc->Inputs().Tag("START_POS").Set<TimedBoxProtoList>();
|
||||
if (cc->Inputs().HasTag(kStartPosTag)) {
|
||||
cc->Inputs().Tag(kStartPosTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("START_POS_PROTO_STRING")) {
|
||||
cc->Inputs().Tag("START_POS_PROTO_STRING").Set<std::string>();
|
||||
if (cc->Inputs().HasTag(kStartPosProtoStringTag)) {
|
||||
cc->Inputs().Tag(kStartPosProtoStringTag).Set<std::string>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("RESTART_POS")) {
|
||||
cc->Inputs().Tag("RESTART_POS").Set<TimedBoxProtoList>();
|
||||
if (cc->Inputs().HasTag(kRestartPosTag)) {
|
||||
cc->Inputs().Tag(kRestartPosTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("CANCEL_OBJECT_ID")) {
|
||||
cc->Inputs().Tag("CANCEL_OBJECT_ID").Set<int>();
|
||||
if (cc->Inputs().HasTag(kCancelObjectIdTag)) {
|
||||
cc->Inputs().Tag(kCancelObjectIdTag).Set<int>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("RA_TRACK")) {
|
||||
cc->Inputs().Tag("RA_TRACK").Set<TimedBoxProtoList>();
|
||||
if (cc->Inputs().HasTag(kRaTrackTag)) {
|
||||
cc->Inputs().Tag(kRaTrackTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")) {
|
||||
cc->Inputs().Tag("RA_TRACK_PROTO_STRING").Set<std::string>();
|
||||
if (cc->Inputs().HasTag(kRaTrackProtoStringTag)) {
|
||||
cc->Inputs().Tag(kRaTrackProtoStringTag).Set<std::string>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("VIZ")) {
|
||||
RET_CHECK(cc->Inputs().HasTag("VIDEO"))
|
||||
if (cc->Outputs().HasTag(kVizTag)) {
|
||||
RET_CHECK(cc->Inputs().HasTag(kVideoTag))
|
||||
<< "Output stream VIZ requires VIDEO to be present.";
|
||||
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
|
||||
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("BOXES")) {
|
||||
cc->Outputs().Tag("BOXES").Set<TimedBoxProtoList>();
|
||||
if (cc->Outputs().HasTag(kBoxesTag)) {
|
||||
cc->Outputs().Tag(kBoxesTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("RA_BOXES")) {
|
||||
cc->Outputs().Tag("RA_BOXES").Set<TimedBoxProtoList>();
|
||||
if (cc->Outputs().HasTag(kRaBoxesTag)) {
|
||||
cc->Outputs().Tag(kRaBoxesTag).Set<TimedBoxProtoList>();
|
||||
}
|
||||
|
||||
#if defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS"))
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag))
|
||||
<< "Unsupported on mobile";
|
||||
#else
|
||||
if (cc->InputSidePackets().HasTag("INITIAL_POS")) {
|
||||
cc->InputSidePackets().Tag("INITIAL_POS").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
|
||||
cc->InputSidePackets().Tag(kInitialPosTag).Set<std::string>();
|
||||
}
|
||||
#endif // defined(__ANDROID__) || defined(__APPLE__) || defined(__EMSCRIPTEN__)
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
|
||||
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
|
||||
cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
|
||||
}
|
||||
|
||||
RET_CHECK(cc->Inputs().HasTag("TRACKING") !=
|
||||
cc->InputSidePackets().HasTag("CACHE_DIR"))
|
||||
RET_CHECK(cc->Inputs().HasTag(kTrackingTag) !=
|
||||
cc->InputSidePackets().HasTag(kCacheDirTag))
|
||||
<< "Either TRACKING or CACHE_DIR needs to be specified.";
|
||||
|
||||
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||
|
@ -459,7 +475,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
|||
options_ = tool::RetrieveOptions(cc->Options<BoxTrackerCalculatorOptions>(),
|
||||
cc->InputSidePackets(), kOptionsTag);
|
||||
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag("INITIAL_POS") ||
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag(kInitialPosTag) ||
|
||||
!options_.has_initial_position())
|
||||
<< "Can not specify initial position as side packet and via options";
|
||||
|
||||
|
@ -468,11 +484,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
#if !defined(__ANDROID__) && !defined(__APPLE__) && !defined(__EMSCRIPTEN__)
|
||||
if (cc->InputSidePackets().HasTag("INITIAL_POS")) {
|
||||
if (cc->InputSidePackets().HasTag(kInitialPosTag)) {
|
||||
LOG(INFO) << "Parsing: "
|
||||
<< cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>();
|
||||
<< cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>();
|
||||
initial_pos_ = ParseTextProtoOrDie<TimedBoxProtoList>(
|
||||
cc->InputSidePackets().Tag("INITIAL_POS").Get<std::string>());
|
||||
cc->InputSidePackets().Tag(kInitialPosTag).Get<std::string>());
|
||||
}
|
||||
#endif // !defined(__ANDROID__) && !defined(__APPLE__) &&
|
||||
// !defined(__EMSCRIPTEN__)
|
||||
|
@ -484,10 +500,11 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
visualize_tracking_data_ =
|
||||
options_.visualize_tracking_data() && cc->Outputs().HasTag("VIZ");
|
||||
visualize_state_ = options_.visualize_state() && cc->Outputs().HasTag("VIZ");
|
||||
options_.visualize_tracking_data() && cc->Outputs().HasTag(kVizTag);
|
||||
visualize_state_ =
|
||||
options_.visualize_state() && cc->Outputs().HasTag(kVizTag);
|
||||
visualize_internal_state_ =
|
||||
options_.visualize_internal_state() && cc->Outputs().HasTag("VIZ");
|
||||
options_.visualize_internal_state() && cc->Outputs().HasTag(kVizTag);
|
||||
|
||||
// Force recording of internal state for rendering.
|
||||
if (visualize_internal_state_) {
|
||||
|
@ -500,8 +517,8 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
|||
options_.mutable_tracker_options()->set_record_path_states(true);
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
|
||||
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
|
||||
cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
|
||||
RET_CHECK(!cache_dir_.empty());
|
||||
box_tracker_.reset(new BoxTracker(cache_dir_, options_.tracker_options()));
|
||||
} else {
|
||||
|
@ -511,7 +528,7 @@ absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
if (options_.streaming_track_data_cache_size() > 0) {
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag("CACHE_DIR"))
|
||||
RET_CHECK(!cc->InputSidePackets().HasTag(kCacheDirTag))
|
||||
<< "Streaming mode not compatible with cache dir.";
|
||||
}
|
||||
|
||||
|
@ -533,11 +550,11 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
InputStream* track_stream = cc->Inputs().HasTag("TRACKING")
|
||||
? &(cc->Inputs().Tag("TRACKING"))
|
||||
InputStream* track_stream = cc->Inputs().HasTag(kTrackingTag)
|
||||
? &(cc->Inputs().Tag(kTrackingTag))
|
||||
: nullptr;
|
||||
InputStream* track_time_stream = cc->Inputs().HasTag("TRACK_TIME")
|
||||
? &(cc->Inputs().Tag("TRACK_TIME"))
|
||||
InputStream* track_time_stream = cc->Inputs().HasTag(kTrackTimeTag)
|
||||
? &(cc->Inputs().Tag(kTrackTimeTag))
|
||||
: nullptr;
|
||||
|
||||
// Cache tracking data if possible.
|
||||
|
@ -562,8 +579,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
|
||||
InputStream* start_pos_stream = cc->Inputs().HasTag("START_POS")
|
||||
? &(cc->Inputs().Tag("START_POS"))
|
||||
InputStream* start_pos_stream = cc->Inputs().HasTag(kStartPosTag)
|
||||
? &(cc->Inputs().Tag(kStartPosTag))
|
||||
: nullptr;
|
||||
|
||||
MotionBoxMap fast_forward_boxes;
|
||||
|
@ -575,8 +592,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
InputStream* start_pos_proto_string_stream =
|
||||
cc->Inputs().HasTag("START_POS_PROTO_STRING")
|
||||
? &(cc->Inputs().Tag("START_POS_PROTO_STRING"))
|
||||
cc->Inputs().HasTag(kStartPosProtoStringTag)
|
||||
? &(cc->Inputs().Tag(kStartPosProtoStringTag))
|
||||
: nullptr;
|
||||
if (start_pos_stream == nullptr || start_pos_stream->IsEmpty()) {
|
||||
if (start_pos_proto_string_stream &&
|
||||
|
@ -589,8 +606,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
|
||||
InputStream* restart_pos_stream = cc->Inputs().HasTag("RESTART_POS")
|
||||
? &(cc->Inputs().Tag("RESTART_POS"))
|
||||
InputStream* restart_pos_stream = cc->Inputs().HasTag(kRestartPosTag)
|
||||
? &(cc->Inputs().Tag(kRestartPosTag))
|
||||
: nullptr;
|
||||
|
||||
if (restart_pos_stream && !restart_pos_stream->IsEmpty()) {
|
||||
|
@ -600,8 +617,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
InputStream* cancel_object_id_stream =
|
||||
cc->Inputs().HasTag("CANCEL_OBJECT_ID")
|
||||
? &(cc->Inputs().Tag("CANCEL_OBJECT_ID"))
|
||||
cc->Inputs().HasTag(kCancelObjectIdTag)
|
||||
? &(cc->Inputs().Tag(kCancelObjectIdTag))
|
||||
: nullptr;
|
||||
if (cancel_object_id_stream && !cancel_object_id_stream->IsEmpty()) {
|
||||
const int cancel_object_id = cancel_object_id_stream->Get<int>();
|
||||
|
@ -616,8 +633,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
TrackingData track_data_to_render;
|
||||
|
||||
if (cc->Outputs().HasTag("VIZ")) {
|
||||
InputStream* video_stream = &(cc->Inputs().Tag("VIDEO"));
|
||||
if (cc->Outputs().HasTag(kVizTag)) {
|
||||
InputStream* video_stream = &(cc->Inputs().Tag(kVideoTag));
|
||||
if (!video_stream->IsEmpty()) {
|
||||
input_view = formats::MatView(&video_stream->Get<ImageFrame>());
|
||||
|
||||
|
@ -745,7 +762,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
++frame_num_since_reset_;
|
||||
|
||||
// Generate results for queued up request.
|
||||
if (cc->Outputs().HasTag("BOXES") && !queued_track_requests_.empty()) {
|
||||
if (cc->Outputs().HasTag(kBoxesTag) && !queued_track_requests_.empty()) {
|
||||
for (int j = 0; j < queued_track_requests_.size(); ++j) {
|
||||
const Timestamp& past_time = queued_track_requests_[j];
|
||||
RET_CHECK(past_time.Value() < timestamp.Value())
|
||||
|
@ -770,7 +787,7 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
// Output for every time.
|
||||
cc->Outputs().Tag("BOXES").Add(past_box_list.release(), past_time);
|
||||
cc->Outputs().Tag(kBoxesTag).Add(past_box_list.release(), past_time);
|
||||
}
|
||||
|
||||
queued_track_requests_.clear();
|
||||
|
@ -845,8 +862,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
// Handle random access track requests.
|
||||
InputStream* ra_track_stream = cc->Inputs().HasTag("RA_TRACK")
|
||||
? &(cc->Inputs().Tag("RA_TRACK"))
|
||||
InputStream* ra_track_stream = cc->Inputs().HasTag(kRaTrackTag)
|
||||
? &(cc->Inputs().Tag(kRaTrackTag))
|
||||
: nullptr;
|
||||
|
||||
if (ra_track_stream && !ra_track_stream->IsEmpty()) {
|
||||
|
@ -861,8 +878,8 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
InputStream* ra_track_proto_string_stream =
|
||||
cc->Inputs().HasTag("RA_TRACK_PROTO_STRING")
|
||||
? &(cc->Inputs().Tag("RA_TRACK_PROTO_STRING"))
|
||||
cc->Inputs().HasTag(kRaTrackProtoStringTag)
|
||||
? &(cc->Inputs().Tag(kRaTrackProtoStringTag))
|
||||
: nullptr;
|
||||
if (ra_track_stream == nullptr || ra_track_stream->IsEmpty()) {
|
||||
if (ra_track_proto_string_stream &&
|
||||
|
@ -881,15 +898,15 @@ absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
// Always output in batch, only output in streaming if tracking data
|
||||
// is present (might be in fast forward mode instead).
|
||||
if (cc->Outputs().HasTag("BOXES") &&
|
||||
if (cc->Outputs().HasTag(kBoxesTag) &&
|
||||
(box_tracker_ || !track_stream->IsEmpty())) {
|
||||
std::unique_ptr<TimedBoxProtoList> boxes(new TimedBoxProtoList());
|
||||
*boxes = std::move(box_track_list);
|
||||
cc->Outputs().Tag("BOXES").Add(boxes.release(), timestamp);
|
||||
cc->Outputs().Tag(kBoxesTag).Add(boxes.release(), timestamp);
|
||||
}
|
||||
|
||||
if (viz_frame) {
|
||||
cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp);
|
||||
cc->Outputs().Tag(kVizTag).Add(viz_frame.release(), timestamp);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -1001,7 +1018,7 @@ void BoxTrackerCalculator::OutputRandomAccessTrack(
|
|||
}
|
||||
|
||||
cc->Outputs()
|
||||
.Tag("RA_BOXES")
|
||||
.Tag(kRaBoxesTag)
|
||||
.Add(result_list.release(), cc->InputTimestamp());
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,13 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kCacheDirTag[] = "CACHE_DIR";
|
||||
constexpr char kCompleteTag[] = "COMPLETE";
|
||||
constexpr char kTrackingChunkTag[] = "TRACKING_CHUNK";
|
||||
constexpr char kTrackingTag[] = "TRACKING";
|
||||
constexpr char kCameraTag[] = "CAMERA";
|
||||
constexpr char kFlowTag[] = "FLOW";
|
||||
|
||||
using mediapipe::CameraMotion;
|
||||
using mediapipe::FlowPackager;
|
||||
using mediapipe::RegionFlowFeatureList;
|
||||
|
@ -91,27 +98,27 @@ class FlowPackagerCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(FlowPackagerCalculator);
|
||||
|
||||
absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) {
|
||||
if (!cc->Inputs().HasTag("FLOW")) {
|
||||
if (!cc->Inputs().HasTag(kFlowTag)) {
|
||||
return tool::StatusFail("No input flow was specified.");
|
||||
}
|
||||
|
||||
cc->Inputs().Tag("FLOW").Set<RegionFlowFeatureList>();
|
||||
cc->Inputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
|
||||
|
||||
if (cc->Inputs().HasTag("CAMERA")) {
|
||||
cc->Inputs().Tag("CAMERA").Set<CameraMotion>();
|
||||
if (cc->Inputs().HasTag(kCameraTag)) {
|
||||
cc->Inputs().Tag(kCameraTag).Set<CameraMotion>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("TRACKING")) {
|
||||
cc->Outputs().Tag("TRACKING").Set<TrackingData>();
|
||||
if (cc->Outputs().HasTag(kTrackingTag)) {
|
||||
cc->Outputs().Tag(kTrackingTag).Set<TrackingData>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
|
||||
cc->Outputs().Tag("TRACKING_CHUNK").Set<TrackingDataChunk>();
|
||||
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
|
||||
cc->Outputs().Tag(kTrackingChunkTag).Set<TrackingDataChunk>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("COMPLETE")) {
|
||||
cc->Outputs().Tag("COMPLETE").Set<bool>();
|
||||
if (cc->Outputs().HasTag(kCompleteTag)) {
|
||||
cc->Outputs().Tag(kCompleteTag).Set<bool>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CACHE_DIR")) {
|
||||
cc->InputSidePackets().Tag("CACHE_DIR").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kCacheDirTag)) {
|
||||
cc->InputSidePackets().Tag(kCacheDirTag).Set<std::string>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -122,24 +129,24 @@ absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
flow_packager_.reset(new FlowPackager(options_.flow_packager_options()));
|
||||
|
||||
use_caching_ = cc->InputSidePackets().HasTag("CACHE_DIR");
|
||||
build_chunk_ = use_caching_ || cc->Outputs().HasTag("TRACKING_CHUNK");
|
||||
use_caching_ = cc->InputSidePackets().HasTag(kCacheDirTag);
|
||||
build_chunk_ = use_caching_ || cc->Outputs().HasTag(kTrackingChunkTag);
|
||||
if (use_caching_) {
|
||||
cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get<std::string>();
|
||||
cache_dir_ = cc->InputSidePackets().Tag(kCacheDirTag).Get<std::string>();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
||||
InputStream* flow_stream = &(cc->Inputs().Tag("FLOW"));
|
||||
InputStream* flow_stream = &(cc->Inputs().Tag(kFlowTag));
|
||||
const RegionFlowFeatureList& flow = flow_stream->Get<RegionFlowFeatureList>();
|
||||
|
||||
const Timestamp timestamp = flow_stream->Value().Timestamp();
|
||||
|
||||
const CameraMotion* camera_motion = nullptr;
|
||||
if (cc->Inputs().HasTag("CAMERA")) {
|
||||
InputStream* camera_stream = &(cc->Inputs().Tag("CAMERA"));
|
||||
if (cc->Inputs().HasTag(kCameraTag)) {
|
||||
InputStream* camera_stream = &(cc->Inputs().Tag(kCameraTag));
|
||||
camera_motion = &camera_stream->Get<CameraMotion>();
|
||||
}
|
||||
|
||||
|
@ -161,7 +168,7 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
|||
if (frame_idx_ > 0) {
|
||||
item->set_prev_timestamp_usec(prev_timestamp_.Value());
|
||||
}
|
||||
if (cc->Outputs().HasTag("TRACKING")) {
|
||||
if (cc->Outputs().HasTag(kTrackingTag)) {
|
||||
// Need to copy as output is requested.
|
||||
*item->mutable_tracking_data() = *tracking_data;
|
||||
} else {
|
||||
|
@ -172,9 +179,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
|||
options_.caching_chunk_size_msec() * (chunk_idx_ + 1);
|
||||
|
||||
if (timestamp.Value() / 1000 >= next_chunk_msec) {
|
||||
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
|
||||
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("TRACKING_CHUNK")
|
||||
.Tag(kTrackingChunkTag)
|
||||
.Add(new TrackingDataChunk(tracking_chunk_),
|
||||
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
|
||||
}
|
||||
|
@ -185,9 +192,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("TRACKING")) {
|
||||
if (cc->Outputs().HasTag(kTrackingTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("TRACKING")
|
||||
.Tag(kTrackingTag)
|
||||
.Add(tracking_data.release(), flow_stream->Value().Timestamp());
|
||||
}
|
||||
|
||||
|
@ -199,9 +206,9 @@ absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) {
|
|||
absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
|
||||
if (frame_idx_ > 0) {
|
||||
tracking_chunk_.set_last_chunk(true);
|
||||
if (cc->Outputs().HasTag("TRACKING_CHUNK")) {
|
||||
if (cc->Outputs().HasTag(kTrackingChunkTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("TRACKING_CHUNK")
|
||||
.Tag(kTrackingChunkTag)
|
||||
.Add(new TrackingDataChunk(tracking_chunk_),
|
||||
Timestamp(tracking_chunk_.item(0).timestamp_usec()));
|
||||
}
|
||||
|
@ -211,8 +218,8 @@ absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("COMPLETE")) {
|
||||
cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream());
|
||||
if (cc->Outputs().HasTag(kCompleteTag)) {
|
||||
cc->Outputs().Tag(kCompleteTag).Add(new bool(true), Timestamp::PreStream());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -38,6 +38,18 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kDownsampleTag[] = "DOWNSAMPLE";
|
||||
constexpr char kCsvFileTag[] = "CSV_FILE";
|
||||
constexpr char kGrayVideoOutTag[] = "GRAY_VIDEO_OUT";
|
||||
constexpr char kVideoOutTag[] = "VIDEO_OUT";
|
||||
constexpr char kDenseFgTag[] = "DENSE_FG";
|
||||
constexpr char kVizTag[] = "VIZ";
|
||||
constexpr char kSaliencyTag[] = "SALIENCY";
|
||||
constexpr char kCameraTag[] = "CAMERA";
|
||||
constexpr char kFlowTag[] = "FLOW";
|
||||
constexpr char kSelectionTag[] = "SELECTION";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
|
||||
using mediapipe::AffineAdapter;
|
||||
using mediapipe::CameraMotion;
|
||||
using mediapipe::FrameSelectionResult;
|
||||
|
@ -190,55 +202,56 @@ class MotionAnalysisCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(MotionAnalysisCalculator);
|
||||
|
||||
absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) {
|
||||
if (cc->Inputs().HasTag("VIDEO")) {
|
||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
||||
if (cc->Inputs().HasTag(kVideoTag)) {
|
||||
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
// Optional input stream from frame selection calculator.
|
||||
if (cc->Inputs().HasTag("SELECTION")) {
|
||||
cc->Inputs().Tag("SELECTION").Set<FrameSelectionResult>();
|
||||
if (cc->Inputs().HasTag(kSelectionTag)) {
|
||||
cc->Inputs().Tag(kSelectionTag).Set<FrameSelectionResult>();
|
||||
}
|
||||
|
||||
RET_CHECK(cc->Inputs().HasTag("VIDEO") || cc->Inputs().HasTag("SELECTION"))
|
||||
RET_CHECK(cc->Inputs().HasTag(kVideoTag) ||
|
||||
cc->Inputs().HasTag(kSelectionTag))
|
||||
<< "Either VIDEO, SELECTION must be specified.";
|
||||
|
||||
if (cc->Outputs().HasTag("FLOW")) {
|
||||
cc->Outputs().Tag("FLOW").Set<RegionFlowFeatureList>();
|
||||
if (cc->Outputs().HasTag(kFlowTag)) {
|
||||
cc->Outputs().Tag(kFlowTag).Set<RegionFlowFeatureList>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("CAMERA")) {
|
||||
cc->Outputs().Tag("CAMERA").Set<CameraMotion>();
|
||||
if (cc->Outputs().HasTag(kCameraTag)) {
|
||||
cc->Outputs().Tag(kCameraTag).Set<CameraMotion>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("SALIENCY")) {
|
||||
cc->Outputs().Tag("SALIENCY").Set<SalientPointFrame>();
|
||||
if (cc->Outputs().HasTag(kSaliencyTag)) {
|
||||
cc->Outputs().Tag(kSaliencyTag).Set<SalientPointFrame>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("VIZ")) {
|
||||
cc->Outputs().Tag("VIZ").Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag(kVizTag)) {
|
||||
cc->Outputs().Tag(kVizTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("DENSE_FG")) {
|
||||
cc->Outputs().Tag("DENSE_FG").Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag(kDenseFgTag)) {
|
||||
cc->Outputs().Tag(kDenseFgTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("VIDEO_OUT")) {
|
||||
cc->Outputs().Tag("VIDEO_OUT").Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag(kVideoOutTag)) {
|
||||
cc->Outputs().Tag(kVideoOutTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("GRAY_VIDEO_OUT")) {
|
||||
if (cc->Outputs().HasTag(kGrayVideoOutTag)) {
|
||||
// We only output grayscale video if we're actually performing full region-
|
||||
// flow analysis on the video.
|
||||
RET_CHECK(cc->Inputs().HasTag("VIDEO") &&
|
||||
!cc->Inputs().HasTag("SELECTION"));
|
||||
cc->Outputs().Tag("GRAY_VIDEO_OUT").Set<ImageFrame>();
|
||||
RET_CHECK(cc->Inputs().HasTag(kVideoTag) &&
|
||||
!cc->Inputs().HasTag(kSelectionTag));
|
||||
cc->Outputs().Tag(kGrayVideoOutTag).Set<ImageFrame>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("CSV_FILE")) {
|
||||
cc->InputSidePackets().Tag("CSV_FILE").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kCsvFileTag)) {
|
||||
cc->InputSidePackets().Tag(kCsvFileTag).Set<std::string>();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) {
|
||||
cc->InputSidePackets().Tag("DOWNSAMPLE").Set<float>();
|
||||
if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
|
||||
cc->InputSidePackets().Tag(kDownsampleTag).Set<float>();
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag(kOptionsTag)) {
|
||||
|
@ -253,16 +266,16 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
|||
tool::RetrieveOptions(cc->Options<MotionAnalysisCalculatorOptions>(),
|
||||
cc->InputSidePackets(), kOptionsTag);
|
||||
|
||||
video_input_ = cc->Inputs().HasTag("VIDEO");
|
||||
selection_input_ = cc->Inputs().HasTag("SELECTION");
|
||||
region_flow_feature_output_ = cc->Outputs().HasTag("FLOW");
|
||||
camera_motion_output_ = cc->Outputs().HasTag("CAMERA");
|
||||
saliency_output_ = cc->Outputs().HasTag("SALIENCY");
|
||||
visualize_output_ = cc->Outputs().HasTag("VIZ");
|
||||
dense_foreground_output_ = cc->Outputs().HasTag("DENSE_FG");
|
||||
video_output_ = cc->Outputs().HasTag("VIDEO_OUT");
|
||||
grayscale_output_ = cc->Outputs().HasTag("GRAY_VIDEO_OUT");
|
||||
csv_file_input_ = cc->InputSidePackets().HasTag("CSV_FILE");
|
||||
video_input_ = cc->Inputs().HasTag(kVideoTag);
|
||||
selection_input_ = cc->Inputs().HasTag(kSelectionTag);
|
||||
region_flow_feature_output_ = cc->Outputs().HasTag(kFlowTag);
|
||||
camera_motion_output_ = cc->Outputs().HasTag(kCameraTag);
|
||||
saliency_output_ = cc->Outputs().HasTag(kSaliencyTag);
|
||||
visualize_output_ = cc->Outputs().HasTag(kVizTag);
|
||||
dense_foreground_output_ = cc->Outputs().HasTag(kDenseFgTag);
|
||||
video_output_ = cc->Outputs().HasTag(kVideoOutTag);
|
||||
grayscale_output_ = cc->Outputs().HasTag(kGrayVideoOutTag);
|
||||
csv_file_input_ = cc->InputSidePackets().HasTag(kCsvFileTag);
|
||||
hybrid_meta_analysis_ = options_.meta_analysis() ==
|
||||
MotionAnalysisCalculatorOptions::META_ANALYSIS_HYBRID;
|
||||
|
||||
|
@ -310,7 +323,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
|||
if (csv_file_input_) {
|
||||
// Read from file and parse.
|
||||
const std::string filename =
|
||||
cc->InputSidePackets().Tag("CSV_FILE").Get<std::string>();
|
||||
cc->InputSidePackets().Tag(kCsvFileTag).Get<std::string>();
|
||||
|
||||
std::string file_contents;
|
||||
std::ifstream input_file(filename, std::ios::in);
|
||||
|
@ -327,11 +340,12 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
// Get video header from video or selection input if present.
|
||||
const VideoHeader* video_header = nullptr;
|
||||
if (video_input_ && !cc->Inputs().Tag("VIDEO").Header().IsEmpty()) {
|
||||
video_header = &(cc->Inputs().Tag("VIDEO").Header().Get<VideoHeader>());
|
||||
if (video_input_ && !cc->Inputs().Tag(kVideoTag).Header().IsEmpty()) {
|
||||
video_header = &(cc->Inputs().Tag(kVideoTag).Header().Get<VideoHeader>());
|
||||
} else if (selection_input_ &&
|
||||
!cc->Inputs().Tag("SELECTION").Header().IsEmpty()) {
|
||||
video_header = &(cc->Inputs().Tag("SELECTION").Header().Get<VideoHeader>());
|
||||
!cc->Inputs().Tag(kSelectionTag).Header().IsEmpty()) {
|
||||
video_header =
|
||||
&(cc->Inputs().Tag(kSelectionTag).Header().Get<VideoHeader>());
|
||||
} else {
|
||||
LOG(WARNING) << "No input video header found. Downstream calculators "
|
||||
"expecting video headers are likely to fail.";
|
||||
|
@ -339,7 +353,7 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
with_saliency_ = options_.analysis_options().compute_motion_saliency();
|
||||
// Force computation of saliency if requested as output.
|
||||
if (cc->Outputs().HasTag("SALIENCY")) {
|
||||
if (cc->Outputs().HasTag(kSaliencyTag)) {
|
||||
with_saliency_ = true;
|
||||
if (!options_.analysis_options().compute_motion_saliency()) {
|
||||
LOG(WARNING) << "Enable saliency computation. Set "
|
||||
|
@ -353,11 +367,11 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
|||
cc->SetOffset(TimestampDiff(0));
|
||||
}
|
||||
|
||||
if (cc->InputSidePackets().HasTag("DOWNSAMPLE")) {
|
||||
if (cc->InputSidePackets().HasTag(kDownsampleTag)) {
|
||||
options_.mutable_analysis_options()
|
||||
->mutable_flow_options()
|
||||
->set_downsample_factor(
|
||||
cc->InputSidePackets().Tag("DOWNSAMPLE").Get<float>());
|
||||
cc->InputSidePackets().Tag(kDownsampleTag).Get<float>());
|
||||
}
|
||||
|
||||
// If no video header is provided, just return and initialize on the first
|
||||
|
@ -369,30 +383,33 @@ absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) {
|
|||
////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE ///////////////
|
||||
|
||||
if (visualize_output_) {
|
||||
cc->Outputs().Tag("VIZ").SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||
cc->Outputs().Tag(kVizTag).SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||
}
|
||||
|
||||
if (video_output_) {
|
||||
cc->Outputs()
|
||||
.Tag("VIDEO_OUT")
|
||||
.Tag(kVideoOutTag)
|
||||
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("DENSE_FG")) {
|
||||
if (cc->Outputs().HasTag(kDenseFgTag)) {
|
||||
std::unique_ptr<VideoHeader> foreground_header(
|
||||
new VideoHeader(*video_header));
|
||||
foreground_header->format = ImageFormat::GRAY8;
|
||||
cc->Outputs().Tag("DENSE_FG").SetHeader(Adopt(foreground_header.release()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("CAMERA")) {
|
||||
cc->Outputs().Tag("CAMERA").SetHeader(
|
||||
Adopt(new VideoHeader(*video_header)));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag("SALIENCY")) {
|
||||
cc->Outputs()
|
||||
.Tag("SALIENCY")
|
||||
.Tag(kDenseFgTag)
|
||||
.SetHeader(Adopt(foreground_header.release()));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag(kCameraTag)) {
|
||||
cc->Outputs()
|
||||
.Tag(kCameraTag)
|
||||
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||
}
|
||||
|
||||
if (cc->Outputs().HasTag(kSaliencyTag)) {
|
||||
cc->Outputs()
|
||||
.Tag(kSaliencyTag)
|
||||
.SetHeader(Adopt(new VideoHeader(*video_header)));
|
||||
}
|
||||
|
||||
|
@ -405,9 +422,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
|
||||
InputStream* video_stream =
|
||||
video_input_ ? &(cc->Inputs().Tag("VIDEO")) : nullptr;
|
||||
video_input_ ? &(cc->Inputs().Tag(kVideoTag)) : nullptr;
|
||||
InputStream* selection_stream =
|
||||
selection_input_ ? &(cc->Inputs().Tag("SELECTION")) : nullptr;
|
||||
selection_input_ ? &(cc->Inputs().Tag(kSelectionTag)) : nullptr;
|
||||
|
||||
// Checked on Open.
|
||||
CHECK(video_stream || selection_stream);
|
||||
|
@ -425,8 +442,9 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
|||
CameraMotion output_motion = meta_motions_.front();
|
||||
meta_motions_.pop_front();
|
||||
output_motion.set_timestamp_usec(timestamp.Value());
|
||||
cc->Outputs().Tag("CAMERA").Add(new CameraMotion(output_motion),
|
||||
timestamp);
|
||||
cc->Outputs()
|
||||
.Tag(kCameraTag)
|
||||
.Add(new CameraMotion(output_motion), timestamp);
|
||||
}
|
||||
|
||||
if (region_flow_feature_output_) {
|
||||
|
@ -435,8 +453,8 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
|||
meta_features_.pop_front();
|
||||
|
||||
output_features.set_timestamp_usec(timestamp.Value());
|
||||
cc->Outputs().Tag("FLOW").Add(new RegionFlowFeatureList(output_features),
|
||||
timestamp);
|
||||
cc->Outputs().Tag(kFlowTag).Add(
|
||||
new RegionFlowFeatureList(output_features), timestamp);
|
||||
}
|
||||
|
||||
++frame_idx_;
|
||||
|
@ -478,16 +496,17 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
|||
MotionAnalysisCalculatorOptions::NO_ANALYSIS_USE_SELECTION) {
|
||||
// Output concatenated results, nothing to compute here.
|
||||
if (camera_motion_output_) {
|
||||
cc->Outputs().Tag("CAMERA").Add(
|
||||
frame_selection_result->release_camera_motion(), timestamp);
|
||||
cc->Outputs()
|
||||
.Tag(kCameraTag)
|
||||
.Add(frame_selection_result->release_camera_motion(), timestamp);
|
||||
}
|
||||
if (region_flow_feature_output_) {
|
||||
cc->Outputs().Tag("FLOW").Add(frame_selection_result->release_features(),
|
||||
timestamp);
|
||||
cc->Outputs().Tag(kFlowTag).Add(
|
||||
frame_selection_result->release_features(), timestamp);
|
||||
}
|
||||
|
||||
if (video_output_) {
|
||||
cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value());
|
||||
cc->Outputs().Tag(kVideoOutTag).AddPacket(video_stream->Value());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -565,7 +584,7 @@ absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) {
|
|||
grayscale_mat.copyTo(image_frame_mat);
|
||||
|
||||
cc->Outputs()
|
||||
.Tag("GRAY_VIDEO_OUT")
|
||||
.Tag(kGrayVideoOutTag)
|
||||
.Add(grayscale_image.release(), timestamp);
|
||||
}
|
||||
|
||||
|
@ -640,7 +659,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
|
|||
*feature_list, *camera_motion,
|
||||
with_saliency_ ? saliency[k].get() : nullptr, &visualization);
|
||||
|
||||
cc->Outputs().Tag("VIZ").Add(visualization_frame.release(), timestamp);
|
||||
cc->Outputs().Tag(kVizTag).Add(visualization_frame.release(), timestamp);
|
||||
}
|
||||
|
||||
// Output dense foreground mask.
|
||||
|
@ -650,26 +669,26 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames(
|
|||
cv::Mat foreground = formats::MatView(foreground_frame.get());
|
||||
motion_analysis_->ComputeDenseForeground(*feature_list, *camera_motion,
|
||||
&foreground);
|
||||
cc->Outputs().Tag("DENSE_FG").Add(foreground_frame.release(), timestamp);
|
||||
cc->Outputs().Tag(kDenseFgTag).Add(foreground_frame.release(), timestamp);
|
||||
}
|
||||
|
||||
// Output flow features if requested.
|
||||
if (region_flow_feature_output_) {
|
||||
cc->Outputs().Tag("FLOW").Add(feature_list.release(), timestamp);
|
||||
cc->Outputs().Tag(kFlowTag).Add(feature_list.release(), timestamp);
|
||||
}
|
||||
|
||||
// Output camera motion.
|
||||
if (camera_motion_output_) {
|
||||
cc->Outputs().Tag("CAMERA").Add(camera_motion.release(), timestamp);
|
||||
cc->Outputs().Tag(kCameraTag).Add(camera_motion.release(), timestamp);
|
||||
}
|
||||
|
||||
if (video_output_) {
|
||||
cc->Outputs().Tag("VIDEO_OUT").AddPacket(packet_buffer_[k]);
|
||||
cc->Outputs().Tag(kVideoOutTag).AddPacket(packet_buffer_[k]);
|
||||
}
|
||||
|
||||
// Output saliency.
|
||||
if (saliency_output_) {
|
||||
cc->Outputs().Tag("SALIENCY").Add(saliency[k].release(), timestamp);
|
||||
cc->Outputs().Tag(kSaliencyTag).Add(saliency[k].release(), timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,12 @@
|
|||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kSavedAudioPathTag[] = "SAVED_AUDIO_PATH";
|
||||
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
|
||||
|
||||
// cv::VideoCapture set data type to unsigned char by default. Therefore, the
|
||||
// image format is only related to the number of channles the cv::Mat has.
|
||||
ImageFormat::Format GetImageFormat(int num_channels) {
|
||||
|
@ -87,20 +93,20 @@ ImageFormat::Format GetImageFormat(int num_channels) {
|
|||
class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
|
||||
cc->Outputs().Tag("VIDEO").Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) {
|
||||
cc->Outputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
||||
cc->InputSidePackets().Tag(kInputFilePathTag).Set<std::string>();
|
||||
cc->Outputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
|
||||
cc->Outputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||
}
|
||||
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) {
|
||||
cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set<std::string>();
|
||||
if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
|
||||
cc->OutputSidePackets().Tag(kSavedAudioPathTag).Set<std::string>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
const std::string& input_file_path =
|
||||
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
|
||||
cc->InputSidePackets().Tag(kInputFilePathTag).Get<std::string>();
|
||||
cap_ = absl::make_unique<cv::VideoCapture>(input_file_path);
|
||||
if (!cap_->isOpened()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
|
@ -140,16 +146,16 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
header->frame_rate = fps;
|
||||
header->duration = frame_count_ / fps;
|
||||
|
||||
if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) {
|
||||
if (cc->Outputs().HasTag(kVideoPrestreamTag)) {
|
||||
cc->Outputs()
|
||||
.Tag("VIDEO_PRESTREAM")
|
||||
.Tag(kVideoPrestreamTag)
|
||||
.Add(header.release(), Timestamp::PreStream());
|
||||
cc->Outputs().Tag("VIDEO_PRESTREAM").Close();
|
||||
cc->Outputs().Tag(kVideoPrestreamTag).Close();
|
||||
}
|
||||
// Rewind to the very first frame.
|
||||
cap_->set(cv::CAP_PROP_POS_AVI_RATIO, 0);
|
||||
|
||||
if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) {
|
||||
if (cc->OutputSidePackets().HasTag(kSavedAudioPathTag)) {
|
||||
#ifdef HAVE_FFMPEG
|
||||
std::string saved_audio_path = std::tmpnam(nullptr);
|
||||
std::string ffmpeg_command =
|
||||
|
@ -159,14 +165,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
int status_code = system(absl::StrCat("ls ", saved_audio_path).c_str());
|
||||
if (status_code == 0) {
|
||||
cc->OutputSidePackets()
|
||||
.Tag("SAVED_AUDIO_PATH")
|
||||
.Tag(kSavedAudioPathTag)
|
||||
.Set(MakePacket<std::string>(saved_audio_path));
|
||||
} else {
|
||||
LOG(WARNING) << "FFmpeg can't extract audio from " << input_file_path
|
||||
<< " by executing the following command: "
|
||||
<< ffmpeg_command;
|
||||
cc->OutputSidePackets()
|
||||
.Tag("SAVED_AUDIO_PATH")
|
||||
.Tag(kSavedAudioPathTag)
|
||||
.Set(MakePacket<std::string>(std::string()));
|
||||
}
|
||||
#else
|
||||
|
@ -208,7 +214,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
// If the timestamp of the current frame is not greater than the one of the
|
||||
// previous frame, the new frame will be discarded.
|
||||
if (prev_timestamp_ < timestamp) {
|
||||
cc->Outputs().Tag("VIDEO").Add(image_frame.release(), timestamp);
|
||||
cc->Outputs().Tag(kVideoTag).Add(image_frame.release(), timestamp);
|
||||
prev_timestamp_ = timestamp;
|
||||
decoded_frames_++;
|
||||
}
|
||||
|
|
|
@ -29,6 +29,10 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
|
||||
|
||||
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||
|
@ -37,19 +41,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
|||
output_stream: "VIDEO:video"
|
||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_MP4_AVC720P_AAC.video"));
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
|
||||
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("VIDEO_PRESTREAM")
|
||||
.Tag(kVideoPrestreamTag)
|
||||
.packets[0]
|
||||
.ValidateAsType<VideoHeader>());
|
||||
const mediapipe::VideoHeader& header =
|
||||
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
|
||||
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
|
||||
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
||||
EXPECT_EQ(1280, header.width);
|
||||
EXPECT_EQ(640, header.height);
|
||||
|
@ -58,10 +62,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
|||
// The number of the output packets should be 180.
|
||||
// Some OpenCV version returns the first two frames with the same timestamp on
|
||||
// macos and we might miss one frame here.
|
||||
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size();
|
||||
int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
|
||||
EXPECT_GE(num_of_packets, 179);
|
||||
for (int i = 0; i < num_of_packets; ++i) {
|
||||
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
|
||||
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
|
||||
cv::Mat output_mat =
|
||||
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
||||
EXPECT_EQ(1280, output_mat.size().width);
|
||||
|
@ -83,19 +87,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
|
|||
output_stream: "VIDEO:video"
|
||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_FLV_H264_AAC.video"));
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
|
||||
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("VIDEO_PRESTREAM")
|
||||
.Tag(kVideoPrestreamTag)
|
||||
.packets[0]
|
||||
.ValidateAsType<VideoHeader>());
|
||||
const mediapipe::VideoHeader& header =
|
||||
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
|
||||
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
|
||||
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
||||
EXPECT_EQ(640, header.width);
|
||||
EXPECT_EQ(320, header.height);
|
||||
|
@ -103,9 +107,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
|
|||
// can be either 30.30303f (with opencv2) or 30f (with opencv3 and opencv4).
|
||||
// EXPECT_FLOAT_EQ(6.0f, header.duration);
|
||||
// EXPECT_FLOAT_EQ(30.0f, header.frame_rate);
|
||||
EXPECT_EQ(180, runner.Outputs().Tag("VIDEO").packets.size());
|
||||
EXPECT_EQ(180, runner.Outputs().Tag(kVideoTag).packets.size());
|
||||
for (int i = 0; i < 180; ++i) {
|
||||
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
|
||||
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
|
||||
cv::Mat output_mat =
|
||||
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
||||
EXPECT_EQ(640, output_mat.size().width);
|
||||
|
@ -127,19 +131,19 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
|
|||
output_stream: "VIDEO:video"
|
||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_MKV_VP8_VORBIS.video"));
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_EQ(runner.Outputs().Tag("VIDEO_PRESTREAM").packets.size(), 1);
|
||||
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||
MP_EXPECT_OK(runner.Outputs()
|
||||
.Tag("VIDEO_PRESTREAM")
|
||||
.Tag(kVideoPrestreamTag)
|
||||
.packets[0]
|
||||
.ValidateAsType<VideoHeader>());
|
||||
const mediapipe::VideoHeader& header =
|
||||
runner.Outputs().Tag("VIDEO_PRESTREAM").packets[0].Get<VideoHeader>();
|
||||
runner.Outputs().Tag(kVideoPrestreamTag).packets[0].Get<VideoHeader>();
|
||||
EXPECT_EQ(ImageFormat::SRGB, header.format);
|
||||
EXPECT_EQ(640, header.width);
|
||||
EXPECT_EQ(320, header.height);
|
||||
|
@ -148,10 +152,10 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
|
|||
// The number of the output packets should be 180.
|
||||
// Some OpenCV version returns the first two frames with the same timestamp on
|
||||
// macos and we might miss one frame here.
|
||||
int num_of_packets = runner.Outputs().Tag("VIDEO").packets.size();
|
||||
int num_of_packets = runner.Outputs().Tag(kVideoTag).packets.size();
|
||||
EXPECT_GE(num_of_packets, 179);
|
||||
for (int i = 0; i < num_of_packets; ++i) {
|
||||
Packet image_frame_packet = runner.Outputs().Tag("VIDEO").packets[i];
|
||||
Packet image_frame_packet = runner.Outputs().Tag(kVideoTag).packets[i];
|
||||
cv::Mat output_mat =
|
||||
formats::MatView(&(image_frame_packet.Get<ImageFrame>()));
|
||||
EXPECT_EQ(640, output_mat.size().width);
|
||||
|
|
|
@ -36,6 +36,11 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kAudioFilePathTag[] = "AUDIO_FILE_PATH";
|
||||
constexpr char kOutputFilePathTag[] = "OUTPUT_FILE_PATH";
|
||||
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
|
||||
// Encodes the input video stream and produces a media file.
|
||||
// The media file can be output to the output_file_path specified as a side
|
||||
// packet. Currently, the calculator only supports one video stream (in
|
||||
|
@ -90,15 +95,15 @@ class OpenCvVideoEncoderCalculator : public CalculatorBase {
|
|||
};
|
||||
|
||||
absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK(cc->Inputs().HasTag("VIDEO"));
|
||||
cc->Inputs().Tag("VIDEO").Set<ImageFrame>();
|
||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
|
||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
||||
RET_CHECK(cc->Inputs().HasTag(kVideoTag));
|
||||
cc->Inputs().Tag(kVideoTag).Set<ImageFrame>();
|
||||
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
|
||||
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||
}
|
||||
RET_CHECK(cc->InputSidePackets().HasTag("OUTPUT_FILE_PATH"));
|
||||
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) {
|
||||
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set<std::string>();
|
||||
RET_CHECK(cc->InputSidePackets().HasTag(kOutputFilePathTag));
|
||||
cc->InputSidePackets().Tag(kOutputFilePathTag).Set<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
|
||||
cc->InputSidePackets().Tag(kAudioFilePathTag).Set<std::string>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -116,7 +121,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
|
|||
<< "Video format must be specified in "
|
||||
"OpenCvVideoEncoderCalculatorOptions";
|
||||
output_file_path_ =
|
||||
cc->InputSidePackets().Tag("OUTPUT_FILE_PATH").Get<std::string>();
|
||||
cc->InputSidePackets().Tag(kOutputFilePathTag).Get<std::string>();
|
||||
std::vector<std::string> splited_file_path =
|
||||
absl::StrSplit(output_file_path_, '.');
|
||||
RET_CHECK(splited_file_path.size() >= 2 &&
|
||||
|
@ -126,7 +131,7 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
|
|||
// If the video header will be available, the video metadata will be fetched
|
||||
// from the video header directly. The calculator will receive the video
|
||||
// header packet at timestamp prestream.
|
||||
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) {
|
||||
if (cc->Inputs().HasTag(kVideoPrestreamTag)) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
return SetUpVideoWriter(options.fps(), options.width(), options.height());
|
||||
|
@ -135,13 +140,13 @@ absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) {
|
|||
absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
||||
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||
const VideoHeader& video_header =
|
||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
|
||||
cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
|
||||
return SetUpVideoWriter(video_header.frame_rate, video_header.width,
|
||||
video_header.height);
|
||||
}
|
||||
|
||||
const ImageFrame& image_frame =
|
||||
cc->Inputs().Tag("VIDEO").Value().Get<ImageFrame>();
|
||||
cc->Inputs().Tag(kVideoTag).Value().Get<ImageFrame>();
|
||||
ImageFormat::Format format = image_frame.Format();
|
||||
cv::Mat frame;
|
||||
if (format == ImageFormat::GRAY8) {
|
||||
|
@ -149,7 +154,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
|||
if (frame.empty()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Receive empty frame at timestamp "
|
||||
<< cc->Inputs().Tag("VIDEO").Value().Timestamp()
|
||||
<< cc->Inputs().Tag(kVideoTag).Value().Timestamp()
|
||||
<< " in OpenCvVideoEncoderCalculator::Process()";
|
||||
}
|
||||
} else {
|
||||
|
@ -157,7 +162,7 @@ absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) {
|
|||
if (tmp_frame.empty()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Receive empty frame at timestamp "
|
||||
<< cc->Inputs().Tag("VIDEO").Value().Timestamp()
|
||||
<< cc->Inputs().Tag(kVideoTag).Value().Timestamp()
|
||||
<< " in OpenCvVideoEncoderCalculator::Process()";
|
||||
}
|
||||
if (format == ImageFormat::SRGB) {
|
||||
|
@ -177,10 +182,10 @@ absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) {
|
|||
if (writer_ && writer_->isOpened()) {
|
||||
writer_->release();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) {
|
||||
if (cc->InputSidePackets().HasTag(kAudioFilePathTag)) {
|
||||
#ifdef HAVE_FFMPEG
|
||||
const std::string& audio_file_path =
|
||||
cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Get<std::string>();
|
||||
cc->InputSidePackets().Tag(kAudioFilePathTag).Get<std::string>();
|
||||
if (audio_file_path.empty()) {
|
||||
LOG(WARNING) << "OpenCvVideoEncoderCalculator isn't able to attach the "
|
||||
"audio tracks to the generated video because the audio "
|
||||
|
|
|
@ -23,6 +23,11 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kBackwardFlowTag[] = "BACKWARD_FLOW";
|
||||
constexpr char kForwardFlowTag[] = "FORWARD_FLOW";
|
||||
constexpr char kSecondFrameTag[] = "SECOND_FRAME";
|
||||
constexpr char kFirstFrameTag[] = "FIRST_FRAME";
|
||||
|
||||
// Checks that img1 and img2 have the same dimensions.
|
||||
bool ImageSizesMatch(const ImageFrame& img1, const ImageFrame& img2) {
|
||||
return (img1.Width() == img2.Width()) && (img1.Height() == img2.Height());
|
||||
|
@ -94,19 +99,19 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase {
|
|||
};
|
||||
|
||||
absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) {
|
||||
if (!cc->Inputs().HasTag("FIRST_FRAME") ||
|
||||
!cc->Inputs().HasTag("SECOND_FRAME")) {
|
||||
if (!cc->Inputs().HasTag(kFirstFrameTag) ||
|
||||
!cc->Inputs().HasTag(kSecondFrameTag)) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Missing required input streams. Both FIRST_FRAME and SECOND_FRAME "
|
||||
"must be specified.");
|
||||
}
|
||||
cc->Inputs().Tag("FIRST_FRAME").Set<ImageFrame>();
|
||||
cc->Inputs().Tag("SECOND_FRAME").Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag("FORWARD_FLOW")) {
|
||||
cc->Outputs().Tag("FORWARD_FLOW").Set<OpticalFlowField>();
|
||||
cc->Inputs().Tag(kFirstFrameTag).Set<ImageFrame>();
|
||||
cc->Inputs().Tag(kSecondFrameTag).Set<ImageFrame>();
|
||||
if (cc->Outputs().HasTag(kForwardFlowTag)) {
|
||||
cc->Outputs().Tag(kForwardFlowTag).Set<OpticalFlowField>();
|
||||
}
|
||||
if (cc->Outputs().HasTag("BACKWARD_FLOW")) {
|
||||
cc->Outputs().Tag("BACKWARD_FLOW").Set<OpticalFlowField>();
|
||||
if (cc->Outputs().HasTag(kBackwardFlowTag)) {
|
||||
cc->Outputs().Tag(kBackwardFlowTag).Set<OpticalFlowField>();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -116,10 +121,10 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
|
|||
absl::MutexLock lock(&mutex_);
|
||||
tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1());
|
||||
}
|
||||
if (cc->Outputs().HasTag("FORWARD_FLOW")) {
|
||||
if (cc->Outputs().HasTag(kForwardFlowTag)) {
|
||||
forward_requested_ = true;
|
||||
}
|
||||
if (cc->Outputs().HasTag("BACKWARD_FLOW")) {
|
||||
if (cc->Outputs().HasTag(kBackwardFlowTag)) {
|
||||
backward_requested_ = true;
|
||||
}
|
||||
|
||||
|
@ -128,15 +133,15 @@ absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) {
|
|||
|
||||
absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
|
||||
const ImageFrame& first_frame =
|
||||
cc->Inputs().Tag("FIRST_FRAME").Value().Get<ImageFrame>();
|
||||
cc->Inputs().Tag(kFirstFrameTag).Value().Get<ImageFrame>();
|
||||
const ImageFrame& second_frame =
|
||||
cc->Inputs().Tag("SECOND_FRAME").Value().Get<ImageFrame>();
|
||||
cc->Inputs().Tag(kSecondFrameTag).Value().Get<ImageFrame>();
|
||||
if (forward_requested_) {
|
||||
auto forward_optical_flow_field = absl::make_unique<OpticalFlowField>();
|
||||
MP_RETURN_IF_ERROR(CalculateOpticalFlow(first_frame, second_frame,
|
||||
forward_optical_flow_field.get()));
|
||||
cc->Outputs()
|
||||
.Tag("FORWARD_FLOW")
|
||||
.Tag(kForwardFlowTag)
|
||||
.Add(forward_optical_flow_field.release(), cc->InputTimestamp());
|
||||
}
|
||||
if (backward_requested_) {
|
||||
|
@ -144,7 +149,7 @@ absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) {
|
|||
MP_RETURN_IF_ERROR(CalculateOpticalFlow(second_frame, first_frame,
|
||||
backward_optical_flow_field.get()));
|
||||
cc->Outputs()
|
||||
.Tag("BACKWARD_FLOW")
|
||||
.Tag(kBackwardFlowTag)
|
||||
.Add(backward_optical_flow_field.release(), cc->InputTimestamp());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||
constexpr char kFrameTag[] = "FRAME";
|
||||
|
||||
// Sets up VideoHeader based on the 1st ImageFrame and emits it with timestamp
|
||||
// PreStream. Note that this calculator only fills in format, width, and height,
|
||||
// i.e. frame_rate and duration will not be filled, unless:
|
||||
|
@ -64,8 +67,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
|
|||
if (!cc->Inputs().UsesTags()) {
|
||||
cc->Inputs().Index(0).Set<ImageFrame>();
|
||||
} else {
|
||||
cc->Inputs().Tag("FRAME").Set<ImageFrame>();
|
||||
cc->Inputs().Tag("VIDEO_PRESTREAM").Set<VideoHeader>();
|
||||
cc->Inputs().Tag(kFrameTag).Set<ImageFrame>();
|
||||
cc->Inputs().Tag(kVideoPrestreamTag).Set<VideoHeader>();
|
||||
}
|
||||
cc->Outputs().Index(0).Set<VideoHeader>();
|
||||
return absl::OkStatus();
|
||||
|
@ -73,8 +76,8 @@ absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) {
|
|||
|
||||
absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) {
|
||||
frame_rate_in_prestream_ = cc->Inputs().UsesTags() &&
|
||||
cc->Inputs().HasTag("FRAME") &&
|
||||
cc->Inputs().HasTag("VIDEO_PRESTREAM");
|
||||
cc->Inputs().HasTag(kFrameTag) &&
|
||||
cc->Inputs().HasTag(kVideoPrestreamTag);
|
||||
header_ = absl::make_unique<VideoHeader>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -82,15 +85,15 @@ absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream(
|
|||
CalculatorContext* cc) {
|
||||
cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment();
|
||||
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||
RET_CHECK(cc->Inputs().Tag("FRAME").IsEmpty());
|
||||
RET_CHECK(!cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty());
|
||||
*header_ = cc->Inputs().Tag("VIDEO_PRESTREAM").Get<VideoHeader>();
|
||||
RET_CHECK(cc->Inputs().Tag(kFrameTag).IsEmpty());
|
||||
RET_CHECK(!cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty());
|
||||
*header_ = cc->Inputs().Tag(kVideoPrestreamTag).Get<VideoHeader>();
|
||||
RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero";
|
||||
} else {
|
||||
RET_CHECK(cc->Inputs().Tag("VIDEO_PRESTREAM").IsEmpty())
|
||||
RET_CHECK(cc->Inputs().Tag(kVideoPrestreamTag).IsEmpty())
|
||||
<< "Packet on VIDEO_PRESTREAM must come in at Timestamp::PreStream().";
|
||||
RET_CHECK(!cc->Inputs().Tag("FRAME").IsEmpty());
|
||||
const auto& frame = cc->Inputs().Tag("FRAME").Get<ImageFrame>();
|
||||
RET_CHECK(!cc->Inputs().Tag(kFrameTag).IsEmpty());
|
||||
const auto& frame = cc->Inputs().Tag(kFrameTag).Get<ImageFrame>();
|
||||
header_->format = frame.Format();
|
||||
header_->width = frame.Width();
|
||||
header_->height = frame.Height();
|
||||
|
|
|
@ -44,28 +44,32 @@ using mediapipe::MakePacket;
|
|||
using mediapipe::OutputStreamShardSet;
|
||||
using mediapipe::Timestamp;
|
||||
namespace proto_ns = mediapipe::proto_ns;
|
||||
|
||||
constexpr char kEventTag[] = "EVENT";
|
||||
constexpr char kOutTag[] = "OUT";
|
||||
|
||||
using mediapipe::CalculatorGraph;
|
||||
using mediapipe::Packet;
|
||||
|
||||
class InfiniteSequenceCalculator : public mediapipe::CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
|
||||
cc->Outputs().Tag("OUT").Set<int>();
|
||||
cc->Outputs().Tag("EVENT").Set<int>();
|
||||
cc->Outputs().Tag(kOutTag).Set<int>();
|
||||
cc->Outputs().Tag(kEventTag).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
||||
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
cc->Outputs().Tag("OUT").AddPacket(
|
||||
cc->Outputs().Tag(kOutTag).AddPacket(
|
||||
MakePacket<int>(count_).At(Timestamp(count_)));
|
||||
count_++;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Close(CalculatorContext* cc) override {
|
||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
||||
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -81,11 +85,11 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
|
|||
cc->Inputs().Get("", i).SetAny();
|
||||
cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i));
|
||||
}
|
||||
cc->Outputs().Tag("EVENT").Set<int>();
|
||||
cc->Outputs().Tag(kEventTag).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
||||
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(1).At(Timestamp(1)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
|
@ -98,7 +102,7 @@ class StoppingPassThroughCalculator : public mediapipe::CalculatorBase {
|
|||
: mediapipe::tool::StatusStop();
|
||||
}
|
||||
absl::Status Close(CalculatorContext* cc) override {
|
||||
cc->Outputs().Tag("EVENT").AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
||||
cc->Outputs().Tag(kEventTag).AddPacket(MakePacket<int>(2).At(Timestamp(2)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -65,6 +65,16 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kCounter2Tag[] = "COUNTER2";
|
||||
constexpr char kCounter1Tag[] = "COUNTER1";
|
||||
constexpr char kExtraTag[] = "EXTRA";
|
||||
constexpr char kWaitSemTag[] = "WAIT_SEM";
|
||||
constexpr char kPostSemTag[] = "POST_SEM";
|
||||
constexpr char kErrorOnOpenTag[] = "ERROR_ON_OPEN";
|
||||
constexpr char kOutputTag[] = "OUTPUT";
|
||||
constexpr char kInputTag[] = "INPUT";
|
||||
constexpr char kSelectTag[] = "SELECT";
|
||||
|
||||
using testing::ElementsAre;
|
||||
using testing::HasSubstr;
|
||||
|
||||
|
@ -125,8 +135,8 @@ class DemuxTimedCalculator : public CalculatorBase {
|
|||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
|
||||
cc->Inputs().Tag("SELECT").Set<int>();
|
||||
PacketType* data_input = &cc->Inputs().Tag("INPUT");
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||
PacketType* data_input = &cc->Inputs().Tag(kInputTag);
|
||||
data_input->SetAny();
|
||||
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
|
||||
id < cc->Outputs().EndId("OUTPUT"); ++id) {
|
||||
|
@ -182,7 +192,7 @@ REGISTER_CALCULATOR(DemuxTimedCalculator);
|
|||
class MuxTimedCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("SELECT").Set<int>();
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||
CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT");
|
||||
PacketType* data_input0 = &cc->Inputs().Get(data_input_id);
|
||||
data_input0->SetAny();
|
||||
|
@ -191,7 +201,7 @@ class MuxTimedCalculator : public CalculatorBase {
|
|||
cc->Inputs().Get(data_input_id).SetSameAs(data_input0);
|
||||
}
|
||||
RET_CHECK_EQ(cc->Outputs().NumEntries(), 1);
|
||||
cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0);
|
||||
cc->Outputs().Tag(kOutputTag).SetSameAs(data_input0);
|
||||
cc->SetTimestampOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -598,12 +608,12 @@ class ErrorOnOpenCalculator : public CalculatorBase {
|
|||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set<bool>();
|
||||
cc->InputSidePackets().Tag(kErrorOnOpenTag).Set<bool>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
if (cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get<bool>()) {
|
||||
if (cc->InputSidePackets().Tag(kErrorOnOpenTag).Get<bool>()) {
|
||||
return absl::NotFoundError("expected error");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -920,8 +930,8 @@ class SemaphoreCalculator : public CalculatorBase {
|
|||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).SetAny();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
cc->InputSidePackets().Tag("POST_SEM").Set<Semaphore*>();
|
||||
cc->InputSidePackets().Tag("WAIT_SEM").Set<Semaphore*>();
|
||||
cc->InputSidePackets().Tag(kPostSemTag).Set<Semaphore*>();
|
||||
cc->InputSidePackets().Tag(kWaitSemTag).Set<Semaphore*>();
|
||||
cc->SetTimestampOffset(TimestampDiff(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -929,8 +939,8 @@ class SemaphoreCalculator : public CalculatorBase {
|
|||
absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); }
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
cc->InputSidePackets().Tag("POST_SEM").Get<Semaphore*>()->Release(1);
|
||||
cc->InputSidePackets().Tag("WAIT_SEM").Get<Semaphore*>()->Acquire(1);
|
||||
cc->InputSidePackets().Tag(kPostSemTag).Get<Semaphore*>()->Release(1);
|
||||
cc->InputSidePackets().Tag(kWaitSemTag).Get<Semaphore*>()->Acquire(1);
|
||||
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -1177,9 +1187,9 @@ class IncrementingStatusHandler : public StatusHandler {
|
|||
static absl::Status FillExpectations(
|
||||
const MediaPipeOptions& extendable_options,
|
||||
PacketTypeSet* input_side_packets) {
|
||||
input_side_packets->Tag("EXTRA").SetAny().Optional();
|
||||
input_side_packets->Tag("COUNTER1").Set<std::unique_ptr<int>>();
|
||||
input_side_packets->Tag("COUNTER2").Set<std::unique_ptr<int>>();
|
||||
input_side_packets->Tag(kExtraTag).SetAny().Optional();
|
||||
input_side_packets->Tag(kCounter1Tag).Set<std::unique_ptr<int>>();
|
||||
input_side_packets->Tag(kCounter2Tag).Set<std::unique_ptr<int>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -1187,7 +1197,7 @@ class IncrementingStatusHandler : public StatusHandler {
|
|||
const MediaPipeOptions& extendable_options,
|
||||
const PacketSet& input_side_packets, //
|
||||
const absl::Status& pre_run_status) {
|
||||
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER1"));
|
||||
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter1Tag));
|
||||
(*counter)++;
|
||||
return pre_run_status_result_;
|
||||
}
|
||||
|
@ -1195,7 +1205,7 @@ class IncrementingStatusHandler : public StatusHandler {
|
|||
static absl::Status HandleStatus(const MediaPipeOptions& extendable_options,
|
||||
const PacketSet& input_side_packets, //
|
||||
const absl::Status& run_status) {
|
||||
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag("COUNTER2"));
|
||||
int* counter = GetFromUniquePtr<int>(input_side_packets.Tag(kCounter2Tag));
|
||||
(*counter)++;
|
||||
return post_run_status_result_;
|
||||
}
|
||||
|
@ -2228,20 +2238,20 @@ class DemuxUntimedCalculator : public CalculatorBase {
|
|||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
RET_CHECK_EQ(cc->Inputs().NumEntries(), 2);
|
||||
cc->Inputs().Tag("INPUT").SetAny();
|
||||
cc->Inputs().Tag("SELECT").Set<int>();
|
||||
cc->Inputs().Tag(kInputTag).SetAny();
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>();
|
||||
for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT");
|
||||
id < cc->Outputs().EndId("OUTPUT"); ++id) {
|
||||
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag("INPUT"));
|
||||
cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag(kInputTag));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
int index = cc->Inputs().Tag("SELECT").Get<int>();
|
||||
if (!cc->Inputs().Tag("INPUT").IsEmpty()) {
|
||||
int index = cc->Inputs().Tag(kSelectTag).Get<int>();
|
||||
if (!cc->Inputs().Tag(kInputTag).IsEmpty()) {
|
||||
cc->Outputs()
|
||||
.Get("OUTPUT", index)
|
||||
.AddPacket(cc->Inputs().Tag("INPUT").Value());
|
||||
.AddPacket(cc->Inputs().Tag(kInputTag).Value());
|
||||
} else {
|
||||
cc->Outputs()
|
||||
.Get("OUTPUT", index)
|
||||
|
|
|
@ -32,6 +32,11 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kTag[] = "";
|
||||
constexpr char kBTag[] = "B";
|
||||
constexpr char kATag[] = "A";
|
||||
constexpr char kSideOutputTag[] = "SIDE_OUTPUT";
|
||||
|
||||
// Inputs: 2 streams with ints. Headers are strings.
|
||||
// Input side packets: 1.
|
||||
// Outputs: 3 streams with ints. #0 and #1 will contain the negated values from
|
||||
|
@ -48,7 +53,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
|
|||
cc->Outputs().Index(2).SetSameAs(&cc->InputSidePackets().Index(0));
|
||||
cc->InputSidePackets().Index(0).SetAny();
|
||||
cc->OutputSidePackets()
|
||||
.Tag("SIDE_OUTPUT")
|
||||
.Tag(kSideOutputTag)
|
||||
.SetSameAs(&cc->InputSidePackets().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -64,7 +69,7 @@ class CalculatorRunnerTestCalculator : public CalculatorBase {
|
|||
Adopt(new std::string(absl::StrCat(input_header_string, i))));
|
||||
}
|
||||
cc->OutputSidePackets()
|
||||
.Tag("SIDE_OUTPUT")
|
||||
.Tag(kSideOutputTag)
|
||||
.Set(cc->InputSidePackets().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -152,7 +157,7 @@ TEST(CalculatorRunner, RunsCalculator) {
|
|||
Adopt(new int(input_side_packet_content));
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
EXPECT_EQ(input_side_packet_content,
|
||||
runner.OutputSidePackets().Tag("SIDE_OUTPUT").Get<int>());
|
||||
runner.OutputSidePackets().Tag(kSideOutputTag).Get<int>());
|
||||
const auto& outputs = runner.Outputs();
|
||||
ASSERT_EQ(3, outputs.NumEntries());
|
||||
|
||||
|
@ -209,9 +214,9 @@ TEST(CalculatorRunner, MultiTagTestCalculatorOk) {
|
|||
const auto& outputs = runner.Outputs();
|
||||
ASSERT_EQ(3, outputs.NumEntries());
|
||||
for (int ts = 0; ts < 5; ++ts) {
|
||||
const std::vector<Packet>& a_packets = outputs.Tag("A").packets;
|
||||
const std::vector<Packet>& b_packets = outputs.Tag("B").packets;
|
||||
const std::vector<Packet>& c_packets = outputs.Tag("").packets;
|
||||
const std::vector<Packet>& a_packets = outputs.Tag(kATag).packets;
|
||||
const std::vector<Packet>& b_packets = outputs.Tag(kBTag).packets;
|
||||
const std::vector<Packet>& c_packets = outputs.Tag(kTag).packets;
|
||||
EXPECT_EQ(Timestamp(ts), a_packets[ts].Timestamp());
|
||||
EXPECT_EQ(Timestamp(ts), b_packets[ts].Timestamp());
|
||||
EXPECT_EQ(Timestamp(ts), c_packets[ts].Timestamp());
|
||||
|
|
|
@ -24,6 +24,10 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kTag2Tag[] = "TAG_2";
|
||||
constexpr char kTag0Tag[] = "TAG_0";
|
||||
constexpr char kTag1Tag[] = "TAG_1";
|
||||
|
||||
TEST(CollectionTest, BasicByIndex) {
|
||||
tool::TagAndNameInfo info;
|
||||
info.names.push_back("name_1");
|
||||
|
@ -55,14 +59,14 @@ TEST(CollectionTest, BasicByTag) {
|
|||
info.names.push_back("name_2");
|
||||
info.tags.push_back("TAG_2");
|
||||
internal::Collection<int> collection(info);
|
||||
collection.Tag("TAG_1") = 101;
|
||||
collection.Tag("TAG_0") = 100;
|
||||
collection.Tag("TAG_2") = 102;
|
||||
collection.Tag(kTag1Tag) = 101;
|
||||
collection.Tag(kTag0Tag) = 100;
|
||||
collection.Tag(kTag2Tag) = 102;
|
||||
|
||||
// Test the stored values.
|
||||
EXPECT_EQ(100, collection.Tag("TAG_0"));
|
||||
EXPECT_EQ(101, collection.Tag("TAG_1"));
|
||||
EXPECT_EQ(102, collection.Tag("TAG_2"));
|
||||
EXPECT_EQ(100, collection.Tag(kTag0Tag));
|
||||
EXPECT_EQ(101, collection.Tag(kTag1Tag));
|
||||
EXPECT_EQ(102, collection.Tag(kTag2Tag));
|
||||
// Test access using a range based for.
|
||||
int i = 0;
|
||||
for (int num : collection) {
|
||||
|
|
|
@ -134,6 +134,21 @@ void Tensor::AllocateMtlBuffer(id<MTLDevice> device) const {
|
|||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
bool Tensor::NeedsHalfFloatRenderTarget() const {
|
||||
static bool has_color_buffer_float =
|
||||
gl_context_->HasGlExtension("WEBGL_color_buffer_float") ||
|
||||
gl_context_->HasGlExtension("EXT_color_buffer_float");
|
||||
if (!has_color_buffer_float) {
|
||||
static bool has_color_buffer_half_float =
|
||||
gl_context_->HasGlExtension("EXT_color_buffer_half_float");
|
||||
LOG_IF(FATAL, !has_color_buffer_half_float)
|
||||
<< "EXT_color_buffer_half_float or WEBGL_color_buffer_float "
|
||||
<< "required on web to use MP tensor";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
||||
LOG_IF(FATAL, valid_ == kValidNone)
|
||||
<< "Tensor must be written prior to read from.";
|
||||
|
@ -164,8 +179,24 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
|||
// Set alignment for the proper value (default) to avoid address sanitizer
|
||||
// error "out of boundary reading".
|
||||
glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
|
||||
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_,
|
||||
GL_RGBA, GL_FLOAT, temp_buffer.get());
|
||||
#ifdef __EMSCRIPTEN__
|
||||
// Under WebGL1, format must match in order to use glTexSubImage2D, so if we
|
||||
// have a half-float texture, then uploading from GL_FLOAT here would fail.
|
||||
// We change the texture's data type to float here to accommodate.
|
||||
// Furthermore, for a full-image replacement operation, glTexImage2D is
|
||||
// expected to be more performant than glTexSubImage2D. Note that for WebGL2
|
||||
// we cannot use glTexImage2D, because we allocate using glTexStorage2D in
|
||||
// that case, which is incompatible.
|
||||
if (gl_context_->GetGlVersion() == mediapipe::GlVersion::kGLES2) {
|
||||
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
|
||||
0, GL_RGBA, GL_FLOAT, temp_buffer.get());
|
||||
texture_is_half_float_ = false;
|
||||
} else
|
||||
#endif // __EMSCRIPTEN__
|
||||
{
|
||||
glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_,
|
||||
GL_RGBA, GL_FLOAT, temp_buffer.get());
|
||||
}
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
valid_ |= kValidOpenGlTexture2d;
|
||||
}
|
||||
|
@ -175,6 +206,16 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const {
|
|||
Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const {
|
||||
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
|
||||
AllocateOpenGlTexture2d();
|
||||
#ifdef __EMSCRIPTEN__
|
||||
// On web, we may have to change type from float to half-float
|
||||
if (!texture_is_half_float_ && NeedsHalfFloatRenderTarget()) {
|
||||
glBindTexture(GL_TEXTURE_2D, opengl_texture2d_);
|
||||
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_, 0,
|
||||
GL_RGBA, GL_HALF_FLOAT_OES, 0 /* data */);
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
texture_is_half_float_ = true;
|
||||
}
|
||||
#endif
|
||||
valid_ = kValidOpenGlTexture2d;
|
||||
return {opengl_texture2d_, std::move(lock)};
|
||||
}
|
||||
|
@ -255,8 +296,18 @@ void Tensor::AllocateOpenGlTexture2d() const {
|
|||
<< "with GLES 2.0";
|
||||
// Allocate the image data; note that it's no longer RGBA32F, so will be
|
||||
// lower precision.
|
||||
auto type = GL_FLOAT;
|
||||
// On web, we might need to change type to half-float (e.g. for iOS-
|
||||
// Safari) in order to have a valid framebuffer. See b/194442743 for more
|
||||
// details.
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (NeedsHalfFloatRenderTarget()) {
|
||||
type = GL_HALF_FLOAT_OES;
|
||||
texture_is_half_float_ = true;
|
||||
}
|
||||
#endif // __EMSCRIPTEN
|
||||
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_,
|
||||
0, GL_RGBA, GL_FLOAT, 0 /* data */);
|
||||
0, GL_RGBA, type, 0 /* data */);
|
||||
}
|
||||
glBindTexture(GL_TEXTURE_2D, 0);
|
||||
glGenFramebuffers(1, &frame_buffer_);
|
||||
|
@ -443,7 +494,6 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
|||
glPixelStorei(GL_PACK_ALIGNMENT, 4);
|
||||
glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT,
|
||||
buffer);
|
||||
|
||||
uint8_t* dest_buffer = reinterpret_cast<uint8_t*>(cpu_buffer_);
|
||||
const int actual_depth_size =
|
||||
BhwcDepthFromShape(shape_) * element_size();
|
||||
|
|
|
@ -266,11 +266,15 @@ class Tensor {
|
|||
mutable GLuint frame_buffer_ = GL_INVALID_INDEX;
|
||||
mutable int texture_width_;
|
||||
mutable int texture_height_;
|
||||
#ifdef __EMSCRIPTEN__
|
||||
mutable bool texture_is_half_float_ = false;
|
||||
#endif // __EMSCRIPTEN__
|
||||
void AllocateOpenGlTexture2d() const;
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
mutable GLuint opengl_buffer_ = GL_INVALID_INDEX;
|
||||
void AllocateOpenGlBuffer() const;
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
bool NeedsHalfFloatRenderTarget() const;
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
};
|
||||
|
||||
|
|
|
@ -31,6 +31,11 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
constexpr char kOutputTag[] = "OUTPUT";
|
||||
constexpr char kEnableTag[] = "ENABLE";
|
||||
constexpr char kSelectTag[] = "SELECT";
|
||||
constexpr char kSideinputTag[] = "SIDEINPUT";
|
||||
|
||||
// Shows validation success for a graph and a subgraph.
|
||||
TEST(GraphValidationTest, InitializeGraphFromProtos) {
|
||||
auto config_1 = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@ -323,20 +328,21 @@ TEST(GraphValidationTest, OptionalSubgraphStreamsMismatched) {
|
|||
class OptionalSideInputTestCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag("SIDEINPUT").Set<std::string>().Optional();
|
||||
cc->Inputs().Tag("SELECT").Set<int>().Optional();
|
||||
cc->Inputs().Tag("ENABLE").Set<bool>().Optional();
|
||||
cc->Outputs().Tag("OUTPUT").Set<std::string>();
|
||||
cc->InputSidePackets().Tag(kSideinputTag).Set<std::string>().Optional();
|
||||
cc->Inputs().Tag(kSelectTag).Set<int>().Optional();
|
||||
cc->Inputs().Tag(kEnableTag).Set<bool>().Optional();
|
||||
cc->Outputs().Tag(kOutputTag).Set<std::string>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
std::string value("default");
|
||||
if (cc->InputSidePackets().HasTag("SIDEINPUT")) {
|
||||
value = cc->InputSidePackets().Tag("SIDEINPUT").Get<std::string>();
|
||||
if (cc->InputSidePackets().HasTag(kSideinputTag)) {
|
||||
value = cc->InputSidePackets().Tag(kSideinputTag).Get<std::string>();
|
||||
}
|
||||
cc->Outputs().Tag("OUTPUT").Add(new std::string(value),
|
||||
cc->InputTimestamp());
|
||||
cc->Outputs()
|
||||
.Tag(kOutputTag)
|
||||
.Add(new std::string(value), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -26,17 +26,20 @@ namespace {
|
|||
|
||||
namespace test_ns {
|
||||
|
||||
constexpr char kOutTag[] = "OUT";
|
||||
constexpr char kInTag[] = "IN";
|
||||
|
||||
class TestSinkCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("IN").Set<mediapipe::InputOnlyProto>();
|
||||
cc->Outputs().Tag("OUT").Set<int>();
|
||||
cc->Inputs().Tag(kInTag).Set<mediapipe::InputOnlyProto>();
|
||||
cc->Outputs().Tag(kOutTag).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
int x = cc->Inputs().Tag("IN").Get<mediapipe::InputOnlyProto>().x();
|
||||
cc->Outputs().Tag("OUT").AddPacket(
|
||||
int x = cc->Inputs().Tag(kInTag).Get<mediapipe::InputOnlyProto>().x();
|
||||
cc->Outputs().Tag(kOutTag).AddPacket(
|
||||
MakePacket<int>(x).At(cc->InputTimestamp()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -34,6 +34,19 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
constexpr char kOutTag[] = "OUT";
|
||||
constexpr char kClockTag[] = "CLOCK";
|
||||
constexpr char kSleepMicrosTag[] = "SLEEP_MICROS";
|
||||
constexpr char kCloseTag[] = "CLOSE";
|
||||
constexpr char kProcessTag[] = "PROCESS";
|
||||
constexpr char kOpenTag[] = "OPEN";
|
||||
constexpr char kTag[] = "";
|
||||
constexpr char kMeanTag[] = "MEAN";
|
||||
constexpr char kDataTag[] = "DATA";
|
||||
constexpr char kPairTag[] = "PAIR";
|
||||
constexpr char kLowTag[] = "LOW";
|
||||
constexpr char kHighTag[] = "HIGH";
|
||||
|
||||
using RandomEngine = std::mt19937_64;
|
||||
|
||||
// A Calculator that outputs twice the value of its input packet (an int).
|
||||
|
@ -95,9 +108,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
|
|||
PacketTypeSet* input_side_packets, //
|
||||
PacketTypeSet* output_side_packets) {
|
||||
input_side_packets->Index(0).Set<uint64>();
|
||||
output_side_packets->Tag("HIGH").Set<uint32>();
|
||||
output_side_packets->Tag("LOW").Set<uint32>();
|
||||
output_side_packets->Tag("PAIR").Set<std::pair<uint32, uint32>>();
|
||||
output_side_packets->Tag(kHighTag).Set<uint32>();
|
||||
output_side_packets->Tag(kLowTag).Set<uint32>();
|
||||
output_side_packets->Tag(kPairTag).Set<std::pair<uint32, uint32>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -108,9 +121,9 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator {
|
|||
uint64 value = input_side_packets.Index(0).Get<uint64>();
|
||||
uint32 high = value >> 32;
|
||||
uint32 low = value & 0xFFFFFFFF;
|
||||
output_side_packets->Tag("HIGH") = Adopt(new uint32(high));
|
||||
output_side_packets->Tag("LOW") = Adopt(new uint32(low));
|
||||
output_side_packets->Tag("PAIR") =
|
||||
output_side_packets->Tag(kHighTag) = Adopt(new uint32(high));
|
||||
output_side_packets->Tag(kLowTag) = Adopt(new uint32(low));
|
||||
output_side_packets->Tag(kPairTag) =
|
||||
Adopt(new std::pair<uint32, uint32>(high, low));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -221,8 +234,8 @@ class StdDevCalculator : public CalculatorBase {
|
|||
StdDevCalculator() {}
|
||||
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Tag("DATA").Set<int>();
|
||||
cc->Inputs().Tag("MEAN").Set<double>();
|
||||
cc->Inputs().Tag(kDataTag).Set<int>();
|
||||
cc->Inputs().Tag(kMeanTag).Set<double>();
|
||||
cc->Outputs().Index(0).Set<int>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -234,15 +247,15 @@ class StdDevCalculator : public CalculatorBase {
|
|||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (cc->InputTimestamp() == Timestamp::PreStream()) {
|
||||
RET_CHECK(cc->Inputs().Tag("DATA").Value().IsEmpty());
|
||||
RET_CHECK(!cc->Inputs().Tag("MEAN").Value().IsEmpty());
|
||||
mean_ = cc->Inputs().Tag("MEAN").Get<double>();
|
||||
RET_CHECK(cc->Inputs().Tag(kDataTag).Value().IsEmpty());
|
||||
RET_CHECK(!cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
|
||||
mean_ = cc->Inputs().Tag(kMeanTag).Get<double>();
|
||||
initialized_ = true;
|
||||
} else {
|
||||
RET_CHECK(initialized_);
|
||||
RET_CHECK(!cc->Inputs().Tag("DATA").Value().IsEmpty());
|
||||
RET_CHECK(cc->Inputs().Tag("MEAN").Value().IsEmpty());
|
||||
double diff = cc->Inputs().Tag("DATA").Get<int>() - mean_;
|
||||
RET_CHECK(!cc->Inputs().Tag(kDataTag).Value().IsEmpty());
|
||||
RET_CHECK(cc->Inputs().Tag(kMeanTag).Value().IsEmpty());
|
||||
double diff = cc->Inputs().Tag(kDataTag).Get<int>() - mean_;
|
||||
cummulative_variance_ += diff * diff;
|
||||
++count_;
|
||||
}
|
||||
|
@ -564,8 +577,8 @@ class LambdaCalculator : public CalculatorBase {
|
|||
id < cc->Outputs().EndId(); ++id) {
|
||||
cc->Outputs().Get(id).SetAny();
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("") > 0) {
|
||||
cc->InputSidePackets().Tag("").Set<ProcessFunction>();
|
||||
if (cc->InputSidePackets().HasTag(kTag) > 0) {
|
||||
cc->InputSidePackets().Tag(kTag).Set<ProcessFunction>();
|
||||
}
|
||||
for (const std::string& tag : {"OPEN", "PROCESS", "CLOSE"}) {
|
||||
if (cc->InputSidePackets().HasTag(tag)) {
|
||||
|
@ -576,24 +589,24 @@ class LambdaCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
if (cc->InputSidePackets().HasTag("OPEN")) {
|
||||
if (cc->InputSidePackets().HasTag(kOpenTag)) {
|
||||
return GetContextFn(cc, "OPEN")(cc);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (cc->InputSidePackets().HasTag("PROCESS")) {
|
||||
if (cc->InputSidePackets().HasTag(kProcessTag)) {
|
||||
return GetContextFn(cc, "PROCESS")(cc);
|
||||
}
|
||||
if (cc->InputSidePackets().HasTag("") > 0) {
|
||||
if (cc->InputSidePackets().HasTag(kTag) > 0) {
|
||||
return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Close(CalculatorContext* cc) final {
|
||||
if (cc->InputSidePackets().HasTag("CLOSE")) {
|
||||
if (cc->InputSidePackets().HasTag(kCloseTag)) {
|
||||
return GetContextFn(cc, "CLOSE")(cc);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -645,17 +658,18 @@ class PassThroughWithSleepCalculator : public CalculatorBase {
|
|||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
cc->InputSidePackets().Tag("SLEEP_MICROS").Set<int>();
|
||||
cc->InputSidePackets().Tag("CLOCK").Set<std::shared_ptr<Clock>>();
|
||||
cc->InputSidePackets().Tag(kSleepMicrosTag).Set<int>();
|
||||
cc->InputSidePackets().Tag(kClockTag).Set<std::shared_ptr<Clock>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get<int>();
|
||||
sleep_micros_ = cc->InputSidePackets().Tag(kSleepMicrosTag).Get<int>();
|
||||
if (sleep_micros_ < 0) {
|
||||
return absl::InternalError("SLEEP_MICROS should be >= 0");
|
||||
}
|
||||
clock_ = cc->InputSidePackets().Tag("CLOCK").Get<std::shared_ptr<Clock>>();
|
||||
clock_ =
|
||||
cc->InputSidePackets().Tag(kClockTag).Get<std::shared_ptr<Clock>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
|
@ -678,8 +692,8 @@ class MultiplyIntCalculator : public CalculatorBase {
|
|||
cc->Inputs().Index(0).Set<int>();
|
||||
cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0));
|
||||
// cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
|
||||
RET_CHECK(cc->Outputs().HasTag("OUT"));
|
||||
cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0));
|
||||
RET_CHECK(cc->Outputs().HasTag(kOutTag));
|
||||
cc->Outputs().Tag(kOutTag).SetSameAs(&cc->Inputs().Index(0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
|
@ -689,7 +703,7 @@ class MultiplyIntCalculator : public CalculatorBase {
|
|||
absl::Status Process(CalculatorContext* cc) final {
|
||||
int x = cc->Inputs().Index(0).Value().Get<int>();
|
||||
int y = cc->Inputs().Index(1).Value().Get<int>();
|
||||
cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp());
|
||||
cc->Outputs().Tag(kOutTag).Add(new int(x * y), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -60,6 +60,18 @@ def mediapipe_aar(
|
|||
assets: additional assets to be included into the archive.
|
||||
assets_dir: path where the assets will the packaged.
|
||||
"""
|
||||
|
||||
# When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
|
||||
# the OpenCV so libraries will be excluded from the AAR package to
|
||||
# save the package size.
|
||||
native.config_setting(
|
||||
name = "exclude_opencv_so_lib",
|
||||
define_values = {
|
||||
"EXCLUDE_OPENCV_SO_LIB": "1",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
_mediapipe_jni(
|
||||
name = name + "_jni",
|
||||
gen_libmediapipe = gen_libmediapipe,
|
||||
|
@ -133,6 +145,7 @@ EOF
|
|||
] + select({
|
||||
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
|
||||
"//mediapipe/framework/port:disable_opencv": [],
|
||||
"exclude_opencv_so_lib": [],
|
||||
}),
|
||||
assets = assets,
|
||||
assets_dir = assets_dir,
|
||||
|
|
|
@ -245,14 +245,6 @@ void Box::Fit(const std::vector<T>& vertices) {
|
|||
auto system_g = system_h.colPivHouseholderQr();
|
||||
auto solution = system_g.solve(v).eval();
|
||||
transformation_.topLeftCorner<3, 4>() = solution.transpose();
|
||||
|
||||
// Adjust rotation matrix to its nearest orthogonal matrix.
|
||||
const auto rotation = transformation_.topLeftCorner<3, 3>();
|
||||
Eigen::JacobiSVD<Eigen::Matrix3f> svd(
|
||||
rotation, Eigen::ComputeFullV | Eigen::ComputeFullU);
|
||||
const Eigen::Matrix3f matrix_u = svd.matrixU();
|
||||
const Eigen::Matrix3f matrix_v = svd.matrixV();
|
||||
transformation_.topLeftCorner<3, 3>() = matrix_u * matrix_v.transpose();
|
||||
Update();
|
||||
}
|
||||
|
||||
|
|
|
@ -254,11 +254,11 @@ class SolutionBase:
|
|||
for stream_name in self._output_stream_type_info.keys():
|
||||
self._graph.observe_output_stream(stream_name, callback, True)
|
||||
|
||||
input_side_packets = {
|
||||
self._input_side_packets = {
|
||||
name: self._make_packet(self._side_input_type_info[name], data)
|
||||
for name, data in (side_inputs or {}).items()
|
||||
}
|
||||
self._graph.start_run(input_side_packets)
|
||||
self._graph.start_run(self._input_side_packets)
|
||||
|
||||
# TODO: Use "inspect.Parameter" to fetch the input argument names and
|
||||
# types from "_input_stream_type_info" and then auto generate the process
|
||||
|
@ -353,6 +353,12 @@ class SolutionBase:
|
|||
self._input_stream_type_info = None
|
||||
self._output_stream_type_info = None
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the graph for another run."""
|
||||
if self._graph:
|
||||
self._graph.close()
|
||||
self._graph.start_run(self._input_side_packets)
|
||||
|
||||
def _initialize_graph_interface(
|
||||
self,
|
||||
validated_graph: validated_graph_config.ValidatedGraphConfig,
|
||||
|
|
|
@ -298,6 +298,56 @@ class SolutionBaseTest(parameterized.TestCase):
|
|||
'ImageTransformation.output_height': 0
|
||||
})
|
||||
|
||||
@parameterized.named_parameters(('graph_without_side_packets', """
|
||||
input_stream: 'image_in'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:transformed_image_in'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}
|
||||
""", None), ('graph_with_side_packets', """
|
||||
input_stream: 'image_in'
|
||||
input_side_packet: 'allow_signal'
|
||||
input_side_packet: 'rotation_degrees'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'GateCalculator'
|
||||
input_stream: 'transformed_image_in'
|
||||
input_side_packet: 'ALLOW:allow_signal'
|
||||
output_stream: 'image_out_to_transform'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_out_to_transform'
|
||||
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}""", {
|
||||
'allow_signal': True,
|
||||
'rotation_degrees': 0
|
||||
}))
|
||||
def test_solution_reset(self, text_config, side_inputs):
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
|
||||
with solution_base.SolutionBase(
|
||||
graph_config=config_proto, side_inputs=side_inputs) as solution:
|
||||
for _ in range(20):
|
||||
outputs = solution.process(input_image)
|
||||
self.assertTrue(np.array_equal(input_image, outputs.image_out))
|
||||
solution.reset()
|
||||
|
||||
def _process_and_verify(self,
|
||||
config_proto,
|
||||
side_inputs=None,
|
||||
|
|
|
@ -26,7 +26,7 @@ import numpy.testing as npt
|
|||
from mediapipe.python.solutions import objectron as mp_objectron
|
||||
|
||||
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
|
||||
DIFF_THRESHOLD = 35 # pixels
|
||||
DIFF_THRESHOLD = 30 # pixels
|
||||
EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457],
|
||||
[383, 505], [80, 478], [408, 345],
|
||||
[130, 347], [384, 355], [72, 353]],
|
||||
|
|
Loading…
Reference in New Issue
Block a user