tensor_to_joints stream utility function.
PiperOrigin-RevId: 569043195
This commit is contained in:
parent
0ae9ff6b98
commit
66a279418c
|
@ -324,3 +324,34 @@ cc_test(
|
|||
"//mediapipe/framework/port:status_matchers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_to_joints",
|
||||
srcs = ["tensor_to_joints.cc"],
|
||||
hdrs = ["tensor_to_joints.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:tensor_to_joints_calculator",
|
||||
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_to_joints_test",
|
||||
srcs = ["tensor_to_joints_test.cc"],
|
||||
deps = [
|
||||
":tensor_to_joints",
|
||||
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status_matchers",
|
||||
],
|
||||
)
|
||||
|
|
27
mediapipe/framework/api2/stream/tensor_to_joints.cc
Normal file
27
mediapipe/framework/api2/stream/tensor_to_joints.cc
Normal file
|
@ -0,0 +1,27 @@
|
|||
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"
|
||||
|
||||
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h"
|
||||
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
namespace {} // namespace
|
||||
|
||||
Stream<JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
|
||||
const int num_joints,
|
||||
const int start_index,
|
||||
Graph& graph) {
|
||||
auto& to_joints = graph.AddNode("TensorToJointsCalculator");
|
||||
auto& to_joints_options =
|
||||
to_joints.GetOptions<TensorToJointsCalculatorOptions>();
|
||||
to_joints_options.set_num_joints(num_joints);
|
||||
to_joints_options.set_start_index(start_index);
|
||||
tensor.ConnectTo(to_joints[TensorToJointsCalculator::kInTensor]);
|
||||
return to_joints[TensorToJointsCalculator::kOutJoints];
|
||||
}
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
26
mediapipe/framework/api2/stream/tensor_to_joints.h
Normal file
26
mediapipe/framework/api2/stream/tensor_to_joints.h
Normal file
|
@ -0,0 +1,26 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
|
||||
// Updates @graph to convert @tensor to a JointList skipping first @start_index
|
||||
// values of a @tensor.
|
||||
Stream<mediapipe::JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
|
||||
const int num_joints,
|
||||
const int start_index,
|
||||
Graph& graph);
|
||||
|
||||
// Updates @graph to convert @tensor to a JointList.
|
||||
inline Stream<::mediapipe::JointList> ConvertTensorToJoints(
|
||||
Stream<Tensor> tensor, const int num_joints, Graph& graph) {
|
||||
return ConvertTensorToJointsAtIndex(tensor, num_joints, /*start_index=*/0,
|
||||
graph);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::api2::builder
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
|
74
mediapipe/framework/api2/stream/tensor_to_joints_test.cc
Normal file
74
mediapipe/framework/api2/stream/tensor_to_joints_test.cc
Normal file
|
@ -0,0 +1,74 @@
|
|||
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe::api2::builder {
|
||||
namespace {
|
||||
|
||||
TEST(ConvertTensorToJoints, ConvertTensorToJoints) {
|
||||
Graph graph;
|
||||
|
||||
Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
|
||||
Stream<JointList> joint_list =
|
||||
ConvertTensorToJoints(tensor, /*num_joints=*/56, graph);
|
||||
joint_list.SetName("joints");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "TensorToJointsCalculator"
|
||||
input_stream: "TENSOR:__stream_0"
|
||||
output_stream: "JOINTS:joints"
|
||||
options {
|
||||
[mediapipe.TensorToJointsCalculatorOptions.ext] {
|
||||
num_joints: 56
|
||||
start_index: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "TENSOR:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
TEST(ConvertTensorToJointsAtIndex, ConvertTensorToJointsAtIndex) {
|
||||
Graph graph;
|
||||
|
||||
Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
|
||||
Stream<JointList> joint_list = ConvertTensorToJointsAtIndex(
|
||||
tensor, /*num_joints=*/56, /*start_index=*/3, graph);
|
||||
joint_list.SetName("joints");
|
||||
|
||||
EXPECT_THAT(graph.GetConfig(),
|
||||
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "TensorToJointsCalculator"
|
||||
input_stream: "TENSOR:__stream_0"
|
||||
output_stream: "JOINTS:joints"
|
||||
options {
|
||||
[mediapipe.TensorToJointsCalculatorOptions.ext] {
|
||||
num_joints: 56
|
||||
start_index: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "TENSOR:__stream_0"
|
||||
)pb")));
|
||||
|
||||
CalculatorGraph calcualtor_graph;
|
||||
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user