Update PackMediaSequenceCalculator to support setting clip/media/string, clip/media/confidence and clip/label/index.
The input stream is provided as drishti::Detection. PiperOrigin-RevId: 562070790
This commit is contained in:
		
							parent
							
								
									e7d071ab39
								
							
						
					
					
						commit
						23057ac146
					
				|  | @ -379,6 +379,7 @@ cc_library( | |||
|         "//mediapipe/util/sequence:media_sequence", | ||||
|         "//mediapipe/util/sequence:media_sequence_util", | ||||
|         "@com_google_absl//absl/container:flat_hash_map", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@org_tensorflow//tensorflow/core:protos_all_cc", | ||||
|     ], | ||||
|  | @ -950,8 +951,10 @@ cc_test( | |||
|         "//mediapipe/framework/port:gtest_main", | ||||
|         "//mediapipe/framework/port:opencv_imgcodecs", | ||||
|         "//mediapipe/util/sequence:media_sequence", | ||||
|         "//mediapipe/util/sequence:media_sequence_util", | ||||
|         "@com_google_absl//absl/log:absl_check", | ||||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_googletest//:gtest_main", | ||||
|         "@org_tensorflow//tensorflow/core:protos_all_cc", | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ | |||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/container/flat_hash_map.h" | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/strings/match.h" | ||||
| #include "absl/strings/strip.h" | ||||
| #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" | ||||
|  | @ -38,6 +39,7 @@ namespace mediapipe { | |||
| const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; | ||||
| const char kImageTag[] = "IMAGE"; | ||||
| const char kImageLabelPrefixTag[] = "IMAGE_LABEL_"; | ||||
| const char kClipLabelPrefixTag[] = "CLIP_LABEL_"; | ||||
| const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; | ||||
| const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; | ||||
| const char kIntFeaturePrefixTag[] = "INT_FEATURE_"; | ||||
|  | @ -70,6 +72,8 @@ namespace mpms = mediapipe::mediasequence; | |||
| // * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map<string,
 | ||||
| //   vector<pair<float, float>>>,
 | ||||
| // * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
 | ||||
| // * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
 | ||||
| //   mediapipe::Detection.
 | ||||
| // "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
 | ||||
| // prefixed versions of each stream, which allows for multiple image streams to
 | ||||
| // be included. However, the default names are suppored by more tools.
 | ||||
|  | @ -165,6 +169,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|         } | ||||
|         cc->Inputs().Tag(tag).Set<std::vector<Detection>>(); | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kClipLabelPrefixTag)) { | ||||
|         cc->Inputs().Tag(tag).Set<Detection>(); | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) { | ||||
|         cc->Inputs().Tag(tag).Set<std::vector<float>>(); | ||||
|       } | ||||
|  | @ -217,6 +224,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|     replace_keypoints_ = false; | ||||
|     if (cc->Options<PackMediaSequenceCalculatorOptions>() | ||||
|             .replace_data_instead_of_append()) { | ||||
|       // Clear the existing values under the same key.
 | ||||
|       for (const auto& tag : cc->Inputs().GetTags()) { | ||||
|         if (absl::StartsWith(tag, kImageTag)) { | ||||
|           if (absl::StartsWith(tag, kImageLabelPrefixTag)) { | ||||
|  | @ -264,6 +272,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|           mpms::ClearBBoxTrackIndex(key, sequence_.get()); | ||||
|           mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get()); | ||||
|         } | ||||
|         if (absl::StartsWith(tag, kClipLabelPrefixTag)) { | ||||
|           const std::string& key = tag.substr( | ||||
|               sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); | ||||
|           mpms::ClearClipLabelIndex(key, sequence_.get()); | ||||
|           mpms::ClearClipLabelString(key, sequence_.get()); | ||||
|           mpms::ClearClipLabelConfidence(key, sequence_.get()); | ||||
|         } | ||||
|         if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { | ||||
|           std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / | ||||
|                                            sizeof(*kFloatFeaturePrefixTag) - | ||||
|  | @ -465,6 +480,33 @@ class PackMediaSequenceCalculator : public CalculatorBase { | |||
|         } | ||||
|         replace_keypoints_ = false; | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kClipLabelPrefixTag) && | ||||
|           !cc->Inputs().Tag(tag).IsEmpty()) { | ||||
|         const std::string& key = tag.substr( | ||||
|             sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); | ||||
|         const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>(); | ||||
|         if (detection.label().size() != detection.score().size()) { | ||||
|           return absl::InvalidArgumentError( | ||||
|               "Different size of detection.label and detection.score"); | ||||
|         } | ||||
|         // Allow empty label_ids, but if label_ids is not empty, it should have
 | ||||
|         // the same size as the label and score fields.
 | ||||
|         if (!detection.label_id().empty()) { | ||||
|           if (detection.label_id().size() != detection.label().size()) { | ||||
|             return absl::InvalidArgumentError( | ||||
|                 "Different size of detection.label_id and detection.label"); | ||||
|           } | ||||
|         } | ||||
|         for (int i = 0; i < detection.label().size(); ++i) { | ||||
|           if (!detection.label_id().empty()) { | ||||
|             mpms::AddClipLabelIndex(key, detection.label_id(i), | ||||
|                                     sequence_.get()); | ||||
|           } | ||||
|           mpms::AddClipLabelString(key, detection.label(i), sequence_.get()); | ||||
|           mpms::AddClipLabelConfidence(key, detection.score(i), | ||||
|                                        sequence_.get()); | ||||
|         } | ||||
|       } | ||||
|       if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) && | ||||
|           !cc->Inputs().Tag(tag).IsEmpty()) { | ||||
|         std::string key = | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ | |||
| 
 | ||||
| #include "absl/log/absl_check.h" | ||||
| #include "absl/memory/memory.h" | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/strings/str_cat.h" | ||||
| #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" | ||||
| #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" | ||||
|  | @ -31,6 +32,7 @@ | |||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/framework/timestamp.h" | ||||
| #include "mediapipe/util/sequence/media_sequence.h" | ||||
| #include "mediapipe/util/sequence/media_sequence_util.h" | ||||
| #include "tensorflow/core/example/example.pb.h" | ||||
| #include "tensorflow/core/example/feature.pb.h" | ||||
| #include "testing/base/public/gmock.h" | ||||
|  | @ -66,6 +68,8 @@ constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; | |||
| constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; | ||||
| constexpr char kImageTag[] = "IMAGE"; | ||||
| constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID"; | ||||
| constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST"; | ||||
| constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER"; | ||||
| 
 | ||||
| class PackMediaSequenceCalculatorTest : public ::testing::Test { | ||||
|  protected: | ||||
|  | @ -848,6 +852,283 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { | |||
|               testing::ElementsAreArray(::std::vector<std::string>({"mask"}))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, | ||||
|       /*features=*/{}, /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); | ||||
| 
 | ||||
|   Detection detection_1; | ||||
|   detection_1.add_label("label_1"); | ||||
|   detection_1.add_label("label_2"); | ||||
|   detection_1.add_label_id(1); | ||||
|   detection_1.add_label_id(2); | ||||
|   detection_1.add_score(0.1); | ||||
|   detection_1.add_score(0.2); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelTestTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1))); | ||||
|   // No label ID for detection_2.
 | ||||
