From fb4872b068b9d34d63997779f3b746d389852fa5 Mon Sep 17 00:00:00 2001 From: kinaryml Date: Sun, 30 Oct 2022 15:42:26 -0700 Subject: [PATCH] Refactored code and removed some issues --- .../python/components/containers/landmark.py | 40 +---- .../containers/landmark_detection_result.py | 14 +- .../tasks/python/vision/gesture_recognizer.py | 145 +++++------------- 3 files changed, 49 insertions(+), 150 deletions(-) diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index 2c87ee676..7eb7d8e96 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -30,9 +30,9 @@ class Landmark: Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points. Attributes: - x: The x coordinate of the 3D point. - y: The y coordinate of the 3D point. - z: The z coordinate of the 3D point. + x: The x coordinate. + y: The y coordinate. + z: The z coordinate. visibility: Landmark visibility. Should stay unset if not supported. Float score of whether landmark is visible or occluded by other objects. Landmark considered as invisible also if it is not present on the screen @@ -72,20 +72,6 @@ class Landmark: visibility=pb2_obj.visibility, presence=pb2_obj.presence) - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, Landmark): - return False - - return self.to_pb2().__eq__(other.to_pb2()) - @dataclasses.dataclass class NormalizedLandmark: @@ -94,9 +80,9 @@ class NormalizedLandmark: All coordinates should be within [0, 1]. Attributes: - x: The normalized x coordinate of the 3D point. - y: The normalized y coordinate of the 3D point. - z: The normalized z coordinate of the 3D point. + x: The normalized x coordinate. + y: The normalized y coordinate. + z: The normalized z coordinate. visibility: Landmark visibility. Should stay unset if not supported. Float score of whether landmark is visible or occluded by other objects. Landmark considered as invisible also if it is not present on the screen @@ -138,17 +124,3 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - - Args: - other: The object to be compared with. - - Returns: - True if the objects are equal. - """ - if not isinstance(other, NormalizedLandmark): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py index ad21812c7..7c21733e2 100644 --- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py +++ b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -14,7 +14,7 @@ """Landmarks Detection Result data class.""" import dataclasses -from typing import Any, Optional, List +from typing import Optional, List from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 from mediapipe.framework.formats import classification_pb2 @@ -89,15 +89,3 @@ class LandmarksDetectionResult: _Landmark.create_from_pb2(landmark) for landmark in pb2_obj.world_landmarks.landmark], rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) - - def __eq__(self, other: Any) -> bool: - """Checks if this object is equal to the given object. - Args: - other: The object to be compared with. - Returns: - True if the objects are equal. - """ - if not isinstance(other, LandmarksDetectionResult): - return False - - return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index e8d9ef342..c6d30dc4e 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -48,7 +48,6 @@ _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_TaskRunner = task_runner_module.TaskRunner _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' @@ -86,6 +85,45 @@ class GestureRecognitionResult: hand_world_landmarks: List[List[landmark_module.Landmark]] +def _build_recognition_result( + output_packets: Mapping[str, packet_module.Packet] +) -> GestureRecognitionResult: + gestures_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_GESTURE_STREAM_NAME]) + handedness_proto_list = packet_getter.get_proto_list( + output_packets[_HANDEDNESS_STREAM_NAME]) + hand_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_LANDMARKS_STREAM_NAME]) + hand_world_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + + return GestureRecognitionResult( + [ + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in gesture_classifications.classification] + for gesture_classifications in gestures_proto_list + ], [ + [ + category_module.Category( + index=gesture.index, score=gesture.score, + display_name=gesture.display_name, category_name=gesture.label) + for gesture in handedness_classifications.classification] + for handedness_classifications in handedness_proto_list + ], [ + [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + for hand_landmark in hand_landmarks.landmark] + for hand_landmarks in hand_landmarks_proto_list + ], [ + [landmark_module.Landmark.create_from_pb2(hand_world_landmark) + for hand_world_landmark in hand_world_landmarks.landmark] + for hand_world_landmarks in hand_world_landmarks_proto_list + ] + ) + + @dataclasses.dataclass class GestureRecognizerOptions: """Options for the gesture recognizer task. @@ -220,40 +258,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) return - gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) - handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) - hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) - hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) - - gesture_recognition_result = GestureRecognitionResult( - [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in gesture_classifications.classification] - for gesture_classifications in gestures_proto_list - ], [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] - for handedness_classifications in handedness_proto_list - ], [ - [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - for hand_landmark in hand_landmarks.landmark] - for hand_landmarks in hand_landmarks_proto_list - ], [ - [landmark_module.Landmark.create_from_pb2(hand_world_landmark) - for hand_world_landmark in hand_world_landmarks.landmark] - for hand_world_landmarks in hand_world_landmarks_proto_list - ] - ) + gesture_recognition_result = _build_recognition_result(output_packets) timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp options.result_callback( gesture_recognition_result, image, @@ -313,40 +318,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): return GestureRecognitionResult([], [], [], []) - gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) - handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) - hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) - hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) - - return GestureRecognitionResult( - [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in gesture_classifications.classification] - for gesture_classifications in gestures_proto_list - ], [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] - for handedness_classifications in handedness_proto_list - ], [ - [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - for hand_landmark in hand_landmarks.landmark] - for hand_landmarks in hand_landmarks_proto_list - ], [ - [landmark_module.Landmark.create_from_pb2(hand_world_landmark) - for hand_world_landmark in hand_world_landmarks.landmark] - for hand_world_landmarks in hand_world_landmarks_proto_list - ] - ) + return _build_recognition_result(output_packets) def recognize_for_video( self, image: image_module.Image, @@ -386,40 +358,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): return GestureRecognitionResult([], [], [], []) - gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) - handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) - hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) - hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) - - return GestureRecognitionResult( - [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in gesture_classifications.classification] - for gesture_classifications in gestures_proto_list - ], [ - [ - category_module.Category( - index=gesture.index, score=gesture.score, - display_name=gesture.display_name, category_name=gesture.label) - for gesture in handedness_classifications.classification] - for handedness_classifications in handedness_proto_list - ], [ - [landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) - for hand_landmark in hand_landmarks.landmark] - for hand_landmarks in hand_landmarks_proto_list - ], [ - [landmark_module.Landmark.create_from_pb2(hand_world_landmark) - for hand_world_landmark in hand_world_landmarks.landmark] - for hand_world_landmarks in hand_world_landmarks_proto_list - ] - ) + return _build_recognition_result(output_packets) def recognize_async( self,