Removed classification proto to use the existing category dataclass instead and removed NormalizedLandmarkList and LandmarkList dataclasses
This commit is contained in:
parent
0f7c5d5e90
commit
f62cfd1690
|
@ -36,15 +36,6 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "classification",
|
||||
srcs = ["classification.py"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "landmark",
|
||||
srcs = ["landmark.py"],
|
||||
|
@ -59,9 +50,11 @@ py_library(
|
|||
srcs = ["landmark_detection_result.py"],
|
||||
deps = [
|
||||
":rect",
|
||||
":classification",
|
||||
":landmark",
|
||||
"//mediapipe/framework/formats:classification_py_pb2",
|
||||
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
"""Category data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
@ -39,10 +39,10 @@ class Category:
|
|||
category_name: The label of this category object.
|
||||
"""
|
||||
|
||||
index: int
|
||||
score: float
|
||||
display_name: str
|
||||
category_name: str
|
||||
index: Optional[int] = None
|
||||
score: Optional[float] = None
|
||||
display_name: Optional[str] = None
|
||||
category_name: Optional[str] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _CategoryProto:
|
||||
|
|
|
@ -1,121 +0,0 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Classification data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_ClassificationProto = classification_pb2.Classification
|
||||
_ClassificationListProto = classification_pb2.ClassificationList
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Classification:
|
||||
"""A classification.
|
||||
|
||||
Attributes:
|
||||
index: The index of the class in the corresponding label map.
|
||||
score: The probability score for this class.
|
||||
label_name: Label or name of the class.
|
||||
display_name: Optional human-readable string for display purposes.
|
||||
"""
|
||||
|
||||
index: Optional[int] = None
|
||||
score: Optional[float] = None
|
||||
label: Optional[str] = None
|
||||
display_name: Optional[str] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationProto:
|
||||
"""Generates a Classification protobuf object."""
|
||||
return _ClassificationProto(
|
||||
index=self.index,
|
||||
score=self.score,
|
||||
label=self.label,
|
||||
display_name=self.display_name)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Classification':
|
||||
"""Creates a `Classification` object from the given protobuf object."""
|
||||
return Classification(
|
||||
index=pb2_obj.index,
|
||||
score=pb2_obj.score,
|
||||
label=pb2_obj.label,
|
||||
display_name=pb2_obj.display_name)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, Classification):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ClassificationList:
|
||||
"""Represents the classifications for a given classifier.
|
||||
Attributes:
|
||||
classification : A list of `Classification` objects.
|
||||
tensor_index: Optional index of the tensor that produced these
|
||||
classifications.
|
||||
tensor_name: Optional name of the tensor that produced these
|
||||
classifications tensor metadata name.
|
||||
"""
|
||||
|
||||
classifications: List[Classification]
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ClassificationListProto:
|
||||
"""Generates a ClassificationList protobuf object."""
|
||||
return _ClassificationListProto(
|
||||
classification=[
|
||||
classification.to_pb2()
|
||||
for classification in self.classifications
|
||||
])
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
cls,
|
||||
pb2_obj: _ClassificationListProto
|
||||
) -> 'ClassificationList':
|
||||
"""Creates a `ClassificationList` object from the given protobuf object."""
|
||||
return ClassificationList(
|
||||
classifications=[
|
||||
Classification.create_from_pb2(classification)
|
||||
for classification in pb2_obj.classification
|
||||
])
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, ClassificationList):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -20,9 +20,7 @@ from mediapipe.framework.formats import landmark_pb2
|
|||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_LandmarkProto = landmark_pb2.Landmark
|
||||
_LandmarkListProto = landmark_pb2.LandmarkList
|
||||
_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
|
||||
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -89,53 +87,6 @@ class Landmark:
|
|||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LandmarkList:
|
||||
"""Represents the list of landmarks.
|
||||
|
||||
Attributes:
|
||||
landmarks : A list of `Landmark` objects.
|
||||
"""
|
||||
|
||||
landmarks: List[Landmark]
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _LandmarkListProto:
|
||||
"""Generates a LandmarkList protobuf object."""
|
||||
return _LandmarkListProto(
|
||||
landmark=[
|
||||
landmark.to_pb2()
|
||||
for landmark in self.landmarks
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
cls,
|
||||
pb2_obj: _LandmarkListProto
|
||||
) -> 'LandmarkList':
|
||||
"""Creates a `LandmarkList` object from the given protobuf object."""
|
||||
return LandmarkList(
|
||||
landmarks=[
|
||||
Landmark.create_from_pb2(landmark)
|
||||
for landmark in pb2_obj.landmark
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, LandmarkList):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NormalizedLandmark:
|
||||
"""A normalized version of above Landmark proto.
|
||||
|
@ -201,50 +152,3 @@ class NormalizedLandmark:
|
|||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NormalizedLandmarkList:
|
||||
"""Represents the list of normalized landmarks.
|
||||
|
||||
Attributes:
|
||||
landmarks : A list of `Landmark` objects.
|
||||
"""
|
||||
|
||||
landmarks: List[NormalizedLandmark]
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _NormalizedLandmarkListProto:
|
||||
"""Generates a NormalizedLandmarkList protobuf object."""
|
||||
return _NormalizedLandmarkListProto(
|
||||
landmark=[
|
||||
landmark.to_pb2()
|
||||
for landmark in self.landmarks
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(
|
||||
cls,
|
||||
pb2_obj: _NormalizedLandmarkListProto
|
||||
) -> 'NormalizedLandmarkList':
|
||||
"""Creates a `NormalizedLandmarkList` object from the given protobuf object."""
|
||||
return NormalizedLandmarkList(
|
||||
landmarks=[
|
||||
NormalizedLandmark.create_from_pb2(landmark)
|
||||
for landmark in pb2_obj.landmark
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
Args:
|
||||
other: The object to be compared with.
|
||||
Returns:
|
||||
True if the objects are equal.
|
||||
"""
|
||||
if not isinstance(other, NormalizedLandmarkList):
|
||||
return False
|
||||
|
||||
return self.to_pb2().__eq__(other.to_pb2())
|
||||
|
|
|
@ -14,19 +14,25 @@
|
|||
"""Landmarks Detection Result data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, List
|
||||
|
||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
from mediapipe.framework.formats import landmark_pb2
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.components.containers import classification as classification_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
|
||||
_ClassificationProto = classification_pb2.Classification
|
||||
_ClassificationListProto = classification_pb2.ClassificationList
|
||||
_LandmarkListProto = landmark_pb2.LandmarkList
|
||||
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
|
||||
_NormalizedRect = rect_module.NormalizedRect
|
||||
_ClassificationList = classification_module.ClassificationList
|
||||
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
|
||||
_LandmarkList = landmark_module.LandmarkList
|
||||
_Category = category_module.Category
|
||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||
_Landmark = landmark_module.Landmark
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -34,25 +40,32 @@ class LandmarksDetectionResult:
|
|||
"""Represents the landmarks detection result.
|
||||
|
||||
Attributes:
|
||||
landmarks : A `NormalizedLandmarkList` object.
|
||||
classifications : A `ClassificationList` object.
|
||||
world_landmarks : A `LandmarkList` object.
|
||||
landmarks : A list of `NormalizedLandmark` objects.
|
||||
categories : A list of `Category` objects.
|
||||
world_landmarks : A list of `Landmark` objects.
|
||||
rect : A `NormalizedRect` object.
|
||||
"""
|
||||
|
||||
landmarks: Optional[_NormalizedLandmarkList]
|
||||
classifications: Optional[_ClassificationList]
|
||||
world_landmarks: Optional[_LandmarkList]
|
||||
landmarks: Optional[List[_NormalizedLandmark]]
|
||||
categories: Optional[List[_Category]]
|
||||
world_landmarks: Optional[List[_Landmark]]
|
||||
rect: _NormalizedRect
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _LandmarksDetectionResultProto:
|
||||
"""Generates a LandmarksDetectionResult protobuf object."""
|
||||
return _LandmarksDetectionResultProto(
|
||||
landmarks=self.landmarks.to_pb2(),
|
||||
classifications=self.classifications.to_pb2(),
|
||||
world_landmarks=self.world_landmarks.to_pb2(),
|
||||
rect=self.rect.to_pb2())
|
||||
landmarks=_NormalizedLandmarkListProto(landmarks=self.landmarks),
|
||||
classifications=_ClassificationListProto(
|
||||
classification=[
|
||||
_ClassificationProto(
|
||||
index=category.index,
|
||||
score=category.score,
|
||||
label=category.category_name,
|
||||
display_name=category.display_name)
|
||||
for category in self.categories]),
|
||||
world_landmarks=_LandmarkListProto(landmarks=self.world_landmarks),
|
||||
rect=self.rect.to_pb2())
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
|
@ -63,11 +76,19 @@ class LandmarksDetectionResult:
|
|||
"""Creates a `LandmarksDetectionResult` object from the given protobuf
|
||||
object."""
|
||||
return LandmarksDetectionResult(
|
||||
landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks),
|
||||
classifications=_ClassificationList.create_from_pb2(
|
||||
pb2_obj.classifications),
|
||||
world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks),
|
||||
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
||||
landmarks=[
|
||||
_NormalizedLandmark.create_from_pb2(landmark)
|
||||
for landmark in pb2_obj.landmarks.landmark],
|
||||
categories=[category_module.Category(
|
||||
score=classification.score,
|
||||
index=classification.index,
|
||||
category_name=classification.label,
|
||||
display_name=classification.display_name)
|
||||
for classification in pb2_obj.classifications.classification],
|
||||
world_landmarks=[
|
||||
_Landmark.create_from_pb2(landmark)
|
||||
for landmark in pb2_obj.world_landmarks.landmark],
|
||||
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
|
||||
|
|
|
@ -69,7 +69,7 @@ py_test(
|
|||
"//mediapipe/python:_framework_bindings",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:rect",
|
||||
"//mediapipe/tasks/python/components/containers:classification",
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:landmark",
|
||||
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
|
|
|
@ -24,7 +24,7 @@ from absl.testing import parameterized
|
|||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.components.containers import classification as classification_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -36,12 +36,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
|
|||
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Rect = rect_module.Rect
|
||||
_Classification = classification_module.Classification
|
||||
_ClassificationList = classification_module.ClassificationList
|
||||
_Category = category_module.Category
|
||||
_Landmark = landmark_module.Landmark
|
||||
_LandmarkList = landmark_module.LandmarkList
|
||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
|
||||
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
|
||||
_Image = image_module.Image
|
||||
_GestureRecognizer = gesture_recognizer.GestureRecognizer
|
||||
|
@ -76,14 +73,11 @@ def _get_expected_gesture_recognition_result(
|
|||
text_format.Parse(f.read(), landmarks_detection_result_proto)
|
||||
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
|
||||
landmarks_detection_result_proto)
|
||||
gesture = _ClassificationList(
|
||||
classifications=[
|
||||
_Classification(label=gesture_label, index=gesture_index,
|
||||
display_name='')
|
||||
])
|
||||
gesture = _Category(category_name=gesture_label, index=gesture_index,
|
||||
display_name='')
|
||||
return _GestureRecognitionResult(
|
||||
gestures=[gesture],
|
||||
handedness=[landmarks_detection_result.classifications],
|
||||
gestures=[[gesture]],
|
||||
handedness=[landmarks_detection_result.categories],
|
||||
hand_landmarks=[landmarks_detection_result.landmarks],
|
||||
hand_world_landmarks=[landmarks_detection_result.world_landmarks])
|
||||
|
||||
|
@ -115,25 +109,27 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
self.assertLen(actual_result.handedness, len(expected_result.handedness))
|
||||
self.assertLen(actual_result.gestures, len(expected_result.gestures))
|
||||
# Actual landmarks match expected landmarks.
|
||||
self.assertLen(actual_result.hand_landmarks[0].landmarks,
|
||||
len(expected_result.hand_landmarks[0].landmarks))
|
||||
actual_landmarks = actual_result.hand_landmarks[0].landmarks
|
||||
expected_landmarks = expected_result.hand_landmarks[0].landmarks
|
||||
self.assertLen(actual_result.hand_landmarks[0],
|
||||
len(expected_result.hand_landmarks[0]))
|
||||
actual_landmarks = actual_result.hand_landmarks[0]
|
||||
expected_landmarks = expected_result.hand_landmarks[0]
|
||||
for i in range(len(actual_landmarks)):
|
||||
self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x,
|
||||
delta=_LANDMARKS_ERROR_TOLERANCE)
|
||||
self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y,
|
||||
delta=_LANDMARKS_ERROR_TOLERANCE)
|
||||
# Actual handedness matches expected handedness.
|
||||
actual_top_handedness = actual_result.handedness[0].classifications[0]
|
||||
expected_top_handedness = expected_result.handedness[0].classifications[0]
|
||||
actual_top_handedness = actual_result.handedness[0][0]
|
||||
expected_top_handedness = expected_result.handedness[0][0]
|
||||
self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
|
||||
self.assertEqual(actual_top_handedness.label, expected_top_handedness.label)
|
||||
self.assertEqual(actual_top_handedness.category_name,
|
||||
expected_top_handedness.category_name)
|
||||
# Actual gesture with top score matches expected gesture.
|
||||
actual_top_gesture = actual_result.gestures[0].classifications[0]
|
||||
expected_top_gesture = expected_result.gestures[0].classifications[0]
|
||||
actual_top_gesture = actual_result.gestures[0][0]
|
||||
expected_top_gesture = expected_result.gestures[0][0]
|
||||
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
||||
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
|
||||
self.assertEqual(actual_top_gesture.category_name,
|
||||
expected_top_gesture.category_name)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
|
@ -235,12 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase):
|
|||
expected_result = _get_expected_gesture_recognition_result(
|
||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
|
||||
# Only contains one top scoring gesture.
|
||||
self.assertLen(recognition_result.gestures[0].classifications, 1)
|
||||
self.assertLen(recognition_result.gestures[0], 1)
|
||||
# Actual gesture with top score matches expected gesture.
|
||||
actual_top_gesture = recognition_result.gestures[0].classifications[0]
|
||||
expected_top_gesture = expected_result.gestures[0].classifications[0]
|
||||
actual_top_gesture = recognition_result.gestures[0][0]
|
||||
expected_top_gesture = expected_result.gestures[0][0]
|
||||
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
||||
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
|
||||
self.assertEqual(actual_top_gesture.category_name,
|
||||
expected_top_gesture.category_name)
|
||||
|
||||
def test_recognize_succeeds_with_num_hands(self):
|
||||
# Creates gesture recognizer.
|
||||
|
|
|
@ -74,7 +74,7 @@ py_library(
|
|||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
|
||||
"//mediapipe/tasks/python/components/containers:classification",
|
||||
"//mediapipe/tasks/python/components/containers:category",
|
||||
"//mediapipe/tasks/python/components/containers:landmark",
|
||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/python/core:base_options",
|
||||
|
|
|
@ -27,7 +27,7 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco
|
|||
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
|
||||
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import classification as classification_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
|
@ -80,10 +80,10 @@ class GestureRecognitionResult:
|
|||
hand_world_landmarks: Detected hand landmarks in world coordinates.
|
||||
"""
|
||||
|
||||
gestures: List[classification_module.ClassificationList]
|
||||
handedness: List[classification_module.ClassificationList]
|
||||
hand_landmarks: List[landmark_module.NormalizedLandmarkList]
|
||||
hand_world_landmarks: List[landmark_module.LandmarkList]
|
||||
gestures: List[List[category_module.Category]]
|
||||
handedness: List[List[category_module.Category]]
|
||||
hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||
hand_world_landmarks: List[List[landmark_module.Landmark]]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -231,16 +231,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
gesture_recognition_result = GestureRecognitionResult(
|
||||
[
|
||||
classification_module.ClassificationList.create_from_pb2(gestures)
|
||||
for gestures in gestures_proto_list
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
classification_module.ClassificationList.create_from_pb2(handedness)
|
||||
for handedness in handedness_proto_list
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
|
@ -314,16 +324,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
return GestureRecognitionResult(
|
||||
[
|
||||
classification_module.ClassificationList.create_from_pb2(gestures)
|
||||
for gestures in gestures_proto_list
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
classification_module.ClassificationList.create_from_pb2(handedness)
|
||||
for handedness in handedness_proto_list
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
|
@ -377,16 +397,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
|
||||
return GestureRecognitionResult(
|
||||
[
|
||||
classification_module.ClassificationList.create_from_pb2(gestures)
|
||||
for gestures in gestures_proto_list
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in gesture_classifications.classification]
|
||||
for gesture_classifications in gestures_proto_list
|
||||
], [
|
||||
classification_module.ClassificationList.create_from_pb2(handedness)
|
||||
for handedness in handedness_proto_list
|
||||
[
|
||||
category_module.Category(
|
||||
index=gesture.index, score=gesture.score,
|
||||
display_name=gesture.display_name, category_name=gesture.label)
|
||||
for gesture in handedness_classifications.classification]
|
||||
for handedness_classifications in handedness_proto_list
|
||||
], [
|
||||
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
|
||||
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||
for hand_landmark in hand_landmarks.landmark]
|
||||
for hand_landmarks in hand_landmarks_proto_list
|
||||
], [
|
||||
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
|
||||
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user