Revised gesture recognizer implementation
This commit is contained in:
parent
5ec87c8bd2
commit
19be9e9012
|
@ -47,22 +47,26 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult
|
|||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task'
|
||||
_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = 'gesture_recognizer.task'
|
||||
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE = 'gesture_recognizer_with_custom_classifier.task'
|
||||
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
|
||||
_TWO_HANDS_IMAGE = 'right_hands.jpg'
|
||||
_FIST_IMAGE = 'fist.jpg'
|
||||
_FIST_LANDMARKS = 'fist_landmarks.pbtxt'
|
||||
_FIST_LABEL = 'Closed_Fist'
|
||||
_THUMB_UP_IMAGE = 'thumb_up.jpg'
|
||||
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
|
||||
_THUMB_UP_LABEL = 'Thumb_Up'
|
||||
_THUMB_UP_INDEX = 5
|
||||
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
|
||||
_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
|
||||
_POINTING_UP_LABEL = 'Pointing_Up'
|
||||
_POINTING_UP_INDEX = 3
|
||||
_ROCK_LABEL = "Rock"
|
||||
_LANDMARKS_ERROR_TOLERANCE = 0.03
|
||||
_GESTURE_EXPECTED_INDEX = -1
|
||||
|
||||
|
||||
def _get_expected_gesture_recognition_result(
|
||||
file_path: str, gesture_label: str, gesture_index: int
|
||||
file_path: str, gesture_label: str
|
||||
) -> _GestureRecognitionResult:
|
||||
landmarks_detection_result_file_path = test_utils.get_test_data_path(
|
||||
file_path)
|
||||
|
@ -73,7 +77,8 @@ def _get_expected_gesture_recognition_result(
|
|||
text_format.Parse(f.read(), landmarks_detection_result_proto)
|
||||
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
|
||||
landmarks_detection_result_proto)
|
||||
gesture = _Category(category_name=gesture_label, index=gesture_index,
|
||||
gesture = _Category(category_name=gesture_label,
|
||||
index=_GESTURE_EXPECTED_INDEX,
|
||||
display_name='')
|
||||
return _GestureRecognitionResult(
|
||||
gestures=[[gesture]],
|
||||
|
@ -94,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_THUMB_UP_IMAGE))
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
_GESTURE_RECOGNIZER_MODEL_FILE)
|
||||
_GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
|
||||
|
||||
def _assert_actual_result_approximately_matches_expected_result(
|
||||
self,
|
||||
|
@ -127,7 +132,7 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
# Actual gesture with top score matches expected gesture.
|
||||
actual_top_gesture = actual_result.gestures[0][0]
|
||||
expected_top_gesture = expected_result.gestures[0][0]
|
||||
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
||||
self.assertEqual(actual_top_gesture.index, _GESTURE_EXPECTED_INDEX)
|
||||
self.assertEqual(actual_top_gesture.category_name,
|
||||
expected_top_gesture.category_name)
|
||||
|
||||
|
@ -163,10 +168,10 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
|
||||
)),
|
||||
(ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
|
||||
)))
|
||||
def test_recognize(self, model_file_type, expected_recognition_result):
|
||||
# Creates gesture recognizer.
|
||||
|
@ -194,10 +199,10 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
|
||||
)),
|
||||
(ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL
|
||||
)))
|
||||
def test_recognize_in_context(self, model_file_type,
|
||||
expected_recognition_result):
|
||||
|
@ -224,12 +229,12 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
# Creates gesture recognizer.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _GestureRecognizerOptions(base_options=base_options,
|
||||
min_gesture_confidence=2)
|
||||
min_gesture_confidence=0.5)
|
||||
with _GestureRecognizer.create_from_options(options) as recognizer:
|
||||
# Performs hand gesture recognition on the input.
|
||||
recognition_result = recognizer.recognize(self.test_image)
|
||||
expected_result = _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)
|
||||
# Only contains one top scoring gesture.
|
||||
self.assertLen(recognition_result.gestures[0], 1)
|
||||
# Actual gesture with top score matches expected gesture.
|
||||
|
@ -266,11 +271,29 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
recognition_result = recognizer.recognize(test_image,
|
||||
image_processing_options)
|
||||
expected_recognition_result = _get_expected_gesture_recognition_result(
|
||||
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX)
|
||||
_POINTING_UP_LANDMARKS, _POINTING_UP_LABEL)
|
||||
# Comparing results.
|
||||
self._assert_actual_result_approximately_matches_expected_result(
|
||||
recognition_result, expected_recognition_result)
|
||||
|
||||
def test_recognize_succeeds_with_custom_gesture_fist(self):
|
||||
# Creates gesture recognizer.
|
||||
model_path = test_utils.get_test_data_path(
|
||||
_GESTURE_RECOGNIZER_WITH_CUSTOM_CLASSIFIER_BUNDLE_ASSET_FILE)
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
|
||||
with _GestureRecognizer.create_from_options(options) as recognizer:
|
||||
# Load the fist image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_FIST_IMAGE))
|
||||
# Performs hand gesture recognition on the input.
|
||||
recognition_result = recognizer.recognize(test_image)
|
||||
expected_recognition_result = _get_expected_gesture_recognition_result(
|
||||
_FIST_LANDMARKS, _ROCK_LABEL)
|
||||
# Comparing results.
|
||||
self._assert_actual_result_approximately_matches_expected_result(
|
||||
recognition_result, expected_recognition_result)
|
||||
|
||||
def test_recognize_fails_with_region_of_interest(self):
|
||||
# Creates gesture recognizer.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
|
@ -373,7 +396,7 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
recognition_result = recognizer.recognize_for_video(self.test_image,
|
||||
timestamp)
|
||||
expected_recognition_result = _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)
|
||||
self._assert_actual_result_approximately_matches_expected_result(
|
||||
recognition_result, expected_recognition_result)
|
||||
|
||||
|
@ -410,7 +433,7 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
|
||||
@parameterized.parameters(
|
||||
(_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)),
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL)),
|
||||
(_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], [])))
|
||||
def test_recognize_async_calls(self, image_path, expected_result):
|
||||
test_image = _Image.create_from_file(
|
||||
|
|
|
@ -87,12 +87,9 @@ py_library(
|
|||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/python:packet_creator",
|
||||
"//mediapipe/python:packet_getter",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:landmark",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
|
|
|
@ -20,13 +20,9 @@ from mediapipe.python import packet_creator
|
|||
from mediapipe.python import packet_getter
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.python._framework_bindings import task_runner as task_runner_module
|
||||
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_recognizer_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
|
@ -38,12 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions
|
||||
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
|
||||
_HandGestureRecognizerGraphOptionsProto = hand_gesture_recognizer_graph_options_pb2.HandGestureRecognizerGraphOptions
|
||||
_HandDetectorGraphOptionsProto = hand_detector_graph_options_pb2.HandDetectorGraphOptions
|
||||
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
|
||||
_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
@ -64,6 +57,7 @@ _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
|
|||
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
_GESTURE_DEFAULT_INDEX = -1
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -72,8 +66,9 @@ class GestureRecognitionResult:
|
|||
element represents a single hand detected in the image.
|
||||
|
||||
Attributes:
|
||||
gestures: Recognized hand gestures with sorted order such that the
|
||||
winning label is the first item in the list.
|
||||
gestures: Recognized hand gestures of detected hands. Note that the index
|
||||
of the gesture is always 0, because the raw indices from multiple gesture
|
||||
classifiers cannot consolidate to a meaningful index.
|
||||
handedness: Classification of handedness.
|
||||
hand_landmarks: Detected hand landmarks in normalized image coordinates.
|
||||
hand_world_landmarks: Detected hand landmarks in world coordinates.
|
||||
|
@ -101,16 +96,16 @@ def _build_recognition_result(
|
|||
[
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
index=_GESTURE_DEFAULT_INDEX, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
index=handedness.index, score=handedness.score,
|
||||
display_name=handedness.display_name, category_name=handedness.label)
|
||||
for handedness in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
|
@ -170,26 +165,17 @@ class GestureRecognizerOptions:
|
|||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||
|
||||
# Configure hand detector options.
|
||||
hand_detector_options_proto = _HandDetectorGraphOptionsProto(
|
||||
num_hands=self.num_hands,
|
||||
min_detection_confidence=self.min_hand_detection_confidence)
|
||||
|
||||
# Configure hand landmarker options.
|
||||
hand_landmarks_detector_options_proto = _HandLandmarksDetectorGraphOptionsProto(
|
||||
min_detection_confidence=self.min_hand_presence_confidence)
|
||||
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
|
||||
hand_detector_graph_options=hand_detector_options_proto,
|
||||
hand_landmarks_detector_graph_options=hand_landmarks_detector_options_proto,
|
||||
min_tracking_confidence=self.min_tracking_confidence)
|
||||
# Configure hand detector and hand landmarker options.
|
||||
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto()
|
||||
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
|
||||
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
|
||||
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
|
||||
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
|
||||
|
||||
# Configure hand gesture recognizer options.
|
||||
classifier_options = _ClassifierOptions(
|
||||
score_threshold=self.min_gesture_confidence)
|
||||
gesture_classifier_options = _GestureClassifierGraphOptionsProto(
|
||||
classifier_options=classifier_options.to_pb2())
|
||||
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto(
|
||||
canned_gesture_classifier_graph_options=gesture_classifier_options)
|
||||
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto()
|
||||
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
|
||||
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.score_threshold = self.min_gesture_confidence
|
||||
|
||||
return _GestureRecognizerGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
|
|
1
mediapipe/tasks/testdata/vision/BUILD
vendored
1
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -130,6 +130,7 @@ filegroup(
|
|||
"hand_landmark_lite.tflite",
|
||||
"hand_landmarker.task",
|
||||
"gesture_recognizer.task",
|
||||
"gesture_recognizer_with_custom_classifier.task",
|
||||
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
|
||||
"mobilenet_v1_0.25_224_1_default_1.tflite",
|
||||
"mobilenet_v1_0.25_224_1_metadata_1.tflite",
|
||||
|
|
Loading…
Reference in New Issue
Block a user