Removed classification proto to use the existing category dataclass instead and removed NormalizedLandmarkList and LandmarkList dataclasses
This commit is contained in:
parent
0f7c5d5e90
commit
f62cfd1690
|
@ -36,15 +36,6 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "classification",
|
|
||||||
srcs = ["classification.py"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework/formats:classification_py_pb2",
|
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "landmark",
|
name = "landmark",
|
||||||
srcs = ["landmark.py"],
|
srcs = ["landmark.py"],
|
||||||
|
@ -59,9 +50,11 @@ py_library(
|
||||||
srcs = ["landmark_detection_result.py"],
|
srcs = ["landmark_detection_result.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":rect",
|
":rect",
|
||||||
":classification",
|
|
||||||
":landmark",
|
":landmark",
|
||||||
|
"//mediapipe/framework/formats:classification_py_pb2",
|
||||||
|
"//mediapipe/framework/formats:landmark_py_pb2",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
"""Category data class."""
|
"""Category data class."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
@ -39,10 +39,10 @@ class Category:
|
||||||
category_name: The label of this category object.
|
category_name: The label of this category object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
index: int
|
index: Optional[int] = None
|
||||||
score: float
|
score: Optional[float] = None
|
||||||
display_name: str
|
display_name: Optional[str] = None
|
||||||
category_name: str
|
category_name: Optional[str] = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _CategoryProto:
|
def to_pb2(self) -> _CategoryProto:
|
||||||
|
|
|
@ -1,121 +0,0 @@
|
||||||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Classification data class."""
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from mediapipe.framework.formats import classification_pb2
|
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
|
||||||
|
|
||||||
_ClassificationProto = classification_pb2.Classification
|
|
||||||
_ClassificationListProto = classification_pb2.ClassificationList
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Classification:
|
|
||||||
"""A classification.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
index: The index of the class in the corresponding label map.
|
|
||||||
score: The probability score for this class.
|
|
||||||
label_name: Label or name of the class.
|
|
||||||
display_name: Optional human-readable string for display purposes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
index: Optional[int] = None
|
|
||||||
score: Optional[float] = None
|
|
||||||
label: Optional[str] = None
|
|
||||||
display_name: Optional[str] = None
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _ClassificationProto:
|
|
||||||
"""Generates a Classification protobuf object."""
|
|
||||||
return _ClassificationProto(
|
|
||||||
index=self.index,
|
|
||||||
score=self.score,
|
|
||||||
label=self.label,
|
|
||||||
display_name=self.display_name)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Classification':
|
|
||||||
"""Creates a `Classification` object from the given protobuf object."""
|
|
||||||
return Classification(
|
|
||||||
index=pb2_obj.index,
|
|
||||||
score=pb2_obj.score,
|
|
||||||
label=pb2_obj.label,
|
|
||||||
display_name=pb2_obj.display_name)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, Classification):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ClassificationList:
|
|
||||||
"""Represents the classifications for a given classifier.
|
|
||||||
Attributes:
|
|
||||||
classification : A list of `Classification` objects.
|
|
||||||
tensor_index: Optional index of the tensor that produced these
|
|
||||||
classifications.
|
|
||||||
tensor_name: Optional name of the tensor that produced these
|
|
||||||
classifications tensor metadata name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
classifications: List[Classification]
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _ClassificationListProto:
|
|
||||||
"""Generates a ClassificationList protobuf object."""
|
|
||||||
return _ClassificationListProto(
|
|
||||||
classification=[
|
|
||||||
classification.to_pb2()
|
|
||||||
for classification in self.classifications
|
|
||||||
])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(
|
|
||||||
cls,
|
|
||||||
pb2_obj: _ClassificationListProto
|
|
||||||
) -> 'ClassificationList':
|
|
||||||
"""Creates a `ClassificationList` object from the given protobuf object."""
|
|
||||||
return ClassificationList(
|
|
||||||
classifications=[
|
|
||||||
Classification.create_from_pb2(classification)
|
|
||||||
for classification in pb2_obj.classification
|
|
||||||
])
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, ClassificationList):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
|
@ -20,9 +20,7 @@ from mediapipe.framework.formats import landmark_pb2
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_LandmarkProto = landmark_pb2.Landmark
|
_LandmarkProto = landmark_pb2.Landmark
|
||||||
_LandmarkListProto = landmark_pb2.LandmarkList
|
|
||||||
_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
|
_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
|
||||||
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -89,53 +87,6 @@ class Landmark:
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class LandmarkList:
|
|
||||||
"""Represents the list of landmarks.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
landmarks : A list of `Landmark` objects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
landmarks: List[Landmark]
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _LandmarkListProto:
|
|
||||||
"""Generates a LandmarkList protobuf object."""
|
|
||||||
return _LandmarkListProto(
|
|
||||||
landmark=[
|
|
||||||
landmark.to_pb2()
|
|
||||||
for landmark in self.landmarks
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(
|
|
||||||
cls,
|
|
||||||
pb2_obj: _LandmarkListProto
|
|
||||||
) -> 'LandmarkList':
|
|
||||||
"""Creates a `LandmarkList` object from the given protobuf object."""
|
|
||||||
return LandmarkList(
|
|
||||||
landmarks=[
|
|
||||||
Landmark.create_from_pb2(landmark)
|
|
||||||
for landmark in pb2_obj.landmark
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, LandmarkList):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class NormalizedLandmark:
|
class NormalizedLandmark:
|
||||||
"""A normalized version of above Landmark proto.
|
"""A normalized version of above Landmark proto.
|
||||||
|
@ -201,50 +152,3 @@ class NormalizedLandmark:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class NormalizedLandmarkList:
|
|
||||||
"""Represents the list of normalized landmarks.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
landmarks : A list of `Landmark` objects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
landmarks: List[NormalizedLandmark]
|
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def to_pb2(self) -> _NormalizedLandmarkListProto:
|
|
||||||
"""Generates a NormalizedLandmarkList protobuf object."""
|
|
||||||
return _NormalizedLandmarkListProto(
|
|
||||||
landmark=[
|
|
||||||
landmark.to_pb2()
|
|
||||||
for landmark in self.landmarks
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@doc_controls.do_not_generate_docs
|
|
||||||
def create_from_pb2(
|
|
||||||
cls,
|
|
||||||
pb2_obj: _NormalizedLandmarkListProto
|
|
||||||
) -> 'NormalizedLandmarkList':
|
|
||||||
"""Creates a `NormalizedLandmarkList` object from the given protobuf object."""
|
|
||||||
return NormalizedLandmarkList(
|
|
||||||
landmarks=[
|
|
||||||
NormalizedLandmark.create_from_pb2(landmark)
|
|
||||||
for landmark in pb2_obj.landmark
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
|
||||||
"""Checks if this object is equal to the given object.
|
|
||||||
Args:
|
|
||||||
other: The object to be compared with.
|
|
||||||
Returns:
|
|
||||||
True if the objects are equal.
|
|
||||||
"""
|
|
||||||
if not isinstance(other, NormalizedLandmarkList):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.to_pb2().__eq__(other.to_pb2())
|
|
||||||
|
|
|
@ -14,19 +14,25 @@
|
||||||
"""Landmarks Detection Result data class."""
|
"""Landmarks Detection Result data class."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||||
|
from mediapipe.framework.formats import classification_pb2
|
||||||
|
from mediapipe.framework.formats import landmark_pb2
|
||||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||||
from mediapipe.tasks.python.components.containers import classification as classification_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
|
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
|
||||||
|
_ClassificationProto = classification_pb2.Classification
|
||||||
|
_ClassificationListProto = classification_pb2.ClassificationList
|
||||||
|
_LandmarkListProto = landmark_pb2.LandmarkList
|
||||||
|
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
|
||||||
_NormalizedRect = rect_module.NormalizedRect
|
_NormalizedRect = rect_module.NormalizedRect
|
||||||
_ClassificationList = classification_module.ClassificationList
|
_Category = category_module.Category
|
||||||
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
|
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||||
_LandmarkList = landmark_module.LandmarkList
|
_Landmark = landmark_module.Landmark
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -34,24 +40,31 @@ class LandmarksDetectionResult:
|
||||||
"""Represents the landmarks detection result.
|
"""Represents the landmarks detection result.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
landmarks : A `NormalizedLandmarkList` object.
|
landmarks : A list of `NormalizedLandmark` objects.
|
||||||
classifications : A `ClassificationList` object.
|
categories : A list of `Category` objects.
|
||||||
world_landmarks : A `LandmarkList` object.
|
world_landmarks : A list of `Landmark` objects.
|
||||||
rect : A `NormalizedRect` object.
|
rect : A `NormalizedRect` object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
landmarks: Optional[_NormalizedLandmarkList]
|
landmarks: Optional[List[_NormalizedLandmark]]
|
||||||
classifications: Optional[_ClassificationList]
|
categories: Optional[List[_Category]]
|
||||||
world_landmarks: Optional[_LandmarkList]
|
world_landmarks: Optional[List[_Landmark]]
|
||||||
rect: _NormalizedRect
|
rect: _NormalizedRect
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _LandmarksDetectionResultProto:
|
def to_pb2(self) -> _LandmarksDetectionResultProto:
|
||||||
"""Generates a LandmarksDetectionResult protobuf object."""
|
"""Generates a LandmarksDetectionResult protobuf object."""
|
||||||
return _LandmarksDetectionResultProto(
|
return _LandmarksDetectionResultProto(
|
||||||
landmarks=self.landmarks.to_pb2(),
|
landmarks=_NormalizedLandmarkListProto(landmarks=self.landmarks),
|
||||||
classifications=self.classifications.to_pb2(),
|
classifications=_ClassificationListProto(
|
||||||
world_landmarks=self.world_landmarks.to_pb2(),
|
classification=[
|
||||||
|
_ClassificationProto(
|
||||||
|
index=category.index,
|
||||||
|
score=category.score,
|
||||||
|
label=category.category_name,
|
||||||
|
display_name=category.display_name)
|
||||||
|
for category in self.categories]),
|
||||||
|
world_landmarks=_LandmarkListProto(landmarks=self.world_landmarks),
|
||||||
rect=self.rect.to_pb2())
|
rect=self.rect.to_pb2())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -63,10 +76,18 @@ class LandmarksDetectionResult:
|
||||||
"""Creates a `LandmarksDetectionResult` object from the given protobuf
|
"""Creates a `LandmarksDetectionResult` object from the given protobuf
|
||||||
object."""
|
object."""
|
||||||
return LandmarksDetectionResult(
|
return LandmarksDetectionResult(
|
||||||
landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks),
|
landmarks=[
|
||||||
classifications=_ClassificationList.create_from_pb2(
|
_NormalizedLandmark.create_from_pb2(landmark)
|
||||||
pb2_obj.classifications),
|
for landmark in pb2_obj.landmarks.landmark],
|
||||||
world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks),
|
categories=[category_module.Category(
|
||||||
|
score=classification.score,
|
||||||
|
index=classification.index,
|
||||||
|
category_name=classification.label,
|
||||||
|
display_name=classification.display_name)
|
||||||
|
for classification in pb2_obj.classifications.classification],
|
||||||
|
world_landmarks=[
|
||||||
|
_Landmark.create_from_pb2(landmark)
|
||||||
|
for landmark in pb2_obj.world_landmarks.landmark],
|
||||||
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
|
|
@ -69,7 +69,7 @@ py_test(
|
||||||
"//mediapipe/python:_framework_bindings",
|
"//mediapipe/python:_framework_bindings",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:rect",
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
"//mediapipe/tasks/python/components/containers:classification",
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/components/containers:landmark",
|
"//mediapipe/tasks/python/components/containers:landmark",
|
||||||
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
|
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
|
|
@ -24,7 +24,7 @@ from absl.testing import parameterized
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||||
from mediapipe.tasks.python.components.containers import classification as classification_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||||
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
|
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -36,12 +36,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
|
||||||
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
|
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_Rect = rect_module.Rect
|
_Rect = rect_module.Rect
|
||||||
_Classification = classification_module.Classification
|
_Category = category_module.Category
|
||||||
_ClassificationList = classification_module.ClassificationList
|
|
||||||
_Landmark = landmark_module.Landmark
|
_Landmark = landmark_module.Landmark
|
||||||
_LandmarkList = landmark_module.LandmarkList
|
|
||||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||||
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
|
|
||||||
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
|
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_GestureRecognizer = gesture_recognizer.GestureRecognizer
|
_GestureRecognizer = gesture_recognizer.GestureRecognizer
|
||||||
|
@ -76,14 +73,11 @@ def _get_expected_gesture_recognition_result(
|
||||||
text_format.Parse(f.read(), landmarks_detection_result_proto)
|
text_format.Parse(f.read(), landmarks_detection_result_proto)
|
||||||
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
|
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
|
||||||
landmarks_detection_result_proto)
|
landmarks_detection_result_proto)
|
||||||
gesture = _ClassificationList(
|
gesture = _Category(category_name=gesture_label, index=gesture_index,
|
||||||
classifications=[
|
|
||||||
_Classification(label=gesture_label, index=gesture_index,
|
|
||||||
display_name='')
|
display_name='')
|
||||||
])
|
|
||||||
return _GestureRecognitionResult(
|
return _GestureRecognitionResult(
|
||||||
gestures=[gesture],
|
gestures=[[gesture]],
|
||||||
handedness=[landmarks_detection_result.classifications],
|
handedness=[landmarks_detection_result.categories],
|
||||||
hand_landmarks=[landmarks_detection_result.landmarks],
|
hand_landmarks=[landmarks_detection_result.landmarks],
|
||||||
hand_world_landmarks=[landmarks_detection_result.world_landmarks])
|
hand_world_landmarks=[landmarks_detection_result.world_landmarks])
|
||||||
|
|
||||||
|
@ -115,25 +109,27 @@ class GestureRecognizerTest(parameterized.TestCase):
|
||||||
self.assertLen(actual_result.handedness, len(expected_result.handedness))
|
self.assertLen(actual_result.handedness, len(expected_result.handedness))
|
||||||
self.assertLen(actual_result.gestures, len(expected_result.gestures))
|
self.assertLen(actual_result.gestures, len(expected_result.gestures))
|
||||||
# Actual landmarks match expected landmarks.
|
# Actual landmarks match expected landmarks.
|
||||||
self.assertLen(actual_result.hand_landmarks[0].landmarks,
|
self.assertLen(actual_result.hand_landmarks[0],
|
||||||
len(expected_result.hand_landmarks[0].landmarks))
|
len(expected_result.hand_landmarks[0]))
|
||||||
actual_landmarks = actual_result.hand_landmarks[0].landmarks
|
actual_landmarks = actual_result.hand_landmarks[0]
|
||||||
expected_landmarks = expected_result.hand_landmarks[0].landmarks
|
expected_landmarks = expected_result.hand_landmarks[0]
|
||||||
for i in range(len(actual_landmarks)):
|
for i in range(len(actual_landmarks)):
|
||||||
self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x,
|
self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x,
|
||||||
delta=_LANDMARKS_ERROR_TOLERANCE)
|
delta=_LANDMARKS_ERROR_TOLERANCE)
|
||||||
self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y,
|
self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y,
|
||||||
delta=_LANDMARKS_ERROR_TOLERANCE)
|
delta=_LANDMARKS_ERROR_TOLERANCE)
|
||||||
# Actual handedness matches expected handedness.
|
# Actual handedness matches expected handedness.
|
||||||
actual_top_handedness = actual_result.handedness[0].classifications[0]
|
actual_top_handedness = actual_result.handedness[0][0]
|
||||||
expected_top_handedness = expected_result.handedness[0].classifications[0]
|
expected_top_handedness = expected_result.handedness[0][0]
|
||||||
self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
|
self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
|
||||||
self.assertEqual(actual_top_handedness.label, expected_top_handedness.label)
|
self.assertEqual(actual_top_handedness.category_name,
|
||||||
|
expected_top_handedness.category_name)
|
||||||
# Actual gesture with top score matches expected gesture.
|
# Actual gesture with top score matches expected gesture.
|
||||||
actual_top_gesture = actual_result.gestures[0].classifications[0]
|
actual_top_gesture = actual_result.gestures[0][0]
|
||||||
expected_top_gesture = expected_result.gestures[0].classifications[0]
|
expected_top_gesture = expected_result.gestures[0][0]
|
||||||
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
||||||
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
|
self.assertEqual(actual_top_gesture.category_name,
|
||||||
|
expected_top_gesture.category_name)
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
@ -235,12 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase):
|
||||||
expected_result = _get_expected_gesture_recognition_result(
|
expected_result = _get_expected_gesture_recognition_result(
|
||||||
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
|
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
|
||||||
# Only contains one top scoring gesture.
|
# Only contains one top scoring gesture.
|
||||||
self.assertLen(recognition_result.gestures[0].classifications, 1)
|
self.assertLen(recognition_result.gestures[0], 1)
|
||||||
# Actual gesture with top score matches expected gesture.
|
# Actual gesture with top score matches expected gesture.
|
||||||
actual_top_gesture = recognition_result.gestures[0].classifications[0]
|
actual_top_gesture = recognition_result.gestures[0][0]
|
||||||
expected_top_gesture = expected_result.gestures[0].classifications[0]
|
expected_top_gesture = expected_result.gestures[0][0]
|
||||||
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
|
||||||
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
|
self.assertEqual(actual_top_gesture.category_name,
|
||||||
|
expected_top_gesture.category_name)
|
||||||
|
|
||||||
def test_recognize_succeeds_with_num_hands(self):
|
def test_recognize_succeeds_with_num_hands(self):
|
||||||
# Creates gesture recognizer.
|
# Creates gesture recognizer.
|
||||||
|
|
|
@ -74,7 +74,7 @@ py_library(
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
|
||||||
"//mediapipe/tasks/python/components/containers:classification",
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
"//mediapipe/tasks/python/components/containers:landmark",
|
"//mediapipe/tasks/python/components/containers:landmark",
|
||||||
"//mediapipe/tasks/python/components/processors:classifier_options",
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/python/core:base_options",
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
|
|
@ -27,7 +27,7 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco
|
||||||
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
|
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
|
||||||
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
|
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
|
||||||
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
|
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
|
||||||
from mediapipe.tasks.python.components.containers import classification as classification_module
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||||
from mediapipe.tasks.python.components.processors import classifier_options
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
@ -80,10 +80,10 @@ class GestureRecognitionResult:
|
||||||
hand_world_landmarks: Detected hand landmarks in world coordinates.
|
hand_world_landmarks: Detected hand landmarks in world coordinates.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gestures: List[classification_module.ClassificationList]
|
gestures: List[List[category_module.Category]]
|
||||||
handedness: List[classification_module.ClassificationList]
|
handedness: List[List[category_module.Category]]
|
||||||
hand_landmarks: List[landmark_module.NormalizedLandmarkList]
|
hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||||
hand_world_landmarks: List[landmark_module.LandmarkList]
|
hand_world_landmarks: List[List[landmark_module.Landmark]]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -231,16 +231,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
gesture_recognition_result = GestureRecognitionResult(
|
gesture_recognition_result = GestureRecognitionResult(
|
||||||
[
|
[
|
||||||
classification_module.ClassificationList.create_from_pb2(gestures)
|
[
|
||||||
for gestures in gestures_proto_list
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in gesture_classifications.classification]
|
||||||
|
for gesture_classifications in gestures_proto_list
|
||||||
], [
|
], [
|
||||||
classification_module.ClassificationList.create_from_pb2(handedness)
|
[
|
||||||
for handedness in handedness_proto_list
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in handedness_classifications.classification]
|
||||||
|
for handedness_classifications in handedness_proto_list
|
||||||
], [
|
], [
|
||||||
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
|
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||||
|
for hand_landmark in hand_landmarks.landmark]
|
||||||
for hand_landmarks in hand_landmarks_proto_list
|
for hand_landmarks in hand_landmarks_proto_list
|
||||||
], [
|
], [
|
||||||
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
|
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||||
|
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -314,16 +324,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
return GestureRecognitionResult(
|
return GestureRecognitionResult(
|
||||||
[
|
[
|
||||||
classification_module.ClassificationList.create_from_pb2(gestures)
|
[
|
||||||
for gestures in gestures_proto_list
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in gesture_classifications.classification]
|
||||||
|
for gesture_classifications in gestures_proto_list
|
||||||
], [
|
], [
|
||||||
classification_module.ClassificationList.create_from_pb2(handedness)
|
[
|
||||||
for handedness in handedness_proto_list
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in handedness_classifications.classification]
|
||||||
|
for handedness_classifications in handedness_proto_list
|
||||||
], [
|
], [
|
||||||
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
|
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||||
|
for hand_landmark in hand_landmarks.landmark]
|
||||||
for hand_landmarks in hand_landmarks_proto_list
|
for hand_landmarks in hand_landmarks_proto_list
|
||||||
], [
|
], [
|
||||||
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
|
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||||
|
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -377,16 +397,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
|
||||||
return GestureRecognitionResult(
|
return GestureRecognitionResult(
|
||||||
[
|
[
|
||||||
classification_module.ClassificationList.create_from_pb2(gestures)
|
[
|
||||||
for gestures in gestures_proto_list
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in gesture_classifications.classification]
|
||||||
|
for gesture_classifications in gestures_proto_list
|
||||||
], [
|
], [
|
||||||
classification_module.ClassificationList.create_from_pb2(handedness)
|
[
|
||||||
for handedness in handedness_proto_list
|
category_module.Category(
|
||||||
|
index=gesture.index, score=gesture.score,
|
||||||
|
display_name=gesture.display_name, category_name=gesture.label)
|
||||||
|
for gesture in handedness_classifications.classification]
|
||||||
|
for handedness_classifications in handedness_proto_list
|
||||||
], [
|
], [
|
||||||
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
|
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
|
||||||
|
for hand_landmark in hand_landmarks.landmark]
|
||||||
for hand_landmarks in hand_landmarks_proto_list
|
for hand_landmarks in hand_landmarks_proto_list
|
||||||
], [
|
], [
|
||||||
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
|
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
|
||||||
|
for hand_world_landmark in hand_world_landmarks.landmark]
|
||||||
for hand_world_landmarks in hand_world_landmarks_proto_list
|
for hand_world_landmarks in hand_world_landmarks_proto_list
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user