From d3b472e888ae7b62b7dd921949b3e9db71c37303 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Mon, 31 Oct 2022 22:16:37 -0700 Subject: [PATCH] Add allow_list/deny_list support --- mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/gesture_recognizer_test.py | 111 ++++++++++++++++-- mediapipe/tasks/python/vision/BUILD | 3 +- .../tasks/python/vision/gesture_recognizer.py | 41 ++++--- 4 files changed, 127 insertions(+), 29 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 40afe22b8..da8ad3f83 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -88,6 +88,7 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark_detection_result", + "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:gesture_recognizer", diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index e8aa61883..d5cd72cd7 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -27,6 +27,7 @@ from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module +from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import gesture_recognizer @@ -40,6 +41,7 @@ _Category = category_module.Category _Landmark = landmark_module.Landmark _NormalizedLandmark = landmark_module.NormalizedLandmark _LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult +_ClassifierOptions = classifier_options.ClassifierOptions _Image = image_module.Image _GestureRecognizer = gesture_recognizer.GestureRecognizer _GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions @@ -59,10 +61,12 @@ _VICTORY_LABEL = 'Victory' _THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LABEL = 'Thumb_Up' +_POINTING_UP_IMAGE = 'pointing_up.jpg' +_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt' _POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg' -_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' +_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' _POINTING_UP_LABEL = 'Pointing_Up' -_ROCK_LABEL = "Rock" +_ROCK_LABEL = 'Rock' _LANDMARKS_ERROR_TOLERANCE = 0.03 _GESTURE_EXPECTED_INDEX = -1 @@ -227,11 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) - def test_recognize_succeeds_with_min_gesture_confidence(self): + def test_recognize_succeeds_with_score_threshold(self): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) - options = _GestureRecognizerOptions(base_options=base_options, - min_gesture_confidence=0.5) + canned_gesture_classifier_options = _ClassifierOptions(score_threshold=.5) + options = _GestureRecognizerOptions( + base_options=base_options, + canned_gesture_classifier_options=canned_gesture_classifier_options) with _GestureRecognizer.create_from_options(options) as recognizer: # Performs hand gesture recognition on the input. recognition_result = recognizer.recognize(self.test_image) @@ -273,7 +279,7 @@ class GestureRecognizerTest(parameterized.TestCase): recognition_result = recognizer.recognize(test_image, image_processing_options) expected_recognition_result = _get_expected_gesture_recognition_result( - _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) + _POINTING_UP_ROTATED_LANDMARKS, _POINTING_UP_LABEL) # Comparing results. self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) @@ -294,14 +300,14 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) - def test_recognize_succeeds_with_custom_gesture_fist(self): + def test_recognize_succeeds_with_custom_gesture_rock(self): # Creates gesture recognizer. model_path = test_utils.get_test_data_path( _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) base_options = _BaseOptions(model_asset_path=model_path) options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) with _GestureRecognizer.create_from_options(options) as recognizer: - # Load the fist image. + # Load the rock image. test_image = _Image.create_from_file( test_utils.get_test_data_path(_FIST_IMAGE)) # Performs hand gesture recognition on the input. @@ -312,6 +318,95 @@ class GestureRecognizerTest(parameterized.TestCase): self._assert_actual_result_approximately_matches_expected_result( recognition_result, expected_recognition_result) + def test_recognize_succeeds_with_allow_gesture_pointing_up(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + category_allowlist=['Pointing_Up']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + + def test_recognize_succeeds_with_deny_gesture_pointing_up(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + category_denylist=['Pointing_Up']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + actual_top_gesture = recognition_result.gestures[0][0] + self.assertEqual(actual_top_gesture.category_name, 'None') + + def test_recognize_succeeds_with_allow_all_gestures_except_pointing_up(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + score_threshold=.5, category_allowlist=[ + 'None', 'Open_Palm', 'Victory', 'Thumb_Down', 'Thumb_Up', + 'ILoveYou', 'Closed_Fist']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + actual_top_gesture = recognition_result.gestures[0][0] + self.assertEqual(actual_top_gesture.category_name, 'None') + + def test_recognize_succeeds_with_prefer_allow_list_than_deny_list(self): + # Creates gesture recognizer. + model_path = test_utils.get_test_data_path( + _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) + base_options = _BaseOptions(model_asset_path=model_path) + canned_gesture_classifier_options = _ClassifierOptions( + score_threshold=.5, category_allowlist=['Pointing_Up'], + category_denylist=['Pointing_Up']) + options = _GestureRecognizerOptions( + base_options=base_options, + num_hands=1, + canned_gesture_classifier_options=canned_gesture_classifier_options) + with _GestureRecognizer.create_from_options(options) as recognizer: + # Load the pointing up image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(_POINTING_UP_IMAGE)) + # Performs hand gesture recognition on the input. + recognition_result = recognizer.recognize(test_image) + expected_recognition_result = _get_expected_gesture_recognition_result( + _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) + # Comparing results. + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) + def test_recognize_fails_with_region_of_interest(self): # Creates gesture recognizer. base_options = _BaseOptions(model_asset_path=self.model_path) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 0505471e8..dec149908 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -88,10 +88,9 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2", "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:landmark", + "//mediapipe/tasks/python/components/processors:classifier_options", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 82dc00f19..2659f9a03 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -21,10 +21,9 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2 -from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2 -from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2 from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.components.processors import classifier_options from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -34,8 +33,7 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image _BaseOptions = base_options_module.BaseOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions -_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions -_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions +_ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -137,11 +135,16 @@ class GestureRecognizerOptions: score in the hand landmark detection. min_tracking_confidence: The minimum confidence score for the hand tracking to be considered successful. - min_gesture_confidence: The minimum confidence score for the gestures to be - considered successful. If < 0, the gesture confidence thresholds in the - model metadata are used. - TODO: Note this option is subject to change, after scoring merging - calculator is implemented. + canned_gesture_classifier_options: Options for configuring the canned + gestures classifier, such as score threshold, allow list and deny list of + gestures. The categories for canned gesture classifiers are: + ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down", + "Thumb_Up", "Victory", "ILoveYou"] + TODO :Note this option is subject to change. + custom_gesture_classifier_options: Options for configuring the custom + gestures classifier, such as score threshold, allow list and deny list of + gestures. + TODO :Note this option is subject to change. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. @@ -152,7 +155,8 @@ class GestureRecognizerOptions: min_hand_detection_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 - min_gesture_confidence: Optional[float] = -1 + canned_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions() + custom_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions() result_callback: Optional[ Callable[[GestureRecognitionResult, image_module.Image, int], None]] = None @@ -163,23 +167,22 @@ class GestureRecognizerOptions: base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + # Initialize gesture recognizer options from base options. + gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto( + base_options=base_options_proto) # Configure hand detector and hand landmarker options. - hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto() + hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence # Configure hand gesture recognizer options. - hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() - hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence - hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence + hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options + hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(self.canned_gesture_classifier_options.to_pb2()) + hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(self.custom_gesture_classifier_options.to_pb2()) - return _GestureRecognizerGraphOptionsProto( - base_options=base_options_proto, - hand_landmarker_graph_options=hand_landmarker_options_proto, - hand_gesture_recognizer_graph_options=hand_gesture_recognizer_options_proto - ) + return gesture_recognizer_options_proto class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):