Better handling of empty packets in vector calculators.

PiperOrigin-RevId: 493000695
This commit is contained in:
MediaPipe Team 2022-12-05 07:22:51 -08:00 committed by Copybara-Service
parent e457039fc6
commit 35bb18945f
5 changed files with 51 additions and 14 deletions

View File

@ -65,6 +65,7 @@ class GetVectorItemCalculator : public Node {
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
absl::Status Open(CalculatorContext* cc) final {
cc->SetOffset(mediapipe::TimestampDiff(0));
auto& options = cc->Options<mediapipe::GetVectorItemCalculatorOptions>();
RET_CHECK(kIdx(cc).IsConnected() || options.has_item_index());
return absl::OkStatus();
@ -90,8 +91,12 @@ class GetVectorItemCalculator : public Node {
return absl::OkStatus();
}
RET_CHECK(idx >= 0 && idx < items.size());
kOut(cc).Send(items[idx]);
RET_CHECK(idx >= 0);
RET_CHECK(options.output_empty_on_oob() || idx < items.size());
if (idx < items.size()) {
kOut(cc).Send(items[idx]);
}
return absl::OkStatus();
}

View File

@ -26,4 +26,7 @@ message GetVectorItemCalculatorOptions {
// Index of vector item to get. INDEX input stream can be used instead, or to
// override.
optional int32 item_index = 1;
// Set to true to output an empty packet when the index is out of bounds.
optional bool output_empty_on_oob = 2;
}

View File

@ -32,18 +32,21 @@ CalculatorRunner MakeRunnerWithStream() {
)");
}
CalculatorRunner MakeRunnerWithOptions(int set_index) {
return CalculatorRunner(absl::StrFormat(R"(
CalculatorRunner MakeRunnerWithOptions(int set_index,
bool output_empty_on_oob = false) {
return CalculatorRunner(
absl::StrFormat(R"(
calculator: "TestGetIntVectorItemCalculator"
input_stream: "VECTOR:vector_stream"
output_stream: "ITEM:item_stream"
options {
[mediapipe.GetVectorItemCalculatorOptions.ext] {
item_index: %d
output_empty_on_oob: %s
}
}
)",
set_index));
set_index, output_empty_on_oob ? "true" : "false"));
}
void AddInputVector(CalculatorRunner& runner, const std::vector<int>& inputs,
@ -140,8 +143,7 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail1) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
}
TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
@ -155,7 +157,8 @@ TEST(TestGetIntVectorItemCalculatorTest, StreamIndexBoundsCheckFail2) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
testing::HasSubstr(
"options.output_empty_on_oob() || idx < items.size()"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
@ -167,8 +170,7 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail1) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
EXPECT_THAT(status.message(), testing::HasSubstr("idx >= 0"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
@ -181,7 +183,21 @@ TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail2) {
absl::Status status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("idx >= 0 && idx < items.size()"));
testing::HasSubstr(
"options.output_empty_on_oob() || idx < items.size()"));
}
TEST(TestGetIntVectorItemCalculatorTest, OptionsIndexBoundsCheckFail3) {
const int try_index = 3;
CalculatorRunner runner = MakeRunnerWithOptions(try_index, true);
const std::vector<int> inputs = {1, 2, 3};
AddInputVector(runner, inputs, 1);
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets;
EXPECT_THAT(outputs, testing::ElementsAre());
}
TEST(TestGetIntVectorItemCalculatorTest, IndexStreamTwoTimestamps) {

View File

@ -23,5 +23,9 @@ namespace api2 {
typedef MergeToVectorCalculator<mediapipe::Image> MergeImagesToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeImagesToVectorCalculator);
typedef MergeToVectorCalculator<mediapipe::GpuBuffer>
MergeGpuBuffersToVectorCalculator;
MEDIAPIPE_REGISTER_NODE(MergeGpuBuffersToVectorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -42,11 +42,20 @@ class MergeToVectorCalculator : public Node {
return absl::OkStatus();
}
absl::Status Open(::mediapipe::CalculatorContext* cc) {
cc->SetOffset(::mediapipe::TimestampDiff(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) {
const int input_num = kIn(cc).Count();
std::vector<T> output_vector(input_num);
std::transform(kIn(cc).begin(), kIn(cc).end(), output_vector.begin(),
[](const auto& elem) -> T { return elem.Get(); });
std::vector<T> output_vector;
for (auto it = kIn(cc).begin(); it != kIn(cc).end(); it++) {
const auto& elem = *it;
if (!elem.IsEmpty()) {
output_vector.push_back(elem.Get());
}
}
kOut(cc).Send(output_vector);
return absl::OkStatus();
}