Update Python models to allow extracting any set of graph outputs

This commit is contained in:
Leevar Williams 2020-11-25 09:44:22 -08:00
parent 7d01b7ff32
commit facd7520d7
3 changed files with 40 additions and 17 deletions

View File

@ -15,7 +15,7 @@
# Lint as: python3 # Lint as: python3
"""MediaPipe FaceMesh.""" """MediaPipe FaceMesh."""
from typing import NamedTuple from typing import NamedTuple, Optional, Tuple
import numpy as np import numpy as np
@ -249,7 +249,8 @@ class FaceMesh(SolutionBase):
static_image_mode=False, static_image_mode=False,
max_num_faces=2, max_num_faces=2,
min_detection_confidence=0.5, min_detection_confidence=0.5,
min_tracking_confidence=0.5): min_tracking_confidence=0.5,
outputs: Optional[Tuple[str]] = ('multi_face_landmarks',)):
"""Initializes a MediaPipe FaceMesh object. """Initializes a MediaPipe FaceMesh object.
Args: Args:
@ -274,6 +275,9 @@ class FaceMesh(SolutionBase):
robustness of the solution, at the expense of a higher latency. Ignored robustness of the solution, at the expense of a higher latency. Ignored
if "static_image_mode" is True, where face detection simply runs on if "static_image_mode" is True, where face detection simply runs on
every image. Default to 0.5. every image. Default to 0.5.
outputs: A tuple of the graph output stream names to observe. If the tuple
is empty, all the output streams listed in the graph config will be
automatically observed by default.
""" """
super().__init__( super().__init__(
binary_graph_path=BINARYPB_FILE_PATH, binary_graph_path=BINARYPB_FILE_PATH,
@ -287,7 +291,7 @@ class FaceMesh(SolutionBase):
'facelandmarkcpu__ThresholdingCalculator.threshold': 'facelandmarkcpu__ThresholdingCalculator.threshold':
min_tracking_confidence, min_tracking_confidence,
}, },
outputs=['multi_face_landmarks']) outputs=list(outputs) if outputs else [])
def process(self, image: np.ndarray) -> NamedTuple: def process(self, image: np.ndarray) -> NamedTuple:
"""Processes an RGB image and returns the face landmarks on each detected face. """Processes an RGB image and returns the face landmarks on each detected face.
@ -300,8 +304,12 @@ class FaceMesh(SolutionBase):
ValueError: If the input image is not three channel RGB. ValueError: If the input image is not three channel RGB.
Returns: Returns:
A NamedTuple object with a "multi_face_landmarks" field that contains the A NamedTuple object with fields corresponding to the set of outputs passed to the
face landmarks on each detected face. constructor. Fields may include:
"multi_hand_landmarks" The face landmarks on each detected face
"face_detections" The detected faces
"face_rects_from_landmarks" Regions of interest calculated based on landmarks
"face_rects_from_detections" Regions of interest calculated based on face detections
""" """
return super().process(input_data={'image': image}) return super().process(input_data={'image': image})

View File

