diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index cd4d1ad88..5f5f51657 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -379,6 +379,7 @@ cc_library( "//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:protos_all_cc", ], @@ -950,8 +951,10 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/util/sequence:media_sequence", + "//mediapipe/util/sequence:media_sequence_util", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:protos_all_cc", diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 7a1f24722..521d27063 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -18,6 +18,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/strip.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" @@ -38,6 +39,7 @@ namespace mediapipe { const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; const char kImageTag[] = "IMAGE"; const char kImageLabelPrefixTag[] = "IMAGE_LABEL_"; +const char kClipLabelPrefixTag[] = "CLIP_LABEL_"; const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; const char kIntFeaturePrefixTag[] = "INT_FEATURE_"; @@ -70,6 +72,8 @@ namespace mpms = mediapipe::mediasequence; // * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map>>, // * "CLIP_MEDIA_ID", which stores the clip's media ID as a string. +// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in +// mediapipe::Detection. // "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store // prefixed versions of each stream, which allows for multiple image streams to // be included. However, the default names are suppored by more tools. @@ -165,6 +169,9 @@ class PackMediaSequenceCalculator : public CalculatorBase { } cc->Inputs().Tag(tag).Set>(); } + if (absl::StartsWith(tag, kClipLabelPrefixTag)) { + cc->Inputs().Tag(tag).Set(); + } if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) { cc->Inputs().Tag(tag).Set>(); } @@ -217,6 +224,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { replace_keypoints_ = false; if (cc->Options() .replace_data_instead_of_append()) { + // Clear the existing values under the same key. for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { if (absl::StartsWith(tag, kImageLabelPrefixTag)) { @@ -264,6 +272,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { mpms::ClearBBoxTrackIndex(key, sequence_.get()); mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get()); } + if (absl::StartsWith(tag, kClipLabelPrefixTag)) { + const std::string& key = tag.substr( + sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); + mpms::ClearClipLabelIndex(key, sequence_.get()); + mpms::ClearClipLabelString(key, sequence_.get()); + mpms::ClearClipLabelConfidence(key, sequence_.get()); + } if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) { std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) / sizeof(*kFloatFeaturePrefixTag) - @@ -465,6 +480,33 @@ class PackMediaSequenceCalculator : public CalculatorBase { } replace_keypoints_ = false; } + if (absl::StartsWith(tag, kClipLabelPrefixTag) && + !cc->Inputs().Tag(tag).IsEmpty()) { + const std::string& key = tag.substr( + sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1); + const Detection& detection = cc->Inputs().Tag(tag).Get(); + if (detection.label().size() != detection.score().size()) { + return absl::InvalidArgumentError( + "Different size of detection.label and detection.score"); + } + // Allow empty label_ids, but if label_ids is not empty, it should have + // the same size as the label and score fields. + if (!detection.label_id().empty()) { + if (detection.label_id().size() != detection.label().size()) { + return absl::InvalidArgumentError( + "Different size of detection.label_id and detection.label"); + } + } + for (int i = 0; i < detection.label().size(); ++i) { + if (!detection.label_id().empty()) { + mpms::AddClipLabelIndex(key, detection.label_id(i), + sequence_.get()); + } + mpms::AddClipLabelString(key, detection.label(i), sequence_.get()); + mpms::AddClipLabelConfidence(key, detection.score(i), + sequence_.get()); + } + } if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 3fb48d1e7..9adca0013 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -18,6 +18,7 @@ #include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" @@ -31,6 +32,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/util/sequence/media_sequence.h" +#include "mediapipe/util/sequence/media_sequence_util.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "testing/base/public/gmock.h" @@ -66,6 +68,8 @@ constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; constexpr char kImageTag[] = "IMAGE"; constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID"; +constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST"; +constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER"; class PackMediaSequenceCalculatorTest : public ::testing::Test { protected: @@ -848,6 +852,283 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) { testing::ElementsAreArray(::std::vector({"mask"}))); } +TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) { + SetUpCalculator( + /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, + /*features=*/{}, /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/true); + auto input_sequence = ::absl::make_unique(); + + Detection detection_1; + detection_1.add_label("label_1"); + detection_1.add_label("label_2"); + detection_1.add_label_id(1); + detection_1.add_label_id(2); + detection_1.add_score(0.1); + detection_1.add_score(0.2); + runner_->MutableInputs() + ->Tag(kClipLabelTestTag) + .packets.push_back(MakePacket(detection_1).At(Timestamp(1))); + // No label ID for detection_2. + Detection detection_2; + detection_2.add_label("label_3"); + detection_2.add_label("label_4"); + detection_2.add_score(0.3); + detection_2.add_score(0.4); + runner_->MutableInputs() + ->Tag(kClipLabelOtherTag) + .packets.push_back(MakePacket(detection_2).At(Timestamp(2))); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence), + testing::ElementsAre("label_1", "label_2")); + ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence), + testing::ElementsAre(1, 2)); + ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence), + testing::ElementsAre(0.1, 0.2)); + ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence), + testing::ElementsAre("label_3", "label_4")); + ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence)); + ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), + testing::ElementsAre(0.3, 0.4)); +} + +TEST_F(PackMediaSequenceCalculatorTest, + PackTwoClipLabels_DifferentLabelScoreSize) { + SetUpCalculator( + /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, + /*features=*/{}, /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/true); + auto input_sequence = ::absl::make_unique(); + + // 2 labels and 1 score in detection_1. + Detection detection_1; + detection_1.add_label("label_1"); + detection_1.add_label("label_2"); + detection_1.add_score(0.1); + runner_->MutableInputs() + ->Tag(kClipLabelTestTag) + .packets.push_back(MakePacket(detection_1).At(Timestamp(1))); + Detection detection_2; + detection_2.add_label("label_3"); + detection_2.add_label("label_4"); + detection_2.add_score(0.3); + detection_2.add_score(0.4); + runner_->MutableInputs() + ->Tag(kClipLabelOtherTag) + .packets.push_back(MakePacket(detection_2).At(Timestamp(2))); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + ASSERT_THAT( + runner_->Run(), + testing::status::StatusIs( + absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "Different size of detection.label and detection.score"))); +} + +TEST_F(PackMediaSequenceCalculatorTest, + PackTwoClipLabels_DifferentLabelIdSize) { + SetUpCalculator( + /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, + /*features=*/{}, /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/true); + auto input_sequence = ::absl::make_unique(); + + // 2 labels and 1 label_id in detection_1. + Detection detection_1; + detection_1.add_label("label_1"); + detection_1.add_label("label_2"); + detection_1.add_label_id(1); + detection_1.add_score(0.1); + detection_1.add_score(0.2); + runner_->MutableInputs() + ->Tag(kClipLabelTestTag) + .packets.push_back(MakePacket(detection_1).At(Timestamp(1))); + Detection detection_2; + detection_2.add_label("label_3"); + detection_2.add_label("label_4"); + detection_2.add_score(0.3); + detection_2.add_score(0.4); + runner_->MutableInputs() + ->Tag(kClipLabelOtherTag) + .packets.push_back(MakePacket(detection_2).At(Timestamp(2))); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + ASSERT_THAT( + runner_->Run(), + testing::status::StatusIs( + absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "Different size of detection.label_id and detection.label"))); +} + +TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) { + // Replace existing clip/label/string and clip/label/confidence values for + // the prefixes. + SetUpCalculator( + /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, + /*features=*/{}, /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/true); + auto input_sequence = ::absl::make_unique(); + mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"}, + input_sequence.get()); + mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get()); + mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"}, + input_sequence.get()); + mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get()); + + Detection detection_1; + detection_1.add_label("label_1"); + detection_1.add_label("label_2"); + detection_1.add_label_id(1); + detection_1.add_label_id(2); + detection_1.add_score(0.9); + detection_1.add_score(0.8); + runner_->MutableInputs() + ->Tag(kClipLabelTestTag) + .packets.push_back(MakePacket(detection_1).At(Timestamp(1))); + Detection detection_2; + detection_2.add_label("label_3"); + detection_2.add_label("label_4"); + detection_2.add_label_id(3); + detection_2.add_label_id(4); + detection_2.add_score(0.7); + detection_2.add_score(0.6); + runner_->MutableInputs() + ->Tag(kClipLabelOtherTag) + .packets.push_back(MakePacket(detection_2).At(Timestamp(2))); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence), + testing::ElementsAre("label_1", "label_2")); + ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence), + testing::ElementsAre(1, 2)); + ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence), + testing::ElementsAre(0.9, 0.8)); + ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence), + testing::ElementsAre("label_3", "label_4")); + ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence), + testing::ElementsAre(3, 4)); + ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), + testing::ElementsAre(0.7, 0.6)); +} + +TEST_F(PackMediaSequenceCalculatorTest, AppendTwoClipLabels) { + // Append to the existing clip/label/string and clip/label/confidence values + // for the prefixes. + SetUpCalculator( + /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, + /*features=*/{}, /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/false); + auto input_sequence = ::absl::make_unique(); + mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"}, + input_sequence.get()); + mpms::SetClipLabelIndex("TEST", {1, 2}, input_sequence.get()); + mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get()); + mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"}, + input_sequence.get()); + mpms::SetClipLabelIndex("OTHER", {3, 4}, input_sequence.get()); + mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get()); + + Detection detection_1; + detection_1.add_label("label_1"); + detection_1.add_label("label_2"); + detection_1.add_label_id(9); + detection_1.add_label_id(8); + detection_1.add_score(0.9); + detection_1.add_score(0.8); + runner_->MutableInputs() + ->Tag(kClipLabelTestTag) + .packets.push_back(MakePacket(detection_1).At(Timestamp(1))); + Detection detection_2; + detection_2.add_label("label_3"); + detection_2.add_label("label_4"); + detection_2.add_label_id(7); + detection_2.add_label_id(6); + detection_2.add_score(0.7); + detection_2.add_score(0.6); + runner_->MutableInputs() + ->Tag(kClipLabelOtherTag) + .packets.push_back(MakePacket(detection_2).At(Timestamp(2))); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_THAT( + mpms::GetClipLabelString("TEST", output_sequence), + testing::ElementsAre("old_label_1", "old_label_2", "label_1", "label_2")); + ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence), + testing::ElementsAre(1, 2, 9, 8)); + ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence), + testing::ElementsAre(0.1, 0.2, 0.9, 0.8)); + ASSERT_THAT( + mpms::GetClipLabelString("OTHER", output_sequence), + testing::ElementsAre("old_label_3", "old_label_4", "label_3", "label_4")); + ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence), + testing::ElementsAre(3, 4, 7, 6)); + ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence), + testing::ElementsAre(0.3, 0.4, 0.7, 0.6)); +} + +TEST_F(PackMediaSequenceCalculatorTest, + DifferentClipLabelScoreAndConfidenceSize) { + SetUpCalculator( + /*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"}, + /*features=*/{}, /*output_only_if_all_present=*/false, + /*replace_instead_of_append=*/true); + auto input_sequence = ::absl::make_unique(); + + Detection detection_1; + // 2 labels and 1 score. + detection_1.add_label("label_1"); + detection_1.add_label("label_2"); + detection_1.add_score(0.1); + runner_->MutableInputs() + ->Tag(kClipLabelTestTag) + .packets.push_back(MakePacket(detection_1).At(Timestamp(1))); + Detection detection_2; + detection_2.add_label("label_3"); + detection_2.add_label("label_4"); + detection_2.add_score(0.3); + detection_2.add_score(0.4); + runner_->MutableInputs() + ->Tag(kClipLabelOtherTag) + .packets.push_back(MakePacket(detection_2).At(Timestamp(2))); + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + ASSERT_THAT(runner_->Run(), + testing::status::StatusIs(absl::StatusCode::kInvalidArgument)); +} + TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) { SetUpCalculator( /*input_streams=*/{"FLOAT_FEATURE_TEST:test",