diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index ac69d969f..751f9c06b 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -1617,6 +1617,48 @@ cc_test( ], ) +cc_library( + name = "combine_joints_calculator", + srcs = ["combine_joints_calculator.cc"], + hdrs = ["combine_joints_calculator.h"], + deps = [ + ":combine_joints_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/port:ret_check", + ], + alwayslink = 1, +) + +mediapipe_proto_library( + name = "combine_joints_calculator_proto", + srcs = ["combine_joints_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/framework/formats:body_rig_proto", + ], +) + +cc_test( + name = "combine_joints_calculator_test", + srcs = ["combine_joints_calculator_test.cc"], + deps = [ + ":combine_joints_calculator", + ":combine_joints_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "pass_through_or_empty_detection_vector_calculator", srcs = ["pass_through_or_empty_detection_vector_calculator.cc"], diff --git a/mediapipe/calculators/util/combine_joints_calculator.cc b/mediapipe/calculators/util/combine_joints_calculator.cc new file mode 100644 index 000000000..ac5e6b033 --- /dev/null +++ b/mediapipe/calculators/util/combine_joints_calculator.cc @@ -0,0 +1,79 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/util/combine_joints_calculator.h" + +#include + +#include "mediapipe/calculators/util/combine_joints_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { +namespace api2 { + +namespace {} // namespace + +class CombineJointsCalculatorImpl : public NodeImpl { + public: + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + RET_CHECK_GE(options_.num_joints(), 0); + RET_CHECK_GT(kInJoints(cc).Count(), 0); + RET_CHECK_EQ(kInJoints(cc).Count(), options_.joints_mapping_size()); + RET_CHECK(options_.has_default_joint()); + for (const auto& mapping : options_.joints_mapping()) { + for (int idx : mapping.idx()) { + RET_CHECK_GE(idx, 0); + RET_CHECK_LT(idx, options_.num_joints()); + } + } + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // Initialize output joints with default values. + JointList out_joints; + for (int i = 0; i < options_.num_joints(); ++i) { + *out_joints.add_joint() = options_.default_joint(); + } + + // Override default joints with provided joints. + for (int i = 0; i < kInJoints(cc).Count(); ++i) { + // Skip empty joint streams. + if (kInJoints(cc)[i].IsEmpty()) { + continue; + } + + const JointList& in_joints = kInJoints(cc)[i].Get(); + const auto& mapping = options_.joints_mapping(i); + RET_CHECK_EQ(in_joints.joint_size(), mapping.idx_size()); + for (int j = 0; j < in_joints.joint_size(); ++j) { + *out_joints.mutable_joint(mapping.idx(j)) = in_joints.joint(j); + } + } + + kOutJoints(cc).Send(std::move(out_joints)); + return absl::OkStatus(); + } + + private: + CombineJointsCalculatorOptions options_; +}; +MEDIAPIPE_NODE_IMPLEMENTATION(CombineJointsCalculatorImpl); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/util/combine_joints_calculator.h b/mediapipe/calculators/util/combine_joints_calculator.h new file mode 100644 index 000000000..41b3c314d --- /dev/null +++ b/mediapipe/calculators/util/combine_joints_calculator.h @@ -0,0 +1,64 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_ + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/body_rig.pb.h" + +namespace mediapipe { +namespace api2 { + +// A calculator to combine several joint sets into one. +// +// Input: +// JOINTS - Multiple JointList +// Joint sets to combine into one. Subsets are applied in provided order and +// overwrite each other. +// +// Output: +// JOINTS - JointList +// Combined joints. +// +// Example: +// node { +// calculator: "CombineJointsCalculator" +// input_stream: "JOINTS:0:joints_0" +// input_stream: "JOINTS:1:joints_1" +// output_stream: "JOINTS:combined_joints" +// options: { +// [mediapipe.CombineJointsCalculatorOptions.ext] { +// num_joints: 63 +// joints_mapping: { idx: [0, 1, 2] } +// joints_mapping: { idx: [2, 3] } +// default_joint: { +// rotation_6d: [1, 0, 0, 1, 0, 0] +// visibility: 1.0 +// } +// } +// } +// } +class CombineJointsCalculator : public NodeIntf { + public: + static constexpr Input::Multiple kInJoints{"JOINTS"}; + static constexpr Output kOutJoints{"JOINTS"}; + MEDIAPIPE_NODE_INTERFACE(CombineJointsCalculator, kInJoints, kOutJoints); +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_UTIL_COMBINE_JOINTS_CALCULATOR_H_ diff --git a/mediapipe/calculators/util/combine_joints_calculator.proto b/mediapipe/calculators/util/combine_joints_calculator.proto new file mode 100644 index 000000000..d90f6df1a --- /dev/null +++ b/mediapipe/calculators/util/combine_joints_calculator.proto @@ -0,0 +1,46 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/formats/body_rig.proto"; + +message CombineJointsCalculatorOptions { + extend CalculatorOptions { + optional CombineJointsCalculatorOptions ext = 406440185; + } + + // Mapping from joint set to the resulting set. + message JointsMapping { + // Indexes of provided joints in the resulting joint set. + // All indexes must be within the [0, num_joints - 1] range. + repeated int32 idx = 1 [packed = true]; + } + + // Number of joints in the resulting set. + optional int32 num_joints = 1; + + // Mapping from joint sets to the resulting set. + // Number of mappings must be equal to number of provided joint sets. Number + // of indexes in each mapping must be equal to number of joints in + // corresponding joint set. Mappings are applied in the provided order and can + // overwrite each other. + repeated JointsMapping joints_mapping = 2; + + // Default joint to initialize joints in the resulting set. + optional Joint default_joint = 3; +} diff --git a/mediapipe/calculators/util/combine_joints_calculator_test.cc b/mediapipe/calculators/util/combine_joints_calculator_test.cc new file mode 100644 index 000000000..3625f2456 --- /dev/null +++ b/mediapipe/calculators/util/combine_joints_calculator_test.cc @@ -0,0 +1,174 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/body_rig.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace api2 { +namespace { + +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +Joint MakeJoint(const std::vector& rotation_6d, + std::optional visibility) { + Joint joint; + for (float r : rotation_6d) { + joint.add_rotation_6d(r); + } + if (visibility) { + joint.set_visibility(visibility.value()); + } + return joint; +} + +struct CombineJointsTestCase { + std::string test_name; + int num_joints; + std::string joints_mapping; + std::vector> in_joints; + std::vector out_joints; +}; + +using CombineJointsTest = ::testing::TestWithParam; + +TEST_P(CombineJointsTest, CombineJointsTest) { + const CombineJointsTestCase& tc = GetParam(); + + std::string input_joint_streams = ""; + for (int i = 0; i < tc.in_joints.size(); ++i) { + input_joint_streams += + absl::StrFormat("input_stream: \"JOINTS:%d:joints_%d\"\n", i, i); + } + + // Prepare graph. + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(absl::Substitute( + R"( + calculator: "CombineJointsCalculator" + $0 + output_stream: "JOINTS:combined_joints" + options: { + [mediapipe.CombineJointsCalculatorOptions.ext] { + num_joints: $1 + joints_mapping: [ $2 ] + default_joint: { + rotation_6d: [1, 0, 0, 1, 0, 0] + visibility: 1.0 + } + } + } + )", + input_joint_streams, tc.num_joints, tc.joints_mapping))); + + // Prepare and send joints. + for (int i = 0; i < tc.in_joints.size(); ++i) { + JointList in_joints; + for (const auto& joint : tc.in_joints[i]) { + *in_joints.add_joint() = joint; + } + runner.MutableInputs() + ->Get("JOINTS", i) + .packets.push_back(MakePacket(std::move(in_joints)) + .At(mediapipe::Timestamp(0))); + } + + // Run the graph. + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets = runner.Outputs().Tag("JOINTS").packets; + EXPECT_EQ(1, output_packets.size()); + + const auto& out_joints = output_packets[0].Get(); + EXPECT_EQ(out_joints.joint_size(), tc.out_joints.size()); + for (int i = 0; i < out_joints.joint_size(); ++i) { + const Joint& actual = out_joints.joint(i); + const Joint& expected = tc.out_joints[i]; + + EXPECT_EQ(actual.rotation_6d_size(), expected.rotation_6d_size()) + << "Unexpected joint #" << i << " rotation"; + for (int j = 0; j < actual.rotation_6d_size(); ++j) { + EXPECT_NEAR(actual.rotation_6d(j), expected.rotation_6d(j), 1e-5) + << "Unexpected joint #" << i << " rotation"; + } + + EXPECT_EQ(actual.has_visibility(), expected.has_visibility()) + << "Unexpected joint #" << i << " visibility"; + if (actual.has_visibility()) { + EXPECT_NEAR(actual.visibility(), expected.visibility(), 1e-5) + << "Unexpected joint #" << i << " visibility"; + } + } +} + +INSTANTIATE_TEST_SUITE_P( + CombineJointsTests, CombineJointsTest, + testing::ValuesIn({ + {"Empty_NoOutJoints", 0, "{ idx: [] }", {{}}, {}}, + {"Empty_SingleOutJoint", + 1, + "{ idx: [] }", + {{}}, + {MakeJoint({1, 0, 0, 1, 0, 0}, 1)}}, + + {"Single_SetFirst", + 2, + "{ idx: [0] }", + {{MakeJoint({3, 3, 3, 3, 3, 3}, 4)}}, + {MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({1, 0, 0, 1, 0, 0}, 1)}}, + {"Single_SetBoth", + 2, + "{ idx: [0, 1] }", + {{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8)}}, + {MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8)}}, + {"Single_SetBoth_ReverseOrder", + 2, + "{ idx: [1, 0] }", + {{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8)}}, + {MakeJoint({7, 7, 7, 7, 7, 7}, 8), MakeJoint({3, 3, 3, 3, 3, 3}, 4)}}, + + {"Double_NoOverwrite", + 3, + "{ idx: [0] }, { idx: [1] }", + {{MakeJoint({3, 3, 3, 3, 3, 3}, 4)}, + {MakeJoint({7, 7, 7, 7, 7, 7}, 8)}}, + {MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({7, 7, 7, 7, 7, 7}, 8), + MakeJoint({1, 0, 0, 1, 0, 0}, 1)}}, + {"Double_OverwriteSecond", + 3, + "{ idx: [0, 1] }, { idx: [1, 2] }", + {{MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({4, 4, 4, 4, 4, 4}, 5)}, + {MakeJoint({6, 6, 6, 6, 6, 6}, 7), MakeJoint({8, 8, 8, 8, 8, 8}, 9)}}, + {MakeJoint({3, 3, 3, 3, 3, 3}, 4), MakeJoint({6, 6, 6, 6, 6, 6}, 7), + MakeJoint({8, 8, 8, 8, 8, 8}, 9)}}, + }), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace api2 +} // namespace mediapipe