Move hand_association_calculator to open source MP

PiperOrigin-RevId: 477901001
This commit is contained in:
MediaPipe Team 2022-09-30 05:13:59 +00:00 committed by Sebastian Schmidt
parent 133c3b3c00
commit 3a3a470a0c
4 changed files with 504 additions and 0 deletions

View File

@ -0,0 +1,49 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/app/xeno:__subpackages__",
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "hand_association_calculator_proto",
srcs = ["hand_association_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "hand_association_calculator",
srcs = ["hand_association_calculator.cc"],
deps = [
":hand_association_calculator_cc_proto",
"//mediapipe/calculators/util:association_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
"//mediapipe/util:rectangle_util",
],
alwayslink = 1,
)
# TODO: Enable this test

View File

@ -0,0 +1,125 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/util/rectangle_util.h"
namespace mediapipe::api2 {
// HandAssociationCalculator accepts multiple inputs of vectors of
// NormalizedRect. The output is a vector of NormalizedRect that contains
// rects from the input vectors that don't overlap with each other. When two
// rects overlap, the rect that comes in from an earlier input stream is
// kept in the output. If a rect has no ID (i.e. from detection stream),
// then a unique rect ID is assigned for it.
// The rects in multiple input streams are effectively flattened to a single
// list. For example:
// Stream1 : rect 1, rect 2
// Stream2: rect 3, rect 4
// Stream3: rect 5, rect 6
// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
// In the flattened list, if a rect with a higher index overlaps with a rect a
// lower index, beyond a specified IOU threshold, the rect with the lower
// index will be in the output, and the rect with higher index will be
// discarded.
// TODO: Upgrade this to latest API for calculators
class HandAssociationCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
// Initialize input and output streams.
for (auto& input_stream : cc->Inputs()) {
input_stream.Set<std::vector<NormalizedRect>>();
}
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<HandAssociationCalculatorOptions>();
CHECK_GT(options_.min_similarity_threshold(), 0.0);
CHECK_LE(options_.min_similarity_threshold(), 1.0);
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
ASSIGN_OR_RETURN(auto result, GetNonOverlappingElements(cc));
auto output =
std::make_unique<std::vector<NormalizedRect>>(std::move(result));
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return absl::OkStatus();
}
private:
HandAssociationCalculatorOptions options_;
// Return a list of non-overlapping elements from all input streams, with
// decreasing order of priority based on input stream index and indices
// within an input stream.
absl::StatusOr<std::vector<NormalizedRect>> GetNonOverlappingElements(
CalculatorContext* cc) {
std::vector<NormalizedRect> result;
for (const auto& input_stream : cc->Inputs()) {
if (input_stream.IsEmpty()) {
continue;
}
for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
ASSIGN_OR_RETURN(
bool is_overlapping,
mediapipe::DoesRectOverlap(rect, result,
options_.min_similarity_threshold()));
if (!is_overlapping) {
if (!rect.has_rect_id()) {
rect.set_rect_id(GetNextRectId());
}
result.push_back(rect);
}
}
}
return result;
}
private:
// Each NormalizedRect processed by the calculator will be assigned
// an unique id, if it does not already have an ID. The starting ID will be 1.
// Note: This rect_id_ is local to an instance of this calculator. And it is
// expected that the hand tracking graph to have only one instance of
// this association calculator.
int64 rect_id_ = 1;
inline int GetNextRectId() { return rect_id_++; }
};
MEDIAPIPE_REGISTER_NODE(HandAssociationCalculator);
} // namespace mediapipe::api2

View File

@ -0,0 +1,28 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message HandAssociationCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional HandAssociationCalculatorOptions ext = 408244367;
}
optional float min_similarity_threshold = 1 [default = 1.0];
}

View File

