No public description
PiperOrigin-RevId: 562075110
This commit is contained in:
parent
ab70d92752
commit
d6119957a4
|
@ -42,6 +42,7 @@ const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
|
||||||
const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
|
const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
|
||||||
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
||||||
const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_";
|
const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_";
|
||||||
|
const char kBytesContextFeaturePrefixTag[] = "BYTES_CONTEXT_FEATURE_";
|
||||||
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
||||||
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
||||||
const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_";
|
const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_";
|
||||||
|
@ -179,6 +180,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
|
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
|
||||||
cc->Inputs().Tag(tag).Set<std::vector<int64_t>>();
|
cc->Inputs().Tag(tag).Set<std::vector<int64_t>>();
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag)) {
|
||||||
|
cc->Inputs().Tag(tag).Set<std::vector<std::string>>();
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
||||||
cc->Inputs().Tag(tag).Set<std::vector<float>>();
|
cc->Inputs().Tag(tag).Set<std::vector<float>>();
|
||||||
}
|
}
|
||||||
|
@ -290,6 +294,13 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
1);
|
1);
|
||||||
mpms::ClearContextFeatureInts(key, sequence_.get());
|
mpms::ClearContextFeatureInts(key, sequence_.get());
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag)) {
|
||||||
|
const std::string& key =
|
||||||
|
tag.substr(sizeof(kBytesContextFeaturePrefixTag) /
|
||||||
|
sizeof(*kBytesContextFeaturePrefixTag) -
|
||||||
|
1);
|
||||||
|
mpms::ClearContextFeatureBytes(key, sequence_.get());
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
||||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||||
sizeof(*kFloatFeaturePrefixTag) -
|
sizeof(*kFloatFeaturePrefixTag) -
|
||||||
|
@ -542,6 +553,19 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
mpms::AddContextFeatureInts(key, value, sequence_.get());
|
mpms::AddContextFeatureInts(key, value, sequence_.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag) &&
|
||||||
|
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||||
|
const std::string& key =
|
||||||
|
tag.substr(sizeof(kBytesContextFeaturePrefixTag) /
|
||||||
|
sizeof(*kBytesContextFeaturePrefixTag) -
|
||||||
|
1);
|
||||||
|
// To ensure only one packet is provided for this tag.
|
||||||
|
RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
|
||||||
|
for (const auto& value :
|
||||||
|
cc->Inputs().Tag(tag).Get<std::vector<std::string>>()) {
|
||||||
|
mpms::AddContextFeatureBytes(key, value, sequence_.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
|
if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
|
||||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||||
|
|
|
@ -60,6 +60,8 @@ constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
|
||||||
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
|
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
|
||||||
constexpr char kIntsContextFeatureTestTag[] = "INTS_CONTEXT_FEATURE_TEST";
|
constexpr char kIntsContextFeatureTestTag[] = "INTS_CONTEXT_FEATURE_TEST";
|
||||||
constexpr char kIntsContextFeatureOtherTag[] = "INTS_CONTEXT_FEATURE_OTHER";
|
constexpr char kIntsContextFeatureOtherTag[] = "INTS_CONTEXT_FEATURE_OTHER";
|
||||||
|
constexpr char kBytesContextFeatureTestTag[] = "BYTES_CONTEXT_FEATURE_TEST";
|
||||||
|
constexpr char kBytesContextFeatureOtherTag[] = "BYTES_CONTEXT_FEATURE_OTHER";
|
||||||
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||||
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||||
constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
|
constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
|
||||||
|
@ -566,6 +568,125 @@ TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextIntLists) {
|
||||||
testing::ElementsAre(2, 4, 7, 8));
|
testing::ElementsAre(2, 4, 7, 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextByteLists) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"BYTES_CONTEXT_FEATURE_TEST:test",
|
||||||
|
"BYTES_CONTEXT_FEATURE_OTHER:test2"},
|
||||||
|
/*features=*/{},
|
||||||
|
/*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||||
|
|
||||||
|
const std::vector<std::string> vb_1 = {"value_1", "value_2"};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kBytesContextFeatureTestTag)
|
||||||
|
.packets.push_back(MakePacket<std::vector<std::string>>(vb_1).At(
|
||||||
|
Timestamp::PostStream()));
|
||||||
|
const std::vector<std::string> vb_2 = {"value_3", "value_4"};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kBytesContextFeatureOtherTag)
|
||||||
|
.packets.push_back(MakePacket<std::vector<std::string>>(vb_2).At(
|
||||||
|
Timestamp::PostStream()));
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureBytes("TEST", output_sequence),
|
||||||
|
testing::ElementsAre("value_1", "value_2"));
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureBytes("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre("value_3", "value_4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextByteLists) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"BYTES_CONTEXT_FEATURE_TEST:test",
|
||||||
|
"BYTES_CONTEXT_FEATURE_OTHER:test2"},
|
||||||
|
/*features=*/{},
|
||||||
|
/*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||||
|
mpms::SetContextFeatureBytes("TEST", {"existing_value_1", "existing_value_2"},
|
||||||
|
input_sequence.get());
|
||||||
|
mpms::SetContextFeatureBytes(
|
||||||
|
"OTHER", {"existing_value_3", "existing_value_4"}, input_sequence.get());
|
||||||
|
|
||||||
|
const std::vector<std::string> vb_1 = {"value_1", "value_2"};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kBytesContextFeatureTestTag)
|
||||||
|
.packets.push_back(MakePacket<std::vector<std::string>>(vb_1).At(
|
||||||
|
Timestamp::PostStream()));
|
||||||
|
const std::vector<std::string> vb_2 = {"value_3", "value_4"};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kBytesContextFeatureOtherTag)
|
||||||
|
.packets.push_back(MakePacket<std::vector<std::string>>(vb_2).At(
|
||||||
|
Timestamp::PostStream()));
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureBytes("TEST", output_sequence),
|
||||||
|
testing::ElementsAre("value_1", "value_2"));
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureBytes("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre("value_3", "value_4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextByteLists) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"BYTES_CONTEXT_FEATURE_TEST:test",
|
||||||
|
"BYTES_CONTEXT_FEATURE_OTHER:test2"},
|
||||||
|
/*features=*/{},
|
||||||
|
/*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/false);
|
||||||
|
auto input_sequence = absl::make_unique<tf::SequenceExample>();
|
||||||
|
mpms::SetContextFeatureBytes("TEST", {"existing_value_1", "existing_value_2"},
|
||||||
|
input_sequence.get());
|
||||||
|
mpms::SetContextFeatureBytes(
|
||||||
|
"OTHER", {"existing_value_3", "existing_value_4"}, input_sequence.get());
|
||||||
|
|
||||||
|
const std::vector<std::string> vb_1 = {"value_1", "value_2"};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kBytesContextFeatureTestTag)
|
||||||
|
.packets.push_back(MakePacket<std::vector<std::string>>(vb_1).At(
|
||||||
|
Timestamp::PostStream()));
|
||||||
|
const std::vector<std::string> vb_2 = {"value_3", "value_4"};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kBytesContextFeatureOtherTag)
|
||||||
|
.packets.push_back(MakePacket<std::vector<std::string>>(vb_2).At(
|
||||||
|
Timestamp::PostStream()));
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureBytes("TEST", output_sequence),
|
||||||
|
testing::ElementsAre("existing_value_1", "existing_value_2",
|
||||||
|
"value_1", "value_2"));
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureBytes("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre("existing_value_3", "existing_value_4",
|
||||||
|
"value_3", "value_4"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
|
||||||
tf::Features context;
|
tf::Features context;
|
||||||
(*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES");
|
(*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES");
|
||||||
|
|
Loading…
Reference in New Issue
Block a user