No public description
PiperOrigin-RevId: 562071599
This commit is contained in:
		
							parent
							
								
									23057ac146
								
							
						
					
					
						commit
						ab70d92752
					
				|  | @ -379,6 +379,7 @@ cc_library( | |||
|         "//mediapipe/util/sequence:media_sequence", | ||||
|         "//mediapipe/util/sequence:media_sequence_util", | ||||
|         "@com_google_absl//absl/container:flat_hash_map", | ||||
|         "@com_google_absl//absl/log", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@org_tensorflow//tensorflow/core:protos_all_cc", | ||||
|  |  | |||
|  | @ -41,6 +41,7 @@ const char kImageTag[] = "IMAGE"; | |||
| const char kImageLabelPrefixTag[] = "IMAGE_LABEL_"; | ||||
| const char kClipLabelPrefixTag[] = "CLIP_LABEL_"; | ||||
| const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; | ||||
| const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_"; | ||||
| const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; | ||||
| const char kIntFeaturePrefixTag[] = "INT_FEATURE_"; | ||||
| const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_"; | ||||
|  | @ -175,6 +176,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|       if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) { | ||||
|         cc->Inputs().Tag(tag).Set<std::vector<float>>(); | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) { | ||||
|         cc->Inputs().Tag(tag).Set<std::vector<int64_t>>(); | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { | ||||
|         cc->Inputs().Tag(tag).Set<std::vector<float>>(); | ||||
|       } | ||||
|  | @ -279,6 +283,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|           mpms::ClearClipLabelString(key, sequence_.get()); | ||||
|           mpms::ClearClipLabelConfidence(key, sequence_.get()); | ||||
|         } | ||||
|         if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) { | ||||
|           const std::string& key = | ||||
|               tag.substr(sizeof(kIntsContextFeaturePrefixTag) / | ||||
|                              sizeof(*kIntsContextFeaturePrefixTag) - | ||||
|                          1); | ||||
|           mpms::ClearContextFeatureInts(key, sequence_.get()); | ||||
|         } | ||||
|         if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { | ||||
|           std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / | ||||
|                                            sizeof(*kFloatFeaturePrefixTag) - | ||||
|  | @ -518,6 +529,19 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|             key, cc->Inputs().Tag(tag).Get<std::vector<float>>(), | ||||
|             sequence_.get()); | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) && | ||||
|           !cc->Inputs().Tag(tag).IsEmpty()) { | ||||
|         const std::string& key = | ||||
|             tag.substr(sizeof(kIntsContextFeaturePrefixTag) / | ||||
|                            sizeof(*kIntsContextFeaturePrefixTag) - | ||||
|                        1); | ||||
|         // To ensure only one packet is provided for this tag.
 | ||||
|         RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream()); | ||||
|         for (const auto& value : | ||||
|              cc->Inputs().Tag(tag).Get<std::vector<int64_t>>()) { | ||||
|           mpms::AddContextFeatureInts(key, value, sequence_.get()); | ||||
|         } | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kFloatFeaturePrefixTag) && | ||||
|           !cc->Inputs().Tag(tag).IsEmpty()) { | ||||
|         std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / | ||||
|  |  | |||
|  | @ -58,6 +58,8 @@ constexpr char kBytesFeatureTestTag[] = "BYTES_FEATURE_TEST"; | |||
| constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; | ||||
| constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; | ||||
| constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; | ||||
| constexpr char kIntsContextFeatureTestTag[] = "INTS_CONTEXT_FEATURE_TEST"; | ||||
| constexpr char kIntsContextFeatureOtherTag[] = "INTS_CONTEXT_FEATURE_OTHER"; | ||||
| constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; | ||||
| constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; | ||||
| constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER"; | ||||
|  | @ -451,6 +453,119 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) { | |||
|               testing::ElementsAre(4, 4)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextIntLists) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test", | ||||
|                          "INTS_CONTEXT_FEATURE_OTHER:test2"}, | ||||
|       /*features=*/{}, | ||||
|       /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = absl::make_unique<tf::SequenceExample>(); | ||||
| 
 | ||||
