refactored unsafe calls into mediagraph struct

This commit is contained in:
Jules Youngberg 2022-06-06 22:16:31 -07:00
parent 9f758c730f
commit 20cc045f11

View File

@ -19,7 +19,7 @@ mod bindings;
pub use bindings::*; pub use bindings::*;
type Mediagraph = mediagraph_Mediagraph; type GraphType = mediagraph_GraphType;
type Landmark = mediagraph_Landmark; type Landmark = mediagraph_Landmark;
impl Default for Landmark { impl Default for Landmark {
@ -70,6 +70,35 @@ impl Default for FaceMesh {
} }
} }
struct Mediagraph {
graph: *mut mediagraph_Mediagraph,
}
impl Mediagraph {
fn new(graph_type: GraphType, graph_config: &str, output_node: &str) -> Self {
let graph_config = CString::new(graph_config).expect("CString::new failed");
let output_node = CString::new(output_node).expect("CString::new failed");
let graph: *mut mediagraph_Mediagraph = unsafe {
mediagraph_Mediagraph::Create(graph_type, graph_config.as_ptr(), output_node.as_ptr())
};
Self { graph }
}
fn process(&mut self, input: &Mat) {
let mut data = input.clone();
let landmarks = unsafe {
mediagraph_Mediagraph_Process(
self.graph as *mut std::ffi::c_void,
data.data_mut(),
data.cols(),
data.rows(),
)
};
}
}
pub mod pose { pub mod pose {
use super::*; use super::*;
@ -114,22 +143,16 @@ pub mod pose {
pub smooth: bool, // true, pub smooth: bool, // true,
pub detection_con: f32, // 0.5 pub detection_con: f32, // 0.5
pub track_con: f32, // 0.5 pub track_con: f32, // 0.5
pub graph: *mut Mediagraph, graph: Mediagraph,
} }
impl PoseDetector { impl PoseDetector {
pub fn new(mode: bool, smooth: bool, detection_con: f32, track_con: f32) -> Self { pub fn new(mode: bool, smooth: bool, detection_con: f32, track_con: f32) -> Self {
let graph_config = let graph = Mediagraph::new(
CString::new(include_str!("pose_tracking_cpu.txt")).expect("CString::new failed");
let output_node = CString::new("pose_landmarks").expect("CString::new failed");
let graph: *mut Mediagraph = unsafe {
Mediagraph::Create(
mediagraph_GraphType_POSE, mediagraph_GraphType_POSE,
graph_config.as_ptr(), include_str!("pose_tracking_cpu.txt"),
output_node.as_ptr(), "pose_landmarks",
) );
};
Self { Self {
mode, mode,
@ -141,16 +164,7 @@ pub mod pose {
} }
pub fn process(&mut self, input: &Mat) -> bool { pub fn process(&mut self, input: &Mat) -> bool {
let mut data = input.clone(); self.graph.process(input);
let landmarks = unsafe {
mediagraph_Mediagraph_Process(
self.graph as *mut std::ffi::c_void,
data.data_mut(),
data.cols(),
data.rows(),
)
};
// @todo read each landmark to build a pose struct // @todo read each landmark to build a pose struct
true true
} }
@ -171,7 +185,7 @@ pub mod face_mesh {
pub max_faces: usize, // 2 pub max_faces: usize, // 2
pub min_detection_con: f32, // 0.5 pub min_detection_con: f32, // 0.5
pub min_track_con: f32, // 0.5 pub min_track_con: f32, // 0.5
pub graph: *mut Mediagraph, graph: Mediagraph,
} }
impl FaceMeshDetector { impl FaceMeshDetector {
@ -181,17 +195,11 @@ pub mod face_mesh {
min_detection_con: f32, min_detection_con: f32,
min_track_con: f32, min_track_con: f32,
) -> Self { ) -> Self {
let graph_config = CString::new(include_str!("face_mesh_desktop_live.txt")) let graph = Mediagraph::new(
.expect("CString::new failed");
let output_node = CString::new("multi_face_landmarks").expect("CString::new failed");
let graph: *mut Mediagraph = unsafe {
Mediagraph::Create(
mediagraph_GraphType_FACE, mediagraph_GraphType_FACE,
graph_config.as_ptr(), include_str!("face_mesh_desktop_live.txt"),
output_node.as_ptr(), "multi_face_landmarks",
) );
};
Self { Self {
static_mode, static_mode,
@ -202,16 +210,8 @@ pub mod face_mesh {
} }
} }
pub fn process(&mut self, input: Mat) -> bool { pub fn process(&mut self, input: &Mat) -> bool {
let mut data = input.clone(); self.graph.process(input);
let landmarks = unsafe {
mediagraph_Mediagraph_Process(
self.graph as *mut std::ffi::c_void,
data.data_mut(),
data.cols(),
data.rows(),
)
};
// @todo read each landmark to build a face mesh struct // @todo read each landmark to build a face mesh struct
true true
} }
@ -256,22 +256,16 @@ pub mod hands {
pub max_hands: usize, pub max_hands: usize,
pub detection_con: f32, // 0.5 pub detection_con: f32, // 0.5
pub min_track_con: f32, // 0.5 pub min_track_con: f32, // 0.5
pub graph: *mut Mediagraph, graph: Mediagraph,
} }
impl HandDetector { impl HandDetector {
pub fn new(mode: bool, max_hands: usize, detection_con: f32, min_track_con: f32) -> Self { pub fn new(mode: bool, max_hands: usize, detection_con: f32, min_track_con: f32) -> Self {
let graph_config = CString::new(include_str!("hand_tracking_desktop_live.txt")) let graph = Mediagraph::new(
.expect("CString::new failed"); mediagraph_GraphType_HANDS,
let output_node = CString::new("hand_landmarks").expect("CString::new failed"); include_str!("hand_tracking_desktop_live.txt"),
"hand_landmarks",
let graph: *mut Mediagraph = unsafe { );
Mediagraph::Create(
mediagraph_GraphType_FACE,
graph_config.as_ptr(),
output_node.as_ptr(),
)
};
Self { Self {
mode, mode,
@ -282,16 +276,8 @@ pub mod hands {
} }
} }
pub fn process(&mut self, input: Mat) -> bool { pub fn process(&mut self, input: &Mat) -> bool {
let mut data = input.clone(); self.graph.process(input);
let landmarks = unsafe {
mediagraph_Mediagraph_Process(
self.graph as *mut std::ffi::c_void,
data.data_mut(),
data.cols(),
data.rows(),
)
};
// @todo read each landmark to build a hands struct // @todo read each landmark to build a hands struct
true true
} }