Added a simple test to verify gesture recognition results

This commit is contained in:
kinaryml 2022-10-25 07:38:04 -07:00
parent 9a1a9d4c13
commit 18eb089d39
8 changed files with 184 additions and 168 deletions

View File

@ -55,11 +55,13 @@ py_library(
)
py_library(
name = "gesture",
srcs = ["gesture.py"],
name = "landmark_detection_result",
srcs = ["landmark_detection_result.py"],
deps = [
":rect",
":classification",
":landmark",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -14,14 +14,13 @@
"""Classification data class."""
import dataclasses
from typing import Any, List
from typing import Any, List, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_ClassificationListCollectionProto = classification_pb2.ClassificationListCollection
@dataclasses.dataclass
@ -35,10 +34,10 @@ class Classification:
display_name: Optional human-readable string for display purposes.
"""
index: int
score: float
label_name: str
display_name: str
index: Optional[int] = None
score: Optional[float] = None
label: Optional[str] = None
display_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationProto:
@ -46,7 +45,7 @@ class Classification:
return _ClassificationProto(
index=self.index,
score=self.score,
label_name=self.label_name,
label=self.label,
display_name=self.display_name)
@classmethod
@ -56,7 +55,7 @@ class Classification:
return Classification(
index=pb2_obj.index,
score=pb2_obj.score,
label_name=pb2_obj.label_name,
label=pb2_obj.label,
display_name=pb2_obj.display_name)
def __eq__(self, other: Any) -> bool:
@ -86,8 +85,8 @@ class ClassificationList:
"""
classifications: List[Classification]
tensor_index: int
tensor_name: str
tensor_index: Optional[int] = None
tensor_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationListProto:

View File

@ -1,138 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gesture data class."""
import dataclasses
from typing import Any, List
from mediapipe.tasks.python.components.containers import classification
from mediapipe.tasks.python.components.containers import landmark
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@dataclasses.dataclass
class GestureRecognitionResult:
""" The gesture recognition result from GestureRecognizer, where each vector
element represents a single hand detected in the image.
Attributes:
gestures: Recognized hand gestures with sorted order such that the
winning label is the first item in the list.
handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates.
hand_world_landmarks: Detected hand landmarks in world coordinates.
"""
gestures: List[classification.ClassificationList]
handedness: List[classification.ClassificationList]
hand_landmarks: List[landmark.NormalizedLandmarkList]
hand_world_landmarks: List[landmark.LandmarkList]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _DetectionProto:
"""Generates a Detection protobuf object."""
labels = []
label_ids = []
scores = []
display_names = []
for category in self.categories:
scores.append(category.score)
if category.index:
label_ids.append(category.index)
if category.category_name:
labels.append(category.category_name)
if category.display_name:
display_names.append(category.display_name)
return _DetectionProto(
label=labels,
label_id=label_ids,
score=scores,
display_name=display_names,
location_data=_LocationDataProto(
format=_LocationDataProto.Format.BOUNDING_BOX,
bounding_box=self.bounding_box.to_pb2()))
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
"""Creates a `Detection` object from the given protobuf object."""
categories = []
for idx, score in enumerate(pb2_obj.score):
categories.append(
category_module.Category(
score=score,
index=pb2_obj.label_id[idx]
if idx < len(pb2_obj.label_id) else None,
category_name=pb2_obj.label[idx]
if idx < len(pb2_obj.label) else None,
display_name=pb2_obj.display_name[idx]
if idx < len(pb2_obj.display_name) else None))
return Detection(
bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
pb2_obj.location_data.bounding_box),
categories=categories)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Detection):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class DetectionResult:
"""Represents the list of detected objects.
Attributes:
detections: A list of `Detection` objects.
"""
detections: List[Detection]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _DetectionListProto:
"""Generates a DetectionList protobuf object."""
return _DetectionListProto(
detection=[detection.to_pb2() for detection in self.detections])
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult':
"""Creates a `DetectionResult` object from the given protobuf object."""
return DetectionResult(detections=[
Detection.create_from_pb2(detection) for detection in pb2_obj.detection
])
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, DetectionResult):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -0,0 +1,82 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Landmark Detection Result data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import classification as classification_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_NormalizedRect = rect_module.NormalizedRect
_ClassificationList = classification_module.ClassificationList
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
_LandmarkList = landmark_module.LandmarkList
@dataclasses.dataclass
class LandmarksDetectionResult:
"""Represents the landmarks detection result.
Attributes:
landmarks : A `NormalizedLandmarkList` object.
classifications : A `ClassificationList` object.
world_landmarks : A `LandmarkList` object.
rect : A `NormalizedRect` object.
"""
landmarks: Optional[_NormalizedLandmarkList]
classifications: Optional[_ClassificationList]
world_landmarks: Optional[_LandmarkList]
rect: _NormalizedRect
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _LandmarksDetectionResultProto:
"""Generates a LandmarksDetectionResult protobuf object."""
return _LandmarksDetectionResultProto(
landmarks=self.landmarks.to_pb2(),
classifications=self.classifications.to_pb2(),
world_landmarks=self.world_landmarks.to_pb2(),
rect=self.rect.to_pb2())
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _LandmarksDetectionResultProto
) -> 'LandmarksDetectionResult':
"""Creates a `LandmarksDetectionResult` object from the given protobuf
object."""
return LandmarksDetectionResult(
landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks),
classifications=_ClassificationList.create_from_pb2(
pb2_obj.classifications),
world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks),
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, LandmarksDetectionResult):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -43,15 +43,19 @@ py_test(
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/containers:classification",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:gesture_recognizer",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"@com_google_protobuf//:protobuf_python"
],
)

View File

@ -15,23 +15,31 @@
import enum
from google.protobuf import text_format
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import classification as classification_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import gesture_recognizer
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_BaseOptions = base_options_module.BaseOptions
_NormalizedRect = rect_module.NormalizedRect
_Classification = classification_module.Classification
_ClassificationList = classification_module.ClassificationList
_Landmark = landmark_module.Landmark
_LandmarkList = landmark_module.LandmarkList
_NormalizedLandmark = landmark_module.NormalizedLandmark
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_Image = image_module.Image
_GestureRecognizer = gesture_recognizer.GestureRecognizer
_GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
@ -39,8 +47,35 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task'
_IMAGE_FILE = 'right_hands.jpg'
_EXPECTED_DETECTION_RESULT = _GestureRecognitionResult([], [], [], [])
_THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = "thumb_up_landmarks.pbtxt"
_THUMB_UP_LABEL = "Thumb_Up"
_THUMB_UP_INDEX = 5
_LANDMARKS_ERROR_TOLERANCE = 0.03
def _get_expected_gesture_recognition_result(
file_path: str, gesture_label: str, gesture_index: int
) -> _GestureRecognitionResult:
landmarks_detection_result_file_path = test_utils.get_test_data_path(
file_path)
with open(landmarks_detection_result_file_path, "rb") as f:
landmarks_detection_result_proto = _LandmarksDetectionResultProto()
# # Use this if a .pb file is available.
# landmarks_detection_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
landmarks_detection_result_proto)
gesture = _ClassificationList(
classifications=[
_Classification(label=gesture_label, index=gesture_index,
display_name='')
], tensor_index=0, tensor_name='')
return _GestureRecognitionResult(
gestures=[gesture],
handedness=[landmarks_detection_result.classifications],
hand_landmarks=[landmarks_detection_result.landmarks],
hand_world_landmarks=[landmarks_detection_result.world_landmarks])
class ModelFileType(enum.Enum):
@ -53,14 +88,45 @@ class GestureRecognizerTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE))
test_utils.get_test_data_path(_THUMB_UP_IMAGE))
self.gesture_recognizer_model_path = test_utils.get_test_data_path(
_GESTURE_RECOGNIZER_MODEL_FILE)
def _assert_actual_result_approximately_matches_expected_result(
self,
actual_result: _GestureRecognitionResult,
expected_result: _GestureRecognitionResult
):
# Expects to have the same number of hands detected.
self.assertLen(actual_result.hand_landmarks,
len(expected_result.hand_landmarks))
self.assertLen(actual_result.hand_world_landmarks,
len(expected_result.hand_world_landmarks))
self.assertLen(actual_result.handedness, len(expected_result.handedness))
self.assertLen(actual_result.gestures, len(expected_result.gestures))
# Actual landmarks match expected landmarks.
self.assertEqual(actual_result.hand_landmarks,
expected_result.hand_landmarks)
# Actual handedness matches expected handedness.
actual_top_handedness = actual_result.handedness[0].classifications[0]
expected_top_handedness = expected_result.handedness[0].classifications[0]
self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
self.assertEqual(actual_top_handedness.label, expected_top_handedness.label)
# Actual gesture with top score matches expected gesture.
actual_top_gesture = actual_result.gestures[0].classifications[0]
expected_top_gesture = expected_result.gestures[0].classifications[0]
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
@parameterized.parameters(
(ModelFileType.FILE_NAME, _EXPECTED_DETECTION_RESULT),
(ModelFileType.FILE_CONTENT, _EXPECTED_DETECTION_RESULT))
def test_recognize(self, model_file_type, expected_recognition_result):
(ModelFileType.FILE_NAME, 0.3, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
)),
(ModelFileType.FILE_CONTENT, 0.3, _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
)))
def test_recognize(self, model_file_type, min_gesture_confidence,
expected_recognition_result):
# Creates gesture recognizer.
if model_file_type is ModelFileType.FILE_NAME:
gesture_recognizer_base_options = _BaseOptions(
@ -75,13 +141,16 @@ class GestureRecognizerTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _GestureRecognizerOptions(
base_options=gesture_recognizer_base_options)
base_options=gesture_recognizer_base_options,
min_gesture_confidence=min_gesture_confidence
)
recognizer = _GestureRecognizer.create_from_options(options)
# Performs hand gesture recognition on the input.
recognition_result = recognizer.recognize(self.test_image)
# Comparing results.
self.assertEqual(recognition_result, expected_recognition_result)
self._assert_actual_result_approximately_matches_expected_result(
recognition_result, expected_recognition_result)
# Closes the gesture recognizer explicitly when the detector is not used in
# a context.
recognizer.close()

View File

@ -136,8 +136,6 @@ class GestureRecognizerOptions:
"""Generates an GestureRecognizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# hand_landmark_detector_base_options_proto = self.hand_landmark_detector_base_options.to_pb2()
# hand_landmark_detector_base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# Configure hand detector options.
hand_detector_options_proto = _HandDetectorGraphOptionsProto(
@ -153,13 +151,12 @@ class GestureRecognizerOptions:
min_tracking_confidence=self.min_tracking_confidence)
# Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto()
if self.min_gesture_confidence >= 0:
classifier_options = _ClassifierOptions(
score_threshold=self.min_gesture_confidence)
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options = \
_GestureClassifierGraphOptionsProto(
classifier_options=classifier_options.to_pb2())
classifier_options = _ClassifierOptions(
score_threshold=self.min_gesture_confidence)
gesture_classifier_options = _GestureClassifierGraphOptionsProto(
classifier_options=classifier_options.to_pb2())
hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto(
canned_gesture_classifier_graph_options=gesture_classifier_options)
return _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto,

View File

@ -121,6 +121,7 @@ filegroup(
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"hand_landmarker.task",
"gesture_recognizer.task",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_metadata_1.tflite",