|   Detection detection_2; | ||||
|   detection_2.add_label("label_3"); | ||||
|   detection_2.add_label("label_4"); | ||||
|   detection_2.add_score(0.3); | ||||
|   detection_2.add_score(0.4); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelOtherTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner_->Run()); | ||||
| 
 | ||||
|   const std::vector<Packet>& output_packets = | ||||
|       runner_->Outputs().Tag(kSequenceExampleTag).packets; | ||||
|   ASSERT_EQ(1, output_packets.size()); | ||||
|   const tf::SequenceExample& output_sequence = | ||||
|       output_packets[0].Get<tf::SequenceExample>(); | ||||
| 
 | ||||
|   ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence), | ||||
|               testing::ElementsAre("label_1", "label_2")); | ||||
|   ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence), | ||||
|               testing::ElementsAre(1, 2)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence), | ||||
|               testing::ElementsAre(0.1, 0.2)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence), | ||||
|               testing::ElementsAre("label_3", "label_4")); | ||||
|   ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), | ||||
|               testing::ElementsAre(0.3, 0.4)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, | ||||
|        PackTwoClipLabels_DifferentLabelScoreSize) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, | ||||
|       /*features=*/{}, /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); | ||||
| 
 | ||||
|   // 2 labels and 1 score in detection_1.
 | ||||