|   const std::vector<int64_t> vi_1 = {2, 3}; | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kIntsContextFeatureTestTag) | ||||
|       .packets.push_back( | ||||
|           MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream())); | ||||
|   const std::vector<int64_t> vi_2 = {2, 4}; | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kIntsContextFeatureOtherTag) | ||||
|       .packets.push_back( | ||||
|           MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream())); | ||||
| 
 | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner_->Run()); | ||||
| 
 | ||||
|   const std::vector<Packet>& output_packets = | ||||
|       runner_->Outputs().Tag(kSequenceExampleTag).packets; | ||||
|   ASSERT_EQ(1, output_packets.size()); | ||||
|   const tf::SequenceExample& output_sequence = | ||||
|       output_packets[0].Get<tf::SequenceExample>(); | ||||
| 
 | ||||
|   ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence), | ||||
|               testing::ElementsAre(2, 3)); | ||||
|   ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence), | ||||
|               testing::ElementsAre(2, 4)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextIntLists) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test", | ||||
|                          "INTS_CONTEXT_FEATURE_OTHER:test2"}, | ||||
|       /*features=*/{}, | ||||
|       /*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = absl::make_unique<tf::SequenceExample>(); | ||||
|   mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get()); | ||||
|   mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get()); | ||||
| 
 | ||||
|   const std::vector<int64_t> vi_1 = {5, 6}; | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kIntsContextFeatureTestTag) | ||||
|       .packets.push_back( | ||||
|           MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream())); | ||||
|   const std::vector<int64_t> vi_2 = {7, 8}; | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kIntsContextFeatureOtherTag) | ||||
|       .packets.push_back( | ||||
|           MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream())); | ||||
| 
 | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner_->Run()); | ||||
| 
 | ||||
|   const std::vector<Packet>& output_packets = | ||||
|       runner_->Outputs().Tag(kSequenceExampleTag).packets; | ||||
|   ASSERT_EQ(1, output_packets.size()); | ||||
|   const tf::SequenceExample& output_sequence = | ||||
|       output_packets[0].Get<tf::SequenceExample>(); | ||||
| 
 | ||||
|   ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence), | ||||
|               testing::ElementsAre(5, 6)); | ||||
|   ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence), | ||||
|               testing::ElementsAre(7, 8)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextIntLists) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test", | ||||
|                          "INTS_CONTEXT_FEATURE_OTHER:test2"}, | ||||
|       /*features=*/{}, | ||||
|       /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/false); | ||||
|   auto input_sequence = absl::make_unique<tf::SequenceExample>(); | ||||
|   mpms::SetContextFeatureInts("TEST", {2, 3}, input_sequence.get()); | ||||
|   mpms::SetContextFeatureInts("OTHER", {2, 4}, input_sequence.get()); | ||||
| 
 | ||||
|   const std::vector<int64_t> vi_1 = {5, 6}; | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kIntsContextFeatureTestTag) | ||||
|       .packets.push_back( | ||||
|           MakePacket<std::vector<int64_t>>(vi_1).At(Timestamp::PostStream())); | ||||
|   const std::vector<int64_t> vi_2 = {7, 8}; | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kIntsContextFeatureOtherTag) | ||||
|       .packets.push_back( | ||||
|           MakePacket<std::vector<int64_t>>(vi_2).At(Timestamp::PostStream())); | ||||
| 
 | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner_->Run()); | ||||
| 
 | ||||
|   const std::vector<Packet>& output_packets = | ||||
|       runner_->Outputs().Tag(kSequenceExampleTag).packets; | ||||
|   ASSERT_EQ(1, output_packets.size()); | ||||
|   const tf::SequenceExample& output_sequence = | ||||
|       output_packets[0].Get<tf::SequenceExample>(); | ||||
| 
 | ||||
|   ASSERT_THAT(mpms::GetContextFeatureInts("TEST", output_sequence), | ||||
|               testing::ElementsAre(2, 3, 5, 6)); | ||||
|   ASSERT_THAT(mpms::GetContextFeatureInts("OTHER", output_sequence), | ||||
|               testing::ElementsAre(2, 4, 7, 8)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { | ||||
|   tf::Features context; | ||||
|   (*context.mutable_feature())["TEST"].mutable_bytes_list()->add_value("YES"); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user