Revised gesture recognizer implementation

kinaryml 2022-10-31 05:34:31 -07:00
parent 5ec87c8bd2
commit 19be9e9012
4 changed files with 57 additions and 50 deletions

View File

@@ -47,22 +47,26 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task'
_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = 'gesture_recognizer.task'
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE = 'gesture_recognizer_with_custom_classifier.task'
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
_TWO_HANDS_IMAGE = 'right_hands.jpg'
_FIST_IMAGE = 'fist.jpg'
_FIST_LANDMARKS = 'fist_landmarks.pbtxt'
_FIST_LABEL = 'Closed_Fist'
_THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_THUMB_UP_LABEL = 'Thumb_Up'
_THUMB_UP_INDEX = 5
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_POINTING_UP_LABEL = 'Pointing_Up'
_POINTING_UP_INDEX = 3
_ROCK_LABEL = 'Rock'
_LANDMARKS_ERROR_TOLERANCE = 0.03
_GESTURE_EXPECTED_INDEX = -1
def _get_expected_gesture_recognition_result(
file_path: str, gesture_label: str, gesture_index: int
file_path: str, gesture_label: str
) -> _GestureRecognitionResult:
landmarks_detection_result_file_path = test_utils.get_test_data_path(
file_path)
@@ -73,7 +77,8 @@ def _get_expected_gesture_recognition_result(
text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
landmarks_detection_result_proto)
gesture = _Category(category_name=gesture_label, index=gesture_index,
gesture = _Category(category_name=gesture_label,
index=_GESTURE_EXPECTED_INDEX,
display_name='')
return _GestureRecognitionResult(
gestures=[[gesture]],
@@ -94,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase):
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_THUMB_UP_IMAGE))
self.model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_MODEL_FILE)
_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
def _assert_actual_result_approximately_matches_expected_result(
self,
@@ -127,7 +132,7 @@ class GestureRecognizerTest(parameterized.TestCase):
# Actual gesture with top score matches expected gesture.
actual_top_gesture = actual_result.gestures[0][0]
expected_top_gesture = expected_result.gestures[0][0]
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
self.assertEqual(actual_top_gesture.index, _GESTURE_EXPECTED_INDEX)
self.assertEqual(actual_top_gesture.category_name,
expected_top_gesture.category_name)
@@ -163,10 +168,10 @@ class GestureRecognizerTest(parameterized.TestCase):
@parameterized.parameters(
(ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
)),
(ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
)))
def test_recognize(self, model_file_type, expected_recognition_result):
# Creates gesture recognizer.
@@ -194,10 +199,10 @@ class GestureRecognizerTest(parameterized.TestCase):
@parameterized.parameters(
(ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
)),
(ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
)))
def test_recognize_in_context(self, model_file_type,
expected_recognition_result):
@@ -224,12 +229,12 @@ class GestureRecognizerTest(parameterized.TestCase):
# Creates gesture recognizer.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _GestureRecognizerOptions(base_options=base_options,
min_gesture_confidence=2)
min_gesture_confidence=0.5)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(self.test_image)
expected_result = _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)
# Only contains one top scoring gesture.
self.assertLen(recognition_result.gestures[0], 1)
# Actual gesture with top score matches expected gesture.
@@ -266,7 +271,25 @@ class GestureRecognizerTest(parameterized.TestCase):
recognition_result = recognizer.recognize(test_image,
image_processing_options)
expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX)
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_custom_gesture_fist(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the fist image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_FIST_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
expected_recognition_result = _get_expected_gesture_recognition_result(
_FIST_LANDMARKS, _ROCK_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
@@ -373,7 +396,7 @@ class GestureRecognizerTest(parameterized.TestCase):
recognition_result = recognizer.recognize_for_video(self.test_image,
timestamp)
expected_recognition_result = _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
@@ -410,7 +433,7 @@ class GestureRecognizerTest(parameterized.TestCase):
@parameterized.parameters(
(_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)),
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)),
(_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], [])))
def test_recognize_async_calls(self, image_path, expected_result):
test_image = _Image.create_from_file(

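For orientation, here is a minimal sketch (not part of the diff) of the flow these tests exercise after the change. It reuses the module aliases imported at the top of the test file (_BaseOptions, _GestureRecognizerOptions, _GestureRecognizer, _Image); loading the assets by bare filename, rather than through test_utils.get_test_data_path, is an assumption for illustration.

# Hedged sketch of the recognize() flow the updated tests cover.
# Assumes the MediaPipe test assets sit in the working directory.
base_options = _BaseOptions(model_asset_path='gesture_recognizer.task')
options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
with _GestureRecognizer.create_from_options(options) as recognizer:
    image = _Image.create_from_file('thumb_up.jpg')
    result = recognizer.recognize(image)
    # The first entry is the top-scoring gesture for the first hand.
    top_gesture = result.gestures[0][0]
    # After this commit the gesture index is always the sentinel -1,
    # so the tests compare category_name (and score) instead of index.
    assert top_gesture.index == -1
    assert top_gesture.category_name == 'Thumb_Up'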
View File

@@ -87,12 +87,9 @@ py_library(
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/processors:classifier_options",

View File

@@ -20,13 +20,9 @@ from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_classifier_graph_options_pb2
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.processors import classifier_options
@@ -38,12 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_BaseOptions = base_options_module.BaseOptions
_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions
_HandDetectorGraphOptionsProto = hand_detector_graph_options_pb2.HandDetectorGraphOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@@ -64,6 +57,7 @@ _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
_GESTURE_DEFAULT_INDEX = -1
@dataclasses.dataclass
@@ -72,8 +66,9 @@ class GestureRecognitionResult:
element represents a single hand detected in the image.
Attributes:
gestures: Recognized hand gestures with sorted order such that the
winning label is the first item in the list.
gestures: Recognized hand gestures of detected hands. Note that the index
of the gesture is always -1, because the raw indices from multiple gesture
classifiers cannot consolidate to a meaningful index.
handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates.
hand_world_landmarks: Detected hand landmarks in world coordinates.
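To make the new index contract concrete, a hedged sketch of consuming a GestureRecognitionResult; the `result` variable is hypothetical, and the handedness behavior mirrors _build_recognition_result in the next hunk:

# Hypothetical consumer of a GestureRecognitionResult named `result`.
for hand_idx, gesture_list in enumerate(result.gestures):
    top_gesture = gesture_list[0]
    assert top_gesture.index == -1  # gesture indices carry no meaning
    # Handedness categories keep their real classifier indices (see
    # _build_recognition_result below); only gestures use the sentinel.
    handedness = result.handedness[hand_idx][0]
    print(handedness.category_name, top_gesture.category_name,
          top_gesture.score)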
@@ -101,16 +96,16 @@ def _build_recognition_result(
[
[
category_module.Category(
index=gesture.index, score=gesture.score,
index=_GESTURE_DEFAULT_INDEX, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
index=handedness.index, score=handedness.score,
display_name=handedness.display_name, category_name=handedness.label)
for handedness in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
@@ -170,26 +165,17 @@ class GestureRecognizerOptions:
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# Configure hand detector options.
hand_detector_options_proto = _HandDetectorGraphOptionsProto(
num_hands=self.num_hands,
min_detection_confidence=self.min_hand_detection_confidence)
# Configure hand landmarker options.
hand_landmarks_detector_options_proto = _HandLandmarksDetectorGraphOptionsProto(
min_detection_confidence=self.min_hand_presence_confidence)
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
hand_detector_graph_options=hand_detector_options_proto,
hand_landmarks_detector_graph_options=hand_landmarks_detector_options_proto,
min_tracking_confidence=self.min_tracking_confidence)
# Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto()
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
# Configure hand gesture recognizer options.
classifier_options = _ClassifierOptions(
score_threshold=self.min_gesture_confidence)
gesture_classifier_options = _GestureClassifierGraphOptionsProto(
classifier_options=classifier_options.to_pb2())
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto(
canned_gesture_classifier_graph_options=gesture_classifier_options)
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto()
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
return _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto,

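The refactored to_pb2() above leans on protobuf's auto-instantiation of nested message fields: assigning through a sub-message path materializes it in place, so the sub-option protos no longer need to be imported and constructed separately, which is also why the corresponding BUILD deps and imports could be dropped. A condensed sketch of the pattern, with illustrative values:

# Condensed sketch of the in-place proto population used above.
hand_landmarker_options = _HandLandmarkerGraphOptionsProto()
# Accessing a nested message field auto-creates it; no separate
# _HandDetectorGraphOptionsProto construction is required.
hand_landmarker_options.hand_detector_graph_options.num_hands = 2
hand_landmarker_options.hand_detector_graph_options.min_detection_confidence = 0.5

gesture_options = _HandGestureRecognizerGraphOptionsProto()
# A single user-facing threshold now fans out to both the canned and
# the custom gesture classifier.
gesture_options.canned_gesture_classifier_graph_options.classifier_options.score_threshold = 0.5
gesture_options.custom_gesture_classifier_graph_options.classifier_options.score_threshold = 0.5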
View File

@@ -130,6 +130,7 @@ filegroup(
"hand_landmark_lite.tflite",
"hand_landmarker.task",
"gesture_recognizer.task",
"gesture_recognizer_with_custom_classifier.task",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_metadata_1.tflite",