Update PackMediaSequenceCalculator to support index feature inputs on the CLIP_MEDIA_ input tag.
For Detection protos representing index features, the `label` field might be empty. With this change, only the `Detection::score` field is required, and `Detection.label` and `Detection.label_id` are both optional but at least one of them should be set. PiperOrigin-RevId: 568944596
This commit is contained in:
parent
698b154ff4
commit
8f8c66430f
|
@ -75,7 +75,8 @@ namespace mpms = mediapipe::mediasequence;
|
|||
// vector<pair<float, float>>>,
|
||||
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
|
||||
// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
|
||||
// mediapipe::Detection.
|
||||
// mediapipe::Detection. In the input Detection, the score field is required,
|
||||
// and label and label_id are optional but at least one of them should be set.
|
||||
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
|
||||
// prefixed versions of each stream, which allows for multiple image streams to
|
||||
// be included. However, the default names are suppored by more tools.
|
||||
|
@ -514,24 +515,37 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
|||
const std::string& key = tag.substr(
|
||||
sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
|
||||
const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>();
|
||||
if (detection.score().empty()) {
|
||||
continue;
|
||||
}
|
||||
if (detection.label().empty() && detection.label_id().empty()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"detection.label and detection.label_id can't be both empty");
|
||||
}
|
||||
// Allow empty label (for indexed feature inputs), but if label is not
|
||||
// empty, it should have the same size as the score field.
|
||||
if (!detection.label().empty()) {
|
||||
if (detection.label().size() != detection.score().size()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Different size of detection.label and detection.score");
|
||||
}
|
||||
}
|
||||
// Allow empty label_ids, but if label_ids is not empty, it should have
|
||||
// the same size as the label and score fields.
|
||||
// the same size as the score field.
|
||||
if (!detection.label_id().empty()) {
|
||||
if (detection.label_id().size() != detection.label().size()) {
|
||||
if (detection.label_id().size() != detection.score().size()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Different size of detection.label_id and detection.label");
|
||||
"Different size of detection.label_id and detection.score");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < detection.label().size(); ++i) {
|
||||
for (int i = 0; i < detection.score().size(); ++i) {
|
||||
if (!detection.label_id().empty()) {
|
||||
mpms::AddClipLabelIndex(key, detection.label_id(i),
|
||||
sequence_.get());
|
||||
}
|
||||
if (!detection.label().empty()) {
|
||||
mpms::AddClipLabelString(key, detection.label(i), sequence_.get());
|
||||
}
|
||||
mpms::AddClipLabelConfidence(key, detection.score(i),
|
||||
sequence_.get());
|
||||
}
|
||||
|
|
|
@ -75,6 +75,7 @@ constexpr char kImageTag[] = "IMAGE";
|
|||
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
|
||||
constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST";
|
||||
constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER";
|
||||
constexpr char kClipLabelAnotherTag[] = "CLIP_LABEL_ANOTHER";
|
||||
|
||||
class PackMediaSequenceCalculatorTest : public ::testing::Test {
|
||||
protected:
|
||||
|
@ -1166,9 +1167,10 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
|
|||
testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PackThreeClipLabels) {
|
||||
SetUpCalculator(
|
||||
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2",
|
||||
"CLIP_LABEL_ANOTHER:test3"},
|
||||
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||
/*replace_instead_of_append=*/true);
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
|
@ -1192,6 +1194,16 @@ TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
|
|||
runner_->MutableInputs()
|
||||
->Tag(kClipLabelOtherTag)
|
||||
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||
// No label for detection_3.
|
||||
Detection detection_3;
|
||||
detection_3.add_label_id(3);
|
||||
detection_3.add_label_id(4);
|
||||
detection_3.add_score(0.3);
|
||||
detection_3.add_score(0.4);
|
||||
runner_->MutableInputs()
|
||||
->Tag(kClipLabelAnotherTag)
|
||||
.packets.push_back(MakePacket<Detection>(detection_3).At(Timestamp(3)));
|
||||
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
|
@ -1214,6 +1226,86 @@ TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
|
|||
ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
|
||||
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
|
||||
testing::ElementsAre(0.3, 0.4));
|
||||
ASSERT_FALSE(mpms::HasClipLabelString("ANOTHER", output_sequence));
|
||||
ASSERT_THAT(mpms::GetClipLabelIndex("ANOTHER", output_sequence),
|
||||
testing::ElementsAre(3, 4));
|
||||
ASSERT_THAT(mpms::GetClipLabelConfidence("ANOTHER", output_sequence),
|
||||
testing::ElementsAre(0.3, 0.4));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels_EmptyScore) {
|
||||
SetUpCalculator(
|
||||
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||
/*replace_instead_of_append=*/true);
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
|
||||
// No score in detection_1. detection_1 is ignored.
|
||||
Detection detection_1;
|
||||
detection_1.add_label("label_1");
|
||||
detection_1.add_label("label_2");
|
||||
runner_->MutableInputs()
|
||||
->Tag(kClipLabelTestTag)
|
||||
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||
Detection detection_2;
|
||||
detection_2.add_label("label_3");
|
||||
detection_2.add_label("label_4");
|
||||
detection_2.add_score(0.3);
|
||||
detection_2.add_score(0.4);
|
||||
runner_->MutableInputs()
|
||||
->Tag(kClipLabelOtherTag)
|
||||
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
MP_ASSERT_OK(runner_->Run());
|
||||
|
||||
const std::vector<Packet>& output_packets =
|
||||
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
const tf::SequenceExample& output_sequence =
|
||||
output_packets[0].Get<tf::SequenceExample>();
|
||||
|
||||
ASSERT_FALSE(mpms::HasClipLabelString("TEST", output_sequence));
|
||||
ASSERT_FALSE(mpms::HasClipLabelIndex("TEST", output_sequence));
|
||||
ASSERT_FALSE(mpms::HasClipLabelConfidence("TEST", output_sequence));
|
||||
ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
|
||||
testing::ElementsAre("label_3", "label_4"));
|
||||
ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
|
||||
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
|
||||
testing::ElementsAre(0.3, 0.4));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels_NoLabelOrLabelIndex) {
|
||||
SetUpCalculator(
|
||||
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||
/*replace_instead_of_append=*/true);
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
|
||||
// No label or label_index in detection_1.
|
||||
Detection detection_1;
|
||||
detection_1.add_score(0.1);
|
||||
runner_->MutableInputs()
|
||||
->Tag(kClipLabelTestTag)
|
||||
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||
Detection detection_2;
|
||||
detection_2.add_label("label_3");
|
||||
detection_2.add_label("label_4");
|
||||
detection_2.add_score(0.3);
|
||||
detection_2.add_score(0.4);
|
||||
runner_->MutableInputs()
|
||||
->Tag(kClipLabelOtherTag)
|
||||
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||
Adopt(input_sequence.release());
|
||||
|
||||
ASSERT_THAT(
|
||||
runner_->Run(),
|
||||
testing::status::StatusIs(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
testing::HasSubstr(
|
||||
"detection.label and detection.label_id can't be both empty")));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest,
|
||||
|
@ -1259,7 +1351,7 @@ TEST_F(PackMediaSequenceCalculatorTest,
|
|||
/*replace_instead_of_append=*/true);
|
||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||
|
||||
// 2 labels and 1 label_id in detection_1.
|
||||
// 2 scores and 1 label_id in detection_1.
|
||||
Detection detection_1;
|
||||
detection_1.add_label("label_1");
|
||||
detection_1.add_label("label_2");
|
||||
|
@ -1285,7 +1377,7 @@ TEST_F(PackMediaSequenceCalculatorTest,
|
|||
testing::status::StatusIs(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
testing::HasSubstr(
|
||||
"Different size of detection.label_id and detection.label")));
|
||||
"Different size of detection.label_id and detection.score")));
|
||||
}
|
||||
|
||||
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user