Update PackMediaSequenceCalculator to support index feature inputs on the CLIP_MEDIA_ input tag.
For Detection protos representing index features, the `label` field might be empty. With this change, only the `Detection::score` field is required, and `Detection.label` and `Detection.label_id` are both optional but at least one of them should be set. PiperOrigin-RevId: 568944596
This commit is contained in:
		
							parent
							
								
									698b154ff4
								
							
						
					
					
						commit
						8f8c66430f
					
				| 
						 | 
				
			
			@ -75,7 +75,8 @@ namespace mpms = mediapipe::mediasequence;
 | 
			
		|||
//   vector<pair<float, float>>>,
 | 
			
		||||
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
 | 
			
		||||
// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
 | 
			
		||||
//   mediapipe::Detection.
 | 
			
		||||
//   mediapipe::Detection. In the input Detection, the score field is required,
 | 
			
		||||
//   and label and label_id are optional but at least one of them should be set.
 | 
			
		||||
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
 | 
			
		||||
// prefixed versions of each stream, which allows for multiple image streams to
 | 
			
		||||
// be included. However, the default names are suppored by more tools.
 | 
			
		||||
| 
						 | 
				
			
			@ -514,24 +515,37 @@ class PackMediaSequenceCalculator : public CalculatorBase {
 | 
			
		|||
        const std::string& key = tag.substr(
 | 
			
		||||
            sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
 | 
			
		||||
        const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>();
 | 
			
		||||
        if (detection.score().empty()) {
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
        if (detection.label().empty() && detection.label_id().empty()) {
 | 
			
		||||
          return absl::InvalidArgumentError(
 | 
			
		||||
              "detection.label and detection.label_id can't be both empty");
 | 
			
		||||
        }
 | 
			
		||||
        // Allow empty label (for indexed feature inputs), but if label is not
 | 
			
		||||
        // empty, it should have the same size as the score field.
 | 
			
		||||
        if (!detection.label().empty()) {
 | 
			
		||||
          if (detection.label().size() != detection.score().size()) {
 | 
			
		||||
            return absl::InvalidArgumentError(
 | 
			
		||||
                "Different size of detection.label and detection.score");
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        // Allow empty label_ids, but if label_ids is not empty, it should have
 | 
			
		||||
        // the same size as the label and score fields.
 | 
			
		||||
        // the same size as the score field.
 | 
			
		||||
        if (!detection.label_id().empty()) {
 | 
			
		||||
          if (detection.label_id().size() != detection.label().size()) {
 | 
			
		||||
          if (detection.label_id().size() != detection.score().size()) {
 | 
			
		||||
            return absl::InvalidArgumentError(
 | 
			
		||||
                "Different size of detection.label_id and detection.label");
 | 
			
		||||
                "Different size of detection.label_id and detection.score");
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        for (int i = 0; i < detection.label().size(); ++i) {
 | 
			
		||||
        for (int i = 0; i < detection.score().size(); ++i) {
 | 
			
		||||
          if (!detection.label_id().empty()) {
 | 
			
		||||
            mpms::AddClipLabelIndex(key, detection.label_id(i),
 | 
			
		||||
                                    sequence_.get());
 | 
			
		||||
          }
 | 
			
		||||
          if (!detection.label().empty()) {
 | 
			
		||||
            mpms::AddClipLabelString(key, detection.label(i), sequence_.get());
 | 
			
		||||
          }
 | 
			
		||||
          mpms::AddClipLabelConfidence(key, detection.score(i),
 | 
			
		||||
                                       sequence_.get());
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -75,6 +75,7 @@ constexpr char kImageTag[] = "IMAGE";
 | 
			
		|||
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
 | 
			
		||||
constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST";
 | 
			
		||||
constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER";
 | 
			
		||||
constexpr char kClipLabelAnotherTag[] = "CLIP_LABEL_ANOTHER";
 | 
			
		||||
 | 
			
		||||
class PackMediaSequenceCalculatorTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
| 
						 | 
				
			
			@ -1166,9 +1167,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
 | 
			
		|||
              testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
 | 
			
		||||
TEST_F(PackMediaSequenceCalculatorTest, PackThreeClipLabels) {
 | 
			
		||||
  SetUpCalculator(
 | 
			
		||||
      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
 | 
			
		||||
      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2",
 | 
			
		||||
                         "CLIP_LABEL_ANOTHER:test3"},
 | 
			
		||||
      /*features=*/{}, /*output_only_if_all_present=*/false,
 | 
			
		||||
      /*replace_instead_of_append=*/true);
 | 
			
		||||
  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
 | 
			
		||||
| 
						 | 
				
			
			@ -1192,6 +1194,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
 | 
			
		|||
  runner_->MutableInputs()
 | 
			
		||||
      ->Tag(kClipLabelOtherTag)
 | 
			
		||||
      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
 | 
			
		||||
  // No label for detection_3.
 | 
			
		||||
  Detection detection_3;
 | 
			
		||||
  detection_3.add_label_id(3);
 | 
			
		||||
  detection_3.add_label_id(4);
 | 
			
		||||
  detection_3.add_score(0.3);
 | 
			
		||||
  detection_3.add_score(0.4);
 | 
			
		||||
  runner_->MutableInputs()
 | 
			
		||||
      ->Tag(kClipLabelAnotherTag)
 | 
			
		||||
      .packets.push_back(MakePacket<Detection>(detection_3).At(Timestamp(3)));
 | 
			
		||||
 | 
			
		||||
  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
 | 
			
		||||
      Adopt(input_sequence.release());
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1214,6 +1226,86 @@ TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
 | 
			
		|||
  ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
 | 
			
		||||
  ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
 | 
			
		||||
              testing::ElementsAre(0.3, 0.4));
 | 
			
		||||
  ASSERT_FALSE(mpms::HasClipLabelString("ANOTHER", output_sequence));
 | 
			
		||||
  ASSERT_THAT(mpms::GetClipLabelIndex("ANOTHER", output_sequence),
 | 
			
		||||
              testing::ElementsAre(3, 4));
 | 
			
		||||
  ASSERT_THAT(mpms::GetClipLabelConfidence("ANOTHER", output_sequence),
 | 
			
		||||
              testing::ElementsAre(0.3, 0.4));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels_EmptyScore) {
 | 
			
		||||
  SetUpCalculator(
 | 
			
		||||
      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
 | 
			
		||||
      /*features=*/{}, /*output_only_if_all_present=*/false,
 | 
			
		||||
      /*replace_instead_of_append=*/true);
 | 
			
		||||
  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
 | 
			
		||||
 | 
			
		||||
  // No score in detection_1. detection_1 is ignored.
 | 
			
		||||
  Detection detection_1;
 | 
			
		||||
  detection_1.add_label("label_1");
 | 
			
		||||
  detection_1.add_label("label_2");
 | 
			
		||||
  runner_->MutableInputs()
 | 
			
		||||
      ->Tag(kClipLabelTestTag)
 | 
			
		||||
      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
 | 
			
		||||
  Detection detection_2;
 | 
			
		||||
  detection_2.add_label("label_3");
 | 
			
		||||
  detection_2.add_label("label_4");
 | 
			
		||||
  detection_2.add_score(0.3);
 | 
			
		||||
  detection_2.add_score(0.4);
 | 
			
		||||
  runner_->MutableInputs()
 | 
			
		||||
      ->Tag(kClipLabelOtherTag)
 | 
			
		||||
      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
 | 
			
		||||
  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
 | 
			
		||||
      Adopt(input_sequence.release());
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(runner_->Run());
 | 
			
		||||
 | 
			
		||||
  const std::vector<Packet>& output_packets =
 | 
			
		||||
      runner_->Outputs().Tag(kSequenceExampleTag).packets;
 | 
			
		||||
  ASSERT_EQ(1, output_packets.size());
 | 
			
		||||
  const tf::SequenceExample& output_sequence =
 | 
			
		||||
      output_packets[0].Get<tf::SequenceExample>();
 | 
			
		||||
 | 
			
		||||
  ASSERT_FALSE(mpms::HasClipLabelString("TEST", output_sequence));
 | 
			
		||||
  ASSERT_FALSE(mpms::HasClipLabelIndex("TEST", output_sequence));
 | 
			
		||||
  ASSERT_FALSE(mpms::HasClipLabelConfidence("TEST", output_sequence));
 | 
			
		||||
  ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
 | 
			
		||||
              testing::ElementsAre("label_3", "label_4"));
 | 
			
		||||
  ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
 | 
			
		||||
  ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
 | 
			
		||||
              testing::ElementsAre(0.3, 0.4));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels_NoLabelOrLabelIndex) {
 | 
			
		||||
  SetUpCalculator(
 | 
			
		||||
      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
 | 
			
		||||
      /*features=*/{}, /*output_only_if_all_present=*/false,
 | 
			
		||||
      /*replace_instead_of_append=*/true);
 | 
			
		||||
  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
 | 
			
		||||
 | 
			
		||||
  // No label or label_index in detection_1.
 | 
			
		||||
  Detection detection_1;
 | 
			
		||||
  detection_1.add_score(0.1);
 | 
			
		||||
  runner_->MutableInputs()
 | 
			
		||||
      ->Tag(kClipLabelTestTag)
 | 
			
		||||
      .packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
 | 
			
		||||
  Detection detection_2;
 | 
			
		||||
  detection_2.add_label("label_3");
 | 
			
		||||
  detection_2.add_label("label_4");
 | 
			
		||||
  detection_2.add_score(0.3);
 | 
			
		||||
  detection_2.add_score(0.4);
 | 
			
		||||
  runner_->MutableInputs()
 | 
			
		||||
      ->Tag(kClipLabelOtherTag)
 | 
			
		||||
      .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
 | 
			
		||||
  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
 | 
			
		||||
      Adopt(input_sequence.release());
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(
 | 
			
		||||
      runner_->Run(),
 | 
			
		||||
      testing::status::StatusIs(
 | 
			
		||||
          absl::StatusCode::kInvalidArgument,
 | 
			
		||||
          testing::HasSubstr(
 | 
			
		||||
              "detection.label and detection.label_id can't be both empty")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(PackMediaSequenceCalculatorTest,
 | 
			
		||||
| 
						 | 
				
			
			@ -1259,7 +1351,7 @@ TEST_F(PackMediaSequenceCalculatorTest,
 | 
			
		|||
      /*replace_instead_of_append=*/true);
 | 
			
		||||
  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
 | 
			
		||||
 | 
			
		||||
  // 2 labels and 1 label_id in detection_1.
 | 
			
		||||
  // 2 scores and 1 label_id in detection_1.
 | 
			
		||||
  Detection detection_1;
 | 
			
		||||
  detection_1.add_label("label_1");
 | 
			
		||||
  detection_1.add_label("label_2");
 | 
			
		||||
| 
						 | 
				
			
			@ -1285,7 +1377,7 @@ TEST_F(PackMediaSequenceCalculatorTest,
 | 
			
		|||
      testing::status::StatusIs(
 | 
			
		||||
          absl::StatusCode::kInvalidArgument,
 | 
			
		||||
          testing::HasSubstr(
 | 
			
		||||
              "Different size of detection.label_id and detection.label")));
 | 
			
		||||
              "Different size of detection.label_id and detection.score")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user