Better handling of empty packets in vector calculators.

PiperOrigin-RevId: 493000695
This commit is contained in:
MediaPipe Team 2022-12-05 07:22:51 -08:00 committed by Copybara-Service
parent e457039fc6
commit 35bb18945f
5 changed files with 51 additions and 14 deletions

View File

@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node {
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final {
cc->SetOffset(mediapipe::TimestampDiff(0));
auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>(); auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>();
RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index()); RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index());
return absl::OkStatus(); return absl::OkStatus();
@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node {
return absl::OkStatus(); return absl::OkStatus();
} }
RET_CHECK(idx >= 0 && idx < items.size()); RET_CHECK(idx >= 0);
RET_CHECK(options.output_empty_on_oob() || idx < items.size());
if (idx < items.size()) {
kOut(cc).Send(items[idx]); kOut(cc).Send(items[idx]);
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions {
// Index of vector item to get. INDEX input stream can be used instead, or to // Index of vector item to get. INDEX input stream can be used instead, or to
// override. // override.
optional int32 item_index = 1; optional int32 item_index = 1;
// Set to true to output an empty packet when the index is out of bounds.
optional bool output_empty_on_oob = 2;
} }

View File

@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() {
)"); )");
} }
CalculatorRunner MakeRunnerWithOptions(int set_index) { CalculatorRunner MakeRunnerWithOptions(int set_index,
return CalculatorRunner(absl::StrFormat(R"( bool output_empty_on_oob = false) {
return CalculatorRunner(
absl::StrFormat(R"(
calculator: "TestGetIntVectorItemCalculator" calculator: "TestGetIntVectorItemCalculator"
input_stream: "VECTOR:vector_stream" input_stream: "VECTOR:vector_stream"
output_stream: "ITEM:item_stream" output_stream: "ITEM:item_stream"
options { options {
[mediapipe.GetVectorItemCalculatorOptions.ext] { [mediapipe.GetVectorItemCalculatorOptions.ext] {
item_index: %d item_index: %d
output_empty_on_oob: %s
} }
} }
)", )",
set_index)); set_index, output_empty_on_oob ? "true" : "false"));
} }
void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs, void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs,
@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) {
absl::Status status = runner.Run(); absl::Status status = runner.Run();
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(), EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
testing::HasSubstr("idx >= 0 && idx < items.size()"));
} }
TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) { TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
absl::Status status = runner.Run(); absl::Status status = runner.Run();
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(), EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()")); testing::HasSubstr(
"options.output_empty_on_oob() || idx < items.size()"));
} }
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) { TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
absl::Status status = runner.Run(); absl::Status status = runner.Run();
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(), EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
testing::HasSubstr("idx >= 0 && idx < items.size()"));
} }
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) { TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
absl::Status status = runner.Run(); absl::Status status = runner.Run();
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(), EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()")); testing::HasSubstr(
"options.output_empty_on_oob() || idx < items.size()"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) {
const int try_index = 3;
CalculatorRunner runner = MakeRunnerWithOptions(try_index, true);
const std::vector<int> inputs = {1, 2, 3};
AddInputVector(runner, inputs, 1);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
EXPECT_THAT(outputs, testing::ElementsAre());
} }
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) { TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) {

View File

@ -23,5 +23,9 @@ namespace api2 {
typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator; typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator); MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator);
typedef MergeToVectorCalculator<mediapipe::GpuBuffer>
MergeGpuBuffersToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator);
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(::mediapipe::CalculatorContext* cc) {
cc->SetOffset(::mediapipe::TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) { absl::Status Process(CalculatorContext* cc) {
const int input_num = kIn(cc).Count(); const int input_num = kIn(cc).Count();
std::vector<T> output_vector(input_num); std::vector<T> output_vector;
std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(), for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) {
[](const auto& elem) -> T { return elem.Get(); }); const auto& elem = *it;
if (!elem.IsEmpty()) {
output_vector.push_back(elem.Get());
}
}
kOut(cc).Send(output_vector); kOut(cc).Send(output_vector);
return absl::OkStatus(); return absl::OkStatus();
} }