|   Detection detection_1; | ||||
|   detection_1.add_label("label_1"); | ||||
|   detection_1.add_label("label_2"); | ||||
|   detection_1.add_score(0.1); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelTestTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1))); | ||||
|   Detection detection_2; | ||||
|   detection_2.add_label("label_3"); | ||||
|   detection_2.add_label("label_4"); | ||||
|   detection_2.add_score(0.3); | ||||
|   detection_2.add_score(0.4); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelOtherTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   ASSERT_THAT( | ||||
|       runner_->Run(), | ||||
|       testing::status::StatusIs( | ||||
|           absl::StatusCode::kInvalidArgument, | ||||
|           testing::HasSubstr( | ||||
|               "Different size of detection.label and detection.score"))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, | ||||
|        PackTwoClipLabels_DifferentLabelIdSize) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, | ||||
|       /*features=*/{}, /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); | ||||
| 
 | ||||
|   // 2 labels and 1 label_id in detection_1.
 | ||||
|   Detection detection_1; | ||||
|   detection_1.add_label("label_1"); | ||||
|   detection_1.add_label("label_2"); | ||||
|   detection_1.add_label_id(1); | ||||
|   detection_1.add_score(0.1); | ||||
|   detection_1.add_score(0.2); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelTestTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1))); | ||||
|   Detection detection_2; | ||||
|   detection_2.add_label("label_3"); | ||||
|   detection_2.add_label("label_4"); | ||||
|   detection_2.add_score(0.3); | ||||
|   detection_2.add_score(0.4); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelOtherTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   ASSERT_THAT( | ||||
|       runner_->Run(), | ||||
|       testing::status::StatusIs( | ||||
|           absl::StatusCode::kInvalidArgument, | ||||
|           testing::HasSubstr( | ||||
|               "Different size of detection.label_id and detection.label"))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) { | ||||
|   // Replace existing clip/label/string and clip/label/confidence values for
 | ||||
|   // the prefixes.
 | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, | ||||
|       /*features=*/{}, /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); | ||||
|   mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"}, | ||||
|                            input_sequence.get()); | ||||
|   mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get()); | ||||
|   mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"}, | ||||
|                            input_sequence.get()); | ||||
|   mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get()); | ||||
| 
 | ||||
|   Detection detection_1; | ||||
|   detection_1.add_label("label_1"); | ||||
|   detection_1.add_label("label_2"); | ||||
|   detection_1.add_label_id(1); | ||||
|   detection_1.add_label_id(2); | ||||
|   detection_1.add_score(0.9); | ||||
|   detection_1.add_score(0.8); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelTestTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1))); | ||||
|   Detection detection_2; | ||||
|   detection_2.add_label("label_3"); | ||||
|   detection_2.add_label("label_4"); | ||||
|   detection_2.add_label_id(3); | ||||
|   detection_2.add_label_id(4); | ||||
|   detection_2.add_score(0.7); | ||||
|   detection_2.add_score(0.6); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelOtherTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner_->Run()); | ||||
| 
 | ||||
|   const std::vector<Packet>& output_packets = | ||||
|       runner_->Outputs().Tag(kSequenceExampleTag).packets; | ||||
|   ASSERT_EQ(1, output_packets.size()); | ||||
|   const tf::SequenceExample& output_sequence = | ||||
|       output_packets[0].Get<tf::SequenceExample>(); | ||||
| 
 | ||||
