Update PackMediaSequenceCalculator to support index feature inputs on the CLIP_LABEL_${NAME} input tags.

For Detection protos representing index features, the `label` field might be empty.

With this change, only the `Detection.score` field is required; `Detection.label` and `Detection.label_id` are each optional, but at least one of them must be set.

PiperOrigin-RevId: 568944596
This commit is contained in:
MediaPipe Team 2023-09-27 13:30:46 -07:00 committed by Copybara-Service
parent 698b154ff4
commit 8f8c66430f
2 changed files with 121 additions and 15 deletions

View File

@ -75,7 +75,8 @@ namespace mpms = mediapipe::mediasequence;
// vector<pair<float, float>>>, // vector<pair<float, float>>>,
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string. // * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in // * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
// mediapipe::Detection. // mediapipe::Detection. In the input Detection, the score field is required,
// and label and label_id are optional but at least one of them should be set.
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store // "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
// prefixed versions of each stream, which allows for multiple image streams to // prefixed versions of each stream, which allows for multiple image streams to
// be included. However, the default names are supported by more tools. // be included. However, the default names are supported by more tools.
@ -514,24 +515,37 @@ class PackMediaSequenceCalculator : public CalculatorBase {
const std::string& key = tag.substr( const std::string& key = tag.substr(
sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>(); const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>();
if (detection.score().empty()) {
continue;
}
if (detection.label().empty() && detection.label_id().empty()) {
return absl::InvalidArgumentError(
"detection.label and detection.label_id can't be both empty");
}
// Allow empty label (for indexed feature inputs), but if label is not
// empty, it should have the same size as the score field.
if (!detection.label().empty()) {
if (detection.label().size() != detection.score().size()) { if (detection.label().size() != detection.score().size()) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"Different size of detection.label and detection.score"); "Different size of detection.label and detection.score");
} }
}
// Allow empty label_ids, but if label_ids is not empty, it should have // Allow empty label_ids, but if label_ids is not empty, it should have
// the same size as the label and score fields. // the same size as the score field.
if (!detection.label_id().empty()) { if (!detection.label_id().empty()) {
if (detection.label_id().size() != detection.label().size()) { if (detection.label_id().size() != detection.score().size()) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"Different size of detection.label_id and detection.label"); "Different size of detection.label_id and detection.score");
} }
} }
for (int i = 0; i < detection.label().size(); ++i) { for (int i = 0; i < detection.score().size(); ++i) {
if (!detection.label_id().empty()) { if (!detection.label_id().empty()) {
mpms::AddClipLabelIndex(key, detection.label_id(i), mpms::AddClipLabelIndex(key, detection.label_id(i),
sequence_.get()); sequence_.get());
} }
if (!detection.label().empty()) {
mpms::AddClipLabelString(key, detection.label(i), sequence_.get()); mpms::AddClipLabelString(key, detection.label(i), sequence_.get());
}
mpms::AddClipLabelConfidence(key, detection.score(i), mpms::AddClipLabelConfidence(key, detection.score(i),
sequence_.get()); sequence_.get());
} }

View File

