Add INTS_CONTEXT_FEATURE_<KEY> input streams to PackMediaSequenceCalculator.

Summary of what the hunks below do:
  * GetContract: every input tag starting with "INTS_CONTEXT_FEATURE_" is
    declared as a vector-of-integers stream.
  * Clearing path (hunk at -279): for such a tag, the context feature named by
    the tag suffix is cleared with mpms::ClearContextFeatureInts. The enclosing
    condition is outside this hunk; the Replace/Append tests suggest it is
    tied to replace_instead_of_append -- confirm against the full file.
  * Process: a non-empty packet on such a tag must arrive at
    Timestamp::PostStream() (enforced with RET_CHECK_EQ); each value is then
    appended with mpms::AddContextFeatureInts.
  * Tests: packing into an empty sequence, replacing existing values, and
    appending to existing values, each with two tags (TEST and OTHER).

NOTE(review): this patch was recovered from a copy whose line breaks were
collapsed and whose template argument lists were stripped (e.g. "Set>()").
The arguments <std::vector<float>>, <std::vector<int64_t>>,
<tf::SequenceExample> and <Packet> below, and the indentation inside hunks,
were reconstructed from context -- confirm against the upstream commit before
applying.
NOTE(review): the BUILD hunk adds "@com_google_absl//absl/log", but no hunk
shown here adds a matching #include or logging call -- confirm the dependency
is actually needed.

diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD
index 5f5f51657..995ba6a7f 100644
--- a/mediapipe/calculators/tensorflow/BUILD
+++ b/mediapipe/calculators/tensorflow/BUILD
@@ -379,6 +379,7 @@ cc_library(
         "//mediapipe/util/sequence:media_sequence",
         "//mediapipe/util/sequence:media_sequence_util",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@org_tensorflow//tensorflow/core:protos_all_cc",
diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc
index 521d27063..c8bdfd9e2 100644
--- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc
+++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc
@@ -41,6 +41,7 @@ const char kImageTag[] = "IMAGE";
 const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
 const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
 const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
+const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_";
 const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
 const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
 const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_";
@@ -175,6 +176,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
       if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
         cc->Inputs().Tag(tag).Set<std::vector<float>>();
       }
+      if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
+        cc->Inputs().Tag(tag).Set<std::vector<int64_t>>();
+      }
       if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
         cc->Inputs().Tag(tag).Set<std::vector<float>>();
       }
@@ -279,6 +283,13 @@ class PackMediaSequenceCalculator : public CalculatorBase {
           mpms::ClearClipLabelString(key, sequence_.get());
           mpms::ClearClipLabelConfidence(key, sequence_.get());
         }
+        if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
+          const std::string& key =
+              tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
+                             sizeof(*kIntsContextFeaturePrefixTag) -
+                         1);
+          mpms::ClearContextFeatureInts(key, sequence_.get());
+        }
         if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
           std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
                                            sizeof(*kFloatFeaturePrefixTag) -
@@ -518,6 +529,19 @@ class PackMediaSequenceCalculator : public CalculatorBase {
             key, cc->Inputs().Tag(tag).Get<std::vector<float>>(),
             sequence_.get());
       }
+      if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) &&
+          !cc->Inputs().Tag(tag).IsEmpty()) {
+        const std::string& key =
+            tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
+                           sizeof(*kIntsContextFeaturePrefixTag) -
+                       1);
+        // To ensure only one packet is provided for this tag.
+        RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
+        for (const auto& value :
+             cc->Inputs().Tag(tag).Get<std::vector<int64_t>>()) {
+          mpms::AddContextFeatureInts(key, value, sequence_.get());
+        }
+      }
       if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
           !cc->Inputs().Tag(tag).IsEmpty()) {
         std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc
index 9adca0013..3644505d8 100644
--- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc
+++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc
@@ -58,6 +58,8 @@ constexpr char kBytesFeatureTestTag[] = "BYTES_FEATURE_TEST";
 constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
 constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
 constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
+constexpr char kIntsContextFeatureTestTag[] = "INTS_CONTEXT_FEATURE_TEST";
+constexpr char kIntsContextFeatureOtherTag[] = "INTS_CONTEXT_FEATURE_OTHER";
 constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
 constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
 constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
@@ -451,6 +453,119 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
               testing::ElementsAre(4, 4));
 }
 
+TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextIntLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
+                         "INTS_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+
+  const std::vector<int64_t> vi_1 = {2, 3};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureTestTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
+  const std::vector<int64_t> vi_2 = {2, 4};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureOtherTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
+              testing::ElementsAre(2, 3));
+  ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
+              testing::ElementsAre(2, 4));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextIntLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
+                         "INTS_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get());
+  mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get());
+
+  const std::vector<int64_t> vi_1 = {5, 6};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureTestTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
+  const std::vector<int64_t> vi_2 = {7, 8};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureOtherTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
+              testing::ElementsAre(5, 6));
+  ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
+              testing::ElementsAre(7, 8));
+}
+
+TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextIntLists) {
+  SetUpCalculator(
+      /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
+                         "INTS_CONTEXT_FEATURE_OTHER:test2"},
+      /*features=*/{},
+      /*output_only_if_all_present=*/false,
+      /*replace_instead_of_append=*/false);
+  auto input_sequence = absl::make_unique<tf::SequenceExample>();
+  mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get());
+  mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get());
+
+  const std::vector<int64_t> vi_1 = {5, 6};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureTestTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream()));
+  const std::vector<int64_t> vi_2 = {7, 8};
+  runner_->MutableInputs()
+      ->Tag(kIntsContextFeatureOtherTag)
+      .packets.push_back(
+          MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream()));
+
+  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
+      Adopt(input_sequence.release());
+
+  MP_ASSERT_OK(runner_->Run());
+
+  const std::vector<Packet>& output_packets =
+      runner_->Outputs().Tag(kSequenceExampleTag).packets;
+  ASSERT_EQ(1, output_packets.size());
+  const tf::SequenceExample& output_sequence =
+      output_packets[0].Get<tf::SequenceExample>();
+
+  ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence),
+              testing::ElementsAre(2, 3, 5, 6));
+  ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence),
+              testing::ElementsAre(2, 4, 7, 8));
+}
+
 TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) {
   tf::Features context;
   (*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES");