diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 8f7c66519..916bd3e0c 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -47,22 +47,26 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task' +_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = 'gesture_recognizer.task' +_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE = 'gesture_recognizer_with_custom_classifier.task' _NO_HANDS_IMAGE = 'cats_and_dogs.jpg' _TWO_HANDS_IMAGE = 'right_hands.jpg' +_FIST_IMAGE = 'fist.jpg' +_FIST_LANDMARKS = 'fist_landmarks.pbtxt' +_FIST_LABEL = 'Closed_Fist' _THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LABEL = 'Thumb_Up' -_THUMB_UP_INDEX = 5 _POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg' _POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' _POINTING_UP_LABEL = 'Pointing_Up' -_POINTING_UP_INDEX = 3 +_ROCK_LABEL = "Rock" _LANDMARKS_ERROR_TOLERANCE = 0.03 +_GESTURE_EXPECTED_INDEX = -1 def _get_expected_gesture_recognition_result( - file_path: str, gesture_label: str, gesture_index: int + file_path: str, gesture_label: str ) -> _GestureRecognitionResult: landmarks_detection_result_file_path = test_utils.get_test_data_path( file_path) @@ -73,7 +77,8 @@ def _get_expected_gesture_recognition_result( text_format.Parse(f.read(), landmarks_detection_result_proto) landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( landmarks_detection_result_proto) - gesture = _Category(category_name=gesture_label, index=gesture_index, + gesture = _Category(category_name=gesture_label, + index=_GESTURE_EXPECTED_INDEX, display_name='') return _GestureRecognitionResult( gestures=[[gesture]], @@ -94,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase): self.test_image = _Image.create_from_file( test_utils.get_test_data_path(_THUMB_UP_IMAGE)) self.model_path = test_utils.get_test_data_path( - _GESTURE_RECOGNIZER_MODEL_FILE) + _GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) def _assert_actual_result_approximately_matches_expected_result( self, @@ -127,7 +132,7 @@ class GestureRecognizerTest(parameterized.TestCase): # Actual gesture with top score matches expected gesture. actual_top_gesture = actual_result.gestures[0][0] expected_top_gesture = expected_result.gestures[0][0] - self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) + self.assertEqual(actual_top_gesture.index, _GESTURE_EXPECTED_INDEX) self.assertEqual(actual_top_gesture.category_name, expected_top_gesture.category_name) @@ -163,10 +168,10 @@ class GestureRecognizerTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL )), (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL ))) def test_recognize(self, model_file_type, expected_recognition_result): # Creates gesture recognizer. @@ -194,10 +199,10 @@ class GestureRecognizerTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL )), (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL ))) def test_recognize_in_context(self, model_file_type, expected_recognition_result): @@ -224,12 +229,12 @@ class GestureRecognizerTest(parameterized.TestCase): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) options = _GestureRecognizerOptions(base_options=base_options, - min_gesture_confidence=2) + min_gesture_confidence=0.5) with _GestureRecognizer.create_from_options(options) as recognizer: # Performs hand gesture recognition on the input. recognition_result = recognizer.recognize(self.test_image) expected_result = _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL) # Only contains one top scoring gesture. self.assertLen(recognition_result.gestures[0], 1) # Actual gesture with top score matches expected gesture. @@ -266,11 +271,29 @@ class GestureRecognizerTest(parameterized.TestCase): recognition_result = recognizer.recognize(test_image, image_processing_options) expected_recognition_result = _get_expected_gesture_recognition_result( - _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX) + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) # Comparing results. self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_custom_gesture_fist(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the fist image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_FIST_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _FIST_LANDMARKS, _ROCK_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + def test_recognize_fails_with_region_of_interest(self): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) @@ -373,7 +396,7 @@ class GestureRecognizerTest(parameterized.TestCase): recognition_result = recognizer.recognize_for_video(self.test_image, timestamp) expected_recognition_result = _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX) + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL) self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) @@ -410,7 +433,7 @@ class GestureRecognizerTest(parameterized.TestCase): @parameterized.parameters( (_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result( - _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)), + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)), (_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], []))) def test_recognize_async_calls(self, image_path, expected_result): test_image = _Image.create_from_file( diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 87de5b987..66c9ece65 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -87,12 +87,9 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/processors:classifier_options", diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index c6d30dc4e..9a2e3ba29 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -20,13 +20,9 @@ from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module -from mediapipe.python._framework_bindings import task_runner as task_runner_module -from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_classifier_graph_options_pb2 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2 -from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2 from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 -from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.processors import classifier_options @@ -38,12 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module _BaseOptions = base_options_module.BaseOptions -_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions _HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions -_HandDetectorGraphOptionsProto = hand_detector_graph_options_pb2.HandDetectorGraphOptions _HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions -_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -64,6 +57,7 @@ _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks' _HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +_GESTURE_DEFAULT_INDEX = -1 @dataclasses.dataclass @@ -72,8 +66,9 @@ class GestureRecognitionResult: element represents a single hand detected in the image. Attributes: - gestures: Recognized hand gestures with sorted order such that the - winning label is the first item in the list. + gestures: Recognized hand gestures of detected hands. Note that the index + of the gesture is always 0, because the raw indices from multiple gesture + classifiers cannot consolidate to a meaningful index. handedness: Classification of handedness. hand_landmarks: Detected hand landmarks in normalized image coordinates. hand_world_landmarks: Detected hand landmarks in world coordinates. @@ -101,16 +96,16 @@ def _build_recognition_result( [ [ category_module.Category( - index=gesture.index, score=gesture.score, + index=_GESTURE_DEFAULT_INDEX, score=gesture.score, display_name=gesture.display_name, category_name=gesture.label) for gesture in gesture_classifications.classification] for gesture_classifications in gestures_proto_list ], [ [ category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] + index=handedness.index, score=handedness.score, + display_name=handedness.display_name, category_name=handedness.label) + for handedness in handedness_classifications.classification] for handedness_classifications in handedness_proto_list ], [ [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) @@ -170,26 +165,17 @@ class GestureRecognizerOptions: base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - # Configure hand detector options. - hand_detector_options_proto = _HandDetectorGraphOptionsProto( - num_hands=self.num_hands, - min_detection_confidence=self.min_hand_detection_confidence) - - # Configure hand landmarker options. - hand_landmarks_detector_options_proto = _HandLandmarksDetectorGraphOptionsProto( - min_detection_confidence=self.min_hand_presence_confidence) - hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto( - hand_detector_graph_options=hand_detector_options_proto, - hand_landmarks_detector_graph_options=hand_landmarks_detector_options_proto, - min_tracking_confidence=self.min_tracking_confidence) + # Configure hand detector and hand landmarker options. + hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto() + hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence + hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands + hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence + hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence # Configure hand gesture recognizer options. - classifier_options = _ClassifierOptions( - score_threshold=self.min_gesture_confidence) - gesture_classifier_options = _GestureClassifierGraphOptionsProto( - classifier_options=classifier_options.to_pb2()) - hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto( - canned_gesture_classifier_graph_options=gesture_classifier_options) + hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() + hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence + hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence return _GestureRecognizerGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 0545c5cca..c7265f5c9 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -130,6 +130,7 @@ filegroup( "hand_landmark_lite.tflite", "hand_landmarker.task", "gesture_recognizer.task", + "gesture_recognizer_with_custom_classifier.task", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite",