update detector to support arbitrary output features

This commit is contained in:
Jules Youngberg 2022-06-15 00:24:40 -07:00
parent 18223e924c
commit c6377f857c
6 changed files with 238 additions and 62 deletions

View File

@ -13,6 +13,7 @@ name = "mediapipe"
[dependencies] [dependencies]
cgmath = "0.18.0" cgmath = "0.18.0"
cxx = "1.0.68"
libc = "0.2.0" libc = "0.2.0"
opencv = { version = "0.63", features = ["clang-runtime"] } opencv = { version = "0.63", features = ["clang-runtime"] }
protobuf = "2.23.0" protobuf = "2.23.0"

View File

@ -72,23 +72,130 @@ fn bindgen_test_layout_mediagraph_Landmark() {
) )
); );
} }
pub const mediagraph_DetectorType_POSE: mediagraph_DetectorType = 0; pub const mediagraph_FeatureType_POSE: mediagraph_FeatureType = 0;
pub const mediagraph_DetectorType_HANDS: mediagraph_DetectorType = 1; pub const mediagraph_FeatureType_HANDS: mediagraph_FeatureType = 1;
pub const mediagraph_DetectorType_FACE: mediagraph_DetectorType = 2; pub const mediagraph_FeatureType_FACE: mediagraph_FeatureType = 2;
pub type mediagraph_DetectorType = ::std::os::raw::c_uint; pub type mediagraph_FeatureType = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct mediagraph_Output {
pub type_: mediagraph_FeatureType,
pub name: *mut ::std::os::raw::c_char,
}
#[test]
fn bindgen_test_layout_mediagraph_Output() {
assert_eq!(
::std::mem::size_of::<mediagraph_Output>(),
16usize,
concat!("Size of: ", stringify!(mediagraph_Output))
);
assert_eq!(
::std::mem::align_of::<mediagraph_Output>(),
8usize,
concat!("Alignment of ", stringify!(mediagraph_Output))
);
assert_eq!(
unsafe { &(*(::std::ptr::null::<mediagraph_Output>())).type_ as *const _ as usize },
0usize,
concat!(
"Offset of field: ",
stringify!(mediagraph_Output),
"::",
stringify!(type_)
)
);
assert_eq!(
unsafe { &(*(::std::ptr::null::<mediagraph_Output>())).name as *const _ as usize },
8usize,
concat!(
"Offset of field: ",
stringify!(mediagraph_Output),
"::",
stringify!(name)
)
);
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct mediagraph_Feature {
pub data: *mut mediagraph_Landmark,
}
#[test]
fn bindgen_test_layout_mediagraph_Feature() {
assert_eq!(
::std::mem::size_of::<mediagraph_Feature>(),
8usize,
concat!("Size of: ", stringify!(mediagraph_Feature))
);
assert_eq!(
::std::mem::align_of::<mediagraph_Feature>(),
8usize,
concat!("Alignment of ", stringify!(mediagraph_Feature))
);
assert_eq!(
unsafe { &(*(::std::ptr::null::<mediagraph_Feature>())).data as *const _ as usize },
0usize,
concat!(
"Offset of field: ",
stringify!(mediagraph_Feature),
"::",
stringify!(data)
)
);
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct mediagraph_FeatureList {
pub num_features: u8,
pub features: *mut mediagraph_Feature,
}
#[test]
fn bindgen_test_layout_mediagraph_FeatureList() {
assert_eq!(
::std::mem::size_of::<mediagraph_FeatureList>(),
16usize,
concat!("Size of: ", stringify!(mediagraph_FeatureList))
);
assert_eq!(
::std::mem::align_of::<mediagraph_FeatureList>(),
8usize,
concat!("Alignment of ", stringify!(mediagraph_FeatureList))
);
assert_eq!(
unsafe {
&(*(::std::ptr::null::<mediagraph_FeatureList>())).num_features as *const _ as usize
},
0usize,
concat!(
"Offset of field: ",
stringify!(mediagraph_FeatureList),
"::",
stringify!(num_features)
)
);
assert_eq!(
unsafe { &(*(::std::ptr::null::<mediagraph_FeatureList>())).features as *const _ as usize },
8usize,
concat!(
"Offset of field: ",
stringify!(mediagraph_FeatureList),
"::",
stringify!(features)
)
);
}
#[repr(C)] #[repr(C)]
pub struct mediagraph_Detector__bindgen_vtable(::std::os::raw::c_void); pub struct mediagraph_Detector__bindgen_vtable(::std::os::raw::c_void);
#[repr(C)] #[repr(C)]
#[derive(Debug)] #[derive(Debug)]
pub struct mediagraph_Detector { pub struct mediagraph_Detector {
pub vtable_: *const mediagraph_Detector__bindgen_vtable, pub vtable_: *const mediagraph_Detector__bindgen_vtable,
pub m_graph_type: mediagraph_DetectorType,
} }
#[test] #[test]
fn bindgen_test_layout_mediagraph_Detector() { fn bindgen_test_layout_mediagraph_Detector() {
assert_eq!( assert_eq!(
::std::mem::size_of::<mediagraph_Detector>(), ::std::mem::size_of::<mediagraph_Detector>(),
16usize, 8usize,
concat!("Size of: ", stringify!(mediagraph_Detector)) concat!("Size of: ", stringify!(mediagraph_Detector))
); );
assert_eq!( assert_eq!(
@ -96,35 +203,23 @@ fn bindgen_test_layout_mediagraph_Detector() {
8usize, 8usize,
concat!("Alignment of ", stringify!(mediagraph_Detector)) concat!("Alignment of ", stringify!(mediagraph_Detector))
); );
assert_eq!(
unsafe {
&(*(::std::ptr::null::<mediagraph_Detector>())).m_graph_type as *const _ as usize
},
8usize,
concat!(
"Offset of field: ",
stringify!(mediagraph_Detector),
"::",
stringify!(m_graph_type)
)
);
} }
extern "C" { extern "C" {
#[link_name = "\u{1}__ZN10mediagraph8Detector6CreateENS_12DetectorTypeEPKcS3_"] #[link_name = "\u{1}__ZN10mediagraph8Detector6CreateEPKcPKNS_6OutputEh"]
pub fn mediagraph_Detector_Create( pub fn mediagraph_Detector_Create(
t: mediagraph_DetectorType,
graph_config: *const ::std::os::raw::c_char, graph_config: *const ::std::os::raw::c_char,
output_node: *const ::std::os::raw::c_char, outputs: *const mediagraph_Output,
num_outputs: u8,
) -> *mut mediagraph_Detector; ) -> *mut mediagraph_Detector;
} }
impl mediagraph_Detector { impl mediagraph_Detector {
#[inline] #[inline]
pub unsafe fn Create( pub unsafe fn Create(
t: mediagraph_DetectorType,
graph_config: *const ::std::os::raw::c_char, graph_config: *const ::std::os::raw::c_char,
output_node: *const ::std::os::raw::c_char, outputs: *const mediagraph_Output,
num_outputs: u8,
) -> *mut mediagraph_Detector { ) -> *mut mediagraph_Detector {
mediagraph_Detector_Create(t, graph_config, output_node) mediagraph_Detector_Create(graph_config, outputs, num_outputs)
} }
} }
extern "C" { extern "C" {
@ -138,7 +233,7 @@ extern "C" {
data: *mut u8, data: *mut u8,
width: ::std::os::raw::c_int, width: ::std::os::raw::c_int,
height: ::std::os::raw::c_int, height: ::std::os::raw::c_int,
) -> *mut mediagraph_Landmark; ) -> *mut mediagraph_FeatureList;
} }
#[repr(C)] #[repr(C)]
pub struct mediagraph_Effect__bindgen_vtable(::std::os::raw::c_void); pub struct mediagraph_Effect__bindgen_vtable(::std::os::raw::c_void);

View File

@ -8,9 +8,11 @@ pub struct FaceMeshDetector {
impl FaceMeshDetector { impl FaceMeshDetector {
pub fn new() -> Self { pub fn new() -> Self {
let graph = Detector::new( let graph = Detector::new(
FACE_GRAPH_TYPE,
include_str!("graphs/face_mesh_desktop_live.pbtxt"), include_str!("graphs/face_mesh_desktop_live.pbtxt"),
"multi_face_landmarks", vec![Output {
type_: FeatureType::Face,
name: "multi_face_landmarks".into(),
}],
); );
Self { graph } Self { graph }
@ -20,12 +22,12 @@ impl FaceMeshDetector {
pub fn process(&mut self, input: &Mat) -> Option<FaceMesh> { pub fn process(&mut self, input: &Mat) -> Option<FaceMesh> {
let landmarks = self.graph.process(input); let landmarks = self.graph.process(input);
if landmarks.is_empty() { if landmarks[0].is_empty() {
return None; return None;
} }
let mut face_mesh = FaceMesh::default(); let mut face_mesh = FaceMesh::default();
face_mesh.data.copy_from_slice(landmarks); face_mesh.data.copy_from_slice(landmarks[0][0].as_slice());
Some(face_mesh) Some(face_mesh)
} }
} }

View File

@ -33,9 +33,11 @@ pub struct HandDetector {
impl HandDetector { impl HandDetector {
pub fn new() -> Self { pub fn new() -> Self {
let graph = Detector::new( let graph = Detector::new(
HANDS_GRAPH_TYPE,
include_str!("graphs/hand_tracking_desktop_live.pbtxt"), include_str!("graphs/hand_tracking_desktop_live.pbtxt"),
"hand_landmarks", vec![Output {
type_: FeatureType::Hands,
name: "hand_landmarks".into(),
}],
); );
Self { graph } Self { graph }
@ -43,12 +45,14 @@ impl HandDetector {
/// Processes the input frame, returns a tuple of hands if detected. /// Processes the input frame, returns a tuple of hands if detected.
pub fn process(&mut self, input: &Mat) -> Option<[Hand; 2]> { pub fn process(&mut self, input: &Mat) -> Option<[Hand; 2]> {
let landmarks = self.graph.process(input); let result = self.graph.process(input);
if landmarks.is_empty() { if result[0].is_empty() {
return None; return None;
} }
let landmarks = &result[0][0];
let mut lh = Hand::default(); let mut lh = Hand::default();
let mut rh = Hand::default(); let mut rh = Hand::default();
lh.data.copy_from_slice(&landmarks[0..21]); lh.data.copy_from_slice(&landmarks[0..21]);

View File

@ -22,16 +22,62 @@ pub mod segmentation;
use bindings::*; use bindings::*;
/// The C++ mediagraph graph type. type mFeatureType = mediagraph_FeatureType;
pub type DetectorType = mediagraph_DetectorType; type mOutput = mediagraph_Output;
type mFeature = mediagraph_Feature;
type mFeatureList = mediagraph_FeatureList;
/// The type of visual feature made up of landmarks.
#[derive(Debug, Clone, Copy)]
pub enum FeatureType {
Face,
Hands,
Pose,
}
impl FeatureType {
fn num_landmarks(&self) -> usize {
match self {
FeatureType::Face => 478,
FeatureType::Hands => 42,
FeatureType::Pose => 33,
}
}
}
impl Into<mFeatureType> for FeatureType {
fn into(self) -> mFeatureType {
match self {
FeatureType::Face => mediagraph_FeatureType_FACE,
FeatureType::Hands => mediagraph_FeatureType_HANDS,
FeatureType::Pose => mediagraph_FeatureType_POSE,
}
}
}
/// The definition of a graph output.
#[derive(Debug, Clone)]
pub struct Output {
type_: FeatureType,
name: String,
}
impl Into<mOutput> for Output {
fn into(self) -> mOutput {
let name = CString::new(self.name)
.expect("CString::new failed")
.into_raw();
mOutput {
type_: self.type_.into(),
name,
}
}
}
/// The C++ mediagraph landmark type. /// The C++ mediagraph landmark type.
pub type Landmark = mediagraph_Landmark; pub type Landmark = mediagraph_Landmark;
pub const FACE_GRAPH_TYPE: DetectorType = mediagraph_DetectorType_FACE;
pub const HANDS_GRAPH_TYPE: DetectorType = mediagraph_DetectorType_HANDS;
pub const POSE_GRAPH_TYPE: DetectorType = mediagraph_DetectorType_POSE;
impl Default for Landmark { impl Default for Landmark {
fn default() -> Self { fn default() -> Self {
Self { Self {
@ -81,36 +127,41 @@ impl Default for FaceMesh {
/// Detector calculator which interacts with the C++ library. /// Detector calculator which interacts with the C++ library.
pub struct Detector { pub struct Detector {
graph: *mut mediagraph_Detector, graph: *mut mediagraph_Detector,
num_landmarks: u32, outputs: Vec<Output>,
} }
impl Detector { impl Detector {
/// Creates a new Mediagraph with the given config. /// Creates a new Mediagraph with the given config.
pub fn new(graph_type: DetectorType, graph_config: &str, output_node: &str) -> Self { pub fn new(graph_config: &str, output_config: Vec<Output>) -> Self {
assert!(
output_config.len() > 0,
"must specify at least one output feature"
);
let graph_config = CString::new(graph_config).expect("CString::new failed"); let graph_config = CString::new(graph_config).expect("CString::new failed");
let output_node = CString::new(output_node).expect("CString::new failed");
let outputs = output_config
.iter()
.map(|f| f.clone().into())
.collect::<Vec<mOutput>>();
let graph: *mut mediagraph_Detector = unsafe { let graph: *mut mediagraph_Detector = unsafe {
mediagraph_Detector::Create(graph_type, graph_config.as_ptr(), output_node.as_ptr()) mediagraph_Detector::Create(
}; graph_config.as_ptr(),
outputs.as_ptr(),
let num_landmarks = match graph_type { outputs.len() as u8,
FACE_GRAPH_TYPE => 478, )
HANDS_GRAPH_TYPE => 42,
POSE_GRAPH_TYPE => 33,
_ => 0,
}; };
Self { Self {
graph, graph,
num_landmarks, outputs: output_config,
} }
} }
/// Processes the input frame, returns a slice of landmarks if any are detected. /// Processes the input frame, returns a slice of landmarks if any are detected.
pub fn process(&mut self, input: &Mat) -> &[Landmark] { pub fn process(&mut self, input: &Mat) -> Vec<Vec<Vec<Landmark>>> {
let mut data = input.clone(); let mut data = input.clone();
let raw_landmarks = unsafe { let results = unsafe {
mediagraph_Detector_Process( mediagraph_Detector_Process(
self.graph as *mut std::ffi::c_void, self.graph as *mut std::ffi::c_void,
data.data_mut(), data.data_mut(),
@ -118,11 +169,30 @@ impl Detector {
data.rows(), data.rows(),
) )
}; };
if raw_landmarks.is_null() {
return &[]; let mut landmarks = vec![];
let feature_lists =
unsafe { std::slice::from_raw_parts(results, self.outputs.len() as usize) };
for (i, feature_list) in feature_lists.iter().enumerate() {
let num_landmarks = self.outputs[i].type_.num_landmarks();
let mut fl = vec![];
let features = unsafe {
std::slice::from_raw_parts(
feature_list.features,
feature_list.num_features as usize,
)
};
for feature in features.iter() {
let landmarks = unsafe { std::slice::from_raw_parts(feature.data, num_landmarks) };
fl.push(landmarks.to_vec());
}
landmarks.push(fl);
} }
let landmarks =
unsafe { std::slice::from_raw_parts(raw_landmarks, self.num_landmarks as usize) };
landmarks landmarks
} }
} }

View File

@ -45,9 +45,11 @@ pub struct PoseDetector {
impl PoseDetector { impl PoseDetector {
pub fn new() -> Self { pub fn new() -> Self {
let graph = Detector::new( let graph = Detector::new(
POSE_GRAPH_TYPE,
include_str!("graphs/pose_tracking_cpu.pbtxt"), include_str!("graphs/pose_tracking_cpu.pbtxt"),
"pose_landmarks", vec![Output {
type_: FeatureType::Pose,
name: "pose_landmarks".into(),
}],
); );
Self { graph } Self { graph }
@ -55,14 +57,16 @@ impl PoseDetector {
/// Processes the input frame, returns a pose if detected. /// Processes the input frame, returns a pose if detected.
pub fn process(&mut self, input: &Mat) -> Option<Pose> { pub fn process(&mut self, input: &Mat) -> Option<Pose> {
let landmarks = self.graph.process(input); let result = self.graph.process(input);
if landmarks.is_empty() { if result[0].is_empty() {
return None; return None;
} }
let landmarks = &result[0][0];
let mut pose = Pose::default(); let mut pose = Pose::default();
pose.data.copy_from_slice(landmarks); pose.data.copy_from_slice(landmarks.as_slice());
Some(pose) Some(pose)
} }
} }