tensor_to_joints stream utility function.

PiperOrigin-RevId: 569043195
This commit is contained in:
MediaPipe Team 2023-09-27 20:32:14 -07:00 committed by Copybara-Service
parent 0ae9ff6b98
commit 66a279418c
4 changed files with 158 additions and 0 deletions

View File

@ -324,3 +324,34 @@ cc_test(
"//mediapipe/framework/port:status_matchers", "//mediapipe/framework/port:status_matchers",
], ],
) )
cc_library(
name = "tensor_to_joints",
srcs = ["tensor_to_joints.cc"],
hdrs = ["tensor_to_joints.h"],
deps = [
"//mediapipe/calculators/tensor:tensor_to_joints_calculator",
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:tensor",
],
)
cc_test(
name = "tensor_to_joints_test",
srcs = ["tensor_to_joints_test.cc"],
deps = [
":tensor_to_joints",
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:body_rig_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status_matchers",
],
)

View File

@ -0,0 +1,27 @@
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h"
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe::api2::builder {
namespace {} // namespace
Stream<JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
const int num_joints,
const int start_index,
Graph& graph) {
auto& to_joints = graph.AddNode("TensorToJointsCalculator");
auto& to_joints_options =
to_joints.GetOptions<TensorToJointsCalculatorOptions>();
to_joints_options.set_num_joints(num_joints);
to_joints_options.set_start_index(start_index);
tensor.ConnectTo(to_joints[TensorToJointsCalculator::kInTensor]);
return to_joints[TensorToJointsCalculator::kOutJoints];
}
} // namespace mediapipe::api2::builder

View File

@ -0,0 +1,26 @@
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe::api2::builder {
// Updates @graph to convert @tensor to a JointList skipping first @start_index
// values of a @tensor.
Stream<mediapipe::JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
const int num_joints,
const int start_index,
Graph& graph);
// Updates @graph to convert @tensor to a JointList.
inline Stream<::mediapipe::JointList> ConvertTensorToJoints(
Stream<Tensor> tensor, const int num_joints, Graph& graph) {
return ConvertTensorToJointsAtIndex(tensor, num_joints, /*start_index=*/0,
graph);
}
} // namespace mediapipe::api2::builder
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_

View File

@ -0,0 +1,74 @@
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"
#include <vector>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe::api2::builder {
namespace {
TEST(ConvertTensorToJoints, ConvertTensorToJoints) {
Graph graph;
Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
Stream<JointList> joint_list =
ConvertTensorToJoints(tensor, /*num_joints=*/56, graph);
joint_list.SetName("joints");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "TensorToJointsCalculator"
input_stream: "TENSOR:__stream_0"
output_stream: "JOINTS:joints"
options {
[mediapipe.TensorToJointsCalculatorOptions.ext] {
num_joints: 56
start_index: 0
}
}
}
input_stream: "TENSOR:__stream_0"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
TEST(ConvertTensorToJointsAtIndex, ConvertTensorToJointsAtIndex) {
Graph graph;
Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
Stream<JointList> joint_list = ConvertTensorToJointsAtIndex(
tensor, /*num_joints=*/56, /*start_index=*/3, graph);
joint_list.SetName("joints");
EXPECT_THAT(graph.GetConfig(),
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "TensorToJointsCalculator"
input_stream: "TENSOR:__stream_0"
output_stream: "JOINTS:joints"
options {
[mediapipe.TensorToJointsCalculatorOptions.ext] {
num_joints: 56
start_index: 3
}
}
}
input_stream: "TENSOR:__stream_0"
)pb")));
CalculatorGraph calcualtor_graph;
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace mediapipe::api2::builder