Add allow_list/deny_list support

This commit is contained in:
kinaryml 2022-10-31 22:16:37 -07:00
parent 1aaaca1e12
commit d3b472e888
4 changed files with 127 additions and 29 deletions

View File

@@ -88,6 +88,7 @@ py_test(
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:landmark_detection_result", "//mediapipe/tasks/python/components/containers:landmark_detection_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:gesture_recognizer", "//mediapipe/tasks/python/vision:gesture_recognizer",

View File

@@ -27,6 +27,7 @@ from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import gesture_recognizer from mediapipe.tasks.python.vision import gesture_recognizer
@@ -40,6 +41,7 @@ _Category = category_module.Category
_Landmark = landmark_module.Landmark _Landmark = landmark_module.Landmark
_NormalizedLandmark = landmark_module.NormalizedLandmark _NormalizedLandmark = landmark_module.NormalizedLandmark
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult _LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_ClassifierOptions = classifier_options.ClassifierOptions
_Image = image_module.Image _Image = image_module.Image
_GestureRecognizer = gesture_recognizer.GestureRecognizer _GestureRecognizer = gesture_recognizer.GestureRecognizer
_GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions _GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
@@ -59,10 +61,12 @@ _VICTORY_LABEL = 'Victory'
_THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_THUMB_UP_LABEL = 'Thumb_Up' _THUMB_UP_LABEL = 'Thumb_Up'
_POINTING_UP_IMAGE = 'pointing_up.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt'
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg' _POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt' _POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_POINTING_UP_LABEL = 'Pointing_Up' _POINTING_UP_LABEL = 'Pointing_Up'
_ROCK_LABEL = "Rock" _ROCK_LABEL = 'Rock'
_LANDMARKS_ERROR_TOLERANCE = 0.03 _LANDMARKS_ERROR_TOLERANCE = 0.03
_GESTURE_EXPECTED_INDEX = -1 _GESTURE_EXPECTED_INDEX = -1
@@ -227,11 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result( self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result) recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_min_gesture_confidence(self): def test_recognize_succeeds_with_score_threshold(self):
# Creates gesture recognizer. # Creates gesture recognizer.
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
options = _GestureRecognizerOptions(base_options=base_options, canned_gesture_classifier_options = _ClassifierOptions(score_threshold=.5)
min_gesture_confidence=0.5) options = _GestureRecognizerOptions(
base_options=base_options,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer: with _GestureRecognizer.create_from_options(options) as recognizer:
# Performs hand gesture recognition on the input. # Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(self.test_image) recognition_result = recognizer.recognize(self.test_image)
@@ -273,7 +279,7 @@ class GestureRecognizerTest(parameterized.TestCase):
recognition_result = recognizer.recognize(test_image, recognition_result = recognizer.recognize(test_image,
image_processing_options) image_processing_options)
expected_recognition_result = _get_expected_gesture_recognition_result( expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL) _POINTING_UP_ROTATED_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results. # Comparing results.
self._assert_actual_result_approximately_matches_expected_result( self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result) recognition_result, expected_recognition_result)
@@ -294,14 +300,14 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result( self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result) recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_custom_gesture_fist(self): def test_recognize_succeeds_with_custom_gesture_rock(self):
# Creates gesture recognizer. # Creates gesture recognizer.
model_path = test_utils.get_test_data_path( model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE) _GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = _GestureRecognizerOptions(base_options=base_options, num_hands=1) options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
with _GestureRecognizer.create_from_options(options) as recognizer: with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the fist image. # Load the rock image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path(_FIST_IMAGE)) test_utils.get_test_data_path(_FIST_IMAGE))
# Performs hand gesture recognition on the input. # Performs hand gesture recognition on the input.
@@ -312,6 +318,95 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result( self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result) recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_allow_gesture_pointing_up(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
category_allowlist=['Pointing_Up'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_deny_gesture_pointing_up(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
category_denylist=['Pointing_Up'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
actual_top_gesture = recognition_result.gestures[0][0]
self.assertEqual(actual_top_gesture.category_name, 'None')
def test_recognize_succeeds_with_allow_all_gestures_except_pointing_up(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
score_threshold=.5, category_allowlist=[
'None', 'Open_Palm', 'Victory', 'Thumb_Down', 'Thumb_Up',
'ILoveYou', 'Closed_Fist'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
actual_top_gesture = recognition_result.gestures[0][0]
self.assertEqual(actual_top_gesture.category_name, 'None')
def test_recognize_succeeds_with_prefer_allow_list_than_deny_list(self):
# Creates gesture recognizer.
model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
base_options = _BaseOptions(model_asset_path=model_path)
canned_gesture_classifier_options = _ClassifierOptions(
score_threshold=.5, category_allowlist=['Pointing_Up'],
category_denylist=['Pointing_Up'])
options = _GestureRecognizerOptions(
base_options=base_options,
num_hands=1,
canned_gesture_classifier_options=canned_gesture_classifier_options)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the pointing up image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
expected_recognition_result = _get_expected_gesture_recognition_result(
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_fails_with_region_of_interest(self): def test_recognize_fails_with_region_of_interest(self):
# Creates gesture recognizer. # Creates gesture recognizer.
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)

View File

@@ -88,10 +88,9 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark", "//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",

View File

@@ -21,10 +21,9 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2 from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@@ -34,8 +33,7 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions _GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions _ClassifierOptions = classifier_options.ClassifierOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@@ -137,11 +135,16 @@ class GestureRecognizerOptions:
score in the hand landmark detection. score in the hand landmark detection.
min_tracking_confidence: The minimum confidence score for the hand tracking min_tracking_confidence: The minimum confidence score for the hand tracking
to be considered successful. to be considered successful.
min_gesture_confidence: The minimum confidence score for the gestures to be canned_gesture_classifier_options: Options for configuring the canned
considered successful. If < 0, the gesture confidence thresholds in the gestures classifier, such as score threshold, allow list and deny list of
model metadata are used. gestures. The categories for canned gesture classifiers are:
TODO: Note this option is subject to change, after scoring merging ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down",
calculator is implemented. "Thumb_Up", "Victory", "ILoveYou"]
TODO: Note this option is subject to change.
custom_gesture_classifier_options: Options for configuring the custom
gestures classifier, such as score threshold, allow list and deny list of
gestures.
TODO: Note this option is subject to change.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
@@ -152,7 +155,8 @@ class GestureRecognizerOptions:
min_hand_detection_confidence: Optional[float] = 0.5 min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5
min_gesture_confidence: Optional[float] = -1 canned_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions()
custom_gesture_classifier_options: Optional[_ClassifierOptions] = _ClassifierOptions()
result_callback: Optional[ result_callback: Optional[
Callable[[GestureRecognitionResult, image_module.Image, Callable[[GestureRecognitionResult, image_module.Image,
int], None]] = None int], None]] = None
@@ -163,23 +167,22 @@ class GestureRecognizerOptions:
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# Initialize gesture recognizer options from base options.
gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto)
# Configure hand detector and hand landmarker options. # Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto() hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
# Configure hand gesture recognizer options. # Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(self.canned_gesture_classifier_options.to_pb2())
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(self.custom_gesture_classifier_options.to_pb2())
return _GestureRecognizerGraphOptionsProto( return gesture_recognizer_options_proto
base_options=base_options_proto,
hand_landmarker_graph_options=hand_landmarker_options_proto,
hand_gesture_recognizer_graph_options=hand_gesture_recognizer_options_proto
)
class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):