Refactored code and removed some issues
This commit is contained in:
parent
4b66599419
commit
fb4872b068
|
@ -30,9 +30,9 @@ class Landmark:
|
|||
Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.
|
||||
|
||||
Attributes:
|
||||
x: The x coordinate of the 3D point.
|
||||
y: The y coordinate of the 3D point.
|
||||
z: The z coordinate of the 3D point.
|
||||
x: The x coordinate.
|
||||
y: The y coordinate.
|
||||
z: The z coordinate.
|
||||
visibility: Landmark visibility. Should stay unset if not supported.
|
||||
Float score of whether landmark is visible or occluded by other objects.
|
||||
Landmark considered as invisible also if it is not present on the screen
|
||||
|
@ -72,20 +72,6 @@ class Landmark:
|
|||
visibility=pb2_obj.visibility,
|
||||
presence=pb2_obj.presence)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, Landmark):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NormalizedLandmark:
|
||||
|
@ -94,9 +80,9 @@ class NormalizedLandmark:
|
|||
All coordinates should be within [0, 1].
|
||||
|
||||
Attributes:
|
||||
x: The normalized x coordinate of the 3D point.
|
||||
y: The normalized y coordinate of the 3D point.
|
||||
z: The normalized z coordinate of the 3D point.
|
||||
x: The normalized x coordinate.
|
||||
y: The normalized y coordinate.
|
||||
z: The normalized z coordinate.
|
||||
visibility: Landmark visibility. Should stay unset if not supported.
|
||||
Float score of whether landmark is visible or occluded by other objects.
|
||||
Landmark considered as invisible also if it is not present on the screen
|
||||
|
@ -138,17 +124,3 @@ class NormalizedLandmark:
|
|||
z=pb2_obj.z,
|
||||
visibility=pb2_obj.visibility,
|
||||
presence=pb2_obj.presence)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, NormalizedLandmark):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
"""Landmarks Detection Result data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Optional, List
|
||||
from typing import Optional, List
|
||||
|
||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
|
@ -89,15 +89,3 @@ class LandmarksDetectionResult:
|
|||
_Landmark.create_from_pb2(landmark)
|
||||
for landmark in pb2_obj.world_landmarks.landmark],
|
||||
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, LandmarksDetectionResult):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
|
|
@ -48,7 +48,6 @@ _ClassifierOptions = classifier_options.ClassifierOptions
|
|||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
_TaskRunner = task_runner_module.TaskRunner
|
||||
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
|
@ -86,6 +85,45 @@ class GestureRecognitionResult:
|
|||
hand_world_landmarks: List[List[landmark_module.Landmark]]
|
||||
|
||||
|
||||
def _build_recognition_result(
|
||||
output_packets: Mapping[str, packet_module.Packet]
|
||||
) -> GestureRecognitionResult:
|
||||
gestures_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
||||
handedness_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
||||
|
||||
return GestureRecognitionResult(
|
||||
[
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GestureRecognizerOptions:
|
||||
"""Options for the gesture recognizer task.
|
||||
|
@ -220,40 +258,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||
return
|
||||
|
||||
gestures_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
||||
handedness_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
||||
|
||||
gesture_recognition_result = GestureRecognitionResult(
|
||||
[
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
gesture_recognition_result = _build_recognition_result(output_packets)
|
||||
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
|
||||
options.result_callback(
|
||||
gesture_recognition_result, image,
|
||||
|
@ -313,40 +318,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
||||
return GestureRecognitionResult([], [], [], [])
|
||||
|
||||
gestures_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
||||
handedness_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
||||
|
||||
return GestureRecognitionResult(
|
||||
[
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
return _build_recognition_result(output_packets)
|
||||
|
||||
def recognize_for_video(
|
||||
self, image: image_module.Image,
|
||||
|
@ -386,40 +358,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
||||
return GestureRecognitionResult([], [], [], [])
|
||||
|
||||
gestures_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
||||
handedness_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
||||
|
||||
return GestureRecognitionResult(
|
||||
[
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
return _build_recognition_result(output_packets)
|
||||
|
||||
def recognize_async(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user