Added a test for the canned classification of the gesture victory

This commit is contained in:
kinaryml 2022-10-31 05:47:28 -07:00
parent 888ddd4b74
commit d635b4281e

View File

@ -53,7 +53,9 @@ _NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
_TWO_HANDS_IMAGE = 'right_hands.jpg' _TWO_HANDS_IMAGE = 'right_hands.jpg'
_FIST_IMAGE = 'fist.jpg' _FIST_IMAGE = 'fist.jpg'
_FIST_LANDMARKS = 'fist_landmarks.pbtxt' _FIST_LANDMARKS = 'fist_landmarks.pbtxt'
_FIST_LABEL = 'Closed_Fist' _VICTORY_IMAGE = 'victory.jpg'
_VICTORY_LANDMARKS = 'victory_landmarks.pbtxt'
_VICTORY_LABEL = 'Victory'
_THUMB_UP_IMAGE = 'thumb_up.jpg' _THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt' _THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_THUMB_UP_LABEL = 'Thumb_Up' _THUMB_UP_LABEL = 'Thumb_Up'
@ -276,6 +278,22 @@ class GestureRecognizerTest(parameterized.TestCase):
self._assert_actual_result_approximately_matches_expected_result( self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result) recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_canned_gesture_victory(self):
# Creates gesture recognizer.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
with _GestureRecognizer.create_from_options(options) as recognizer:
# Load the fist image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_VICTORY_IMAGE))
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(test_image)
expected_recognition_result = _get_expected_gesture_recognition_result(
_VICTORY_LANDMARKS, _VICTORY_LABEL)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
def test_recognize_succeeds_with_custom_gesture_fist(self): def test_recognize_succeeds_with_custom_gesture_fist(self):
# Creates gesture recognizer. # Creates gesture recognizer.
model_path = test_utils.get_test_data_path( model_path = test_utils.get_test_data_path(