@ -75,6 +75,7 @@ constexpr char kImageTag[] = "IMAGE";
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID"; constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST"; constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST";
constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER"; constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER";
constexpr char kClipLabelAnotherTag[] = "CLIP_LABEL_ANOTHER";
class PackMediaSequenceCalculatorTest : public ::testing::Test { class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected: protected:
@ -1166,9 +1167,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
testing::ElementsAreArray(::std::vector<std::string>({"mask"}))); testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
} }
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) { TEST_F(PackMediaSequenceCalculatorTest, PackThreeClipLabels) {
SetUpCalculator( SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2",
"CLIP_LABEL_ANOTHER:test3"},
/*features=*/{}, /*output_only_if_all_present=*/false, /*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true); /*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
@ -1192,6 +1194,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
runner_->MutableInputs() runner_->MutableInputs()
->Tag(kClipLabelOtherTag) ->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2))); .packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
// No label for detection_3.
Detection detection_3;
detection_3.add_label_id(3);
detection_3.add_label_id(4);
detection_3.add_score(0.3);
detection_3.add_score(0.4);
runner_->MutableInputs()
->Tag(kClipLabelAnotherTag)
.packets.push_back(MakePacket<Detection>(detection_3).At(Timestamp(3)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release()); Adopt(input_sequence.release());
@ -1214,6 +1226,86 @@ TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence)); ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
testing::ElementsAre(0.3, 0.4)); testing::ElementsAre(0.3, 0.4));
ASSERT_FALSE(mpms::HasClipLabelString("ANOTHER", output_sequence));
ASSERT_THAT(mpms::GetClipLabelIndex("ANOTHER", output_sequence),
testing::ElementsAre(3, 4));
ASSERT_THAT(mpms::GetClipLabelConfidence("ANOTHER", output_sequence),
testing::ElementsAre(0.3, 0.4));
}
// Checks that a CLIP_LABEL_ Detection carrying labels but no scores is ignored
// rather than rejected, while a complete Detection on a second tag is packed.
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels_EmptyScore) {
  SetUpCalculator(
      /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
      /*features=*/{}, /*output_only_if_all_present=*/false,
      /*replace_instead_of_append=*/true);

  // Labels only, no scores: this packet is expected to contribute nothing.
  Detection unscored;
  unscored.add_label("label_1");
  unscored.add_label("label_2");
  runner_->MutableInputs()
      ->Tag(kClipLabelTestTag)
      .packets.push_back(MakePacket<Detection>(unscored).At(Timestamp(1)));

  // Two labels paired with two scores and no label_id.
  Detection scored;
  scored.add_label("label_3");
  scored.add_score(0.3);
  scored.add_label("label_4");
  scored.add_score(0.4);
  runner_->MutableInputs()
      ->Tag(kClipLabelOtherTag)
      .packets.push_back(MakePacket<Detection>(scored).At(Timestamp(2)));

  // The run starts from an empty SequenceExample released into the side
  // packet.
  auto empty_sequence = ::absl::make_unique<tf::SequenceExample>();
  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
      Adopt(empty_sequence.release());

  MP_ASSERT_OK(runner_->Run());

  const std::vector<Packet>& results =
      runner_->Outputs().Tag(kSequenceExampleTag).packets;
  ASSERT_EQ(1, results.size());
  const tf::SequenceExample& packed = results[0].Get<tf::SequenceExample>();

  // Nothing at all may be recorded under the "TEST" key.
  ASSERT_FALSE(mpms::HasClipLabelString("TEST", packed));
  ASSERT_FALSE(mpms::HasClipLabelIndex("TEST", packed));
  ASSERT_FALSE(mpms::HasClipLabelConfidence("TEST", packed));

  // The "OTHER" key holds the label strings and confidences, but no indices.
  ASSERT_THAT(mpms::GetClipLabelString("OTHER", packed),
              testing::ElementsAre("label_3", "label_4"));
  ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", packed));
  ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", packed),
              testing::ElementsAre(0.3, 0.4));
}
// Verifies that running the calculator fails with kInvalidArgument when a
// CLIP_LABEL_ Detection has a score but neither label nor label_id set.
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels_NoLabelOrLabelIndex) {
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
// No label or label_id in detection_1; only a single score is set.
Detection detection_1;
detection_1.add_score(0.1);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
// detection_2 is well-formed: two labels with two matching scores.
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_score(0.3);
detection_2.add_score(0.4);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
// The run as a whole must report the error, matched on status code and
// message substring.
ASSERT_THAT(
runner_->Run(),
testing::status::StatusIs(
absl::StatusCode::kInvalidArgument,
testing::HasSubstr(
"detection.label and detection.label_id can't be both empty")));
} }
TEST_F(PackMediaSequenceCalculatorTest, TEST_F(PackMediaSequenceCalculatorTest,
@ -1259,7 +1351,7 @@ TEST_F(PackMediaSequenceCalculatorTest,
/*replace_instead_of_append=*/true); /*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
// 2 labels and 1 label_id in detection_1. // 2 scores and 1 label_id in detection_1.
Detection detection_1; Detection detection_1;
detection_1.add_label("label_1"); detection_1.add_label("label_1");
detection_1.add_label("label_2"); detection_1.add_label("label_2");
@ -1285,7 +1377,7 @@ TEST_F(PackMediaSequenceCalculatorTest,
testing::status::StatusIs( testing::status::StatusIs(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
testing::HasSubstr( testing::HasSubstr(
"Different size of detection.label_id and detection.label"))); "Different size of detection.label_id and detection.score")));
} }
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) { TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) {