@ -0,0 +1,302 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
namespace mediapipe {
namespace {
class HandAssociationCalculatorTest : public testing::Test {
protected:
HandAssociationCalculatorTest() {
// 0.4 ================
// | | | |
// 0.3 ===================== | NR2 | |
// | | | NR1 | | | NR4 |
// 0.2 | NR0 | =========== ================
// | | | | | |
// 0.1 =====|=============== |
// | NR3 | | |
// 0.0 ================ |
// | NR5 |
// -0.1 ===========
// 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2
// NormalizedRect nr_0.
nr_0_.set_x_center(0.2);
nr_0_.set_y_center(0.2);
nr_0_.set_width(0.2);
nr_0_.set_height(0.2);
// NormalizedRect nr_1.
nr_1_.set_x_center(0.4);
nr_1_.set_y_center(0.2);
nr_1_.set_width(0.2);
nr_1_.set_height(0.2);
// NormalizedRect nr_2.
nr_2_.set_x_center(1.0);
nr_2_.set_y_center(0.3);
nr_2_.set_width(0.2);
nr_2_.set_height(0.2);
// NormalizedRect nr_3.
nr_3_.set_x_center(0.35);
nr_3_.set_y_center(0.15);
nr_3_.set_width(0.3);
nr_3_.set_height(0.3);
// NormalizedRect nr_4.
nr_4_.set_x_center(1.1);
nr_4_.set_y_center(0.3);
nr_4_.set_width(0.2);
nr_4_.set_height(0.2);
// NormalizedRect nr_5.
nr_5_.set_x_center(0.5);
nr_5_.set_y_center(0.05);
nr_5_.set_width(0.3);
nr_5_.set_height(0.3);
}
NormalizedRect nr_0_, nr_1_, nr_2_, nr_3_, nr_4_, nr_5_;
};
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Input Stream 0: nr_0, nr_1, nr_2.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_0->push_back(nr_0_);
input_vec_0->push_back(nr_1_);
input_vec_0->push_back(nr_2_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_3, nr_4.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_3_);
input_vec_1->push_back(nr_4_);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_5.
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_2->push_back(nr_5_);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_0 is added 1st.
// nr_1 is added because it does not overlap with nr_0.
// nr_2 is added because it does not overlap with nr_0 or nr_1.
// nr_3 is NOT added because it overlaps with nr_0.
// nr_4 is NOT added because it overlaps with nr_2.
// nr_5 is NOT added because it overlaps with nr_1.
EXPECT_EQ(3, assoc_rects.size());
// Check that IDs are filled in and contents match.
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
assoc_rects[1].clear_rect_id();
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Input Stream 0: nr_0, nr_1. Tracked hands.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
// Setting ID to a negative number for test only, since newly generated
// ID by HandAssociationCalculator are positive numbers.
nr_0_.set_rect_id(-2);
input_vec_0->push_back(nr_0_);
nr_1_.set_rect_id(-1);
input_vec_0->push_back(nr_1_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_2, nr_3. Newly detected palms.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_2_);
input_vec_1->push_back(nr_3_);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_0 is added 1st.
// nr_1 is added because it does not overlap with nr_0.
// nr_2 is added because it does not overlap with nr_0 or nr_1.
// nr_3 is NOT added because it overlaps with nr_0.
EXPECT_EQ(3, assoc_rects.size());
// Check that IDs are filled in and contents match.
EXPECT_EQ(assoc_rects[0].rect_id(), -2);
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
EXPECT_EQ(assoc_rects[1].rect_id(), -1);
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
EXPECT_EQ(assoc_rects[2].rect_id(), 1);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Input Stream 0: nr_5.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_0->push_back(nr_5_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_4, nr_3
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_4_);
input_vec_1->push_back(nr_3_);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_2, nr_1, nr_0.
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_2->push_back(nr_2_);
input_vec_2->push_back(nr_1_);
input_vec_2->push_back(nr_0_);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_5 is added 1st.
// nr_4 is added because it does not overlap with nr_5.
// nr_3 is NOT added because it overlaps with nr_5.
// nr_2 is NOT added because it overlaps with nr_4.
// nr_1 is NOT added because it overlaps with nr_5.
// nr_0 is added because it does not overlap with nr_5 or nr_4.
EXPECT_EQ(3, assoc_rects.size());
// Outputs are in same order as inputs, and IDs are filled in.
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
assoc_rects[1].clear_rect_id();
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Just one input stream : nr_3, nr_5.
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
input_vec->push_back(nr_3_);
input_vec->push_back(nr_5_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_3 is added 1st.
// nr_5 is NOT added because it overlaps with nr_3.
EXPECT_EQ(1, assoc_rects.size());
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
}
} // namespace
} // namespace mediapipe