Update PackMediaSequenceCalculator to support setting clip/media/string, clip/media/confidence and clip/label/index.

The input stream is provided as drishti::Detection.

PiperOrigin-RevId: 562070790
This commit is contained in:
MediaPipe Team 2023-09-01 16:06:27 -07:00 committed by Copybara-Service
parent e7d071ab39
commit 23057ac146
3 changed files with 326 additions and 0 deletions

View File

@ -379,6 +379,7 @@ cc_library(
"//mediapipe/util/sequence:media_sequence",
"//mediapipe/util/sequence:media_sequence_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
@ -950,8 +951,10 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/util/sequence:media_sequence",
"//mediapipe/util/sequence:media_sequence_util",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@org_tensorflow//tensorflow/core:protos_all_cc",

View File

@ -18,6 +18,7 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/strip.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
@ -38,6 +39,7 @@ namespace mediapipe {
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
const char kImageTag[] = "IMAGE";
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
@ -70,6 +72,8 @@ namespace mpms = mediapipe::mediasequence;
// * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map<string,
// vector<pair<float, float>>>,
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
// mediapipe::Detection.
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
// prefixed versions of each stream, which allows for multiple image streams to
// be included. However, the default names are suppored by more tools.
@ -165,6 +169,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
cc->Inputs().Tag(tag).Set<std::vector<Detection>>();
}
if (absl::StartsWith(tag, kClipLabelPrefixTag)) {
cc->Inputs().Tag(tag).Set<Detection>();
}
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
cc->Inputs().Tag(tag).Set<std::vector<float>>();
}
@ -217,6 +224,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
replace_keypoints_ = false;
if (cc->Options<PackMediaSequenceCalculatorOptions>()
.replace_data_instead_of_append()) {
// Clear the existing values under the same key.
for (const auto& tag : cc->Inputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) {
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
@ -264,6 +272,13 @@ class PackMediaSequenceCalculator : public CalculatorBase {
mpms::ClearBBoxTrackIndex(key, sequence_.get());
mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get());
}
if (absl::StartsWith(tag, kClipLabelPrefixTag)) {
const std::string& key = tag.substr(
sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
mpms::ClearClipLabelIndex(key, sequence_.get());
mpms::ClearClipLabelString(key, sequence_.get());
mpms::ClearClipLabelConfidence(key, sequence_.get());
}
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
sizeof(*kFloatFeaturePrefixTag) -
@ -465,6 +480,33 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
replace_keypoints_ = false;
}
if (absl::StartsWith(tag, kClipLabelPrefixTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
const std::string& key = tag.substr(
sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>();
if (detection.label().size() != detection.score().size()) {
return absl::InvalidArgumentError(
"Different size of detection.label and detection.score");
}
// Allow empty label_ids, but if label_ids is not empty, it should have
// the same size as the label and score fields.
if (!detection.label_id().empty()) {
if (detection.label_id().size() != detection.label().size()) {
return absl::InvalidArgumentError(
"Different size of detection.label_id and detection.label");
}
}
for (int i = 0; i < detection.label().size(); ++i) {
if (!detection.label_id().empty()) {
mpms::AddClipLabelIndex(key, detection.label_id(i),
sequence_.get());
}
mpms::AddClipLabelString(key, detection.label(i), sequence_.get());
mpms::AddClipLabelConfidence(key, detection.score(i),
sequence_.get());
}
}
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) {
std::string key =

View File

@ -18,6 +18,7 @@
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
@ -31,6 +32,7 @@
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/sequence/media_sequence.h"
#include "mediapipe/util/sequence/media_sequence_util.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "testing/base/public/gmock.h"
@ -66,6 +68,8 @@ constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
constexpr char kImageTag[] = "IMAGE";
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST";
constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER";
class PackMediaSequenceCalculatorTest : public ::testing::Test {
protected:
@ -848,6 +852,283 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
}
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
Detection detection_1;
detection_1.add_label("label_1");
detection_1.add_label("label_2");
detection_1.add_label_id(1);
detection_1.add_label_id(2);
detection_1.add_score(0.1);
detection_1.add_score(0.2);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
// No label ID for detection_2.
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_score(0.3);
detection_2.add_score(0.4);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence),
testing::ElementsAre("label_1", "label_2"));
ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
testing::ElementsAre(1, 2));
ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
testing::ElementsAre(0.1, 0.2));
ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
testing::ElementsAre("label_3", "label_4"));
ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
testing::ElementsAre(0.3, 0.4));
}
TEST_F(PackMediaSequenceCalculatorTest,
PackTwoClipLabels_DifferentLabelScoreSize) {
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
// 2 labels and 1 score in detection_1.
Detection detection_1;
detection_1.add_label("label_1");
detection_1.add_label("label_2");
detection_1.add_score(0.1);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_score(0.3);
detection_2.add_score(0.4);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
ASSERT_THAT(
runner_->Run(),
testing::status::StatusIs(
absl::StatusCode::kInvalidArgument,
testing::HasSubstr(
"Different size of detection.label and detection.score")));
}
TEST_F(PackMediaSequenceCalculatorTest,
PackTwoClipLabels_DifferentLabelIdSize) {
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
// 2 labels and 1 label_id in detection_1.
Detection detection_1;
detection_1.add_label("label_1");
detection_1.add_label("label_2");
detection_1.add_label_id(1);
detection_1.add_score(0.1);
detection_1.add_score(0.2);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_score(0.3);
detection_2.add_score(0.4);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
ASSERT_THAT(
runner_->Run(),
testing::status::StatusIs(
absl::StatusCode::kInvalidArgument,
testing::HasSubstr(
"Different size of detection.label_id and detection.label")));
}
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) {
// Replace existing clip/label/string and clip/label/confidence values for
// the prefixes.
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"},
input_sequence.get());
mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get());
mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"},
input_sequence.get());
mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get());
Detection detection_1;
detection_1.add_label("label_1");
detection_1.add_label("label_2");
detection_1.add_label_id(1);
detection_1.add_label_id(2);
detection_1.add_score(0.9);
detection_1.add_score(0.8);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_label_id(3);
detection_2.add_label_id(4);
detection_2.add_score(0.7);
detection_2.add_score(0.6);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence),
testing::ElementsAre("label_1", "label_2"));
ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
testing::ElementsAre(1, 2));
ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
testing::ElementsAre(0.9, 0.8));
ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
testing::ElementsAre("label_3", "label_4"));
ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence),
testing::ElementsAre(3, 4));
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
testing::ElementsAre(0.7, 0.6));
}
TEST_F(PackMediaSequenceCalculatorTest, AppendTwoClipLabels) {
// Append to the existing clip/label/string and clip/label/confidence values
// for the prefixes.
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/false);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"},
input_sequence.get());
mpms::SetClipLabelIndex("TEST", {1, 2}, input_sequence.get());
mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get());
mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"},
input_sequence.get());
mpms::SetClipLabelIndex("OTHER", {3, 4}, input_sequence.get());
mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get());
Detection detection_1;
detection_1.add_label("label_1");
detection_1.add_label("label_2");
detection_1.add_label_id(9);
detection_1.add_label_id(8);
detection_1.add_score(0.9);
detection_1.add_score(0.8);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_label_id(7);
detection_2.add_label_id(6);
detection_2.add_score(0.7);
detection_2.add_score(0.6);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag(kSequenceExampleTag).packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_THAT(
mpms::GetClipLabelString("TEST", output_sequence),
testing::ElementsAre("old_label_1", "old_label_2", "label_1", "label_2"));
ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
testing::ElementsAre(1, 2, 9, 8));
ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
testing::ElementsAre(0.1, 0.2, 0.9, 0.8));
ASSERT_THAT(
mpms::GetClipLabelString("OTHER", output_sequence),
testing::ElementsAre("old_label_3", "old_label_4", "label_3", "label_4"));
ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence),
testing::ElementsAre(3, 4, 7, 6));
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
testing::ElementsAre(0.3, 0.4, 0.7, 0.6));
}
TEST_F(PackMediaSequenceCalculatorTest,
DifferentClipLabelScoreAndConfidenceSize) {
SetUpCalculator(
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
/*features=*/{}, /*output_only_if_all_present=*/false,
/*replace_instead_of_append=*/true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
Detection detection_1;
// 2 labels and 1 score.
detection_1.add_label("label_1");
detection_1.add_label("label_2");
detection_1.add_score(0.1);
runner_->MutableInputs()
->Tag(kClipLabelTestTag)
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
Detection detection_2;
detection_2.add_label("label_3");
detection_2.add_label("label_4");
detection_2.add_score(0.3);
detection_2.add_score(0.4);
runner_->MutableInputs()
->Tag(kClipLabelOtherTag)
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
Adopt(input_sequence.release());
ASSERT_THAT(runner_->Run(),
testing::status::StatusIs(absl::StatusCode::kInvalidArgument));
}
TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) {
SetUpCalculator(
/*input_streams=*/{"FLOAT_FEATURE_TEST:test",