Introduce CombineJointsCalculator
PiperOrigin-RevId: 570739088
This commit is contained in:
parent
7f1c17065a
commit
c81624d7b2
|
@ -1617,6 +1617,48 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "combine_joints_calculator",
|
||||||
|
srcs = ["combine_joints_calculator.cc"],
|
||||||
|
hdrs = ["combine_joints_calculator.h"],
|
||||||
|
deps = [
|
||||||
|
":combine_joints_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "combine_joints_calculator_proto",
|
||||||
|
srcs = ["combine_joints_calculator.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/framework/formats:body_rig_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "combine_joints_calculator_test",
|
||||||
|
srcs = ["combine_joints_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":combine_joints_calculator",
|
||||||
|
":combine_joints_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status_matchers",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "pass_through_or_empty_detection_vector_calculator",
|
name = "pass_through_or_empty_detection_vector_calculator",
|
||||||
srcs = ["pass_through_or_empty_detection_vector_calculator.cc"],
|
srcs = ["pass_through_or_empty_detection_vector_calculator.cc"],
|
||||||
|
|
79
mediapipe/calculators/util/combine_joints_calculator.cc
Normal file
79
mediapipe/calculators/util/combine_joints_calculator.cc
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/util/combine_joints_calculator.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/util/combine_joints_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
namespace {} // namespace
|
||||||
|
|
||||||
|
class CombineJointsCalculatorImpl : public NodeImpl<CombineJointsCalculator> {
|
||||||
|
public:
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
options_ = cc->Options<CombineJointsCalculatorOptions>();
|
||||||
|
RET_CHECK_GE(options_.num_joints(), 0);
|
||||||
|
RET_CHECK_GT(kInJoints(cc).Count(), 0);
|
||||||
|
RET_CHECK_EQ(kInJoints(cc).Count(), options_.joints_mapping_size());
|
||||||
|
RET_CHECK(options_.has_default_joint());
|
||||||
|
for (const auto& mapping : options_.joints_mapping()) {
|
||||||
|
for (int idx : mapping.idx()) {
|
||||||
|
RET_CHECK_GE(idx, 0);
|
||||||
|
RET_CHECK_LT(idx, options_.num_joints());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
// Initialize output joints with default values.
|
||||||
|
JointList out_joints;
|
||||||
|
for (int i = 0; i < options_.num_joints(); ++i) {
|
||||||
|
*out_joints.add_joint() = options_.default_joint();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override default joints with provided joints.
|
||||||
|
for (int i = 0; i < kInJoints(cc).Count(); ++i) {
|
||||||
|
// Skip empty joint streams.
|
||||||
|
if (kInJoints(cc)[i].IsEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const JointList& in_joints = kInJoints(cc)[i].Get();
|
||||||
|
const auto& mapping = options_.joints_mapping(i);
|
||||||
|
RET_CHECK_EQ(in_joints.joint_size(), mapping.idx_size());
|
||||||
|
for (int j = 0; j < in_joints.joint_size(); ++j) {
|
||||||
|
*out_joints.mutable_joint(mapping.idx(j)) = in_joints.joint(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kOutJoints(cc).Send(std::move(out_joints));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CombineJointsCalculatorOptions options_;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_NODE_IMPLEMENTATION(CombineJointsCalculatorImpl);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
64
mediapipe/calculators/util/combine_joints_calculator.h
Normal file
64
mediapipe/calculators/util/combine_joints_calculator.h
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_
|
||||||
|
#define MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
// A calculator to combine several joint sets into one.
|
||||||
|
//
|
||||||
|
// Input:
|
||||||
|
// JOINTS - Multiple JointList
|
||||||
|
// Joint sets to combine into one. Subsets are applied in provided order and
|
||||||
|
// overwrite each other.
|
||||||
|
//
|
||||||
|
// Output:
|
||||||
|
// JOINTS - JointList
|
||||||
|
// Combined joints.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "CombineJointsCalculator"
|
||||||
|
// input_stream: "JOINTS:0:joints_0"
|
||||||
|
// input_stream: "JOINTS:1:joints_1"
|
||||||
|
// output_stream: "JOINTS:combined_joints"
|
||||||
|
// options: {
|
||||||
|
// [mediapipe.CombineJointsCalculatorOptions.ext] {
|
||||||
|
// num_joints: 63
|
||||||
|
// joints_mapping: { idx: [0, 1, 2] }
|
||||||
|
// joints_mapping: { idx: [2, 3] }
|
||||||
|
// default_joint: {
|
||||||
|
// rotation_6d: [1, 0, 0, 1, 0, 0]
|
||||||
|
// visibility: 1.0
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class CombineJointsCalculator : public NodeIntf {
|
||||||
|
public:
|
||||||
|
static constexpr Input<mediapipe::JointList>::Multiple kInJoints{"JOINTS"};
|
||||||
|
static constexpr Output<mediapipe::JointList> kOutJoints{"JOINTS"};
|
||||||
|
MEDIAPIPE_NODE_INTERFACE(CombineJointsCalculator, kInJoints, kOutJoints);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_
|
46
mediapipe/calculators/util/combine_joints_calculator.proto
Normal file
46
mediapipe/calculators/util/combine_joints_calculator.proto
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/framework/formats/body_rig.proto";
|
||||||
|
|
||||||
|
message CombineJointsCalculatorOptions {
|
||||||
|
extend CalculatorOptions {
|
||||||
|
optional CombineJointsCalculatorOptions ext = 406440185;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mapping from joint set to the resulting set.
|
||||||
|
message JointsMapping {
|
||||||
|
// Indexes of provided joints in the resulting joint set.
|
||||||
|
// All indexes must be within the [0, num_joints - 1] range.
|
||||||
|
repeated int32 idx = 1 [packed = true];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of joints in the resulting set.
|
||||||
|
optional int32 num_joints = 1;
|
||||||
|
|
||||||
|
// Mapping from joint sets to the resulting set.
|
||||||
|
// Number of mappings must be equal to number of provided joint sets. Number
|
||||||
|
// of indexes in each mapping must be equal to number of joints in
|
||||||
|
// corresponding joint set. Mappings are applied in the provided order and can
|
||||||
|
// overwrite each other.
|
||||||
|
repeated JointsMapping joints_mapping = 2;
|
||||||
|
|
||||||
|
// Default joint to initialize joints in the resulting set.
|
||||||
|
optional Joint default_joint = 3;
|
||||||
|
}
|
174
mediapipe/calculators/util/combine_joints_calculator_test.cc
Normal file
174
mediapipe/calculators/util/combine_joints_calculator_test.cc
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||||
|
|
||||||
|
Joint MakeJoint(const std::vector<float>& rotation_6d,
|
||||||
|
std::optional<float> visibility) {
|
||||||
|
Joint joint;
|
||||||
|
for (float r : rotation_6d) {
|
||||||
|
joint.add_rotation_6d(r);
|
||||||
|
}
|
||||||
|
if (visibility) {
|
||||||
|
joint.set_visibility(visibility.value());
|
||||||
|
}
|
||||||
|
return joint;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CombineJointsTestCase {
|
||||||
|
std::string test_name;
|
||||||
|
int num_joints;
|
||||||
|
std::string joints_mapping;
|
||||||
|
std::vector<std::vector<Joint>> in_joints;
|
||||||
|
std::vector<Joint> out_joints;
|
||||||
|
};
|
||||||
|
|
||||||
|
using CombineJointsTest = ::testing::TestWithParam<CombineJointsTestCase>;
|
||||||
|
|
||||||
|
TEST_P(CombineJointsTest, CombineJointsTest) {
|
||||||
|
const CombineJointsTestCase& tc = GetParam();
|
||||||
|
|
||||||
|
std::string input_joint_streams = "";
|
||||||
|
for (int i = 0; i < tc.in_joints.size(); ++i) {
|
||||||
|
input_joint_streams +=
|
||||||
|
absl::StrFormat("input_stream: \"JOINTS:%d:joints_%d\"\n", i, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare graph.
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(absl::Substitute(
|
||||||
|
R"(
|
||||||
|
calculator: "CombineJointsCalculator"
|
||||||
|
$0
|
||||||
|
output_stream: "JOINTS:combined_joints"
|
||||||
|
options: {
|
||||||
|
[mediapipe.CombineJointsCalculatorOptions.ext] {
|
||||||
|
num_joints: $1
|
||||||
|
joints_mapping: [ $2 ]
|
||||||
|
default_joint: {
|
||||||
|
rotation_6d: [1, 0, 0, 1, 0, 0]
|
||||||
|
visibility: 1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)",
|
||||||
|
input_joint_streams, tc.num_joints, tc.joints_mapping)));
|
||||||
|
|
||||||
|
// Prepare and send joints.
|
||||||
|
for (int i = 0; i < tc.in_joints.size(); ++i) {
|
||||||
|
JointList in_joints;
|
||||||
|
for (const auto& joint : tc.in_joints[i]) {
|
||||||
|
*in_joints.add_joint() = joint;
|
||||||
|
}
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Get("JOINTS", i)
|
||||||
|
.packets.push_back(MakePacket<JointList>(std::move(in_joints))
|
||||||
|
.At(mediapipe::Timestamp(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the graph.
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& output_packets = runner.Outputs().Tag("JOINTS").packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
|
||||||
|
const auto& out_joints = output_packets[0].Get<JointList>();
|
||||||
|
EXPECT_EQ(out_joints.joint_size(), tc.out_joints.size());
|
||||||
|
for (int i = 0; i < out_joints.joint_size(); ++i) {
|
||||||
|
const Joint& actual = out_joints.joint(i);
|
||||||
|
const Joint& expected = tc.out_joints[i];
|
||||||
|
|
||||||
|
EXPECT_EQ(actual.rotation_6d_size(), expected.rotation_6d_size())
|
||||||
|
<< "Unexpected joint #" << i << " rotation";
|
||||||
|
for (int j = 0; j < actual.rotation_6d_size(); ++j) {
|
||||||
|
EXPECT_NEAR(actual.rotation_6d(j), expected.rotation_6d(j), 1e-5)
|
||||||
|
<< "Unexpected joint #" << i << " rotation";
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(actual.has_visibility(), expected.has_visibility())
|
||||||
|
<< "Unexpected joint #" << i << " visibility";
|
||||||
|
if (actual.has_visibility()) {
|
||||||
|
EXPECT_NEAR(actual.visibility(), expected.visibility(), 1e-5)
|
||||||
|
<< "Unexpected joint #" << i << " visibility";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
CombineJointsTests, CombineJointsTest,
|
||||||
|
testing::ValuesIn<CombineJointsTestCase>({
|
||||||
|
{"Empty_NoOutJoints", 0, "{ idx: [] }", {{}}, {}},
|
||||||
|
{"Empty_SingleOutJoint",
|
||||||
|
1,
|
||||||
|
"{ idx: [] }",
|
||||||
|
{{}},
|
||||||
|
{MakeJoint({1, 0, 0, 1, 0, 0}, 1)}},
|
||||||
|
|
||||||
|
{"Single_SetFirst",
|
||||||
|
2,
|
||||||
|
"{ idx: [0] }",
|
||||||
|
{{MakeJoint({3, 3, 3, 3, 3, 3}, 4)}},
|
||||||
|
{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({1, 0, 0, 1, 0, 0}, 1)}},
|
||||||
|
{"Single_SetBoth",
|
||||||
|
2,
|
||||||
|
"{ idx: [0, 1] }",
|
||||||
|
{{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8)}},
|
||||||
|
{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8)}},
|
||||||
|
{"Single_SetBoth_ReverseOrder",
|
||||||
|
2,
|
||||||
|
"{ idx: [1, 0] }",
|
||||||
|
{{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8)}},
|
||||||
|
{MakeJoint({7, 7, 7, 7, 7, 7}, 8), MakeJoint({3, 3, 3, 3, 3, 3}, 4)}},
|
||||||
|
|
||||||
|
{"Double_NoOverwrite",
|
||||||
|
3,
|
||||||
|
"{ idx: [0] }, { idx: [1] }",
|
||||||
|
{{MakeJoint({3, 3, 3, 3, 3, 3}, 4)},
|
||||||
|
{MakeJoint({7, 7, 7, 7, 7, 7}, 8)}},
|
||||||
|
{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8),
|
||||||
|
MakeJoint({1, 0, 0, 1, 0, 0}, 1)}},
|
||||||
|
{"Double_OverwriteSecond",
|
||||||
|
3,
|
||||||
|
"{ idx: [0, 1] }, { idx: [1, 2] }",
|
||||||
|
{{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({4, 4, 4, 4, 4, 4}, 5)},
|
||||||
|
{MakeJoint({6, 6, 6, 6, 6, 6}, 7), MakeJoint({8, 8, 8, 8, 8, 8}, 9)}},
|
||||||
|
{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({6, 6, 6, 6, 6, 6}, 7),
|
||||||
|
MakeJoint({8, 8, 8, 8, 8, 8}, 9)}},
|
||||||
|
}),
|
||||||
|
[](const testing::TestParamInfo<CombineJointsTest::ParamType>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user