tensor_to_joints stream utility function.
PiperOrigin-RevId: 569043195
This commit is contained in:
parent
0ae9ff6b98
commit
66a279418c
|
@ -324,3 +324,34 @@ cc_test(
|
||||||
"//mediapipe/framework/port:status_matchers",
|
"//mediapipe/framework/port:status_matchers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_to_joints",
|
||||||
|
srcs = ["tensor_to_joints.cc"],
|
||||||
|
hdrs = ["tensor_to_joints.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/tensor:tensor_to_joints_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "tensor_to_joints_test",
|
||||||
|
srcs = ["tensor_to_joints_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":tensor_to_joints",
|
||||||
|
"//mediapipe/calculators/tensor:tensor_to_joints_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/formats:body_rig_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status_matchers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
27
mediapipe/framework/api2/stream/tensor_to_joints.cc
Normal file
27
mediapipe/framework/api2/stream/tensor_to_joints.cc
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.h"
|
||||||
|
#include "mediapipe/calculators/tensor/tensor_to_joints_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
|
||||||
|
namespace {} // namespace
|
||||||
|
|
||||||
|
Stream<JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
|
||||||
|
const int num_joints,
|
||||||
|
const int start_index,
|
||||||
|
Graph& graph) {
|
||||||
|
auto& to_joints = graph.AddNode("TensorToJointsCalculator");
|
||||||
|
auto& to_joints_options =
|
||||||
|
to_joints.GetOptions<TensorToJointsCalculatorOptions>();
|
||||||
|
to_joints_options.set_num_joints(num_joints);
|
||||||
|
to_joints_options.set_start_index(start_index);
|
||||||
|
tensor.ConnectTo(to_joints[TensorToJointsCalculator::kInTensor]);
|
||||||
|
return to_joints[TensorToJointsCalculator::kOutJoints];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::api2::builder
|
26
mediapipe/framework/api2/stream/tensor_to_joints.h
Normal file
26
mediapipe/framework/api2/stream/tensor_to_joints.h
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
|
||||||
|
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
|
||||||
|
// Updates @graph to convert @tensor to a JointList skipping first @start_index
|
||||||
|
// values of a @tensor.
|
||||||
|
Stream<mediapipe::JointList> ConvertTensorToJointsAtIndex(Stream<Tensor> tensor,
|
||||||
|
const int num_joints,
|
||||||
|
const int start_index,
|
||||||
|
Graph& graph);
|
||||||
|
|
||||||
|
// Updates @graph to convert @tensor to a JointList.
|
||||||
|
inline Stream<::mediapipe::JointList> ConvertTensorToJoints(
|
||||||
|
Stream<Tensor> tensor, const int num_joints, Graph& graph) {
|
||||||
|
return ConvertTensorToJointsAtIndex(tensor, num_joints, /*start_index=*/0,
|
||||||
|
graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::api2::builder
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_FRAMEWORK_API2_STREAM_TENSOR_TO_JOINTS_H_
|
74
mediapipe/framework/api2/stream/tensor_to_joints_test.cc
Normal file
74
mediapipe/framework/api2/stream/tensor_to_joints_test.cc
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
#include "mediapipe/framework/api2/stream/tensor_to_joints.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/body_rig.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe::api2::builder {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(ConvertTensorToJoints, ConvertTensorToJoints) {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
|
||||||
|
Stream<JointList> joint_list =
|
||||||
|
ConvertTensorToJoints(tensor, /*num_joints=*/56, graph);
|
||||||
|
joint_list.SetName("joints");
|
||||||
|
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "TensorToJointsCalculator"
|
||||||
|
input_stream: "TENSOR:__stream_0"
|
||||||
|
output_stream: "JOINTS:joints"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorToJointsCalculatorOptions.ext] {
|
||||||
|
num_joints: 56
|
||||||
|
start_index: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "TENSOR:__stream_0"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConvertTensorToJointsAtIndex, ConvertTensorToJointsAtIndex) {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
Stream<Tensor> tensor = graph.In("TENSOR").Cast<Tensor>();
|
||||||
|
Stream<JointList> joint_list = ConvertTensorToJointsAtIndex(
|
||||||
|
tensor, /*num_joints=*/56, /*start_index=*/3, graph);
|
||||||
|
joint_list.SetName("joints");
|
||||||
|
|
||||||
|
EXPECT_THAT(graph.GetConfig(),
|
||||||
|
EqualsProto(ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "TensorToJointsCalculator"
|
||||||
|
input_stream: "TENSOR:__stream_0"
|
||||||
|
output_stream: "JOINTS:joints"
|
||||||
|
options {
|
||||||
|
[mediapipe.TensorToJointsCalculatorOptions.ext] {
|
||||||
|
num_joints: 56
|
||||||
|
start_index: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "TENSOR:__stream_0"
|
||||||
|
)pb")));
|
||||||
|
|
||||||
|
CalculatorGraph calcualtor_graph;
|
||||||
|
MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::api2::builder
|
Loading…
Reference in New Issue
Block a user