Refactored code and removed some issues
This commit is contained in:
parent
4b66599419
commit
fb4872b068
|
@ -30,9 +30,9 @@ class Landmark:
|
||||||
Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.
|
Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
x: The x coordinate of the 3D point.
|
x: The x coordinate.
|
||||||
y: The y coordinate of the 3D point.
|
y: The y coordinate.
|
||||||
z: The z coordinate of the 3D point.
|
z: The z coordinate.
|
||||||
visibility: Landmark visibility. Should stay unset if not supported.
|
visibility: Landmark visibility. Should stay unset if not supported.
|
||||||
Float score of whether landmark is visible or occluded by other objects.
|
Float score of whether landmark is visible or occluded by other objects.
|
||||||
Landmark considered as invisible also if it is not present on the screen
|
Landmark considered as invisible also if it is not present on the screen
|
||||||
|
@ -72,20 +72,6 @@ class Landmark:
|
||||||
visibility=pb2_obj.visibility,
|
visibility=pb2_obj.visibility,
|
||||||
presence=pb2_obj.presence)
|
presence=pb2_obj.presence)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, Landmark):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class NormalizedLandmark:
|
class NormalizedLandmark:
|
||||||
|
@ -94,9 +80,9 @@ class NormalizedLandmark:
|
||||||
All coordinates should be within [0, 1].
|
All coordinates should be within [0, 1].
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
x: The normalized x coordinate of the 3D point.
|
x: The normalized x coordinate.
|
||||||
y: The normalized y coordinate of the 3D point.
|
y: The normalized y coordinate.
|
||||||
z: The normalized z coordinate of the 3D point.
|
z: The normalized z coordinate.
|
||||||
visibility: Landmark visibility. Should stay unset if not supported.
|
visibility: Landmark visibility. Should stay unset if not supported.
|
||||||
Float score of whether landmark is visible or occluded by other objects.
|
Float score of whether landmark is visible or occluded by other objects.
|
||||||
Landmark considered as invisible also if it is not present on the screen
|
Landmark considered as invisible also if it is not present on the screen
|
||||||
|
@ -138,17 +124,3 @@ class NormalizedLandmark:
|
||||||
z=pb2_obj.z,
|
z=pb2_obj.z,
|
||||||
visibility=pb2_obj.visibility,
|
visibility=pb2_obj.visibility,
|
||||||
presence=pb2_obj.presence)
|
presence=pb2_obj.presence)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, NormalizedLandmark):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
"""Landmarks Detection Result data class."""
|
"""Landmarks Detection Result data class."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||||
from mediapipe.framework.formats import classification_pb2
|
from mediapipe.framework.formats import classification_pb2
|
||||||
|
@ -89,15 +89,3 @@ class LandmarksDetectionResult:
|
||||||
_Landmark.create_from_pb2(landmark)
|
_Landmark.create_from_pb2(landmark)
|
||||||
for landmark in pb2_obj.world_landmarks.landmark],
|
for landmark in pb2_obj.world_landmarks.landmark],
|
||||||
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, LandmarksDetectionResult):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
|
@ -48,7 +48,6 @@ _ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
_TaskRunner = task_runner_module.TaskRunner
|
|
||||||
|
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
|
@ -86,6 +85,45 @@ class GestureRecognitionResult:
|
||||||
hand_world_landmarks: List[List[landmark_module.Landmark]]
|
hand_world_landmarks: List[List[landmark_module.Landmark]]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_recognition_result(
|
||||||
|
output_packets: Mapping[str, packet_module.Packet]
|
||||||
|
) -> GestureRecognitionResult:
|
||||||
|
gestures_proto_list = packet_getter.get_proto_list(
|
||||||
|
output_packets[_HAND_GESTURE_STREAM_NAME])
|
||||||
|
handedness_proto_list = packet_getter.get_proto_list(
|
||||||
|
output_packets[_HANDEDNESS_STREAM_NAME])
|
||||||
|
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
||||||
|
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
||||||
|
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
||||||
|
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
||||||
|
|
||||||
|
return GestureRecognitionResult(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in gesture_classifications.classification]
|
||||||
|
for gesture_classifications in gestures_proto_list
|
||||||
|
], [
|
||||||
|
[
|
||||||
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in handedness_classifications.classification]
|
||||||
|
for handedness_classifications in handedness_proto_list
|
||||||
|
], [
|
||||||
|
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||||
|
for hand_landmark in hand_landmarks.landmark]
|
||||||
|
for hand_landmarks in hand_landmarks_proto_list
|
||||||
|
], [
|
||||||
|
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||||
|
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||||
|
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class GestureRecognizerOptions:
|
class GestureRecognizerOptions:
|
||||||
"""Options for the gesture recognizer task.
|
"""Options for the gesture recognizer task.
|
||||||
|
@ -220,40 +258,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
return
|
return
|
||||||
|
|
||||||
gestures_proto_list = packet_getter.get_proto_list(
|
gesture_recognition_result = _build_recognition_result(output_packets)
|
||||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
|
||||||
handedness_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
|
||||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
|
||||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
|
||||||
|
|
||||||
gesture_recognition_result = GestureRecognitionResult(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
category_module.Category(
|
|
||||||
index=gesture.index, score=gesture.score,
|
|
||||||
display_name=gesture.display_name, category_name=gesture.label)
|
|
||||||
for gesture in gesture_classifications.classification]
|
|
||||||
for gesture_classifications in gestures_proto_list
|
|
||||||
], [
|
|
||||||
[
|
|
||||||
category_module.Category(
|
|
||||||
index=gesture.index, score=gesture.score,
|
|
||||||
display_name=gesture.display_name, category_name=gesture.label)
|
|
||||||
for gesture in handedness_classifications.classification]
|
|
||||||
for handedness_classifications in handedness_proto_list
|
|
||||||
], [
|
|
||||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
|
||||||
for hand_landmark in hand_landmarks.landmark]
|
|
||||||
for hand_landmarks in hand_landmarks_proto_list
|
|
||||||
], [
|
|
||||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
|
||||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
|
||||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
|
||||||
]
|
|
||||||
)
|
|
||||||
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
|
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
|
||||||
options.result_callback(
|
options.result_callback(
|
||||||
gesture_recognition_result, image,
|
gesture_recognition_result, image,
|
||||||
|
@ -313,40 +318,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
||||||
return GestureRecognitionResult([], [], [], [])
|
return GestureRecognitionResult([], [], [], [])
|
||||||
|
|
||||||
gestures_proto_list = packet_getter.get_proto_list(
|
return _build_recognition_result(output_packets)
|
||||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
|
||||||
handedness_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
|
||||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
|
||||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
|
||||||
|
|
||||||
return GestureRecognitionResult(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
category_module.Category(
|
|
||||||
index=gesture.index, score=gesture.score,
|
|
||||||
display_name=gesture.display_name, category_name=gesture.label)
|
|
||||||
for gesture in gesture_classifications.classification]
|
|
||||||
for gesture_classifications in gestures_proto_list
|
|
||||||
], [
|
|
||||||
[
|
|
||||||
category_module.Category(
|
|
||||||
index=gesture.index, score=gesture.score,
|
|
||||||
display_name=gesture.display_name, category_name=gesture.label)
|
|
||||||
for gesture in handedness_classifications.classification]
|
|
||||||
for handedness_classifications in handedness_proto_list
|
|
||||||
], [
|
|
||||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
|
||||||
for hand_landmark in hand_landmarks.landmark]
|
|
||||||
for hand_landmarks in hand_landmarks_proto_list
|
|
||||||
], [
|
|
||||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
|
||||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
|
||||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def recognize_for_video(
|
def recognize_for_video(
|
||||||
self, image: image_module.Image,
|
self, image: image_module.Image,
|
||||||
|
@ -386,40 +358,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
||||||
return GestureRecognitionResult([], [], [], [])
|
return GestureRecognitionResult([], [], [], [])
|
||||||
|
|
||||||
gestures_proto_list = packet_getter.get_proto_list(
|
return _build_recognition_result(output_packets)
|
||||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
|
||||||
handedness_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
|
||||||
hand_landmarks_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HAND_LANDMARKS_STREAM_NAME])
|
|
||||||
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
|
|
||||||
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
|
|
||||||
|
|
||||||
return GestureRecognitionResult(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
category_module.Category(
|
|
||||||
index=gesture.index, score=gesture.score,
|
|
||||||
display_name=gesture.display_name, category_name=gesture.label)
|
|
||||||
for gesture in gesture_classifications.classification]
|
|
||||||
for gesture_classifications in gestures_proto_list
|
|
||||||
], [
|
|
||||||
[
|
|
||||||
category_module.Category(
|
|
||||||
index=gesture.index, score=gesture.score,
|
|
||||||
display_name=gesture.display_name, category_name=gesture.label)
|
|
||||||
for gesture in handedness_classifications.classification]
|
|
||||||
for handedness_classifications in handedness_proto_list
|
|
||||||
], [
|
|
||||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
|
||||||
for hand_landmark in hand_landmarks.landmark]
|
|
||||||
for hand_landmarks in hand_landmarks_proto_list
|
|
||||||
], [
|
|
||||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
|
||||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
|
||||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def recognize_async(
|
def recognize_async(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user