171 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			171 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """Detections data class."""
 | |
| 
 | |
| import dataclasses
 | |
| from typing import Any, List
 | |
| 
 | |
| from mediapipe.framework.formats import detection_pb2
 | |
| from mediapipe.framework.formats import location_data_pb2
 | |
| from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
 | |
| from mediapipe.tasks.python.components.containers import category as category_module
 | |
| from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
 | |
| from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 | |
| 
 | |
| # Short module-private aliases for the generated protobuf message classes
 | |
| # that the container classes in this module convert to and from.
 | |
| _DetectionListProto = detection_pb2.DetectionList
 | |
| _DetectionProto = detection_pb2.Detection
 | |
| _LocationDataProto = location_data_pb2.LocationData
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass
class Detection:
  """Represents one detected object in the object detector's results.

  Attributes:
    bounding_box: A BoundingBox object.
    categories: A list of Category objects.
    keypoints: A list of NormalizedKeypoint objects.
  """

  bounding_box: bounding_box_module.BoundingBox = None
  categories: List[category_module.Category] = None
  keypoints: List[keypoint_module.NormalizedKeypoint] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionProto:
    """Generates a Detection protobuf object.

    Unset attributes (left at their `None` default) are simply omitted from
    the resulting proto instead of raising.

    Returns:
      A Detection proto whose location data uses the BOUNDING_BOX format.
    """
    labels = []
    label_ids = []
    scores = []
    display_names = []
    relative_keypoints = []

    # `categories` defaults to None; treat that the same as an empty list.
    for category in self.categories or []:
      scores.append(category.score)
      # The repeated proto fields are matched up by position when parsed back
      # in `create_from_pb2`, so a valid class index of 0 must not be dropped:
      # test against None instead of relying on truthiness.
      if category.index is not None:
        label_ids.append(category.index)
      # Empty names are still treated as "not provided".
      if category.category_name:
        labels.append(category.category_name)
      if category.display_name:
        display_names.append(category.display_name)

    for keypoint in self.keypoints or []:
      relative_keypoint_proto = _LocationDataProto.RelativeKeypoint()
      # Falsy values (None, 0, '') are left unset on the proto.
      if keypoint.x:
        relative_keypoint_proto.x = keypoint.x
      if keypoint.y:
        relative_keypoint_proto.y = keypoint.y
      if keypoint.label:
        relative_keypoint_proto.keypoint_label = keypoint.label
      if keypoint.score:
        relative_keypoint_proto.score = keypoint.score
      relative_keypoints.append(relative_keypoint_proto)

    location_data_kwargs = {
        'format': _LocationDataProto.Format.BOUNDING_BOX,
        'relative_keypoints': relative_keypoints,
    }
    # `bounding_box` defaults to None; only serialize it when it is present.
    if self.bounding_box is not None:
      location_data_kwargs['bounding_box'] = self.bounding_box.to_pb2()

    return _DetectionProto(
        label=labels,
        label_id=label_ids,
        score=scores,
        display_name=display_names,
        location_data=_LocationDataProto(**location_data_kwargs))

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
    """Creates a `Detection` object from the given protobuf object.

    One category is created per entry in `pb2_obj.score`. The label id, label
    and display name at the same position are attached when present and are
    set to None otherwise.

    Args:
      pb2_obj: The Detection proto to convert.

    Returns:
      A new instance of `cls` populated from `pb2_obj`.
    """
    categories = []
    for idx, score in enumerate(pb2_obj.score):
      categories.append(
          category_module.Category(
              score=score,
              index=pb2_obj.label_id[idx]
              if idx < len(pb2_obj.label_id) else None,
              category_name=pb2_obj.label[idx]
              if idx < len(pb2_obj.label) else None,
              display_name=pb2_obj.display_name[idx]
              if idx < len(pb2_obj.display_name) else None))

    keypoints = [
        keypoint_module.NormalizedKeypoint(
            x=relative_keypoint.x,
            y=relative_keypoint.y,
            label=relative_keypoint.keypoint_label,
            score=relative_keypoint.score)
        for relative_keypoint in pb2_obj.location_data.relative_keypoints
    ]

    # Use `cls` so that subclasses get instances of their own type back.
    return cls(
        bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
            pb2_obj.location_data.bounding_box),
        categories=categories,
        keypoints=keypoints)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Equality is decided by comparing the protobuf forms of both objects.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, Detection):
      return False

    return self.to_pb2() == other.to_pb2()
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass
class DetectionResult:
  """Represents the list of detected objects.

  Attributes:
    detections: A list of `Detection` objects.
  """

  detections: List[Detection]

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionListProto:
    """Generates a DetectionList protobuf object.

    Returns:
      A DetectionList proto holding the proto form of every detection, in
      the same order as `self.detections`.
    """
    return _DetectionListProto(
        detection=[detection.to_pb2() for detection in self.detections])

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult':
    """Creates a `DetectionResult` object from the given protobuf object.

    Args:
      pb2_obj: The DetectionList proto to convert.

    Returns:
      A new instance of `cls` with one `Detection` per proto entry.
    """
    # Use `cls` so that subclasses get instances of their own type back.
    return cls(detections=[
        Detection.create_from_pb2(detection) for detection in pb2_obj.detection
    ])

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Equality is decided by comparing the protobuf forms of both objects.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, DetectionResult):
      return False

    return self.to_pb2() == other.to_pb2()
 |