|   ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence), | ||||
|               testing::ElementsAre("label_1", "label_2")); | ||||
|   ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence), | ||||
|               testing::ElementsAre(1, 2)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence), | ||||
|               testing::ElementsAre(0.9, 0.8)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence), | ||||
|               testing::ElementsAre("label_3", "label_4")); | ||||
|   ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence), | ||||
|               testing::ElementsAre(3, 4)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), | ||||
|               testing::ElementsAre(0.7, 0.6)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, AppendTwoClipLabels) { | ||||
|   // Append to the existing clip/label/string and clip/label/confidence values
 | ||||
|   // for the prefixes.
 | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, | ||||
|       /*features=*/{}, /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/false); | ||||
|   auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); | ||||
|   mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"}, | ||||
|                            input_sequence.get()); | ||||
|   mpms::SetClipLabelIndex("TEST", {1, 2}, input_sequence.get()); | ||||
|   mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get()); | ||||
|   mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"}, | ||||
|                            input_sequence.get()); | ||||
|   mpms::SetClipLabelIndex("OTHER", {3, 4}, input_sequence.get()); | ||||
|   mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get()); | ||||
| 
 | ||||
|   Detection detection_1; | ||||
|   detection_1.add_label("label_1"); | ||||
|   detection_1.add_label("label_2"); | ||||
|   detection_1.add_label_id(9); | ||||
|   detection_1.add_label_id(8); | ||||
|   detection_1.add_score(0.9); | ||||
|   detection_1.add_score(0.8); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelTestTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1))); | ||||
|   Detection detection_2; | ||||
|   detection_2.add_label("label_3"); | ||||
|   detection_2.add_label("label_4"); | ||||
|   detection_2.add_label_id(7); | ||||
|   detection_2.add_label_id(6); | ||||
|   detection_2.add_score(0.7); | ||||
|   detection_2.add_score(0.6); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelOtherTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner_->Run()); | ||||
| 
 | ||||
|   const std::vector<Packet>& output_packets = | ||||
|       runner_->Outputs().Tag(kSequenceExampleTag).packets; | ||||
|   ASSERT_EQ(1, output_packets.size()); | ||||
|   const tf::SequenceExample& output_sequence = | ||||
|       output_packets[0].Get<tf::SequenceExample>(); | ||||
| 
 | ||||
|   ASSERT_THAT( | ||||
|       mpms::GetClipLabelString("TEST", output_sequence), | ||||
|       testing::ElementsAre("old_label_1", "old_label_2", "label_1", "label_2")); | ||||
|   ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence), | ||||
|               testing::ElementsAre(1, 2, 9, 8)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence), | ||||
|               testing::ElementsAre(0.1, 0.2, 0.9, 0.8)); | ||||
|   ASSERT_THAT( | ||||
|       mpms::GetClipLabelString("OTHER", output_sequence), | ||||
|       testing::ElementsAre("old_label_3", "old_label_4", "label_3", "label_4")); | ||||
|   ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence), | ||||
|               testing::ElementsAre(3, 4, 7, 6)); | ||||
|   ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), | ||||
|               testing::ElementsAre(0.3, 0.4, 0.7, 0.6)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, | ||||
|        DifferentClipLabelScoreAndConfidenceSize) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, | ||||
|       /*features=*/{}, /*output_only_if_all_present=*/false, | ||||
|       /*replace_instead_of_append=*/true); | ||||
|   auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); | ||||
| 
 | ||||
|   Detection detection_1; | ||||
|   // 2 labels and 1 score.
 | ||||
|   detection_1.add_label("label_1"); | ||||
|   detection_1.add_label("label_2"); | ||||
|   detection_1.add_score(0.1); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelTestTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1))); | ||||
|   Detection detection_2; | ||||
|   detection_2.add_label("label_3"); | ||||
|   detection_2.add_label("label_4"); | ||||
|   detection_2.add_score(0.3); | ||||
|   detection_2.add_score(0.4); | ||||
|   runner_->MutableInputs() | ||||
|       ->Tag(kClipLabelOtherTag) | ||||
|       .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); | ||||
|   runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = | ||||
|       Adopt(input_sequence.release()); | ||||
| 
 | ||||
|   ASSERT_THAT(runner_->Run(), | ||||
|               testing::status::StatusIs(absl::StatusCode::kInvalidArgument)); | ||||
| } | ||||
| 
 | ||||
| TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) { | ||||
|   SetUpCalculator( | ||||
|       /*input_streams=*/{"FLOAT_FEATURE_TEST:test", | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user