Preserves all elements of BASE_RECTS input streams in HandAssociationCalculator.

PiperOrigin-RevId: 516339343
This commit is contained in:
Esha Uboweja 2023-03-13 15:27:49 -07:00 committed by Copybara-Service
parent c32ddcb04c
commit 0f58d89992
4 changed files with 175 additions and 79 deletions

View File

@ -36,6 +36,7 @@ cc_library(
":hand_association_calculator_cc_proto",
"//mediapipe/calculators/util:association_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:rectangle",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
@ -29,30 +30,55 @@ namespace mediapipe::api2 {
using ::mediapipe::NormalizedRect;
// HandAssociationCalculator accepts multiple inputs of vectors of
// NormalizedRect. The output is a vector of NormalizedRect that contains
// rects from the input vectors that don't overlap with each other. When two
// rects overlap, the rect that comes in from an earlier input stream is
// kept in the output. If a rect has no ID (i.e. from detection stream),
// then a unique rect ID is assigned for it.
// The rects in multiple input streams are effectively flattened to a single
// list. For example:
// Stream1 : rect 1, rect 2
// Stream2: rect 3, rect 4
// Stream3: rect 5, rect 6
// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
// In the flattened list, if a rect with a higher index overlaps with a rect
// with a lower index, beyond a specified IOU threshold, the rect with the
// lower index will be in the output, and the rect with the higher index will
// be discarded.
// Input:
// BASE_RECTS - Vector of NormalizedRect.
// RECTS - Vector of NormalizedRect.
//
// Output:
// No tag - Vector of NormalizedRect.
//
// Example use:
// node {
// calculator: "HandAssociationCalculator"
// input_stream: "BASE_RECTS:base_rects"
// input_stream: "RECTS:0:rects0"
// input_stream: "RECTS:1:rects1"
// input_stream: "RECTS:2:rects2"
// output_stream: "output_rects"
// options {
// [mediapipe.HandAssociationCalculatorOptions.ext] {
// min_similarity_threshold: 0.1
// }
// }
// }
//
// IMPORTANT Notes:
// - Rects from input streams tagged with "BASE_RECTS" are always preserved.
// - This calculator checks for overlap among rects from input streams tagged
// with "RECTS". Rects are prioritized by the tag-index of their input stream
// and by their index within that stream's vector. When two rects overlap,
// the rect that comes from an input stream with a lower tag-index is kept in
// the output.
// - Example of inputs for the node above:
// "base_rects": rect 0, rect 1
// "rects0": rect 2, rect 3
// "rects1": rect 4, rect 5
// "rects2": rect 6, rect 7
// (Conceptually) flattened list: 0, 1, 2, 3, 4, 5, 6, 7.
// Rects 0, 1 will be preserved. Rects 2, 3, 4, 5, 6, 7 will be checked for
// overlap. If a rect with a higher index overlaps with a rect with a lower
// index, beyond a specified IOU threshold, the rect with the lower index
// will be in the output, and the rect with the higher index will be
// discarded.
// TODO: Upgrade this to latest API for calculators
class HandAssociationCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
// Initialize input and output streams.
for (auto& input_stream : cc->Inputs()) {
input_stream.Set<std::vector<NormalizedRect>>();
for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
}
for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
id != cc->Inputs().EndId("RECTS"); ++id) {
cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
}
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
@ -89,7 +115,24 @@ class HandAssociationCalculator : public CalculatorBase {
CalculatorContext* cc) {
std::vector<NormalizedRect> result;
for (const auto& input_stream : cc->Inputs()) {
for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
const auto& input_stream = cc->Inputs().Get(id);
if (input_stream.IsEmpty()) {
continue;
}
for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
if (!rect.has_rect_id()) {
rect.set_rect_id(GetNextRectId());
}
result.push_back(rect);
}
}
for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
id != cc->Inputs().EndId("RECTS"); ++id) {
const auto& input_stream = cc->Inputs().Get(id);
if (input_stream.IsEmpty()) {
continue;
}

View File

@ -27,6 +27,8 @@ namespace mediapipe {
namespace {
using ::mediapipe::NormalizedRect;
using ::testing::ElementsAre;
using ::testing::EqualsProto;
class HandAssociationCalculatorTest : public testing::Test {
protected:
@ -87,9 +89,9 @@ class HandAssociationCalculatorTest : public testing::Test {
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
input_stream: "BASE_RECTS:input_vec_0"
input_stream: "RECTS:0:input_vec_1"
input_stream: "RECTS:1:input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
@ -103,20 +105,23 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
input_vec_0->push_back(nr_0_);
input_vec_0->push_back(nr_1_);
input_vec_0->push_back(nr_2_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
runner.MutableInputs()
->Tag("BASE_RECTS")
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_3, nr_4.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_3_);
input_vec_1->push_back(nr_4_);
runner.MutableInputs()->Index(1).packets.push_back(
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
runner.MutableInputs()->Get(index_id).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_5.
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_2->push_back(nr_5_);
runner.MutableInputs()->Index(2).packets.push_back(
index_id = runner.MutableInputs()->GetId("RECTS", 1);
runner.MutableInputs()->Get(index_id).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@ -134,25 +139,18 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
EXPECT_EQ(3, assoc_rects.size());
// Check that IDs are filled in and contents match.
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
assoc_rects[1].clear_rect_id();
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
nr_0_.set_rect_id(1);
nr_1_.set_rect_id(2);
nr_2_.set_rect_id(3);
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
EqualsProto(nr_2_)));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
input_stream: "BASE_RECTS:input_vec_0"
input_stream: "RECTS:0:input_vec_1"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
@ -169,14 +167,15 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
input_vec_0->push_back(nr_0_);
nr_1_.set_rect_id(-1);
input_vec_0->push_back(nr_1_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
runner.MutableInputs()
->Tag("BASE_RECTS")
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_2, nr_3. Newly detected palms.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_2_);
input_vec_1->push_back(nr_3_);
runner.MutableInputs()->Index(1).packets.push_back(
runner.MutableInputs()->Tag("RECTS").packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@ -192,23 +191,17 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
EXPECT_EQ(3, assoc_rects.size());
// Check that IDs are filled in and contents match.
EXPECT_EQ(assoc_rects[0].rect_id(), -2);
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
EXPECT_EQ(assoc_rects[1].rect_id(), -1);
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
EXPECT_EQ(assoc_rects[2].rect_id(), 1);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
nr_2_.set_rect_id(1);
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
EqualsProto(nr_2_)));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
input_stream: "BASE_RECTS:input_vec_0"
input_stream: "RECTS:0:input_vec_1"
input_stream: "RECTS:1:input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
@ -220,14 +213,16 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
// Input Stream 0: nr_5.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_0->push_back(nr_5_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
runner.MutableInputs()
->Tag("BASE_RECTS")
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_4, nr_3
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_4_);
input_vec_1->push_back(nr_3_);
runner.MutableInputs()->Index(1).packets.push_back(
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
runner.MutableInputs()->Get(index_id).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_2, nr_1, nr_0.
@ -235,7 +230,8 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
input_vec_2->push_back(nr_2_);
input_vec_2->push_back(nr_1_);
input_vec_2->push_back(nr_0_);
runner.MutableInputs()->Index(2).packets.push_back(
index_id = runner.MutableInputs()->GetId("RECTS", 1);
runner.MutableInputs()->Get(index_id).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@ -253,23 +249,78 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
EXPECT_EQ(3, assoc_rects.size());
// Outputs are in same order as inputs, and IDs are filled in.
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
nr_5_.set_rect_id(1);
nr_4_.set_rect_id(2);
nr_0_.set_rect_id(3);
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_4_),
EqualsProto(nr_0_)));
}
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
assoc_rects[1].clear_rect_id();
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReservesBaseRects) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "BASE_RECTS:input_vec_0"
input_stream: "RECTS:0:input_vec_1"
input_stream: "RECTS:1:input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
// Input Stream 0: nr_5, nr_3, nr_1.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_0->push_back(nr_5_);
input_vec_0->push_back(nr_3_);
input_vec_0->push_back(nr_1_);
runner.MutableInputs()
->Tag("BASE_RECTS")
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_4.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_4_);
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
runner.MutableInputs()->Get(index_id).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_2, nr_0.
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_2->push_back(nr_2_);
input_vec_2->push_back(nr_0_);
index_id = runner.MutableInputs()->GetId("RECTS", 1);
runner.MutableInputs()->Get(index_id).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_5 is added because it is in BASE_RECTS input stream.
// nr_3 is added because it is in BASE_RECTS input stream.
// nr_1 is added because it is in BASE_RECTS input stream.
// nr_4 is added because it does not overlap with nr_5.
// nr_2 is NOT added because it overlaps with nr_4.
// nr_0 is NOT added because it overlaps with nr_3.
EXPECT_EQ(4, assoc_rects.size());
// Outputs are in same order as inputs, and IDs are filled in.
nr_5_.set_rect_id(1);
nr_3_.set_rect_id(2);
nr_1_.set_rect_id(3);
nr_4_.set_rect_id(4);
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_3_),
EqualsProto(nr_1_), EqualsProto(nr_4_)));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec"
input_stream: "BASE_RECTS:input_vec"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
@ -282,8 +333,9 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
input_vec->push_back(nr_3_);
input_vec->push_back(nr_5_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec.release()).At(Timestamp(1)));
runner.MutableInputs()
->Tag("BASE_RECTS")
.packets.push_back(Adopt(input_vec.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
@ -292,12 +344,12 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
// Rectangles are added in the following sequence:
// nr_3 is added 1st.
// nr_5 is NOT added because it overlaps with nr_3.
EXPECT_EQ(1, assoc_rects.size());
// nr_5 is added 2nd. The calculator assumes it does not overlap with nr_3.
EXPECT_EQ(2, assoc_rects.size());
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
nr_3_.set_rect_id(1);
nr_5_.set_rect_id(2);
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_3_), EqualsProto(nr_5_)));
}
} // namespace

View File

@ -318,9 +318,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
.set_min_similarity_threshold(
tasks_options.min_tracking_confidence());
prev_hand_rects_from_landmarks >>
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
hand_association[Input<std::vector<NormalizedRect>>("BASE_RECTS")];
hand_rects_from_hand_detector >>
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
hand_association[Input<std::vector<NormalizedRect>>("RECTS")];
auto hand_rects = hand_association.Out("");
hand_rects >> clip_hand_rects.In("");
} else {