diff --git a/src/lib.rs b/src/lib.rs index 54d2304..e91aded 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ mod bindings; pub use bindings::*; -type Mediagraph = mediagraph_Mediagraph; +type GraphType = mediagraph_GraphType; type Landmark = mediagraph_Landmark; impl Default for Landmark { @@ -70,6 +70,35 @@ impl Default for FaceMesh { } } +struct Mediagraph { + graph: *mut mediagraph_Mediagraph, +} + +impl Mediagraph { + fn new(graph_type: GraphType, graph_config: &str, output_node: &str) -> Self { + let graph_config = CString::new(graph_config).expect("CString::new failed"); + let output_node = CString::new(output_node).expect("CString::new failed"); + + let graph: *mut mediagraph_Mediagraph = unsafe { + mediagraph_Mediagraph::Create(graph_type, graph_config.as_ptr(), output_node.as_ptr()) + }; + + Self { graph } + } + + fn process(&mut self, input: &Mat) { + let mut data = input.clone(); + let landmarks = unsafe { + mediagraph_Mediagraph_Process( + self.graph as *mut std::ffi::c_void, + data.data_mut(), + data.cols(), + data.rows(), + ) + }; + } +} + pub mod pose { use super::*; @@ -114,22 +143,16 @@ pub mod pose { pub smooth: bool, // true, pub detection_con: f32, // 0.5 pub track_con: f32, // 0.5 - pub graph: *mut Mediagraph, + graph: Mediagraph, } impl PoseDetector { pub fn new(mode: bool, smooth: bool, detection_con: f32, track_con: f32) -> Self { - let graph_config = - CString::new(include_str!("pose_tracking_cpu.txt")).expect("CString::new failed"); - let output_node = CString::new("pose_landmarks").expect("CString::new failed"); - - let graph: *mut Mediagraph = unsafe { - Mediagraph::Create( - mediagraph_GraphType_POSE, - graph_config.as_ptr(), - output_node.as_ptr(), - ) - }; + let graph = Mediagraph::new( + mediagraph_GraphType_POSE, + include_str!("pose_tracking_cpu.txt"), + "pose_landmarks", + ); Self { mode, @@ -141,16 +164,7 @@ pub mod pose { } pub fn process(&mut self, input: &Mat) -> bool { - let mut data = input.clone(); - let landmarks = unsafe { - mediagraph_Mediagraph_Process( - self.graph as *mut std::ffi::c_void, - data.data_mut(), - data.cols(), - data.rows(), - ) - }; - + self.graph.process(input); // @todo read each landmark to build a pose struct true } @@ -171,7 +185,7 @@ pub mod face_mesh { pub max_faces: usize, // 2 pub min_detection_con: f32, // 0.5 pub min_track_con: f32, // 0.5 - pub graph: *mut Mediagraph, + graph: Mediagraph, } impl FaceMeshDetector { @@ -181,17 +195,11 @@ pub mod face_mesh { min_detection_con: f32, min_track_con: f32, ) -> Self { - let graph_config = CString::new(include_str!("face_mesh_desktop_live.txt")) - .expect("CString::new failed"); - let output_node = CString::new("multi_face_landmarks").expect("CString::new failed"); - - let graph: *mut Mediagraph = unsafe { - Mediagraph::Create( - mediagraph_GraphType_FACE, - graph_config.as_ptr(), - output_node.as_ptr(), - ) - }; + let graph = Mediagraph::new( + mediagraph_GraphType_FACE, + include_str!("face_mesh_desktop_live.txt"), + "multi_face_landmarks", + ); Self { static_mode, @@ -202,16 +210,8 @@ pub mod face_mesh { } } - pub fn process(&mut self, input: Mat) -> bool { - let mut data = input.clone(); - let landmarks = unsafe { - mediagraph_Mediagraph_Process( - self.graph as *mut std::ffi::c_void, - data.data_mut(), - data.cols(), - data.rows(), - ) - }; + pub fn process(&mut self, input: &Mat) -> bool { + self.graph.process(input); // @todo read each landmark to build a face mesh struct true } @@ -256,22 +256,16 @@ pub mod hands { pub max_hands: usize, pub detection_con: f32, // 0.5 pub min_track_con: f32, // 0.5 - pub graph: *mut Mediagraph, + graph: Mediagraph, } impl HandDetector { pub fn new(mode: bool, max_hands: usize, detection_con: f32, min_track_con: f32) -> Self { - let graph_config = CString::new(include_str!("hand_tracking_desktop_live.txt")) - .expect("CString::new failed"); - let output_node = CString::new("hand_landmarks").expect("CString::new failed"); - - let graph: *mut Mediagraph = unsafe { - Mediagraph::Create( - mediagraph_GraphType_FACE, - graph_config.as_ptr(), - output_node.as_ptr(), - ) - }; + let graph = Mediagraph::new( + mediagraph_GraphType_HANDS, + include_str!("hand_tracking_desktop_live.txt"), + "hand_landmarks", + ); Self { mode, @@ -282,16 +276,8 @@ pub mod hands { } } - pub fn process(&mut self, input: Mat) -> bool { - let mut data = input.clone(); - let landmarks = unsafe { - mediagraph_Mediagraph_Process( - self.graph as *mut std::ffi::c_void, - data.data_mut(), - data.cols(), - data.rows(), - ) - }; + pub fn process(&mut self, input: &Mat) -> bool { + self.graph.process(input); // @todo read each landmark to build a hands struct true }