@ -16,7 +16,7 @@
"""MediaPipe Hands.""" """MediaPipe Hands."""
import enum import enum
from typing import NamedTuple from typing import NamedTuple, Optional, Tuple
import numpy as np import numpy as np
@ -168,7 +168,8 @@ class Hands(SolutionBase):
static_image_mode=False, static_image_mode=False,
max_num_hands=2, max_num_hands=2,
min_detection_confidence=0.7, min_detection_confidence=0.7,
min_tracking_confidence=0.5): min_tracking_confidence=0.5,
outputs: Optional[Tuple[str]] = ('multi_hand_landmarks', 'multi_handedness')):
"""Initializes a MediaPipe Hand object. """Initializes a MediaPipe Hand object.
Args: Args:
@ -193,6 +194,9 @@ class Hands(SolutionBase):
robustness of the solution, at the expense of a higher latency. Ignored robustness of the solution, at the expense of a higher latency. Ignored
if "static_image_mode" is True, where hand detection simply runs on if "static_image_mode" is True, where hand detection simply runs on
every image. Default to 0.5. every image. Default to 0.5.
outputs: A tuple of the graph output stream names to observe. If the tuple
is empty, all the output streams listed in the graph config will be
automatically observed by default.
""" """
super().__init__( super().__init__(
binary_graph_path=BINARYPB_FILE_PATH, binary_graph_path=BINARYPB_FILE_PATH,
@ -206,7 +210,7 @@ class Hands(SolutionBase):
'handlandmarkcpu__ThresholdingCalculator.threshold': 'handlandmarkcpu__ThresholdingCalculator.threshold':
min_tracking_confidence, min_tracking_confidence,
}, },
outputs=['multi_hand_landmarks', 'multi_handedness']) outputs=list(outputs) if outputs else [])
def process(self, image: np.ndarray) -> NamedTuple: def process(self, image: np.ndarray) -> NamedTuple:
"""Processes an RGB image and returns the hand landmarks and handedness of each detected hand. """Processes an RGB image and returns the hand landmarks and handedness of each detected hand.
@ -219,10 +223,13 @@ class Hands(SolutionBase):
ValueError: If the input image is not three channel RGB. ValueError: If the input image is not three channel RGB.
Returns: Returns:
A NamedTuple object with two fields: a "multi_hand_landmarks" field that A NamedTuple object with fields corresponding to the set of outputs passed to the
contains the hand landmarks on each detected hand and a "multi_handedness" constructor. Fields may include:
field that contains the handedness (left v.s. right hand) of the detected "multi_hand_landmarks" The hand landmarks on each detected hand
hand. "multi_handedness" The handedness (left v.s. right hand) of the detected hand
"palm_detections" The detected palms
"hand_rects" Regions of interest calculated based on landmarks
"hand_rects_from_palm_detections" Regions of interest calculated based on palm detections
""" """
return super().process(input_data={'image': image}) return super().process(input_data={'image': image})

View File

@ -16,7 +16,7 @@
"""MediaPipe Pose.""" """MediaPipe Pose."""
import enum import enum
from typing import NamedTuple from typing import NamedTuple, Optional, Tuple
import numpy as np import numpy as np
@ -159,7 +159,8 @@ class Pose(SolutionBase):
def __init__(self, def __init__(self,
static_image_mode=False, static_image_mode=False,
min_detection_confidence=0.5, min_detection_confidence=0.5,
min_tracking_confidence=0.5): min_tracking_confidence=0.5,
outputs: Optional[Tuple[str]] = ('pose_landmarks',)):
"""Initializes a MediaPipe Pose object. """Initializes a MediaPipe Pose object.
Args: Args:
@ -181,6 +182,9 @@ class Pose(SolutionBase):
increase robustness of the solution, at the expense of a higher latency. increase robustness of the solution, at the expense of a higher latency.
Ignored if "static_image_mode" is True, where person detection simply Ignored if "static_image_mode" is True, where person detection simply
runs on every image. Default to 0.5. runs on every image. Default to 0.5.
outputs: A tuple of the graph output stream names to observe. If the tuple
is empty, all the output streams listed in the graph config will be
automatically observed by default.
""" """
super().__init__( super().__init__(
binary_graph_path=BINARYPB_FILE_PATH, binary_graph_path=BINARYPB_FILE_PATH,
@ -193,7 +197,7 @@ class Pose(SolutionBase):
'poselandmarkupperbodycpu__poselandmarkupperbodybyroicpu__ThresholdingCalculator.threshold': 'poselandmarkupperbodycpu__poselandmarkupperbodybyroicpu__ThresholdingCalculator.threshold':
min_tracking_confidence, min_tracking_confidence,
}, },
outputs=['pose_landmarks']) outputs=list(outputs) if outputs else [])
def process(self, image: np.ndarray) -> NamedTuple: def process(self, image: np.ndarray) -> NamedTuple:
"""Processes an RGB image and returns the pose landmarks on the most prominent person detected. """Processes an RGB image and returns the pose landmarks on the most prominent person detected.
@ -206,8 +210,12 @@ class Pose(SolutionBase):
ValueError: If the input image is not three channel RGB. ValueError: If the input image is not three channel RGB.
Returns: Returns:
A NamedTuple object with a "pose_landmarks" field that contains the pose A NamedTuple object with fields corresponding to the set of outputs passed to the
landmarks on the most prominent person detected. constructor. Fields may include:
"pose_landmarks" The pose landmarks on the most prominent person detected
"pose_detection" The detected pose
"pose_rect_from_landmarks" Region of interest calculated based on landmarks
"pose_rect_from_detection" Region of interest calculated based on pose detection
""" """
return super().process(input_data={'image': image}) return super().process(input_data={'image': image})