diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 7ebba901c..d7e2313d5 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -980,6 +980,48 @@ cc_test( ], ) +cc_library( + name = "tensor_to_joints_calculator", + srcs = ["tensor_to_joints_calculator.cc"], + hdrs = ["tensor_to_joints_calculator.h"], + deps = [ + ":tensor_to_joints_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +mediapipe_proto_library( + name = "tensor_to_joints_calculator_proto", + srcs = ["tensor_to_joints_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_test( + name = "tensor_to_joints_calculator_test", + srcs = ["tensor_to_joints_calculator_test.cc"], + deps = [ + ":tensor_to_joints_calculator", + ":tensor_to_joints_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:body_rig_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "image_to_tensor_calculator", srcs = ["image_to_tensor_calculator.cc"], diff --git a/mediapipe/calculators/tensor/tensor_to_joints_calculator.cc b/mediapipe/calculators/tensor/tensor_to_joints_calculator.cc new file mode 100644 index 000000000..07b919ab2 --- /dev/null +++ b/mediapipe/calculators/tensor/tensor_to_joints_calculator.cc @@ -0,0 +1,84 @@ +// Copyright 2023 The MediaPipe Authors. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h"
+
+#include
+
+#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/body_rig.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/ret_check.h"
+
+namespace mediapipe {
+namespace api2 {
+namespace {
+
+// Number of values in 6D representation of rotation.
+constexpr int kRotation6dSize = 6;
+
+}  // namespace
+
+class TensorToJointsCalculatorImpl
+    : public mediapipe::api2::NodeImpl {
+ public:
+  absl::Status Open(CalculatorContext* cc) override {
+    const auto& options = cc->Options();
+
+    // Joint count must be non-negative; it fixes the expected tensor size.
+    RET_CHECK_GE(options.num_joints(), 0);
+    num_joints_ = options.num_joints();
+
+    // Get start index. A negative index would read before the tensor buffer.
+    RET_CHECK_GE(options.start_index(), 0);
+    start_index_ = options.start_index();
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(CalculatorContext* cc) override {
+    // An empty input packet produces no output for this timestamp.
+    if (kInTensor(cc).IsEmpty()) {
+      return absl::OkStatus();
+    }
+
+    // Tensor must hold start_index_ skipped values plus 6 values per joint.
+    const Tensor& tensor = kInTensor(cc).Get();
+    RET_CHECK_EQ(tensor.shape().num_elements(),
+                 num_joints_ * kRotation6dSize + start_index_)
+        << "Unexpected number of values in Tensor";
+    const float* raw_floats = tensor.GetCpuReadView().buffer();
+
+    // Copy each consecutive 6-value block into one Joint's rotation_6d.
+    JointList joints;
+    for (int joint_idx = 0; joint_idx < num_joints_; ++joint_idx) {
+      Joint* joint = joints.add_joint();
+      for (int idx_6d = 0; idx_6d < kRotation6dSize; ++idx_6d) {
+        joint->add_rotation_6d(
+            raw_floats[start_index_ + joint_idx * kRotation6dSize + idx_6d]);
+      }
+    }
+
+    kOutJoints(cc).Send(std::move(joints));
+    return absl::OkStatus();
+  }
+
+ private:
+  int num_joints_ = 0;
+  int start_index_ = 0;
+};
+MEDIAPIPE_NODE_IMPLEMENTATION(TensorToJointsCalculatorImpl);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/tensor_to_joints_calculator.h b/mediapipe/calculators/tensor/tensor_to_joints_calculator.h
new file mode 100644
index 000000000..c2798526e
--- /dev/null
+++ b/mediapipe/calculators/tensor/tensor_to_joints_calculator.h
@@ -0,0 +1,64 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_
+
+#include
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/body_rig.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
+
+namespace mediapipe {
+namespace api2 {
+
+// A calculator to convert Tensors to JointList.
+//
+// Calculator fills in only rotation of the joints leaving visibility undefined.
+// +// Input: +// TENSOR - Tensor with kFloat32 values +// Single tensor to be converted to joints. It must hold exactly +// `6 * num_joints + start_index` values; empty packets are skipped. +// +// Output: +// JOINTS - JointList +// List of joints with rotations extracted from given tensor and undefined +// visibility. +// +// Example: +// node { +// calculator: "TensorToJointsCalculator" +// input_stream: "TENSOR:tensor" +// output_stream: "JOINTS:joints" +// options: { +// [mediapipe.TensorToJointsCalculatorOptions.ext] { +// num_joints: 56 +// start_index: 3 +// } +// } +// } +class TensorToJointsCalculator : public NodeIntf { + public: + static constexpr Input kInTensor{"TENSOR"}; + static constexpr Output kOutJoints{"JOINTS"}; + MEDIAPIPE_NODE_INTERFACE(TensorToJointsCalculator, kInTensor, kOutJoints); +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_ diff --git a/mediapipe/calculators/tensor/tensor_to_joints_calculator.proto b/mediapipe/calculators/tensor/tensor_to_joints_calculator.proto new file mode 100644 index 000000000..3534e35ba --- /dev/null +++ b/mediapipe/calculators/tensor/tensor_to_joints_calculator.proto @@ -0,0 +1,32 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorToJointsCalculatorOptions { + extend CalculatorOptions { + optional TensorToJointsCalculatorOptions ext = 406440177; + } + + // Number of joints in the model output; must be non-negative. The calculator + // expects the tensor to contain `6 * num_joints + start_index` values. + optional int32 num_joints = 1; + + // Number of leading tensor values skipped before the first 6-value block. + optional int32 start_index = 2 [default = 0]; +} diff --git a/mediapipe/calculators/tensor/tensor_to_joints_calculator_test.cc b/mediapipe/calculators/tensor/tensor_to_joints_calculator_test.cc new file mode 100644 index 000000000..7d29899f4 --- /dev/null +++ b/mediapipe/calculators/tensor/tensor_to_joints_calculator_test.cc @@ -0,0 +1,123 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/body_rig.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+
+namespace mediapipe {
+namespace api2 {
+namespace {
+
+using Node = ::mediapipe::CalculatorGraphConfig::Node;
+
+struct TensorToJointsTestCase {
+  std::string test_name;
+  int num_joints;
+  int start_index;
+  std::vector raw_values;
+  std::vector> expected_rotations;
+};
+
+using TensorToJointsTest = ::testing::TestWithParam;
+
+TEST_P(TensorToJointsTest, TensorToJointsTest) {
+  const TensorToJointsTestCase& tc = GetParam();
+
+  // Build a single-node graph configured from the test case options.
+  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(absl::Substitute(
+      R"(
+        calculator: "TensorToJointsCalculator"
+        input_stream: "TENSOR:tensor"
+        output_stream: "JOINTS:joints"
+        options: {
+          [mediapipe.TensorToJointsCalculatorOptions.ext] {
+            num_joints: $0
+            start_index: $1
+          }
+        }
+      )",
+      tc.num_joints, tc.start_index)));
+
+  // Fill a float tensor with the raw values of the test case.
+  Tensor tensor(Tensor::ElementType::kFloat32,
+                Tensor::Shape{1, 1, static_cast(tc.raw_values.size()), 1});
+  float* tensor_buffer = tensor.GetCpuWriteView().buffer();
+  ASSERT_NE(tensor_buffer, nullptr);
+  for (size_t i = 0; i < tc.raw_values.size(); ++i) {
+    tensor_buffer[i] = tc.raw_values[i];
+  }
+
+  // Send tensor to the graph.
+  runner.MutableInputs()->Tag("TENSOR").packets.push_back(
+      mediapipe::MakePacket(std::move(tensor)).At(Timestamp(0)));
+
+  // Run the graph.
+  MP_ASSERT_OK(runner.Run());
+
+  // ASSERT (not EXPECT) on sizes: a mismatch must stop before indexing below.
+  const auto& output_packets = runner.Outputs().Tag("JOINTS").packets;
+  ASSERT_EQ(output_packets.size(), 1);
+  const auto& joints = output_packets[0].Get();
+  ASSERT_EQ(joints.joint_size(), tc.expected_rotations.size());
+  for (int i = 0; i < joints.joint_size(); ++i) {
+    const Joint& joint = joints.joint(i);
+    const auto& expected_rotation_6d = tc.expected_rotations[i];
+    ASSERT_EQ(joint.rotation_6d_size(), expected_rotation_6d.size())
+        << "Unexpected joint #" << i << " rotation";
+    for (int j = 0; j < joint.rotation_6d_size(); ++j) {
+      EXPECT_EQ(joint.rotation_6d(j), expected_rotation_6d[j])
+          << "Unexpected joint #" << i << " rotation";
+    }
+    EXPECT_FALSE(joint.has_visibility());
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    TensorToJointsTests, TensorToJointsTest,
+    testing::ValuesIn({
+        {"Empty", 0, 3, {0, 0, 0}, {}},
+
+        {"Single",
+         1,
+         3,
+         {0, 0, 0, 10, 11, 12, 13, 14, 15},
+         {{10, 11, 12, 13, 14, 15}}},
+
+        {"Double",
+         2,
+         3,
+         {0, 0, 0, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21},
+         {{10, 11, 12, 13, 14, 15}, {16, 17, 18, 19, 20, 21}}},
+    }),
+    [](const testing::TestParamInfo& info) {
+      return info.param.test_name;
+    });
+
+}  // namespace
+}  // namespace api2
+}  // namespace mediapipe