diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 325dff5fc..8aaa64cc9 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -55,11 +55,13 @@ py_library( ) py_library( - name = "gesture", - srcs = ["gesture.py"], + name = "landmark_detection_result", + srcs = ["landmark_detection_result.py"], deps = [ + ":rect", ":classification", ":landmark", + "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/classification.py b/mediapipe/tasks/python/components/containers/classification.py index 157c34528..465e2dd28 100644 --- a/mediapipe/tasks/python/components/containers/classification.py +++ b/mediapipe/tasks/python/components/containers/classification.py @@ -14,14 +14,13 @@ """Classification data class.""" import dataclasses -from typing import Any, List +from typing import Any, List, Optional from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _ClassificationProto = classification_pb2.Classification _ClassificationListProto = classification_pb2.ClassificationList -_ClassificationListCollectionProto = classification_pb2.ClassificationListCollection @dataclasses.dataclass @@ -35,10 +34,10 @@ class Classification: display_name: Optional human-readable string for display purposes. """ - index: int - score: float - label_name: str - display_name: str + index: Optional[int] = None + score: Optional[float] = None + label: Optional[str] = None + display_name: Optional[str] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ClassificationProto: @@ -46,7 +45,7 @@ class Classification: return _ClassificationProto( index=self.index, score=self.score, - label_name=self.label_name, + label=self.label, display_name=self.display_name) @classmethod @@ -56,7 +55,7 @@ class Classification: return Classification( index=pb2_obj.index, score=pb2_obj.score, - label_name=pb2_obj.label_name, + label=pb2_obj.label, display_name=pb2_obj.display_name) def __eq__(self, other: Any) -> bool: @@ -86,8 +85,8 @@ class ClassificationList: """ classifications: List[Classification] - tensor_index: int - tensor_name: str + tensor_index: Optional[int] = None + tensor_name: Optional[str] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ClassificationListProto: diff --git a/mediapipe/tasks/python/components/containers/gesture.py b/mediapipe/tasks/python/components/containers/gesture.py deleted file mode 100644 index f314d18bd..000000000 --- a/mediapipe/tasks/python/components/containers/gesture.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Gesture data class.""" - -import dataclasses -from typing import Any, List - -from mediapipe.tasks.python.components.containers import classification -from mediapipe.tasks.python.components.containers import landmark -from mediapipe.tasks.python.core.optional_dependencies import doc_controls - - -@dataclasses.dataclass -class GestureRecognitionResult: - """ The gesture recognition result from GestureRecognizer, where each vector - element represents a single hand detected in the image. - - Attributes: - gestures: Recognized hand gestures with sorted order such that the - winning label is the first item in the list. - handedness: Classification of handedness. - hand_landmarks: Detected hand landmarks in normalized image coordinates. - hand_world_landmarks: Detected hand landmarks in world coordinates. - """ - - gestures: List[classification.ClassificationList] - handedness: List[classification.ClassificationList] - hand_landmarks: List[landmark.NormalizedLandmarkList] - hand_world_landmarks: List[landmark.LandmarkList] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _DetectionProto: - """Generates a Detection protobuf object.""" - labels = [] - label_ids = [] - scores = [] - display_names = [] - for category in self.categories: - scores.append(category.score) - if category.index: - label_ids.append(category.index) - if category.category_name: - labels.append(category.category_name) - if category.display_name: - display_names.append(category.display_name) - return _DetectionProto( - label=labels, - label_id=label_ids, - score=scores, - display_name=display_names, - location_data=_LocationDataProto( - format=_LocationDataProto.Format.BOUNDING_BOX, - bounding_box=self.bounding_box.to_pb2())) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection': - """Creates a `Detection` object from the given protobuf object.""" - categories = [] - for idx, score in enumerate(pb2_obj.score): - categories.append( - category_module.Category( - score=score, - index=pb2_obj.label_id[idx] - if idx < len(pb2_obj.label_id) else None, - category_name=pb2_obj.label[idx] - if idx < len(pb2_obj.label) else None, - display_name=pb2_obj.display_name[idx] - if idx < len(pb2_obj.display_name) else None)) - - return Detection( - bounding_box=bounding_box_module.BoundingBox.create_from_pb2( - pb2_obj.location_data.bounding_box), - categories=categories) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, Detection): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - - -@dataclasses.dataclass -class DetectionResult: - """Represents the list of detected objects. - - Attributes: - detections: A list of `Detection` objects. - """ - - detections: List[Detection] - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _DetectionListProto: - """Generates a DetectionList protobuf object.""" - return _DetectionListProto( - detection=[detection.to_pb2() for detection in self.detections]) - - @classmethod - @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult': - """Creates a `DetectionResult` object from the given protobuf object.""" - return DetectionResult(detections=[ - Detection.create_from_pb2(detection) for detection in pb2_obj.detection - ]) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, DetectionResult): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py new file mode 100644 index 000000000..c3d93d414 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -0,0 +1,82 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Landmark Detection Result data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 +from mediapipe.tasks.python.components.containers import rect as rect_module +from mediapipe.tasks.python.components.containers import classification as classification_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult +_NormalizedRect = rect_module.NormalizedRect +_ClassificationList = classification_module.ClassificationList +_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList +_LandmarkList = landmark_module.LandmarkList + + +@dataclasses.dataclass +class LandmarksDetectionResult: + """Represents the landmarks detection result. + + Attributes: + landmarks : A `NormalizedLandmarkList` object. + classifications : A `ClassificationList` object. + world_landmarks : A `LandmarkList` object. + rect : A `NormalizedRect` object. + """ + + landmarks: Optional[_NormalizedLandmarkList] + classifications: Optional[_ClassificationList] + world_landmarks: Optional[_LandmarkList] + rect: _NormalizedRect + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _LandmarksDetectionResultProto: + """Generates a LandmarksDetectionResult protobuf object.""" + return _LandmarksDetectionResultProto( + landmarks=self.landmarks.to_pb2(), + classifications=self.classifications.to_pb2(), + world_landmarks=self.world_landmarks.to_pb2(), + rect=self.rect.to_pb2()) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, + pb2_obj: _LandmarksDetectionResultProto + ) -> 'LandmarksDetectionResult': + """Creates a `LandmarksDetectionResult` object from the given protobuf + object.""" + return LandmarksDetectionResult( + landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks), + classifications=_ClassificationList.create_from_pb2( + pb2_obj.classifications), + world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks), + rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + Args: + other: The object to be compared with. + Returns: + True if the objects are equal. + """ + if not isinstance(other, LandmarksDetectionResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 0dd83edcf..0d8b99984 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -43,15 +43,19 @@ py_test( data = [ "//mediapipe/tasks/testdata/vision:test_images", "//mediapipe/tasks/testdata/vision:test_models", + "//mediapipe/tasks/testdata/vision:test_protos", ], deps = [ "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/components/containers:classification", "//mediapipe/tasks/python/components/containers:landmark", - "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/components/containers:landmark_detection_result", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:gesture_recognizer", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@com_google_protobuf//:protobuf_python" ], ) diff --git a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py index 288cfd1f5..7d731d805 100644 --- a/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py +++ b/mediapipe/tasks/python/test/vision/gesture_recognizer_test.py @@ -15,23 +15,31 @@ import enum +from google.protobuf import text_format from absl.testing import absltest from absl.testing import parameterized from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 from mediapipe.tasks.python.components.containers import rect as rect_module from mediapipe.tasks.python.components.containers import classification as classification_module from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import gesture_recognizer from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult _BaseOptions = base_options_module.BaseOptions _NormalizedRect = rect_module.NormalizedRect +_Classification = classification_module.Classification _ClassificationList = classification_module.ClassificationList +_Landmark = landmark_module.Landmark _LandmarkList = landmark_module.LandmarkList +_NormalizedLandmark = landmark_module.NormalizedLandmark _NormalizedLandmarkList = landmark_module.NormalizedLandmarkList +_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult _Image = image_module.Image _GestureRecognizer = gesture_recognizer.GestureRecognizer _GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions @@ -39,8 +47,35 @@ _GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task' -_IMAGE_FILE = 'right_hands.jpg' -_EXPECTED_DETECTION_RESULT = _GestureRecognitionResult([], [], [], []) +_THUMB_UP_IMAGE = 'thumb_up.jpg' +_THUMB_UP_LANDMARKS = "thumb_up_landmarks.pbtxt" +_THUMB_UP_LABEL = "Thumb_Up" +_THUMB_UP_INDEX = 5 +_LANDMARKS_ERROR_TOLERANCE = 0.03 + + +def _get_expected_gesture_recognition_result( + file_path: str, gesture_label: str, gesture_index: int +) -> _GestureRecognitionResult: + landmarks_detection_result_file_path = test_utils.get_test_data_path( + file_path) + with open(landmarks_detection_result_file_path, "rb") as f: + landmarks_detection_result_proto = _LandmarksDetectionResultProto() + # # Use this if a .pb file is available. + # landmarks_detection_result_proto.ParseFromString(f.read()) + text_format.Parse(f.read(), landmarks_detection_result_proto) + landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2( + landmarks_detection_result_proto) + gesture = _ClassificationList( + classifications=[ + _Classification(label=gesture_label, index=gesture_index, + display_name='') + ], tensor_index=0, tensor_name='') + return _GestureRecognitionResult( + gestures=[gesture], + handedness=[landmarks_detection_result.classifications], + hand_landmarks=[landmarks_detection_result.landmarks], + hand_world_landmarks=[landmarks_detection_result.world_landmarks]) class ModelFileType(enum.Enum): @@ -53,14 +88,45 @@ class GestureRecognizerTest(parameterized.TestCase): def setUp(self): super().setUp() self.test_image = _Image.create_from_file( - test_utils.get_test_data_path(_IMAGE_FILE)) + test_utils.get_test_data_path(_THUMB_UP_IMAGE)) self.gesture_recognizer_model_path = test_utils.get_test_data_path( _GESTURE_RECOGNIZER_MODEL_FILE) + def _assert_actual_result_approximately_matches_expected_result( + self, + actual_result: _GestureRecognitionResult, + expected_result: _GestureRecognitionResult + ): + # Expects to have the same number of hands detected. + self.assertLen(actual_result.hand_landmarks, + len(expected_result.hand_landmarks)) + self.assertLen(actual_result.hand_world_landmarks, + len(expected_result.hand_world_landmarks)) + self.assertLen(actual_result.handedness, len(expected_result.handedness)) + self.assertLen(actual_result.gestures, len(expected_result.gestures)) + # Actual landmarks match expected landmarks. + self.assertEqual(actual_result.hand_landmarks, + expected_result.hand_landmarks) + # Actual handedness matches expected handedness. + actual_top_handedness = actual_result.handedness[0].classifications[0] + expected_top_handedness = expected_result.handedness[0].classifications[0] + self.assertEqual(actual_top_handedness.index, expected_top_handedness.index) + self.assertEqual(actual_top_handedness.label, expected_top_handedness.label) + # Actual gesture with top score matches expected gesture. + actual_top_gesture = actual_result.gestures[0].classifications[0] + expected_top_gesture = expected_result.gestures[0].classifications[0] + self.assertEqual(actual_top_gesture.index, expected_top_gesture.index) + self.assertEqual(actual_top_gesture.label, expected_top_gesture.label) + @parameterized.parameters( - (ModelFileType.FILE_NAME, _EXPECTED_DETECTION_RESULT), - (ModelFileType.FILE_CONTENT, _EXPECTED_DETECTION_RESULT)) - def test_recognize(self, model_file_type, expected_recognition_result): + (ModelFileType.FILE_NAME, 0.3, _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + )), + (ModelFileType.FILE_CONTENT, 0.3, _get_expected_gesture_recognition_result( + _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX + ))) + def test_recognize(self, model_file_type, min_gesture_confidence, + expected_recognition_result): # Creates gesture recognizer. if model_file_type is ModelFileType.FILE_NAME: gesture_recognizer_base_options = _BaseOptions( @@ -75,13 +141,16 @@ class GestureRecognizerTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _GestureRecognizerOptions( - base_options=gesture_recognizer_base_options) + base_options=gesture_recognizer_base_options, + min_gesture_confidence=min_gesture_confidence + ) recognizer = _GestureRecognizer.create_from_options(options) # Performs hand gesture recognition on the input. recognition_result = recognizer.recognize(self.test_image) # Comparing results. - self.assertEqual(recognition_result, expected_recognition_result) + self._assert_actual_result_approximately_matches_expected_result( + recognition_result, expected_recognition_result) # Closes the gesture recognizer explicitly when the detector is not used in # a context. recognizer.close() diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index aca7a5277..c00508b36 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -136,8 +136,6 @@ class GestureRecognizerOptions: """Generates an GestureRecognizerOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True - # hand_landmark_detector_base_options_proto = self.hand_landmark_detector_base_options.to_pb2() - # hand_landmark_detector_base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True # Configure hand detector options. hand_detector_options_proto = _HandDetectorGraphOptionsProto( @@ -153,13 +151,12 @@ class GestureRecognizerOptions: min_tracking_confidence=self.min_tracking_confidence) # Configure hand gesture recognizer options. - hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto() - if self.min_gesture_confidence >= 0: - classifier_options = _ClassifierOptions( - score_threshold=self.min_gesture_confidence) - hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options = \ - _GestureClassifierGraphOptionsProto( - classifier_options=classifier_options.to_pb2()) + classifier_options = _ClassifierOptions( + score_threshold=self.min_gesture_confidence) + gesture_classifier_options = _GestureClassifierGraphOptionsProto( + classifier_options=classifier_options.to_pb2()) + hand_gesture_recognizer_options_proto = _HandGestureRecognizerGraphOptionsProto( + canned_gesture_classifier_graph_options=gesture_classifier_options) return _GestureRecognizerGraphOptionsProto( base_options=base_options_proto, diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ebb8f05a6..365921bc1 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -121,6 +121,7 @@ filegroup( "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "hand_landmarker.task", + "gesture_recognizer.task", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite",