Add allow_list/deny_list support

This commit is contained in:
kinaryml 2022-10-31 22:16:37 -07:00
parent 1aaaca1e12
commit d3b472e888
4 changed files with 127 additions and 29 deletions

View File

@ -88,6 +88,7 @@ py_test(
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:gesture_recognizer",

View File

@ -27,6 +27,7 @@ from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import gesture_recognizer
@ -40,6 +41,7 @@ _Category = category_module.Category
_Landmark = landmark_module.Landmark
_NormalizedLandmark = landmark_module.NormalizedLandmark
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_ClassifierOptions = classifier_options.ClassifierOptions
_Image = image_module.Image
_GestureRecognizer = gesture_recognizer.GestureRecognizer
_GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
@ -59,10 +61,12 @@ _VICTORY_LABEL = 'Victory'
_THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_THUMB_UP_LABEL = 'Thumb_Up'
_POINTING_UP_IMAGE = 'pointing_up.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt'
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_POINTING_UP_LABEL = 'Pointing_Up'
_ROCK_LABEL = "Rock"
_ROCK_LABEL = 'Rock'
_LANDMARKS_ERROR_TOLERANCE = 0.03
_GESTURE_EXPECTED_INDEX = -1
@ -227,11 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_min_gesture_confidence(self):
def test_recognize_succeeds_with_score_threshold(self):
# Creates gesture recognizer.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _GestureRecognizerOptions(base_options=base_options,
min_gesture_confidence=0.5)
canned_gesture_classifier_options = _ClassifierOptions(score_threshold=.5)
options = _GestureRecognizerOptions(
base_options=base_options,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(self.test_image)
@ -273,7 +279,7 @@ class GestureRecognizerTest(parameterized.TestCase):
recognition_result = recognizer.recognize(test_image,
image_processing_options)
expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
_POINTING_UP_ROTATED_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
@ -294,14 +300,14 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_custom_gesture_fist(self):
def test_recognize_succeeds_with_custom_gesture_rock(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the fist image.
# Load the rock image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_FIST_IMAGE))
# Performs hand gesture recognition on the input.
@ -312,6 +318,95 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_allow_gesture_pointing_up(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
category_allowlist=['Pointing_Up'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_deny_gesture_pointing_up(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
category_denylist=['Pointing_Up'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
actual_top_gesture = recognition_result.gestures[0][0]
self.assertEqual(actual_top_gesture.category_name, 'None')
def test_recognize_succeeds_with_allow_all_gestures_except_pointing_up(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
score_threshold=.5, category_allowlist=[
'None', 'Open_Palm', 'Victory', 'Thumb_Down', 'Thumb_Up',
'ILoveYou', 'Closed_Fist'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
actual_top_gesture = recognition_result.gestures[0][0]
self.assertEqual(actual_top_gesture.category_name, 'None')
def test_recognize_succeeds_with_prefer_allow_list_than_deny_list(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
score_threshold=.5, category_allowlist=['Pointing_Up'],
category_denylist=['Pointing_Up'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_fails_with_region_of_interest(self):
# Creates gesture recognizer.
base_options = _BaseOptions(model_asset_path=self.model_path)

View File

@ -88,10 +88,9 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",

View File

@ -21,10 +21,9 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -34,8 +33,7 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
_BaseOptions = base_options_module.BaseOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
@ -137,11 +135,16 @@ class GestureRecognizerOptions:
score in the hand landmark detection.
min_tracking_confidence: The minimum confidence score for the hand tracking
to be considered successful.
min_gesture_confidence: The minimum confidence score for the gestures to be
considered successful. If < 0, the gesture confidence thresholds in the
model metadata are used.
TODO: Note this option is subject to change, after scoring merging
calculator is implemented.
canned_gesture_classifier_options: Options for configuring the canned
gestures classifier, such as score threshold, allow list and deny list of
gestures. The categories for canned gesture classifiers are:
["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down",
"Thumb_Up", "Victory", "ILoveYou"]
TODO: Note this option is subject to change.
custom_gesture_classifier_options: Options for configuring the custom
gestures classifier, such as score threshold, allow list and deny list of
gestures.
TODO: Note this option is subject to change.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
@ -152,7 +155,8 @@ class GestureRecognizerOptions:
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
min_gesture_confidence: Optional[float] = -1
canned_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions()
custom_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions()
result_callback: Optional[
Callable[[GestureRecognitionResult, image_module.Image,
int], None]] = None
@ -163,23 +167,22 @@ class GestureRecognizerOptions:
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# Initialize gesture recognizer options from base options.
gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto)
# Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto()
hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
# Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto()
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(self.canned_gesture_classifier_options.to_pb2())
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(self.custom_gesture_classifier_options.to_pb2())
return _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto,
hand_landmarker_graph_options=hand_landmarker_options_proto,
hand_gesture_recognizer_graph_options=hand_gesture_recognizer_options_proto
)
return gesture_recognizer_options_proto
class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):