Introduce TensorToJointsCalculator
PiperOrigin-RevId: 569040914
This commit is contained in:
parent
da02052c70
commit
0ae9ff6b98
|
@ -980,6 +980,48 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_to_joints_calculator",
|
||||||
|
srcs = ["tensor_to_joints_calculator.cc"],
|
||||||
|
hdrs = ["tensor_to_joints_calculator.h"],
|
||||||
|
deps = [
|
||||||
|
":tensor_to_joints_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "tensor_to_joints_calculator_proto",
|
||||||
|
srcs = ["tensor_to_joints_calculator.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "tensor_to_joints_calculator_test",
|
||||||
|
srcs = ["tensor_to_joints_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":tensor_to_joints_calculator",
|
||||||
|
":tensor_to_joints_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "image_to_tensor_calculator",
|
name = "image_to_tensor_calculator",
|
||||||
srcs = ["image_to_tensor_calculator.cc"],
|
srcs = ["image_to_tensor_calculator.cc"],
|
||||||
|
|
84
mediapipe/calculators/tensor/tensor_to_joints_calculator.cc
Normal file
84
mediapipe/calculators/tensor/tensor_to_joints_calculator.cc
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Number of values in 6D representation of rotation.
|
||||||
|
constexpr int kRotation6dSize = 6;
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class TensorToJointsCalculatorImpl
|
||||||
|
: public mediapipe::api2::NodeImpl<TensorToJointsCalculator> {
|
||||||
|
public:
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
const auto& options = cc->Options<TensorToJointsCalculatorOptions>();
|
||||||
|
|
||||||
|
// Get number of joints.
|
||||||
|
RET_CHECK_GE(options.num_joints(), 0);
|
||||||
|
num_joints_ = options.num_joints();
|
||||||
|
|
||||||
|
// Get start index.
|
||||||
|
start_index_ = options.start_index();
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
// Skip if Tensor is empty.
|
||||||
|
if (kInTensor(cc).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get raw floats from the Tensor.
|
||||||
|
const Tensor& tensor = kInTensor(cc).Get();
|
||||||
|
RET_CHECK_EQ(tensor.shape().num_elements(),
|
||||||
|
num_joints_ * kRotation6dSize + start_index_)
|
||||||
|
<< "Unexpected number of values in Tensor";
|
||||||
|
const float* raw_floats = tensor.GetCpuReadView().buffer<float>();
|
||||||
|
|
||||||
|
// Convert raw floats into Joint rotations.
|
||||||
|
JointList joints;
|
||||||
|
for (int joint_idx = 0; joint_idx < num_joints_; ++joint_idx) {
|
||||||
|
Joint* joint = joints.add_joint();
|
||||||
|
for (int idx_6d = 0; idx_6d < kRotation6dSize; ++idx_6d) {
|
||||||
|
joint->add_rotation_6d(
|
||||||
|
raw_floats[start_index_ + joint_idx * kRotation6dSize + idx_6d]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kOutJoints(cc).Send(std::move(joints));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int num_joints_ = 0;
|
||||||
|
int start_index_ = 0;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_NODE_IMPLEMENTATION(TensorToJointsCalculatorImpl);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
64
mediapipe/calculators/tensor/tensor_to_joints_calculator.h
Normal file
64
mediapipe/calculators/tensor/tensor_to_joints_calculator.h
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_
|
||||||
|
#define MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
// A calculator to convert Tensors to JointList.
|
||||||
|
//
|
||||||
|
// Calculator fills in only rotation of the joints leaving visibility undefined.
|
||||||
|
//
|
||||||
|
// Input:
|
||||||
|
// TENSOR - std::vector<Tensor> with kFloat32 values
|
||||||
|
// Vector of tensors to be converted to joints. Only the first tensor will
|
||||||
|
// be used. Number of values is expected to be multiple of six.
|
||||||
|
//
|
||||||
|
// Output:
|
||||||
|
// JOINTS - JointList
|
||||||
|
// List of joints with rotations extracted from given tensor and undefined
|
||||||
|
// visibility.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "TensorToJointsCalculator"
|
||||||
|
// input_stream: "TENSOR:tensor"
|
||||||
|
// output_stream: "JOINTS:joints"
|
||||||
|
// options: {
|
||||||
|
// [mediapipe.TensorToJointsCalculatorOptions.ext] {
|
||||||
|
// num_joints: 56
|
||||||
|
// start_index: 3
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class TensorToJointsCalculator : public NodeIntf {
|
||||||
|
public:
|
||||||
|
static constexpr Input<mediapipe::Tensor> kInTensor{"TENSOR"};
|
||||||
|
static constexpr Output<mediapipe::JointList> kOutJoints{"JOINTS"};
|
||||||
|
MEDIAPIPE_NODE_INTERFACE(TensorToJointsCalculator, kInTensor, kOutJoints);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSOR_TO_JOINTS_CALCULATOR_H_
|
|
@ -0,0 +1,32 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message TensorToJointsCalculatorOptions {
|
||||||
|
extend CalculatorOptions {
|
||||||
|
optional TensorToJointsCalculatorOptions ext = 406440177;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of joints from the output of the model. Calculator will expect the
|
||||||
|
// tensor to contain `6 * num_joints + start_index` values.
|
||||||
|
optional int32 num_joints = 1;
|
||||||
|
|
||||||
|
// Index to start reading 6 value blocks from.
|
||||||
|
optional int32 start_index = 2 [default = 0];
|
||||||
|
}
|
123
mediapipe/calculators/tensor/tensor_to_joints_calculator_test.cc
Normal file
123
mediapipe/calculators/tensor/tensor_to_joints_calculator_test.cc
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||||
|
|
||||||
|
struct TensorToJointsTestCase {
|
||||||
|
std::string test_name;
|
||||||
|
int num_joints;
|
||||||
|
int start_index;
|
||||||
|
std::vector<float> raw_values;
|
||||||
|
std::vector<std::vector<float>> expected_rotations;
|
||||||
|
};
|
||||||
|
|
||||||
|
using TensorToJointsTest = ::testing::TestWithParam<TensorToJointsTestCase>;
|
||||||
|
|
||||||
|
TEST_P(TensorToJointsTest, TensorToJointsTest) {
|
||||||
|
const TensorToJointsTestCase& tc = GetParam();
|
||||||
|
|
||||||
|
// Prepare graph.
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(absl::Substitute(
|
||||||
|
R"(
|
||||||
|
calculator: "TensorToJointsCalculator"
|
||||||
|
input_stream: "TENSOR:tensor"
|
||||||
|
output_stream: "JOINTS:joints"
|
||||||
|
options: {
|
||||||
|
[mediapipe.TensorToJointsCalculatorOptions.ext] {
|
||||||
|
num_joints: $0
|
||||||
|
start_index: $1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)",
|
||||||
|
tc.num_joints, tc.start_index)));
|
||||||
|
|
||||||
|
// Prepare tensor.
|
||||||
|
Tensor tensor(Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{1, 1, static_cast<int>(tc.raw_values.size()), 1});
|
||||||
|
float* tensor_buffer = tensor.GetCpuWriteView().buffer<float>();
|
||||||
|
ASSERT_NE(tensor_buffer, nullptr);
|
||||||
|
for (int i = 0; i < tc.raw_values.size(); ++i) {
|
||||||
|
tensor_buffer[i] = tc.raw_values[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send tensor to the graph.
|
||||||
|
runner.MutableInputs()->Tag("TENSOR").packets.push_back(
|
||||||
|
mediapipe::MakePacket<Tensor>(std::move(tensor)).At(Timestamp(0)));
|
||||||
|
|
||||||
|
// Run the graph.
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& output_packets = runner.Outputs().Tag("JOINTS").packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
|
||||||
|
const auto& joints = output_packets[0].Get<JointList>();
|
||||||
|
EXPECT_EQ(joints.joint_size(), tc.expected_rotations.size());
|
||||||
|
for (int i = 0; i < joints.joint_size(); ++i) {
|
||||||
|
const Joint& joint = joints.joint(i);
|
||||||
|
std::vector<float> expected_rotation_6d = tc.expected_rotations[i];
|
||||||
|
EXPECT_EQ(joint.rotation_6d_size(), expected_rotation_6d.size())
|
||||||
|
<< "Unexpected joint #" << i << " rotation";
|
||||||
|
for (int j = 0; j < joint.rotation_6d_size(); ++j) {
|
||||||
|
EXPECT_EQ(joint.rotation_6d(j), expected_rotation_6d[j])
|
||||||
|
<< "Unexpected joint #" << i << " rotation";
|
||||||
|
}
|
||||||
|
EXPECT_FALSE(joint.has_visibility());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TensorToJointsTests, TensorToJointsTest,
|
||||||
|
testing::ValuesIn<TensorToJointsTestCase>({
|
||||||
|
{"Empty", 0, 3, {0, 0, 0}, {}},
|
||||||
|
|
||||||
|
{"Single",
|
||||||
|
1,
|
||||||
|
3,
|
||||||
|
{0, 0, 0, 10, 11, 12, 13, 14, 15},
|
||||||
|
{{10, 11, 12, 13, 14, 15}}},
|
||||||
|
|
||||||
|
{"Double",
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
{0, 0, 0, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21},
|
||||||
|
{{10, 11, 12, 13, 14, 15}, {16, 17, 18, 19, 20, 21}}},
|
||||||
|
}),
|
||||||
|
[](const testing::TestParamInfo<TensorToJointsTest::ParamType>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user