Revised gesture recognizer implementation

Author: kinaryml
Date:   2022-10-31 05:34:31 -07:00
Commit: 19be9e9012 (parent 5ec87c8bd2)

4 changed files with 57 additions and 50 deletions

mediapipe/tasks/python/test/vision/gesture_recognizer_test.py

@@ -47,22 +47,26 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult
 _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
 _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
-_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task'
+_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = 'gesture_recognizer.task'
+_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE = 'gesture_recognizer_with_custom_classifier.task'
 _NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
 _TWO_HANDS_IMAGE = 'right_hands.jpg'
+_FIST_IMAGE = 'fist.jpg'
+_FIST_LANDMARKS = 'fist_landmarks.pbtxt'
+_FIST_LABEL = 'Closed_Fist'
 _THUMB_UP_IMAGE = 'thumb_up.jpg'
 _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
 _THUMB_UP_LABEL = 'Thumb_Up'
-_THUMB_UP_INDEX = 5
 _POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
 _POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
 _POINTING_UP_LABEL = 'Pointing_Up'
-_POINTING_UP_INDEX = 3
+_ROCK_LABEL = "Rock"
 _LANDMARKS_ERROR_TOLERANCE = 0.03
+_GESTURE_EXPECTED_INDEX = -1


 def _get_expected_gesture_recognition_result(
-    file_path: str, gesture_label: str, gesture_index: int
+    file_path: str, gesture_label: str
 ) -> _GestureRecognitionResult:
   landmarks_detection_result_file_path = test_utils.get_test_data_path(
       file_path)
@@ -73,7 +77,8 @@ def _get_expected_gesture_recognition_result(
     text_format.Parse(f.read(), landmarks_detection_result_proto)
   landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
       landmarks_detection_result_proto)
-  gesture = _Category(category_name=gesture_label, index=gesture_index,
+  gesture = _Category(category_name=gesture_label,
+                      index=_GESTURE_EXPECTED_INDEX,
                       display_name='')
   return _GestureRecognitionResult(
       gestures=[[gesture]],
@@ -94,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase):
     self.test_image = _Image.create_from_file(
         test_utils.get_test_data_path(_THUMB_UP_IMAGE))
     self.model_path = test_utils.get_test_data_path(
-        _GESTURE_RECOGNIZER_MODEL_FILE)
+        _GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)

   def _assert_actual_result_approximately_matches_expected_result(
       self,
@@ -127,7 +132,7 @@ class GestureRecognizerTest(parameterized.TestCase):
     # Actual gesture with top score matches expected gesture.
     actual_top_gesture = actual_result.gestures[0][0]
     expected_top_gesture = expected_result.gestures[0][0]
-    self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
+    self.assertEqual(actual_top_gesture.index, _GESTURE_EXPECTED_INDEX)
     self.assertEqual(actual_top_gesture.category_name,
                      expected_top_gesture.category_name)
@@ -163,10 +168,10 @@ class GestureRecognizerTest(parameterized.TestCase):
   @parameterized.parameters(
       (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
       )),
       (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
       )))
   def test_recognize(self, model_file_type, expected_recognition_result):
     # Creates gesture recognizer.
@@ -194,10 +199,10 @@ class GestureRecognizerTest(parameterized.TestCase):
   @parameterized.parameters(
       (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
       )),
       (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
       )))
   def test_recognize_in_context(self, model_file_type,
                                 expected_recognition_result):
@@ -224,12 +229,12 @@ class GestureRecognizerTest(parameterized.TestCase):
     # Creates gesture recognizer.
     base_options = _BaseOptions(model_asset_path=self.model_path)
     options = _GestureRecognizerOptions(base_options=base_options,
-                                        min_gesture_confidence=2)
+                                        min_gesture_confidence=0.5)
     with _GestureRecognizer.create_from_options(options) as recognizer:
       # Performs hand gesture recognition on the input.
       recognition_result = recognizer.recognize(self.test_image)
       expected_result = _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)
       # Only contains one top scoring gesture.
       self.assertLen(recognition_result.gestures[0], 1)
       # Actual gesture with top score matches expected gesture.
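
For context on the hunk above: min_gesture_confidence feeds a classifier score threshold, so it must be a probability in [0, 1]; the previous value of 2 could never be satisfied. Below is a minimal sketch of the thresholding behavior the test relies on, using a hypothetical stand-in Category class rather than MediaPipe's container:

# Hypothetical stand-in for the score-threshold filtering applied by the
# gesture classifier; an illustration, not MediaPipe's implementation.
import dataclasses

@dataclasses.dataclass
class Category:
  category_name: str
  score: float

def filter_by_score(categories, score_threshold):
  # Drop every category whose score falls below the threshold,
  # keeping the survivors sorted best-first.
  kept = [c for c in categories if c.score >= score_threshold]
  return sorted(kept, key=lambda c: c.score, reverse=True)

# With a 0.5 threshold only the winning 'Thumb_Up' prediction survives,
# matching the test's assertLen(recognition_result.gestures[0], 1).
print(filter_by_score(
    [Category('Thumb_Up', 0.97), Category('Open_Palm', 0.02)], 0.5))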
@@ -266,7 +271,25 @@ class GestureRecognizerTest(parameterized.TestCase):
       recognition_result = recognizer.recognize(test_image,
                                                 image_processing_options)
       expected_recognition_result = _get_expected_gesture_recognition_result(
-          _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX)
+          _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
+      # Comparing results.
+      self._assert_actual_result_approximately_matches_expected_result(
+          recognition_result, expected_recognition_result)
+
+  def test_recognize_succeeds_with_custom_gesture_fist(self):
+    # Creates gesture recognizer.
+    model_path = test_utils.get_test_data_path(
+        _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
+    with _GestureRecognizer.create_from_options(options) as recognizer:
+      # Load the fist image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(_FIST_IMAGE))
+      # Performs hand gesture recognition on the input.
+      recognition_result = recognizer.recognize(test_image)
+      expected_recognition_result = _get_expected_gesture_recognition_result(
+          _FIST_LANDMARKS, _ROCK_LABEL)
       # Comparing results.
       self._assert_actual_result_approximately_matches_expected_result(
           recognition_result, expected_recognition_result)
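
Outside the test harness, the new custom-classifier bundle can be exercised through the public API roughly as in the sketch below; the asset and image paths are placeholders, and GestureRecognizer, GestureRecognizerOptions, and BaseOptions are the public names behind the underscore-prefixed aliases the test imports:

# Hedged usage sketch; paths are placeholders for illustration only.
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import gesture_recognizer

options = gesture_recognizer.GestureRecognizerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='gesture_recognizer_with_custom_classifier.task'),
    num_hands=1)
with gesture_recognizer.GestureRecognizer.create_from_options(
    options) as recognizer:
  image = image_module.Image.create_from_file('fist.jpg')
  result = recognizer.recognize(image)
  if result.gestures:
    # The custom classifier bundle maps a closed fist to the 'Rock' label.
    print(result.gestures[0][0].category_name, result.gestures[0][0].score)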
@@ -373,7 +396,7 @@ class GestureRecognizerTest(parameterized.TestCase):
       recognition_result = recognizer.recognize_for_video(self.test_image,
                                                           timestamp)
       expected_recognition_result = _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)
       self._assert_actual_result_approximately_matches_expected_result(
           recognition_result, expected_recognition_result)
@@ -410,7 +433,7 @@ class GestureRecognizerTest(parameterized.TestCase):
   @parameterized.parameters(
       (_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result(
-          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)),
+          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)),
       (_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], [])))
   def test_recognize_async_calls(self, image_path, expected_result):
     test_image = _Image.create_from_file(

mediapipe/tasks/python/vision/BUILD

@@ -87,12 +87,9 @@ py_library(
         "//mediapipe/python:_framework_bindings",
         "//mediapipe/python:packet_creator",
         "//mediapipe/python:packet_getter",
-        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_py_pb2",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2",
-        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
         "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
-        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
         "//mediapipe/tasks/python/components/containers:category",
         "//mediapipe/tasks/python/components/containers:landmark",
         "//mediapipe/tasks/python/components/processors:classifier_options",

mediapipe/tasks/python/vision/gesture_recognizer.py

@@ -20,13 +20,9 @@ from mediapipe.python import packet_creator
 from mediapipe.python import packet_getter
 from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import packet as packet_module
-from mediapipe.python._framework_bindings import task_runner as task_runner_module
-from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_classifier_graph_options_pb2
 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2
-from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
-from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
 from mediapipe.tasks.python.components.containers import category as category_module
 from mediapipe.tasks.python.components.containers import landmark as landmark_module
 from mediapipe.tasks.python.components.processors import classifier_options
@@ -38,12 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
 from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

 _BaseOptions = base_options_module.BaseOptions
-_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions
 _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
 _HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions
-_HandDetectorGraphOptionsProto = hand_detector_graph_options_pb2.HandDetectorGraphOptions
 _HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
-_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions
 _ClassifierOptions = classifier_options.ClassifierOptions
 _RunningMode = running_mode_module.VisionTaskRunningMode
 _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@@ -64,6 +57,7 @@ _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
 _HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
 _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
 _MICRO_SECONDS_PER_MILLISECOND = 1000
+_GESTURE_DEFAULT_INDEX = -1


 @dataclasses.dataclass
@ -72,8 +66,9 @@ class GestureRecognitionResult:
element represents a single hand detected in the image. element represents a single hand detected in the image.
Attributes: Attributes:
gestures: Recognized hand gestures with sorted order such that the gestures: Recognized hand gestures of detected hands. Note that the index
winning label is the first item in the list. of the gesture is always 0, because the raw indices from multiple gesture
classifiers cannot consolidate to a meaningful index.
handedness: Classification of handedness. handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates. hand_landmarks: Detected hand landmarks in normalized image coordinates.
hand_world_landmarks: Detected hand landmarks in world coordinates. hand_world_landmarks: Detected hand landmarks in world coordinates.
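
A practical consequence of this contract: consumers should treat the gesture index as a sentinel and key off category_name and score instead, while handedness categories retain their real classifier indices. A minimal sketch, assuming result is a GestureRecognitionResult returned by recognize():

# Minimal sketch of reading a result under the new index contract; `result`
# is assumed to come from GestureRecognizer.recognize().
def summarize_gestures(result):
  rows = []
  for hand_index, hand_gestures in enumerate(result.gestures):
    for gesture in hand_gestures:
      # gesture.index is always -1 by design; branch on the label instead.
      rows.append((hand_index, gesture.category_name, gesture.score))
  return rows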
@@ -101,16 +96,16 @@ def _build_recognition_result(
       [
           [
               category_module.Category(
-                  index=gesture.index, score=gesture.score,
+                  index=_GESTURE_DEFAULT_INDEX, score=gesture.score,
                   display_name=gesture.display_name, category_name=gesture.label)
               for gesture in gesture_classifications.classification]
           for gesture_classifications in gestures_proto_list
       ], [
           [
               category_module.Category(
-                  index=gesture.index, score=gesture.score,
-                  display_name=gesture.display_name, category_name=gesture.label)
-              for gesture in handedness_classifications.classification]
+                  index=handedness.index, score=handedness.score,
+                  display_name=handedness.display_name, category_name=handedness.label)
+              for handedness in handedness_classifications.classification]
           for handedness_classifications in handedness_proto_list
       ], [
           [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
@@ -170,26 +165,17 @@ class GestureRecognizerOptions:
     base_options_proto = self.base_options.to_pb2()
     base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True

-    # Configure hand detector options.
-    hand_detector_options_proto = _HandDetectorGraphOptionsProto(
-        num_hands=self.num_hands,
-        min_detection_confidence=self.min_hand_detection_confidence)
-
-    # Configure hand landmarker options.
-    hand_landmarks_detector_options_proto = _HandLandmarksDetectorGraphOptionsProto(
-        min_detection_confidence=self.min_hand_presence_confidence)
-    hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
-        hand_detector_graph_options=hand_detector_options_proto,
-        hand_landmarks_detector_graph_options=hand_landmarks_detector_options_proto,
-        min_tracking_confidence=self.min_tracking_confidence)
+    # Configure hand detector and hand landmarker options.
+    hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto()
+    hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
+    hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
+    hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
+    hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence

     # Configure hand gesture recognizer options.
-    classifier_options = _ClassifierOptions(
-        score_threshold=self.min_gesture_confidence)
-    gesture_classifier_options = _GestureClassifierGraphOptionsProto(
-        classifier_options=classifier_options.to_pb2())
-    hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto(
-        canned_gesture_classifier_graph_options=gesture_classifier_options)
+    hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto()
+    hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
+    hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence

     return _GestureRecognizerGraphOptionsProto(
         base_options=base_options_proto,
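
The rewritten to_pb2 above relies on protobuf's lazy creation of nested message fields: assigning through a dotted attribute path instantiates each intermediate sub-message, which is what makes the deleted _HandDetectorGraphOptionsProto, _HandLandmarksDetectorGraphOptionsProto, and _GestureClassifierGraphOptionsProto wrappers unnecessary. Note that the score threshold is now applied to both the canned and the custom gesture classifier. The dataclass sketch below is an analogy, not the MediaPipe protos, to show why the attribute-path style composes the same nested structure:

# Dataclass analogy for nested proto options; protobuf messages behave
# similarly because sub-messages spring into existence on field assignment.
import dataclasses

@dataclasses.dataclass
class HandDetectorGraphOptions:
  num_hands: int = 1
  min_detection_confidence: float = 0.5

@dataclasses.dataclass
class HandLandmarkerGraphOptions:
  min_tracking_confidence: float = 0.5
  hand_detector_graph_options: HandDetectorGraphOptions = dataclasses.field(
      default_factory=HandDetectorGraphOptions)

opts = HandLandmarkerGraphOptions()
# Set nested fields in place via attribute paths, mirroring to_pb2() above.
opts.min_tracking_confidence = 0.6
opts.hand_detector_graph_options.num_hands = 2
opts.hand_detector_graph_options.min_detection_confidence = 0.7
print(opts)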

mediapipe/tasks/testdata/vision/BUILD

@@ -130,6 +130,7 @@ filegroup(
         "hand_landmark_lite.tflite",
         "hand_landmarker.task",
         "gesture_recognizer.task",
+        "gesture_recognizer_with_custom_classifier.task",
         "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
         "mobilenet_v1_0.25_224_1_default_1.tflite",
         "mobilenet_v1_0.25_224_1_metadata_1.tflite",