Update PackMediaSequenceCalculator to support setting clip/media/string, clip/media/confidence and clip/label/index.
The input stream is provided as drishti::Detection. PiperOrigin-RevId: 562070790
This commit is contained in:
parent
e7d071ab39
commit
23057ac146
|
@ -379,6 +379,7 @@ cc_library(
|
||||||
"//mediapipe/util/sequence:media_sequence",
|
"//mediapipe/util/sequence:media_sequence",
|
||||||
"//mediapipe/util/sequence:media_sequence_util",
|
"//mediapipe/util/sequence:media_sequence_util",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
|
@ -950,8 +951,10 @@ cc_test(
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||||
"//mediapipe/util/sequence:media_sequence",
|
"//mediapipe/util/sequence:media_sequence",
|
||||||
|
"//mediapipe/util/sequence:media_sequence_util",
|
||||||
"@com_google_absl//absl/log:absl_check",
|
"@com_google_absl//absl/log:absl_check",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/strip.h"
|
#include "absl/strings/strip.h"
|
||||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||||
|
@ -38,6 +39,7 @@ namespace mediapipe {
|
||||||
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||||
const char kImageTag[] = "IMAGE";
|
const char kImageTag[] = "IMAGE";
|
||||||
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
|
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
|
||||||
|
const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
|
||||||
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
||||||
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
||||||
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
||||||
|
@ -70,6 +72,8 @@ namespace mpms = mediapipe::mediasequence;
|
||||||
// * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map<string,
|
// * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map<string,
|
||||||
// vector<pair<float, float>>>,
|
// vector<pair<float, float>>>,
|
||||||
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
|
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
|
||||||
|
// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
|
||||||
|
// mediapipe::Detection.
|
||||||
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
|
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
|
||||||
// prefixed versions of each stream, which allows for multiple image streams to
|
// prefixed versions of each stream, which allows for multiple image streams to
|
||||||
// be included. However, the default names are suppored by more tools.
|
// be included. However, the default names are suppored by more tools.
|
||||||
|
@ -165,6 +169,9 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
cc->Inputs().Tag(tag).Set<std::vector<Detection>>();
|
cc->Inputs().Tag(tag).Set<std::vector<Detection>>();
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kClipLabelPrefixTag)) {
|
||||||
|
cc->Inputs().Tag(tag).Set<Detection>();
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
|
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
|
||||||
cc->Inputs().Tag(tag).Set<std::vector<float>>();
|
cc->Inputs().Tag(tag).Set<std::vector<float>>();
|
||||||
}
|
}
|
||||||
|
@ -217,6 +224,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
replace_keypoints_ = false;
|
replace_keypoints_ = false;
|
||||||
if (cc->Options<PackMediaSequenceCalculatorOptions>()
|
if (cc->Options<PackMediaSequenceCalculatorOptions>()
|
||||||
.replace_data_instead_of_append()) {
|
.replace_data_instead_of_append()) {
|
||||||
|
// Clear the existing values under the same key.
|
||||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||||
if (absl::StartsWith(tag, kImageTag)) {
|
if (absl::StartsWith(tag, kImageTag)) {
|
||||||
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
||||||
|
@ -264,6 +272,13 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
mpms::ClearBBoxTrackIndex(key, sequence_.get());
|
mpms::ClearBBoxTrackIndex(key, sequence_.get());
|
||||||
mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get());
|
mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get());
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kClipLabelPrefixTag)) {
|
||||||
|
const std::string& key = tag.substr(
|
||||||
|
sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
|
||||||
|
mpms::ClearClipLabelIndex(key, sequence_.get());
|
||||||
|
mpms::ClearClipLabelString(key, sequence_.get());
|
||||||
|
mpms::ClearClipLabelConfidence(key, sequence_.get());
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
|
||||||
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
|
||||||
sizeof(*kFloatFeaturePrefixTag) -
|
sizeof(*kFloatFeaturePrefixTag) -
|
||||||
|
@ -465,6 +480,33 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
replace_keypoints_ = false;
|
replace_keypoints_ = false;
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kClipLabelPrefixTag) &&
|
||||||
|
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||||
|
const std::string& key = tag.substr(
|
||||||
|
sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
|
||||||
|
const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>();
|
||||||
|
if (detection.label().size() != detection.score().size()) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"Different size of detection.label and detection.score");
|
||||||
|
}
|
||||||
|
// Allow empty label_ids, but if label_ids is not empty, it should have
|
||||||
|
// the same size as the label and score fields.
|
||||||
|
if (!detection.label_id().empty()) {
|
||||||
|
if (detection.label_id().size() != detection.label().size()) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"Different size of detection.label_id and detection.label");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < detection.label().size(); ++i) {
|
||||||
|
if (!detection.label_id().empty()) {
|
||||||
|
mpms::AddClipLabelIndex(key, detection.label_id(i),
|
||||||
|
sequence_.get());
|
||||||
|
}
|
||||||
|
mpms::AddClipLabelString(key, detection.label(i), sequence_.get());
|
||||||
|
mpms::AddClipLabelConfidence(key, detection.score(i),
|
||||||
|
sequence_.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) &&
|
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) &&
|
||||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||||
std::string key =
|
std::string key =
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include "absl/log/absl_check.h"
|
#include "absl/log/absl_check.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
||||||
|
@ -31,6 +32,7 @@
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/util/sequence/media_sequence.h"
|
#include "mediapipe/util/sequence/media_sequence.h"
|
||||||
|
#include "mediapipe/util/sequence/media_sequence_util.h"
|
||||||
#include "tensorflow/core/example/example.pb.h"
|
#include "tensorflow/core/example/example.pb.h"
|
||||||
#include "tensorflow/core/example/feature.pb.h"
|
#include "tensorflow/core/example/feature.pb.h"
|
||||||
#include "testing/base/public/gmock.h"
|
#include "testing/base/public/gmock.h"
|
||||||
|
@ -66,6 +68,8 @@ constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
||||||
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
|
constexpr char kClipMediaIdTag[] = "CLIP_MEDIA_ID";
|
||||||
|
constexpr char kClipLabelTestTag[] = "CLIP_LABEL_TEST";
|
||||||
|
constexpr char kClipLabelOtherTag[] = "CLIP_LABEL_OTHER";
|
||||||
|
|
||||||
class PackMediaSequenceCalculatorTest : public ::testing::Test {
|
class PackMediaSequenceCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
|
@ -848,6 +852,283 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoMaskDetections) {
|
||||||
testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
|
testing::ElementsAreArray(::std::vector<std::string>({"mask"})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, PackTwoClipLabels) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||||
|
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
|
||||||
|
Detection detection_1;
|
||||||
|
detection_1.add_label("label_1");
|
||||||
|
detection_1.add_label("label_2");
|
||||||
|
detection_1.add_label_id(1);
|
||||||
|
detection_1.add_label_id(2);
|
||||||
|
detection_1.add_score(0.1);
|
||||||
|
detection_1.add_score(0.2);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelTestTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||||
|
// No label ID for detection_2.
|
||||||
|
Detection detection_2;
|
||||||
|
detection_2.add_label("label_3");
|
||||||
|
detection_2.add_label("label_4");
|
||||||
|
detection_2.add_score(0.3);
|
||||||
|
detection_2.add_score(0.4);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelOtherTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence),
|
||||||
|
testing::ElementsAre("label_1", "label_2"));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(1, 2));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(0.1, 0.2));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre("label_3", "label_4"));
|
||||||
|
ASSERT_FALSE(mpms::HasClipLabelIndex("OTHER", output_sequence));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(0.3, 0.4));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest,
|
||||||
|
PackTwoClipLabels_DifferentLabelScoreSize) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||||
|
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
|
||||||
|
// 2 labels and 1 score in detection_1.
|
||||||
|
Detection detection_1;
|
||||||
|
detection_1.add_label("label_1");
|
||||||
|
detection_1.add_label("label_2");
|
||||||
|
detection_1.add_score(0.1);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelTestTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||||
|
Detection detection_2;
|
||||||
|
detection_2.add_label("label_3");
|
||||||
|
detection_2.add_label("label_4");
|
||||||
|
detection_2.add_score(0.3);
|
||||||
|
detection_2.add_score(0.4);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelOtherTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
ASSERT_THAT(
|
||||||
|
runner_->Run(),
|
||||||
|
testing::status::StatusIs(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
testing::HasSubstr(
|
||||||
|
"Different size of detection.label and detection.score")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest,
|
||||||
|
PackTwoClipLabels_DifferentLabelIdSize) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||||
|
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
|
||||||
|
// 2 labels and 1 label_id in detection_1.
|
||||||
|
Detection detection_1;
|
||||||
|
detection_1.add_label("label_1");
|
||||||
|
detection_1.add_label("label_2");
|
||||||
|
detection_1.add_label_id(1);
|
||||||
|
detection_1.add_score(0.1);
|
||||||
|
detection_1.add_score(0.2);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelTestTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||||
|
Detection detection_2;
|
||||||
|
detection_2.add_label("label_3");
|
||||||
|
detection_2.add_label("label_4");
|
||||||
|
detection_2.add_score(0.3);
|
||||||
|
detection_2.add_score(0.4);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelOtherTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
ASSERT_THAT(
|
||||||
|
runner_->Run(),
|
||||||
|
testing::status::StatusIs(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
testing::HasSubstr(
|
||||||
|
"Different size of detection.label_id and detection.label")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoClipLabels) {
|
||||||
|
// Replace existing clip/label/string and clip/label/confidence values for
|
||||||
|
// the prefixes.
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||||
|
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"},
|
||||||
|
input_sequence.get());
|
||||||
|
mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get());
|
||||||
|
mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"},
|
||||||
|
input_sequence.get());
|
||||||
|
mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get());
|
||||||
|
|
||||||
|
Detection detection_1;
|
||||||
|
detection_1.add_label("label_1");
|
||||||
|
detection_1.add_label("label_2");
|
||||||
|
detection_1.add_label_id(1);
|
||||||
|
detection_1.add_label_id(2);
|
||||||
|
detection_1.add_score(0.9);
|
||||||
|
detection_1.add_score(0.8);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelTestTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||||
|
Detection detection_2;
|
||||||
|
detection_2.add_label("label_3");
|
||||||
|
detection_2.add_label("label_4");
|
||||||
|
detection_2.add_label_id(3);
|
||||||
|
detection_2.add_label_id(4);
|
||||||
|
detection_2.add_score(0.7);
|
||||||
|
detection_2.add_score(0.6);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelOtherTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelString("TEST", output_sequence),
|
||||||
|
testing::ElementsAre("label_1", "label_2"));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(1, 2));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(0.9, 0.8));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelString("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre("label_3", "label_4"));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(3, 4));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(0.7, 0.6));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, AppendTwoClipLabels) {
|
||||||
|
// Append to the existing clip/label/string and clip/label/confidence values
|
||||||
|
// for the prefixes.
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||||
|
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/false);
|
||||||
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
mpms::SetClipLabelString("TEST", {"old_label_1", "old_label_2"},
|
||||||
|
input_sequence.get());
|
||||||
|
mpms::SetClipLabelIndex("TEST", {1, 2}, input_sequence.get());
|
||||||
|
mpms::SetClipLabelConfidence("TEST", {0.1, 0.2}, input_sequence.get());
|
||||||
|
mpms::SetClipLabelString("OTHER", {"old_label_3", "old_label_4"},
|
||||||
|
input_sequence.get());
|
||||||
|
mpms::SetClipLabelIndex("OTHER", {3, 4}, input_sequence.get());
|
||||||
|
mpms::SetClipLabelConfidence("OTHER", {0.3, 0.4}, input_sequence.get());
|
||||||
|
|
||||||
|
Detection detection_1;
|
||||||
|
detection_1.add_label("label_1");
|
||||||
|
detection_1.add_label("label_2");
|
||||||
|
detection_1.add_label_id(9);
|
||||||
|
detection_1.add_label_id(8);
|
||||||
|
detection_1.add_score(0.9);
|
||||||
|
detection_1.add_score(0.8);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelTestTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||||
|
Detection detection_2;
|
||||||
|
detection_2.add_label("label_3");
|
||||||
|
detection_2.add_label("label_4");
|
||||||
|
detection_2.add_label_id(7);
|
||||||
|
detection_2.add_label_id(6);
|
||||||
|
detection_2.add_score(0.7);
|
||||||
|
detection_2.add_score(0.6);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelOtherTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(
|
||||||
|
mpms::GetClipLabelString("TEST", output_sequence),
|
||||||
|
testing::ElementsAre("old_label_1", "old_label_2", "label_1", "label_2"));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelIndex("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(1, 2, 9, 8));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelConfidence("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(0.1, 0.2, 0.9, 0.8));
|
||||||
|
ASSERT_THAT(
|
||||||
|
mpms::GetClipLabelString("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre("old_label_3", "old_label_4", "label_3", "label_4"));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelIndex("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(3, 4, 7, 6));
|
||||||
|
ASSERT_THAT(mpms::GetClipLabelConfidence("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(0.3, 0.4, 0.7, 0.6));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest,
|
||||||
|
DifferentClipLabelScoreAndConfidenceSize) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"CLIP_LABEL_TEST:test", "CLIP_LABEL_OTHER:test2"},
|
||||||
|
/*features=*/{}, /*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
|
||||||
|
Detection detection_1;
|
||||||
|
// 2 labels and 1 score.
|
||||||
|
detection_1.add_label("label_1");
|
||||||
|
detection_1.add_label("label_2");
|
||||||
|
detection_1.add_score(0.1);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelTestTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_1).At(Timestamp(1)));
|
||||||
|
Detection detection_2;
|
||||||
|
detection_2.add_label("label_3");
|
||||||
|
detection_2.add_label("label_4");
|
||||||
|
detection_2.add_score(0.3);
|
||||||
|
detection_2.add_score(0.4);
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kClipLabelOtherTag)
|
||||||
|
.packets.push_back(MakePacket<Detection>(detection_2).At(Timestamp(2)));
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
ASSERT_THAT(runner_->Run(),
|
||||||
|
testing::status::StatusIs(absl::StatusCode::kInvalidArgument));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) {
|
TEST_F(PackMediaSequenceCalculatorTest, AddClipMediaId) {
|
||||||
SetUpCalculator(
|
SetUpCalculator(
|
||||||
/*input_streams=*/{"FLOAT_FEATURE_TEST:test",
|
/*input_streams=*/{"FLOAT_FEATURE_TEST:test",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user