Update Python models to allow extracting any set of graph outputs

Leevar Williams 2020-11-25 09:44:22 -08:00
parent 7d01b7ff32
commit facd7520d7
3 changed files with 40 additions and 17 deletions

mediapipe/python/solutions/face_mesh.py

@@ -15,7 +15,7 @@
 # Lint as: python3
 """MediaPipe FaceMesh."""
-from typing import NamedTuple
+from typing import NamedTuple, Optional, Tuple
 import numpy as np
@@ -249,7 +249,8 @@ class FaceMesh(SolutionBase):
                static_image_mode=False,
                max_num_faces=2,
                min_detection_confidence=0.5,
-               min_tracking_confidence=0.5):
+               min_tracking_confidence=0.5,
+               outputs: Optional[Tuple[str, ...]] = ('multi_face_landmarks',)):
     """Initializes a MediaPipe FaceMesh object.
     Args:
@@ -274,6 +275,9 @@ class FaceMesh(SolutionBase):
         robustness of the solution, at the expense of a higher latency. Ignored
         if "static_image_mode" is True, where face detection simply runs on
         every image. Default to 0.5.
+      outputs: A tuple of the graph output stream names to observe. If the
+        tuple is empty or None, all the output streams listed in the graph
+        config will be observed by default.
     """
     super().__init__(
         binary_graph_path=BINARYPB_FILE_PATH,
@@ -287,7 +291,7 @@ class FaceMesh(SolutionBase):
             'facelandmarkcpu__ThresholdingCalculator.threshold':
                 min_tracking_confidence,
         },
-        outputs=['multi_face_landmarks'])
+        outputs=list(outputs) if outputs else [])
   def process(self, image: np.ndarray) -> NamedTuple:
     """Processes an RGB image and returns the face landmarks on each detected face.
@@ -300,8 +304,12 @@ class FaceMesh(SolutionBase):
       ValueError: If the input image is not three channel RGB.
     Returns:
-      A NamedTuple object with a "multi_face_landmarks" field that contains the
-      face landmarks on each detected face.
+      A NamedTuple object with fields corresponding to the outputs passed to
+      the constructor. Fields may include:
+        "multi_face_landmarks": The face landmarks on each detected face.
+        "face_detections": The detected faces.
+        "face_rects_from_landmarks": Regions of interest calculated based on the landmarks.
+        "face_rects_from_detections": Regions of interest calculated based on the face detections.
     """
     return super().process(input_data={'image': image})
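
For illustration, a minimal usage sketch of the new FaceMesh parameter. This is not part of the commit; it assumes a build containing this change and a pre-loaded RGB numpy array named rgb_image:

import mediapipe as mp

# Hypothetical example: request raw face detections alongside the landmarks.
face_mesh = mp.solutions.face_mesh.FaceMesh(
    static_image_mode=True,
    outputs=('multi_face_landmarks', 'face_detections'))
results = face_mesh.process(rgb_image)  # rgb_image: an RGB numpy ndarray (assumed)
# Each requested output stream becomes a field on the returned NamedTuple.
if results.multi_face_landmarks:
  print('%d face(s) with landmarks' % len(results.multi_face_landmarks))
print(results.face_detections)
face_mesh.close()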

mediapipe/python/solutions/hands.py

@@ -16,7 +16,7 @@
 """MediaPipe Hands."""
 import enum
-from typing import NamedTuple
+from typing import NamedTuple, Optional, Tuple
 import numpy as np
@@ -168,7 +168,8 @@ class Hands(SolutionBase):
                static_image_mode=False,
                max_num_hands=2,
                min_detection_confidence=0.7,
-               min_tracking_confidence=0.5):
+               min_tracking_confidence=0.5,
+               outputs: Optional[Tuple[str, ...]] = ('multi_hand_landmarks', 'multi_handedness')):
     """Initializes a MediaPipe Hand object.
     Args:
@@ -193,6 +194,9 @@ class Hands(SolutionBase):
         robustness of the solution, at the expense of a higher latency. Ignored
         if "static_image_mode" is True, where hand detection simply runs on
         every image. Default to 0.5.
+      outputs: A tuple of the graph output stream names to observe. If the
+        tuple is empty or None, all the output streams listed in the graph
+        config will be observed by default.
     """
     super().__init__(
         binary_graph_path=BINARYPB_FILE_PATH,
@@ -206,7 +210,7 @@ class Hands(SolutionBase):
             'handlandmarkcpu__ThresholdingCalculator.threshold':
                 min_tracking_confidence,
         },
-        outputs=['multi_hand_landmarks', 'multi_handedness'])
+        outputs=list(outputs) if outputs else [])
   def process(self, image: np.ndarray) -> NamedTuple:
     """Processes an RGB image and returns the hand landmarks and handedness of each detected hand.
@@ -219,10 +223,13 @@ class Hands(SolutionBase):
       ValueError: If the input image is not three channel RGB.
     Returns:
-      A NamedTuple object with two fields: a "multi_hand_landmarks" field that
-      contains the hand landmarks on each detected hand and a "multi_handedness"
-      field that contains the handedness (left v.s. right hand) of the detected
-      hand.
+      A NamedTuple object with fields corresponding to the outputs passed to
+      the constructor. Fields may include:
+        "multi_hand_landmarks": The hand landmarks on each detected hand.
+        "multi_handedness": The handedness (left vs. right) of each detected hand.
+        "palm_detections": The detected palms.
+        "hand_rects": Regions of interest calculated based on the landmarks.
+        "hand_rects_from_palm_detections": Regions of interest calculated based on the palm detections.
     """
     return super().process(input_data={'image': image})
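
Similarly for Hands, a sketch that narrows the observed outputs to a single stream. Again, illustrative only; it assumes a pre-loaded RGB numpy array named rgb_image:

import mediapipe as mp

# Hypothetical example: observe only the palm detector output.
hands = mp.solutions.hands.Hands(
    static_image_mode=True,
    outputs=('palm_detections',))
results = hands.process(rgb_image)  # rgb_image: an RGB numpy ndarray (assumed)
# Only the requested stream is present; the default 'multi_hand_landmarks'
# and 'multi_handedness' streams are not observed.
print(results.palm_detections)
hands.close()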

mediapipe/python/solutions/pose.py

@@ -16,7 +16,7 @@
 """MediaPipe Pose."""
 import enum
-from typing import NamedTuple
+from typing import NamedTuple, Optional, Tuple
 import numpy as np
@@ -159,7 +159,8 @@ class Pose(SolutionBase):
   def __init__(self,
                static_image_mode=False,
                min_detection_confidence=0.5,
-               min_tracking_confidence=0.5):
+               min_tracking_confidence=0.5,
+               outputs: Optional[Tuple[str, ...]] = ('pose_landmarks',)):
     """Initializes a MediaPipe Pose object.
     Args:
@@ -181,6 +182,9 @@ class Pose(SolutionBase):
         increase robustness of the solution, at the expense of a higher latency.
         Ignored if "static_image_mode" is True, where person detection simply
         runs on every image. Default to 0.5.
+      outputs: A tuple of the graph output stream names to observe. If the
+        tuple is empty or None, all the output streams listed in the graph
+        config will be observed by default.
     """
     super().__init__(
         binary_graph_path=BINARYPB_FILE_PATH,
@@ -193,7 +197,7 @@ class Pose(SolutionBase):
             'poselandmarkupperbodycpu__poselandmarkupperbodybyroicpu__ThresholdingCalculator.threshold':
                 min_tracking_confidence,
         },
-        outputs=['pose_landmarks'])
+        outputs=list(outputs) if outputs else [])
   def process(self, image: np.ndarray) -> NamedTuple:
     """Processes an RGB image and returns the pose landmarks on the most prominent person detected.
@@ -206,8 +210,12 @@ class Pose(SolutionBase):
       ValueError: If the input image is not three channel RGB.
     Returns:
-      A NamedTuple object with a "pose_landmarks" field that contains the pose
-      landmarks on the most prominent person detected.
+      A NamedTuple object with fields corresponding to the outputs passed to
+      the constructor. Fields may include:
+        "pose_landmarks": The pose landmarks on the most prominent person detected.
+        "pose_detection": The detected pose.
+        "pose_rect_from_landmarks": Region of interest calculated based on the landmarks.
+        "pose_rect_from_detection": Region of interest calculated based on the pose detection.
     """
     return super().process(input_data={'image': image})
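
Finally for Pose, a sketch of the empty-tuple case, which per the docstring above observes every output stream listed in the graph config. Illustrative only; assumes a pre-loaded RGB numpy array named rgb_image:

import mediapipe as mp

# Hypothetical example: an empty tuple observes all graph output streams.
pose = mp.solutions.pose.Pose(static_image_mode=True, outputs=())
results = pose.process(rgb_image)  # rgb_image: an RGB numpy ndarray (assumed)
print(results.pose_landmarks)            # landmarks on the most prominent person
print(results.pose_rect_from_landmarks)  # ROI derived from those landmarks
pose.close()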