Preserves all elements of BASE_HAND_RECTS
input streams in HandAssociationCalculator
.
PiperOrigin-RevId: 516339343
This commit is contained in:
parent
c32ddcb04c
commit
0f58d89992
|
@ -36,6 +36,7 @@ cc_library(
|
||||||
":hand_association_calculator_cc_proto",
|
":hand_association_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/util:association_calculator",
|
"//mediapipe/calculators/util:association_calculator",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:collection_item_id",
|
||||||
"//mediapipe/framework/api2:node",
|
"//mediapipe/framework/api2:node",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/port:rectangle",
|
"//mediapipe/framework/port:rectangle",
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/collection_item_id.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/port/rectangle.h"
|
#include "mediapipe/framework/port/rectangle.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
@ -29,30 +30,55 @@ namespace mediapipe::api2 {
|
||||||
|
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
|
|
||||||
// HandAssociationCalculator accepts multiple inputs of vectors of
|
// Input:
|
||||||
// NormalizedRect. The output is a vector of NormalizedRect that contains
|
// BASE_RECTS - Vector of NormalizedRect.
|
||||||
// rects from the input vectors that don't overlap with each other. When two
|
// RECTS - Vector of NormalizedRect.
|
||||||
// rects overlap, the rect that comes in from an earlier input stream is
|
//
|
||||||
// kept in the output. If a rect has no ID (i.e. from detection stream),
|
// Output:
|
||||||
// then a unique rect ID is assigned for it.
|
// No tag - Vector of NormalizedRect.
|
||||||
|
//
|
||||||
// The rects in multiple input streams are effectively flattened to a single
|
// Example use:
|
||||||
// list. For example:
|
// node {
|
||||||
// Stream1 : rect 1, rect 2
|
// calculator: "HandAssociationCalculator"
|
||||||
// Stream2: rect 3, rect 4
|
// input_stream: "BASE_RECTS:base_rects"
|
||||||
// Stream3: rect 5, rect 6
|
// input_stream: "RECTS:0:rects0"
|
||||||
// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
|
// input_stream: "RECTS:1:rects1"
|
||||||
// In the flattened list, if a rect with a higher index overlaps with a rect a
|
// input_stream: "RECTS:2:rects2"
|
||||||
// lower index, beyond a specified IOU threshold, the rect with the lower
|
// output_stream: "output_rects"
|
||||||
// index will be in the output, and the rect with higher index will be
|
// options {
|
||||||
// discarded.
|
// [mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
// min_similarity_threshold: 0.1
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// IMPORTANT Notes:
|
||||||
|
// - Rects from input streams tagged with "BASE_RECTS" are always preserved.
|
||||||
|
// - This calculator checks for overlap among rects from input streams tagged
|
||||||
|
// with "RECTS". Rects are prioritized based on their index in the vector and
|
||||||
|
// input streams to the calculator. When two rects overlap, the rect that
|
||||||
|
// comes from an input stream with lower tag-index is kept in the output.
|
||||||
|
// - Example of inputs for the node above:
|
||||||
|
// "base_rects": rect 0, rect 1
|
||||||
|
// "rects0": rect 2, rect 3
|
||||||
|
// "rects1": rect 4, rect 5
|
||||||
|
// "rects2": rect 6, rect 7
|
||||||
|
// (Conceptually) flattened list: 0, 1, 2, 3, 4, 5, 6, 7.
|
||||||
|
// Rects 0, 1 will be preserved. Rects 2, 3, 4, 5, 6, 7 will be checked for
|
||||||
|
// overlap. If a rect with a higher index overlaps with a rect with lower
|
||||||
|
// index, beyond a specified IOU threshold, the rect with the lower index
|
||||||
|
// will be in the output, and the rect with higher index will be discarded.
|
||||||
// TODO: Upgrade this to latest API for calculators
|
// TODO: Upgrade this to latest API for calculators
|
||||||
class HandAssociationCalculator : public CalculatorBase {
|
class HandAssociationCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
// Initialize input and output streams.
|
// Initialize input and output streams.
|
||||||
for (auto& input_stream : cc->Inputs()) {
|
for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
|
||||||
input_stream.Set<std::vector<NormalizedRect>>();
|
id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
|
||||||
|
cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
|
||||||
|
}
|
||||||
|
for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
|
||||||
|
id != cc->Inputs().EndId("RECTS"); ++id) {
|
||||||
|
cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
|
||||||
}
|
}
|
||||||
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
|
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
|
||||||
|
|
||||||
|
@ -89,7 +115,24 @@ class HandAssociationCalculator : public CalculatorBase {
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
std::vector<NormalizedRect> result;
|
std::vector<NormalizedRect> result;
|
||||||
|
|
||||||
for (const auto& input_stream : cc->Inputs()) {
|
for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
|
||||||
|
id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
|
||||||
|
const auto& input_stream = cc->Inputs().Get(id);
|
||||||
|
if (input_stream.IsEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
|
||||||
|
if (!rect.has_rect_id()) {
|
||||||
|
rect.set_rect_id(GetNextRectId());
|
||||||
|
}
|
||||||
|
result.push_back(rect);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
|
||||||
|
id != cc->Inputs().EndId("RECTS"); ++id) {
|
||||||
|
const auto& input_stream = cc->Inputs().Get(id);
|
||||||
if (input_stream.IsEmpty()) {
|
if (input_stream.IsEmpty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,8 @@ namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::EqualsProto;
|
||||||
|
|
||||||
class HandAssociationCalculatorTest : public testing::Test {
|
class HandAssociationCalculatorTest : public testing::Test {
|
||||||
protected:
|
protected:
|
||||||
|
@ -87,9 +89,9 @@ class HandAssociationCalculatorTest : public testing::Test {
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec_0"
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
input_stream: "input_vec_1"
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
input_stream: "input_vec_2"
|
input_stream: "RECTS:1:input_vec_2"
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -103,20 +105,23 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||||
input_vec_0->push_back(nr_0_);
|
input_vec_0->push_back(nr_0_);
|
||||||
input_vec_0->push_back(nr_1_);
|
input_vec_0->push_back(nr_1_);
|
||||||
input_vec_0->push_back(nr_2_);
|
input_vec_0->push_back(nr_2_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 1: nr_3, nr_4.
|
// Input Stream 1: nr_3, nr_4.
|
||||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_1->push_back(nr_3_);
|
input_vec_1->push_back(nr_3_);
|
||||||
input_vec_1->push_back(nr_4_);
|
input_vec_1->push_back(nr_4_);
|
||||||
runner.MutableInputs()->Index(1).packets.push_back(
|
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 2: nr_5.
|
// Input Stream 2: nr_5.
|
||||||
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_2->push_back(nr_5_);
|
input_vec_2->push_back(nr_5_);
|
||||||
runner.MutableInputs()->Index(2).packets.push_back(
|
index_id = runner.MutableInputs()->GetId("RECTS", 1);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
@ -134,25 +139,18 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||||
EXPECT_EQ(3, assoc_rects.size());
|
EXPECT_EQ(3, assoc_rects.size());
|
||||||
|
|
||||||
// Check that IDs are filled in and contents match.
|
// Check that IDs are filled in and contents match.
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
nr_0_.set_rect_id(1);
|
||||||
assoc_rects[0].clear_rect_id();
|
nr_1_.set_rect_id(2);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
|
nr_2_.set_rect_id(3);
|
||||||
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
|
||||||
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
|
EqualsProto(nr_2_)));
|
||||||
assoc_rects[1].clear_rect_id();
|
|
||||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
|
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
|
|
||||||
assoc_rects[2].clear_rect_id();
|
|
||||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec_0"
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
input_stream: "input_vec_1"
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
input_stream: "input_vec_2"
|
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -169,14 +167,15 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||||
input_vec_0->push_back(nr_0_);
|
input_vec_0->push_back(nr_0_);
|
||||||
nr_1_.set_rect_id(-1);
|
nr_1_.set_rect_id(-1);
|
||||||
input_vec_0->push_back(nr_1_);
|
input_vec_0->push_back(nr_1_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 1: nr_2, nr_3. Newly detected palms.
|
// Input Stream 1: nr_2, nr_3. Newly detected palms.
|
||||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_1->push_back(nr_2_);
|
input_vec_1->push_back(nr_2_);
|
||||||
input_vec_1->push_back(nr_3_);
|
input_vec_1->push_back(nr_3_);
|
||||||
runner.MutableInputs()->Index(1).packets.push_back(
|
runner.MutableInputs()->Tag("RECTS").packets.push_back(
|
||||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
@ -192,23 +191,17 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||||
EXPECT_EQ(3, assoc_rects.size());
|
EXPECT_EQ(3, assoc_rects.size());
|
||||||
|
|
||||||
// Check that IDs are filled in and contents match.
|
// Check that IDs are filled in and contents match.
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), -2);
|
nr_2_.set_rect_id(1);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
|
||||||
|
EqualsProto(nr_2_)));
|
||||||
EXPECT_EQ(assoc_rects[1].rect_id(), -1);
|
|
||||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
|
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[2].rect_id(), 1);
|
|
||||||
assoc_rects[2].clear_rect_id();
|
|
||||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec_0"
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
input_stream: "input_vec_1"
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
input_stream: "input_vec_2"
|
input_stream: "RECTS:1:input_vec_2"
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -220,14 +213,16 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
// Input Stream 0: nr_5.
|
// Input Stream 0: nr_5.
|
||||||
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_0->push_back(nr_5_);
|
input_vec_0->push_back(nr_5_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 1: nr_4, nr_3
|
// Input Stream 1: nr_4, nr_3
|
||||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_1->push_back(nr_4_);
|
input_vec_1->push_back(nr_4_);
|
||||||
input_vec_1->push_back(nr_3_);
|
input_vec_1->push_back(nr_3_);
|
||||||
runner.MutableInputs()->Index(1).packets.push_back(
|
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 2: nr_2, nr_1, nr_0.
|
// Input Stream 2: nr_2, nr_1, nr_0.
|
||||||
|
@ -235,7 +230,8 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
input_vec_2->push_back(nr_2_);
|
input_vec_2->push_back(nr_2_);
|
||||||
input_vec_2->push_back(nr_1_);
|
input_vec_2->push_back(nr_1_);
|
||||||
input_vec_2->push_back(nr_0_);
|
input_vec_2->push_back(nr_0_);
|
||||||
runner.MutableInputs()->Index(2).packets.push_back(
|
index_id = runner.MutableInputs()->GetId("RECTS", 1);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
@ -253,23 +249,78 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
EXPECT_EQ(3, assoc_rects.size());
|
EXPECT_EQ(3, assoc_rects.size());
|
||||||
|
|
||||||
// Outputs are in same order as inputs, and IDs are filled in.
|
// Outputs are in same order as inputs, and IDs are filled in.
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
nr_5_.set_rect_id(1);
|
||||||
assoc_rects[0].clear_rect_id();
|
nr_4_.set_rect_id(2);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
|
nr_0_.set_rect_id(3);
|
||||||
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_4_),
|
||||||
|
EqualsProto(nr_0_)));
|
||||||
|
}
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReservesBaseRects) {
|
||||||
assoc_rects[1].clear_rect_id();
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
|
calculator: "HandAssociationCalculator"
|
||||||
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
|
input_stream: "RECTS:1:input_vec_2"
|
||||||
|
output_stream: "output_vec"
|
||||||
|
options {
|
||||||
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
min_similarity_threshold: 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
|
// Input Stream 0: nr_5, nr_3, nr_1.
|
||||||
assoc_rects[2].clear_rect_id();
|
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
|
input_vec_0->push_back(nr_5_);
|
||||||
|
input_vec_0->push_back(nr_3_);
|
||||||
|
input_vec_0->push_back(nr_1_);
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
// Input Stream 1: nr_4.
|
||||||
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
|
input_vec_1->push_back(nr_4_);
|
||||||
|
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
// Input Stream 2: nr_2, nr_0.
|
||||||
|
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
|
input_vec_2->push_back(nr_2_);
|
||||||
|
input_vec_2->push_back(nr_0_);
|
||||||
|
index_id = runner.MutableInputs()->GetId("RECTS", 1);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
|
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||||
|
EXPECT_EQ(1, output.size());
|
||||||
|
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||||
|
|
||||||
|
// Rectangles are added in the following sequence:
|
||||||
|
// nr_5 is added because it is in BASE_RECTS input stream.
|
||||||
|
// nr_3 is added because it is in BASE_RECTS input stream.
|
||||||
|
// nr_1 is added because it is in BASE_RECTS input stream.
|
||||||
|
// nr_4 is added because it does not overlap with nr_5.
|
||||||
|
// nr_2 is NOT added because it overlaps with nr_4.
|
||||||
|
// nr_0 is NOT added because it overlaps with nr_3.
|
||||||
|
EXPECT_EQ(4, assoc_rects.size());
|
||||||
|
|
||||||
|
// Outputs are in same order as inputs, and IDs are filled in.
|
||||||
|
nr_5_.set_rect_id(1);
|
||||||
|
nr_3_.set_rect_id(2);
|
||||||
|
nr_1_.set_rect_id(3);
|
||||||
|
nr_4_.set_rect_id(4);
|
||||||
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_3_),
|
||||||
|
EqualsProto(nr_1_), EqualsProto(nr_4_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec"
|
input_stream: "BASE_RECTS:input_vec"
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -282,8 +333,9 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||||
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec->push_back(nr_3_);
|
input_vec->push_back(nr_3_);
|
||||||
input_vec->push_back(nr_5_);
|
input_vec->push_back(nr_5_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||||
|
@ -292,12 +344,12 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||||
|
|
||||||
// Rectangles are added in the following sequence:
|
// Rectangles are added in the following sequence:
|
||||||
// nr_3 is added 1st.
|
// nr_3 is added 1st.
|
||||||
// nr_5 is NOT added because it overlaps with nr_3.
|
// nr_5 is added 2nd. The calculator assumes it does not overlap with nr_3.
|
||||||
EXPECT_EQ(1, assoc_rects.size());
|
EXPECT_EQ(2, assoc_rects.size());
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
nr_3_.set_rect_id(1);
|
||||||
assoc_rects[0].clear_rect_id();
|
nr_5_.set_rect_id(2);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_3_), EqualsProto(nr_5_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -318,9 +318,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
.set_min_similarity_threshold(
|
.set_min_similarity_threshold(
|
||||||
tasks_options.min_tracking_confidence());
|
tasks_options.min_tracking_confidence());
|
||||||
prev_hand_rects_from_landmarks >>
|
prev_hand_rects_from_landmarks >>
|
||||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
|
hand_association[Input<std::vector<NormalizedRect>>("BASE_RECTS")];
|
||||||
hand_rects_from_hand_detector >>
|
hand_rects_from_hand_detector >>
|
||||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
|
hand_association[Input<std::vector<NormalizedRect>>("RECTS")];
|
||||||
auto hand_rects = hand_association.Out("");
|
auto hand_rects = hand_association.Out("");
|
||||||
hand_rects >> clip_hand_rects.In("");
|
hand_rects >> clip_hand_rects.In("");
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user