From c6377f857cdb60de16c613e3e213e438e305f4a8 Mon Sep 17 00:00:00 2001 From: Jules Youngberg Date: Wed, 15 Jun 2022 00:24:40 -0700 Subject: [PATCH] update detector to support arbitrary output features --- Cargo.toml | 1 + src/bindings.rs | 145 +++++++++++++++++++++++++++++++++++++++-------- src/face_mesh.rs | 10 ++-- src/hands.rs | 12 ++-- src/lib.rs | 118 ++++++++++++++++++++++++++++++-------- src/pose.rs | 14 +++-- 6 files changed, 238 insertions(+), 62 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cbe10d4..d2eb753 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ name = "mediapipe" [dependencies] cgmath = "0.18.0" +cxx = "1.0.68" libc = "0.2.0" opencv = { version = "0.63", features = ["clang-runtime"] } protobuf = "2.23.0" diff --git a/src/bindings.rs b/src/bindings.rs index a2f36f8..28a0b60 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -72,23 +72,130 @@ fn bindgen_test_layout_mediagraph_Landmark() { ) ); } -pub const mediagraph_DetectorType_POSE: mediagraph_DetectorType = 0; -pub const mediagraph_DetectorType_HANDS: mediagraph_DetectorType = 1; -pub const mediagraph_DetectorType_FACE: mediagraph_DetectorType = 2; -pub type mediagraph_DetectorType = ::std::os::raw::c_uint; +pub const mediagraph_FeatureType_POSE: mediagraph_FeatureType = 0; +pub const mediagraph_FeatureType_HANDS: mediagraph_FeatureType = 1; +pub const mediagraph_FeatureType_FACE: mediagraph_FeatureType = 2; +pub type mediagraph_FeatureType = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct mediagraph_Output { + pub type_: mediagraph_FeatureType, + pub name: *mut ::std::os::raw::c_char, +} +#[test] +fn bindgen_test_layout_mediagraph_Output() { + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(mediagraph_Output)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(mediagraph_Output)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).type_ as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(mediagraph_Output), + "::", + stringify!(type_) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).name as *const _ as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(mediagraph_Output), + "::", + stringify!(name) + ) + ); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct mediagraph_Feature { + pub data: *mut mediagraph_Landmark, +} +#[test] +fn bindgen_test_layout_mediagraph_Feature() { + assert_eq!( + ::std::mem::size_of::(), + 8usize, + concat!("Size of: ", stringify!(mediagraph_Feature)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(mediagraph_Feature)) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).data as *const _ as usize }, + 0usize, + concat!( + "Offset of field: ", + stringify!(mediagraph_Feature), + "::", + stringify!(data) + ) + ); +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct mediagraph_FeatureList { + pub num_features: u8, + pub features: *mut mediagraph_Feature, +} +#[test] +fn bindgen_test_layout_mediagraph_FeatureList() { + assert_eq!( + ::std::mem::size_of::(), + 16usize, + concat!("Size of: ", stringify!(mediagraph_FeatureList)) + ); + assert_eq!( + ::std::mem::align_of::(), + 8usize, + concat!("Alignment of ", stringify!(mediagraph_FeatureList)) + ); + assert_eq!( + unsafe { + &(*(::std::ptr::null::())).num_features as *const _ as usize + }, + 0usize, + concat!( + "Offset of field: ", + stringify!(mediagraph_FeatureList), + "::", + stringify!(num_features) + ) + ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).features as *const _ as usize }, + 8usize, + concat!( + "Offset of field: ", + stringify!(mediagraph_FeatureList), + "::", + stringify!(features) + ) + ); +} #[repr(C)] pub struct mediagraph_Detector__bindgen_vtable(::std::os::raw::c_void); #[repr(C)] #[derive(Debug)] pub struct mediagraph_Detector { pub vtable_: *const mediagraph_Detector__bindgen_vtable, - pub m_graph_type: mediagraph_DetectorType, } #[test] fn bindgen_test_layout_mediagraph_Detector() { assert_eq!( ::std::mem::size_of::(), - 16usize, + 8usize, concat!("Size of: ", stringify!(mediagraph_Detector)) ); assert_eq!( @@ -96,35 +203,23 @@ fn bindgen_test_layout_mediagraph_Detector() { 8usize, concat!("Alignment of ", stringify!(mediagraph_Detector)) ); - assert_eq!( - unsafe { - &(*(::std::ptr::null::())).m_graph_type as *const _ as usize - }, - 8usize, - concat!( - "Offset of field: ", - stringify!(mediagraph_Detector), - "::", - stringify!(m_graph_type) - ) - ); } extern "C" { - #[link_name = "\u{1}__ZN10mediagraph8Detector6CreateENS_12DetectorTypeEPKcS3_"] + #[link_name = "\u{1}__ZN10mediagraph8Detector6CreateEPKcPKNS_6OutputEh"] pub fn mediagraph_Detector_Create( - t: mediagraph_DetectorType, graph_config: *const ::std::os::raw::c_char, - output_node: *const ::std::os::raw::c_char, + outputs: *const mediagraph_Output, + num_outputs: u8, ) -> *mut mediagraph_Detector; } impl mediagraph_Detector { #[inline] pub unsafe fn Create( - t: mediagraph_DetectorType, graph_config: *const ::std::os::raw::c_char, - output_node: *const ::std::os::raw::c_char, + outputs: *const mediagraph_Output, + num_outputs: u8, ) -> *mut mediagraph_Detector { - mediagraph_Detector_Create(t, graph_config, output_node) + mediagraph_Detector_Create(graph_config, outputs, num_outputs) } } extern "C" { @@ -138,7 +233,7 @@ extern "C" { data: *mut u8, width: ::std::os::raw::c_int, height: ::std::os::raw::c_int, - ) -> *mut mediagraph_Landmark; + ) -> *mut mediagraph_FeatureList; } #[repr(C)] pub struct mediagraph_Effect__bindgen_vtable(::std::os::raw::c_void); diff --git a/src/face_mesh.rs b/src/face_mesh.rs index f6046df..0e85208 100644 --- a/src/face_mesh.rs +++ b/src/face_mesh.rs @@ -8,9 +8,11 @@ pub struct FaceMeshDetector { impl FaceMeshDetector { pub fn new() -> Self { let graph = Detector::new( - FACE_GRAPH_TYPE, include_str!("graphs/face_mesh_desktop_live.pbtxt"), - "multi_face_landmarks", + vec![Output { + type_: FeatureType::Face, + name: "multi_face_landmarks".into(), + }], ); Self { graph } @@ -20,12 +22,12 @@ impl FaceMeshDetector { pub fn process(&mut self, input: &Mat) -> Option { let landmarks = self.graph.process(input); - if landmarks.is_empty() { + if landmarks[0].is_empty() { return None; } let mut face_mesh = FaceMesh::default(); - face_mesh.data.copy_from_slice(landmarks); + face_mesh.data.copy_from_slice(landmarks[0][0].as_slice()); Some(face_mesh) } } diff --git a/src/hands.rs b/src/hands.rs index 7906e43..335aa86 100644 --- a/src/hands.rs +++ b/src/hands.rs @@ -33,9 +33,11 @@ pub struct HandDetector { impl HandDetector { pub fn new() -> Self { let graph = Detector::new( - HANDS_GRAPH_TYPE, include_str!("graphs/hand_tracking_desktop_live.pbtxt"), - "hand_landmarks", + vec![Output { + type_: FeatureType::Hands, + name: "hand_landmarks".into(), + }], ); Self { graph } @@ -43,12 +45,14 @@ impl HandDetector { /// Processes the input frame, returns a tuple of hands if detected. pub fn process(&mut self, input: &Mat) -> Option<[Hand; 2]> { - let landmarks = self.graph.process(input); + let result = self.graph.process(input); - if landmarks.is_empty() { + if result[0].is_empty() { return None; } + let landmarks = &result[0][0]; + let mut lh = Hand::default(); let mut rh = Hand::default(); lh.data.copy_from_slice(&landmarks[0..21]); diff --git a/src/lib.rs b/src/lib.rs index e256ff9..a5bd362 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,16 +22,62 @@ pub mod segmentation; use bindings::*; -/// The C++ mediagraph graph type. -pub type DetectorType = mediagraph_DetectorType; +type mFeatureType = mediagraph_FeatureType; +type mOutput = mediagraph_Output; +type mFeature = mediagraph_Feature; +type mFeatureList = mediagraph_FeatureList; + +/// The type of visual feature made up of landmarks. +#[derive(Debug, Clone, Copy)] +pub enum FeatureType { + Face, + Hands, + Pose, +} + +impl FeatureType { + fn num_landmarks(&self) -> usize { + match self { + FeatureType::Face => 478, + FeatureType::Hands => 42, + FeatureType::Pose => 33, + } + } +} + +impl Into for FeatureType { + fn into(self) -> mFeatureType { + match self { + FeatureType::Face => mediagraph_FeatureType_FACE, + FeatureType::Hands => mediagraph_FeatureType_HANDS, + FeatureType::Pose => mediagraph_FeatureType_POSE, + } + } +} + +/// The definition of a graph output. +#[derive(Debug, Clone)] +pub struct Output { + type_: FeatureType, + name: String, +} + +impl Into for Output { + fn into(self) -> mOutput { + let name = CString::new(self.name) + .expect("CString::new failed") + .into_raw(); + + mOutput { + type_: self.type_.into(), + name, + } + } +} /// The C++ mediagraph landmark type. pub type Landmark = mediagraph_Landmark; -pub const FACE_GRAPH_TYPE: DetectorType = mediagraph_DetectorType_FACE; -pub const HANDS_GRAPH_TYPE: DetectorType = mediagraph_DetectorType_HANDS; -pub const POSE_GRAPH_TYPE: DetectorType = mediagraph_DetectorType_POSE; - impl Default for Landmark { fn default() -> Self { Self { @@ -81,36 +127,41 @@ impl Default for FaceMesh { /// Detector calculator which interacts with the C++ library. pub struct Detector { graph: *mut mediagraph_Detector, - num_landmarks: u32, + outputs: Vec, } impl Detector { /// Creates a new Mediagraph with the given config. - pub fn new(graph_type: DetectorType, graph_config: &str, output_node: &str) -> Self { + pub fn new(graph_config: &str, output_config: Vec) -> Self { + assert!( + output_config.len() > 0, + "must specify at least one output feature" + ); let graph_config = CString::new(graph_config).expect("CString::new failed"); - let output_node = CString::new(output_node).expect("CString::new failed"); + + let outputs = output_config + .iter() + .map(|f| f.clone().into()) + .collect::>(); let graph: *mut mediagraph_Detector = unsafe { - mediagraph_Detector::Create(graph_type, graph_config.as_ptr(), output_node.as_ptr()) - }; - - let num_landmarks = match graph_type { - FACE_GRAPH_TYPE => 478, - HANDS_GRAPH_TYPE => 42, - POSE_GRAPH_TYPE => 33, - _ => 0, + mediagraph_Detector::Create( + graph_config.as_ptr(), + outputs.as_ptr(), + outputs.len() as u8, + ) }; Self { graph, - num_landmarks, + outputs: output_config, } } /// Processes the input frame, returns a slice of landmarks if any are detected. - pub fn process(&mut self, input: &Mat) -> &[Landmark] { + pub fn process(&mut self, input: &Mat) -> Vec>> { let mut data = input.clone(); - let raw_landmarks = unsafe { + let results = unsafe { mediagraph_Detector_Process( self.graph as *mut std::ffi::c_void, data.data_mut(), @@ -118,11 +169,30 @@ impl Detector { data.rows(), ) }; - if raw_landmarks.is_null() { - return &[]; + + let mut landmarks = vec![]; + + let feature_lists = + unsafe { std::slice::from_raw_parts(results, self.outputs.len() as usize) }; + + for (i, feature_list) in feature_lists.iter().enumerate() { + let num_landmarks = self.outputs[i].type_.num_landmarks(); + let mut fl = vec![]; + let features = unsafe { + std::slice::from_raw_parts( + feature_list.features, + feature_list.num_features as usize, + ) + }; + + for feature in features.iter() { + let landmarks = unsafe { std::slice::from_raw_parts(feature.data, num_landmarks) }; + fl.push(landmarks.to_vec()); + } + + landmarks.push(fl); } - let landmarks = - unsafe { std::slice::from_raw_parts(raw_landmarks, self.num_landmarks as usize) }; + landmarks } } diff --git a/src/pose.rs b/src/pose.rs index ce48342..9d256de 100644 --- a/src/pose.rs +++ b/src/pose.rs @@ -45,9 +45,11 @@ pub struct PoseDetector { impl PoseDetector { pub fn new() -> Self { let graph = Detector::new( - POSE_GRAPH_TYPE, include_str!("graphs/pose_tracking_cpu.pbtxt"), - "pose_landmarks", + vec![Output { + type_: FeatureType::Pose, + name: "pose_landmarks".into(), + }], ); Self { graph } @@ -55,14 +57,16 @@ impl PoseDetector { /// Processes the input frame, returns a pose if detected. pub fn process(&mut self, input: &Mat) -> Option { - let landmarks = self.graph.process(input); + let result = self.graph.process(input); - if landmarks.is_empty() { + if result[0].is_empty() { return None; } + let landmarks = &result[0][0]; + let mut pose = Pose::default(); - pose.data.copy_from_slice(landmarks); + pose.data.copy_from_slice(landmarks.as_slice()); Some(pose) } }