No public description
PiperOrigin-RevId: 562071599
This commit is contained in:
parent
23057ac146
commit
ab70d92752
|
@ -379,6 +379,7 @@ cc_library(
|
|||
"//mediapipe/util/sequence:media_sequence",
|
||||
"//mediapipe/util/sequence:media_sequence_util",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||
|
|
|
@ -41,6 +41,7 @@ const char kImageTag[] = "IMAGE";
|
|||
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
|
||||
const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
|
||||
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
||||
const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_";
|
||||
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
||||
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
||||
const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_";
|
||||
|
@ -175,6 +176,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
|
||||
cc->Inputs().Tag(tag).Set<std::vector<float>>();
|
||||
}
|
||||
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
|
||||
cc->Inputs().Tag(tag).Set<std::vector<int64_t>>();
|
||||
}
|
||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
||||
cc->Inputs().Tag(tag).Set<std::vector<float>>();
|
||||
}
|
||||
|
@ -279,6 +283,13 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
mpms::ClearClipLabelString(key, sequence_.get());
|
||||
mpms::ClearClipLabelConfidence(key, sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
|
||||
const std::string& key =
|
||||
tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
|
||||
sizeof(*kIntsContextFeaturePrefixTag) -
|
||||
1);
|
||||
mpms::ClearContextFeatureInts(key, sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||
sizeof(*kFloatFeaturePrefixTag) -
|
||||
|
@ -518,6 +529,19 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
key, cc->Inputs().Tag(tag).Get<std::vector<float>>(),
|
||||
sequence_.get());
|
||||
}
|
||||
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) &&
|
||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
const std::string& key =
|
||||
tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
|
||||
sizeof(*kIntsContextFeaturePrefixTag) -
|
||||
1);
|
||||
// To ensure only one packet is provided for this tag.
|
||||
RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
|
||||
for (const auto& value :
|
||||
cc->Inputs().Tag(tag).Get<std::vector<int64_t>>()) {
|
||||
mpms::AddContextFeatureInts(key, value, sequence_.get());
|
||||
}
|
||||
}
|
||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
|
||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||
|
|
|
@ -58,6 +58,8 @@ constexpr char kBytesFeatureTestTag[] = "BYTES_FEATURE_TEST";
|
|||
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
|
||||
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
|
||||
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
|
||||
constexpr char kIntsContextFeatureTestTag[] = "INTS_CONTEXT_FEATURE_TEST";
|
||||
constexpr char kIntsContextFeatureOtherTag[] = "INTS_CONTEXT_FEATURE_OTHER";
|
||||
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||
constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
|
||||
|
@ -451,6 +453,119 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
|
|||
testing::ElementsAre(4, 4));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextIntLists) {
|
||||
SetUpCalculator(
|
||||
/*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
|
||||
"INTS_CONTEXT_FEATURE_OTHER:test2"},
|
||||
/*features=*/{},
|
||||
/*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
|
||||
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||
|
||||
const std::vector<int64_t> vi_1 = {2, 3};
|
||||
runner_->MutableInputs()
|
||||
->Tag(kIntsContextFeatureTestTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
|
||||
const std::vector<int64_t> vi_2 = {2, 4};
|
||||
runner_->MutableInputs()
|
||||
->Tag(kIntsContextFeatureOtherTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
|
||||
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
||||
ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
|
||||
testing::ElementsAre(2, 3));
|
||||
ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
|
||||
testing::ElementsAre(2, 4));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextIntLists) {
|
||||
SetUpCalculator(
|
||||
/*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
|
||||
"INTS_CONTEXT_FEATURE_OTHER:test2"},
|
||||
/*features=*/{},
|
||||
/*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
|
||||
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||
mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get());
|
||||
mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get());
|
||||
|
||||
const std::vector<int64_t> vi_1 = {5, 6};
|
||||
runner_->MutableInputs()
|
||||
->Tag(kIntsContextFeatureTestTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
|
||||
const std::vector<int64_t> vi_2 = {7, 8};
|
||||
runner_->MutableInputs()
|
||||
->Tag(kIntsContextFeatureOtherTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
|
||||
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
||||
ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
|
||||
testing::ElementsAre(5, 6));
|
||||
ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
|
||||
testing::ElementsAre(7, 8));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextIntLists) {
|
||||
SetUpCalculator(
|
||||
/*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
|
||||
"INTS_CONTEXT_FEATURE_OTHER:test2"},
|
||||
/*features=*/{},
|
||||
/*output_only_if_all_present=*/false,
|
||||
/*replace_instead_of_append=*/false);
|
||||
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||
mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get());
|
||||
mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get());
|
||||
|
||||
const std::vector<int64_t> vi_1 = {5, 6};
|
||||
runner_->MutableInputs()
|
||||
->Tag(kIntsContextFeatureTestTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
|
||||
const std::vector<int64_t> vi_2 = {7, 8};
|
||||
runner_->MutableInputs()
|
||||
->Tag(kIntsContextFeatureOtherTag)
|
||||
.packets.push_back(
|
||||
MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
|
||||
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
||||
ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
|
||||
testing::ElementsAre(2, 3, 5, 6));
|
||||
ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
|
||||
testing::ElementsAre(2, 4, 7, 8));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
||||
tf::Features context;
|
||||
(*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES");
|
||||
|
|
Loading…
Reference in New Issue
Block a user