Removed classification proto to use the existing category dataclass instead and removed NormalizedLandmarkList and LandmarkList dataclasses

This commit is contained in:
kinaryml 2022-10-30 08:23:14 -07:00
parent 0f7c5d5e90
commit f62cfd1690
9 changed files with 127 additions and 303 deletions

View File

@ -36,15 +36,6 @@ py_library(
],
)
py_library(
name = "classification",
srcs = ["classification.py"],
deps = [
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "landmark",
srcs = ["landmark.py"],
@ -59,9 +50,11 @@ py_library(
srcs = ["landmark_detection_result.py"],
deps = [
":rect",
":classification",
":landmark",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -14,7 +14,7 @@
"""Category data class."""
import dataclasses
from typing import Any
from typing import Any, Optional
from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -39,10 +39,10 @@ class Category:
category_name: The label of this category object.
"""
index: int
score: float
display_name: str
category_name: str
index: Optional[int] = None
score: Optional[float] = None
display_name: Optional[str] = None
category_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _CategoryProto:

View File

@ -1,121 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classification data class."""
import dataclasses
from typing import Any, List, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
@dataclasses.dataclass
class Classification:
"""A classification.
Attributes:
index: The index of the class in the corresponding label map.
score: The probability score for this class.
label_name: Label or name of the class.
display_name: Optional human-readable string for display purposes.
"""
index: Optional[int] = None
score: Optional[float] = None
label: Optional[str] = None
display_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationProto:
"""Generates a Classification protobuf object."""
return _ClassificationProto(
index=self.index,
score=self.score,
label=self.label,
display_name=self.display_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Classification':
"""Creates a `Classification` object from the given protobuf object."""
return Classification(
index=pb2_obj.index,
score=pb2_obj.score,
label=pb2_obj.label,
display_name=pb2_obj.display_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Classification):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class ClassificationList:
"""Represents the classifications for a given classifier.
Attributes:
classification : A list of `Classification` objects.
tensor_index: Optional index of the tensor that produced these
classifications.
tensor_name: Optional name of the tensor that produced these
classifications tensor metadata name.
"""
classifications: List[Classification]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationListProto:
"""Generates a ClassificationList protobuf object."""
return _ClassificationListProto(
classification=[
classification.to_pb2()
for classification in self.classifications
])
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _ClassificationListProto
) -> 'ClassificationList':
"""Creates a `ClassificationList` object from the given protobuf object."""
return ClassificationList(
classifications=[
Classification.create_from_pb2(classification)
for classification in pb2_obj.classification
])
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationList):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -20,9 +20,7 @@ from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_LandmarkProto = landmark_pb2.Landmark
_LandmarkListProto = landmark_pb2.LandmarkList
_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
@dataclasses.dataclass
@ -89,53 +87,6 @@ class Landmark:
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class LandmarkList:
"""Represents the list of landmarks.
Attributes:
landmarks : A list of `Landmark` objects.
"""
landmarks: List[Landmark]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _LandmarkListProto:
"""Generates a LandmarkList protobuf object."""
return _LandmarkListProto(
landmark=[
landmark.to_pb2()
for landmark in self.landmarks
]
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _LandmarkListProto
) -> 'LandmarkList':
"""Creates a `LandmarkList` object from the given protobuf object."""
return LandmarkList(
landmarks=[
Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmark
]
)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, LandmarkList):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class NormalizedLandmark:
"""A normalized version of above Landmark proto.
@ -201,50 +152,3 @@ class NormalizedLandmark:
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class NormalizedLandmarkList:
"""Represents the list of normalized landmarks.
Attributes:
landmarks : A list of `Landmark` objects.
"""
landmarks: List[NormalizedLandmark]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _NormalizedLandmarkListProto:
"""Generates a NormalizedLandmarkList protobuf object."""
return _NormalizedLandmarkListProto(
landmark=[
landmark.to_pb2()
for landmark in self.landmarks
]
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _NormalizedLandmarkListProto
) -> 'NormalizedLandmarkList':
"""Creates a `NormalizedLandmarkList` object from the given protobuf object."""
return NormalizedLandmarkList(
landmarks=[
NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmark
]
)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, NormalizedLandmarkList):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -14,19 +14,25 @@
"""Landmarks Detection Result data class."""
import dataclasses
from typing import Any, Optional
from typing import Any, Optional, List
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import classification as classification_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_LandmarkListProto = landmark_pb2.LandmarkList
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
_NormalizedRect = rect_module.NormalizedRect
_ClassificationList = classification_module.ClassificationList
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
_LandmarkList = landmark_module.LandmarkList
_Category = category_module.Category
_NormalizedLandmark = landmark_module.NormalizedLandmark
_Landmark = landmark_module.Landmark
@dataclasses.dataclass
@ -34,25 +40,32 @@ class LandmarksDetectionResult:
"""Represents the landmarks detection result.
Attributes:
landmarks : A `NormalizedLandmarkList` object.
classifications : A `ClassificationList` object.
world_landmarks : A `LandmarkList` object.
landmarks : A list of `NormalizedLandmark` objects.
categories : A list of `Category` objects.
world_landmarks : A list of `Landmark` objects.
rect : A `NormalizedRect` object.
"""
landmarks: Optional[_NormalizedLandmarkList]
classifications: Optional[_ClassificationList]
world_landmarks: Optional[_LandmarkList]
landmarks: Optional[List[_NormalizedLandmark]]
categories: Optional[List[_Category]]
world_landmarks: Optional[List[_Landmark]]
rect: _NormalizedRect
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _LandmarksDetectionResultProto:
"""Generates a LandmarksDetectionResult protobuf object."""
return _LandmarksDetectionResultProto(
landmarks=self.landmarks.to_pb2(),
classifications=self.classifications.to_pb2(),
world_landmarks=self.world_landmarks.to_pb2(),
rect=self.rect.to_pb2())
landmarks=_NormalizedLandmarkListProto(landmarks=self.landmarks),
classifications=_ClassificationListProto(
classification=[
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name)
for category in self.categories]),
world_landmarks=_LandmarkListProto(landmarks=self.world_landmarks),
rect=self.rect.to_pb2())
@classmethod
@doc_controls.do_not_generate_docs
@ -63,11 +76,19 @@ class LandmarksDetectionResult:
"""Creates a `LandmarksDetectionResult` object from the given protobuf
object."""
return LandmarksDetectionResult(
landmarks=_NormalizedLandmarkList.create_from_pb2(pb2_obj.landmarks),
classifications=_ClassificationList.create_from_pb2(
pb2_obj.classifications),
world_landmarks=_LandmarkList.create_from_pb2(pb2_obj.world_landmarks),
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
landmarks=[
_NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmarks.landmark],
categories=[category_module.Category(
score=classification.score,
index=classification.index,
category_name=classification.label,
display_name=classification.display_name)
for classification in pb2_obj.classifications.classification],
world_landmarks=[
_Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.world_landmarks.landmark],
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.

View File

@ -69,7 +69,7 @@ py_test(
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/containers:classification",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
"//mediapipe/tasks/python/core:base_options",

View File

@ -24,7 +24,7 @@ from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import classification as classification_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
@ -36,12 +36,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_BaseOptions = base_options_module.BaseOptions
_Rect = rect_module.Rect
_Classification = classification_module.Classification
_ClassificationList = classification_module.ClassificationList
_Category = category_module.Category
_Landmark = landmark_module.Landmark
_LandmarkList = landmark_module.LandmarkList
_NormalizedLandmark = landmark_module.NormalizedLandmark
_NormalizedLandmarkList = landmark_module.NormalizedLandmarkList
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_Image = image_module.Image
_GestureRecognizer = gesture_recognizer.GestureRecognizer
@ -76,14 +73,11 @@ def _get_expected_gesture_recognition_result(
text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
landmarks_detection_result_proto)
gesture = _ClassificationList(
classifications=[
_Classification(label=gesture_label, index=gesture_index,
display_name='')
])
gesture = _Category(category_name=gesture_label, index=gesture_index,
display_name='')
return _GestureRecognitionResult(
gestures=[gesture],
handedness=[landmarks_detection_result.classifications],
gestures=[[gesture]],
handedness=[landmarks_detection_result.categories],
hand_landmarks=[landmarks_detection_result.landmarks],
hand_world_landmarks=[landmarks_detection_result.world_landmarks])
@ -115,25 +109,27 @@ class GestureRecognizerTest(parameterized.TestCase):
self.assertLen(actual_result.handedness, len(expected_result.handedness))
self.assertLen(actual_result.gestures, len(expected_result.gestures))
# Actual landmarks match expected landmarks.
self.assertLen(actual_result.hand_landmarks[0].landmarks,
len(expected_result.hand_landmarks[0].landmarks))
actual_landmarks = actual_result.hand_landmarks[0].landmarks
expected_landmarks = expected_result.hand_landmarks[0].landmarks
self.assertLen(actual_result.hand_landmarks[0],
len(expected_result.hand_landmarks[0]))
actual_landmarks = actual_result.hand_landmarks[0]
expected_landmarks = expected_result.hand_landmarks[0]
for i in range(len(actual_landmarks)):
self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x,
delta=_LANDMARKS_ERROR_TOLERANCE)
self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y,
delta=_LANDMARKS_ERROR_TOLERANCE)
# Actual handedness matches expected handedness.
actual_top_handedness = actual_result.handedness[0].classifications[0]
expected_top_handedness = expected_result.handedness[0].classifications[0]
actual_top_handedness = actual_result.handedness[0][0]
expected_top_handedness = expected_result.handedness[0][0]
self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
self.assertEqual(actual_top_handedness.label, expected_top_handedness.label)
self.assertEqual(actual_top_handedness.category_name,
expected_top_handedness.category_name)
# Actual gesture with top score matches expected gesture.
actual_top_gesture = actual_result.gestures[0].classifications[0]
expected_top_gesture = expected_result.gestures[0].classifications[0]
actual_top_gesture = actual_result.gestures[0][0]
expected_top_gesture = expected_result.gestures[0][0]
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
self.assertEqual(actual_top_gesture.category_name,
expected_top_gesture.category_name)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
@ -235,12 +231,13 @@ class GestureRecognizerTest(parameterized.TestCase):
expected_result = _get_expected_gesture_recognition_result(
_THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
# Only contains one top scoring gesture.
self.assertLen(recognition_result.gestures[0].classifications, 1)
self.assertLen(recognition_result.gestures[0], 1)
# Actual gesture with top score matches expected gesture.
actual_top_gesture = recognition_result.gestures[0].classifications[0]
expected_top_gesture = expected_result.gestures[0].classifications[0]
actual_top_gesture = recognition_result.gestures[0][0]
expected_top_gesture = expected_result.gestures[0][0]
self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
self.assertEqual(actual_top_gesture.category_name,
expected_top_gesture.category_name)
def test_recognize_succeeds_with_num_hands(self):
# Creates gesture recognizer.

View File

@ -74,7 +74,7 @@ py_library(
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classification",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",

View File

@ -27,7 +27,7 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification as classification_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
@ -80,10 +80,10 @@ class GestureRecognitionResult:
hand_world_landmarks: Detected hand landmarks in world coordinates.
"""
gestures: List[classification_module.ClassificationList]
handedness: List[classification_module.ClassificationList]
hand_landmarks: List[landmark_module.NormalizedLandmarkList]
hand_world_landmarks: List[landmark_module.LandmarkList]
gestures: List[List[category_module.Category]]
handedness: List[List[category_module.Category]]
hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
hand_world_landmarks: List[List[landmark_module.Landmark]]
@dataclasses.dataclass
@ -231,16 +231,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
gesture_recognition_result = GestureRecognitionResult(
[
classification_module.ClassificationList.create_from_pb2(gestures)
for gestures in gestures_proto_list
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
classification_module.ClassificationList.create_from_pb2(handedness)
for handedness in handedness_proto_list
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)
@ -314,16 +324,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
return GestureRecognitionResult(
[
classification_module.ClassificationList.create_from_pb2(gestures)
for gestures in gestures_proto_list
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
classification_module.ClassificationList.create_from_pb2(handedness)
for handedness in handedness_proto_list
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)
@ -377,16 +397,26 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
return GestureRecognitionResult(
[
classification_module.ClassificationList.create_from_pb2(gestures)
for gestures in gestures_proto_list
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in gesture_classifications.classification]
for gesture_classifications in gestures_proto_list
], [
classification_module.ClassificationList.create_from_pb2(handedness)
for handedness in handedness_proto_list
[
category_module.Category(
index=gesture.index, score=gesture.score,
display_name=gesture.display_name, category_name=gesture.label)
for gesture in handedness_classifications.classification]
for handedness_classifications in handedness_proto_list
], [
landmark_module.NormalizedLandmarkList.create_from_pb2(hand_landmarks)
[landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark]
for hand_landmarks in hand_landmarks_proto_list
], [
landmark_module.LandmarkList.create_from_pb2(hand_world_landmarks)
[landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark]
for hand_world_landmarks in hand_world_landmarks_proto_list
]
)