Refactored code and removed some issues

This commit is contained in:
kinaryml 2022-10-30 15:42:26 -07:00
parent 4b66599419
commit fb4872b068
3 changed files with 49 additions and 150 deletions

View File

@ -30,9 +30,9 @@ class Landmark:
Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points. Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.
Attributes: Attributes:
x: The x coordinate of the 3D point. x: The x coordinate.
y: The y coordinate of the 3D point. y: The y coordinate.
z: The z coordinate of the 3D point. z: The z coordinate.
visibility: Landmark visibility. Should stay unset if not supported. visibility: Landmark visibility. Should stay unset if not supported.
Float score of whether landmark is visible or occluded by other objects. Float score of whether landmark is visible or occluded by other objects.
Landmark considered as invisible also if it is not present on the screen Landmark considered as invisible also if it is not present on the screen
@ -72,20 +72,6 @@ class Landmark:
visibility=pb2_obj.visibility, visibility=pb2_obj.visibility,
presence=pb2_obj.presence) presence=pb2_obj.presence)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Landmark):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass @dataclasses.dataclass
class NormalizedLandmark: class NormalizedLandmark:
@ -94,9 +80,9 @@ class NormalizedLandmark:
All coordinates should be within [0, 1]. All coordinates should be within [0, 1].
Attributes: Attributes:
x: The normalized x coordinate of the 3D point. x: The normalized x coordinate.
y: The normalized y coordinate of the 3D point. y: The normalized y coordinate.
z: The normalized z coordinate of the 3D point. z: The normalized z coordinate.
visibility: Landmark visibility. Should stay unset if not supported. visibility: Landmark visibility. Should stay unset if not supported.
Float score of whether landmark is visible or occluded by other objects. Float score of whether landmark is visible or occluded by other objects.
Landmark considered as invisible also if it is not present on the screen Landmark considered as invisible also if it is not present on the screen
@ -138,17 +124,3 @@ class NormalizedLandmark:
z=pb2_obj.z, z=pb2_obj.z,
visibility=pb2_obj.visibility, visibility=pb2_obj.visibility,
presence=pb2_obj.presence) presence=pb2_obj.presence)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, NormalizedLandmark):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -14,7 +14,7 @@
"""Landmarks Detection Result data class.""" """Landmarks Detection Result data class."""
import dataclasses import dataclasses
from typing import Any, Optional, List from typing import Optional, List
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2 from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.framework.formats import classification_pb2 from mediapipe.framework.formats import classification_pb2
@ -89,15 +89,3 @@ class LandmarksDetectionResult:
_Landmark.create_from_pb2(landmark) _Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.world_landmarks.landmark], for landmark in pb2_obj.world_landmarks.landmark],
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, LandmarksDetectionResult):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -48,7 +48,6 @@ _ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_TaskRunner = task_runner_module.TaskRunner
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
@ -86,6 +85,45 @@ class GestureRecognitionResult:
hand_world_landmarks: List[List[landmark_module.Landmark]] hand_world_landmarks: List[List[landmark_module.Landmark]]
def _build_recognition_result(
output_packets: Mapping[str, packet_module.Packet]
) -> GestureRecognitionResult:
gestures_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_GESTURE_STREAM_NAME])
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
return GestureRecognitionResult(
[
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)
@dataclasses.dataclass @dataclasses.dataclass
class GestureRecognizerOptions: class GestureRecognizerOptions:
"""Options for the gesture recognizer task. """Options for the gesture recognizer task.
@ -220,40 +258,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
return return
gestures_proto_list = packet_getter.get_proto_list( gesture_recognition_result = _build_recognition_result(output_packets)
output_packets[_HAND_GESTURE_STREAM_NAME])
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
gesture_recognition_result = GestureRecognitionResult(
[
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
options.result_callback( options.result_callback(
gesture_recognition_result, image, gesture_recognition_result, image,
@ -313,40 +318,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
return GestureRecognitionResult([], [], [], []) return GestureRecognitionResult([], [], [], [])
gestures_proto_list = packet_getter.get_proto_list( return _build_recognition_result(output_packets)
output_packets[_HAND_GESTURE_STREAM_NAME])
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
return GestureRecognitionResult(
[
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)
def recognize_for_video( def recognize_for_video(
self, image: image_module.Image, self, image: image_module.Image,
@ -386,40 +358,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
return GestureRecognitionResult([], [], [], []) return GestureRecognitionResult([], [], [], [])
gestures_proto_list = packet_getter.get_proto_list( return _build_recognition_result(output_packets)
output_packets[_HAND_GESTURE_STREAM_NAME])
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
return GestureRecognitionResult(
[
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)
def recognize_async( def recognize_async(
self, self,