Better handling of empty packets in vector calculators.
PiperOrigin-RevId: 493000695
This commit is contained in:
parent
e457039fc6
commit
35bb18945f
|
@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node {
|
||||||
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
|
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) final {
|
absl::Status Open(CalculatorContext* cc) final {
|
||||||
|
cc->SetOffset(mediapipe::TimestampDiff(0));
|
||||||
auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>();
|
auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>();
|
||||||
RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index());
|
RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
RET_CHECK(idx >= 0 && idx < items.size());
|
RET_CHECK(idx >= 0);
|
||||||
|
RET_CHECK(options.output_empty_on_oob() || idx < items.size());
|
||||||
|
|
||||||
|
if (idx < items.size()) {
|
||||||
kOut(cc).Send(items[idx]);
|
kOut(cc).Send(items[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions {
|
||||||
// Index of vector item to get. INDEX input stream can be used instead, or to
|
// Index of vector item to get. INDEX input stream can be used instead, or to
|
||||||
// override.
|
// override.
|
||||||
optional int32 item_index = 1;
|
optional int32 item_index = 1;
|
||||||
|
|
||||||
|
// Set to true to output an empty packet when the index is out of bounds.
|
||||||
|
optional bool output_empty_on_oob = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() {
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
CalculatorRunner MakeRunnerWithOptions(int set_index) {
|
CalculatorRunner MakeRunnerWithOptions(int set_index,
|
||||||
return CalculatorRunner(absl::StrFormat(R"(
|
bool output_empty_on_oob = false) {
|
||||||
|
return CalculatorRunner(
|
||||||
|
absl::StrFormat(R"(
|
||||||
calculator: "TestGetIntVectorItemCalculator"
|
calculator: "TestGetIntVectorItemCalculator"
|
||||||
input_stream: "VECTOR:vector_stream"
|
input_stream: "VECTOR:vector_stream"
|
||||||
output_stream: "ITEM:item_stream"
|
output_stream: "ITEM:item_stream"
|
||||||
options {
|
options {
|
||||||
[mediapipe.GetVectorItemCalculatorOptions.ext] {
|
[mediapipe.GetVectorItemCalculatorOptions.ext] {
|
||||||
item_index: %d
|
item_index: %d
|
||||||
|
output_empty_on_oob: %s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
set_index));
|
set_index, output_empty_on_oob ? "true" : "false"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs,
|
void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs,
|
||||||
|
@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) {
|
||||||
|
|
||||||
absl::Status status = runner.Run();
|
absl::Status status = runner.Run();
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
|
||||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
|
TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
|
||||||
|
@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
|
||||||
absl::Status status = runner.Run();
|
absl::Status status = runner.Run();
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(status.message(),
|
||||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
testing::HasSubstr(
|
||||||
|
"options.output_empty_on_oob() || idx < items.size()"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
|
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
|
||||||
|
@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
|
||||||
|
|
||||||
absl::Status status = runner.Run();
|
absl::Status status = runner.Run();
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
|
||||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
|
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
|
||||||
|
@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
|
||||||
absl::Status status = runner.Run();
|
absl::Status status = runner.Run();
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(status.message(),
|
||||||
testing::HasSubstr("idx >= 0 && idx < items.size()"));
|
testing::HasSubstr(
|
||||||
|
"options.output_empty_on_oob() || idx < items.size()"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) {
|
||||||
|
const int try_index = 3;
|
||||||
|
CalculatorRunner runner = MakeRunnerWithOptions(try_index, true);
|
||||||
|
const std::vector<int> inputs = {1, 2, 3};
|
||||||
|
|
||||||
|
AddInputVector(runner, inputs, 1);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
|
||||||
|
EXPECT_THAT(outputs, testing::ElementsAre());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) {
|
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) {
|
||||||
|
|
|
@ -23,5 +23,9 @@ namespace api2 {
|
||||||
typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator;
|
typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator;
|
||||||
MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator);
|
MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator);
|
||||||
|
|
||||||
|
typedef MergeToVectorCalculator<mediapipe::GpuBuffer>
|
||||||
|
MergeGpuBuffersToVectorCalculator;
|
||||||
|
MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator);
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status Open(::mediapipe::CalculatorContext* cc) {
|
||||||
|
cc->SetOffset(::mediapipe::TimestampDiff(0));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) {
|
absl::Status Process(CalculatorContext* cc) {
|
||||||
const int input_num = kIn(cc).Count();
|
const int input_num = kIn(cc).Count();
|
||||||
std::vector<T> output_vector(input_num);
|
std::vector<T> output_vector;
|
||||||
std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(),
|
for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) {
|
||||||
[](const auto& elem) -> T { return elem.Get(); });
|
const auto& elem = *it;
|
||||||
|
if (!elem.IsEmpty()) {
|
||||||
|
output_vector.push_back(elem.Get());
|
||||||
|
}
|
||||||
|
}
|
||||||
kOut(cc).Send(output_vector);
|
kOut(cc).Send(output_vector);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user