No public description
PiperOrigin-RevId: 562915674
This commit is contained in:
parent
4e52e96973
commit
bf32d1acb2
|
@ -379,7 +379,6 @@ cc_library(
|
||||||
"//mediapipe/util/sequence:media_sequence",
|
"//mediapipe/util/sequence:media_sequence",
|
||||||
"//mediapipe/util/sequence:media_sequence_util",
|
"//mediapipe/util/sequence:media_sequence_util",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/log",
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||||
|
|
|
@ -287,6 +287,13 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
mpms::ClearClipLabelString(key, sequence_.get());
|
mpms::ClearClipLabelString(key, sequence_.get());
|
||||||
mpms::ClearClipLabelConfidence(key, sequence_.get());
|
mpms::ClearClipLabelConfidence(key, sequence_.get());
|
||||||
}
|
}
|
||||||
|
if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
|
||||||
|
const std::string& key =
|
||||||
|
tag.substr(sizeof(kFloatContextFeaturePrefixTag) /
|
||||||
|
sizeof(*kFloatContextFeaturePrefixTag) -
|
||||||
|
1);
|
||||||
|
mpms::ClearContextFeatureFloats(key, sequence_.get());
|
||||||
|
}
|
||||||
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
|
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
|
||||||
const std::string& key =
|
const std::string& key =
|
||||||
tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
|
tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
|
||||||
|
@ -536,9 +543,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
sizeof(*kFloatContextFeaturePrefixTag) -
|
sizeof(*kFloatContextFeaturePrefixTag) -
|
||||||
1);
|
1);
|
||||||
RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
|
RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
|
||||||
mpms::SetContextFeatureFloats(
|
for (const auto& value :
|
||||||
key, cc->Inputs().Tag(tag).Get<std::vector<float>>(),
|
cc->Inputs().Tag(tag).Get<std::vector<float>>()) {
|
||||||
sequence_.get());
|
mpms::AddContextFeatureFloats(key, value, sequence_.get());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) &&
|
if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) &&
|
||||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -455,6 +456,83 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoContextFloatLists) {
|
||||||
testing::ElementsAre(4, 4));
|
testing::ElementsAre(4, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, ReplaceTwoContextFloatLists) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"FLOAT_CONTEXT_FEATURE_TEST:test",
|
||||||
|
"FLOAT_CONTEXT_FEATURE_OTHER:test2"},
|
||||||
|
/*features=*/{},
|
||||||
|
/*output_only_if_all_present=*/false, /*replace_instead_of_append=*/true);
|
||||||
|
auto input_sequence = std::make_unique<tf::SequenceExample>();
|
||||||
|
mpms::SetContextFeatureFloats("TEST", {2, 3}, input_sequence.get());
|
||||||
|
mpms::SetContextFeatureFloats("OTHER", {2, 4}, input_sequence.get());
|
||||||
|
|
||||||
|
const std::vector<float> vf_1 = {5, 6};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kFloatContextFeatureTestTag)
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<std::vector<float>>(vf_1).At(Timestamp::PostStream()));
|
||||||
|
const std::vector<float> vf_2 = {7, 8};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kFloatContextFeatureOtherTag)
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<std::vector<float>>(vf_2).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureFloats("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(5, 6));
|
||||||
|
ASSERT_THAT(mpms::GetContextFeatureFloats("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(7, 8));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PackMediaSequenceCalculatorTest, AppendTwoContextFloatLists) {
|
||||||
|
SetUpCalculator(
|
||||||
|
/*input_streams=*/{"FLOAT_CONTEXT_FEATURE_TEST:test",
|
||||||
|
"FLOAT_CONTEXT_FEATURE_OTHER:test2"},
|
||||||
|
/*features=*/{},
|
||||||
|
/*output_only_if_all_present=*/false,
|
||||||
|
/*replace_instead_of_append=*/false);
|
||||||
|
auto input_sequence = std::make_unique<tf::SequenceExample>();
|
||||||
|
mpms::SetContextFeatureFloats("TEST", {2, 3}, input_sequence.get());
|
||||||
|
mpms::SetContextFeatureFloats("OTHER", {2, 4}, input_sequence.get());
|
||||||
|
|
||||||
|
const std::vector<float> vf_1 = {5, 6};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kFloatContextFeatureTestTag)
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<std::vector<float>>(vf_1).At(Timestamp::PostStream()));
|
||||||
|
const std::vector<float> vf_2 = {7, 8};
|
||||||
|
runner_->MutableInputs()
|
||||||
|
->Tag(kFloatContextFeatureOtherTag)
|
||||||
|
.packets.push_back(
|
||||||
|
MakePacket<std::vector<float>>(vf_2).At(Timestamp::PostStream()));
|
||||||
|
|
||||||
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner_->Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner_->Outputs().Tag(kSequenceExampleTag).packets;
|
||||||
|
ASSERT_EQ(1, output_packets.size());
|
||||||
|
const tf::SequenceExample& output_sequence =
|
||||||
|
output_packets[0].Get<tf::SequenceExample>();
|
||||||
|
|
||||||
|
EXPECT_THAT(mpms::GetContextFeatureFloats("TEST", output_sequence),
|
||||||
|
testing::ElementsAre(2, 3, 5, 6));
|
||||||
|
EXPECT_THAT(mpms::GetContextFeatureFloats("OTHER", output_sequence),
|
||||||
|
testing::ElementsAre(2, 4, 7, 8));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextIntLists) {
|
TEST_F(PackMediaSequenceCalculatorTest, PackTwoContextIntLists) {
|
||||||
SetUpCalculator(
|
SetUpCalculator(
|
||||||
/*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
|
/*input_streams=*/{"INTS_CONTEXT_FEATURE_TEST:test",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user