171 lines
5.7 KiB
Python
171 lines
5.7 KiB
Python
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Detections data class."""
|
|
|
|
import dataclasses
from typing import Any, List, Optional

from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
|
|
|
# Short private aliases for the protobuf message classes that the containers
# below convert to and from.
_DetectionProto = detection_pb2.Detection
_DetectionListProto = detection_pb2.DetectionList
_LocationDataProto = location_data_pb2.LocationData
|
|
|
|
|
|
@dataclasses.dataclass
class Detection:
  """Represents one detected object in the object detector's results.

  Attributes:
    bounding_box: A BoundingBox object. If None, no bounding box is written to
      the protobuf produced by `to_pb2`.
    categories: A list of Category objects. None is treated as an empty list.
    keypoints: A list of NormalizedKeypoint objects. None is treated as an
      empty list.
  """

  bounding_box: Optional[bounding_box_module.BoundingBox] = None
  categories: Optional[List[category_module.Category]] = None
  keypoints: Optional[List[keypoint_module.NormalizedKeypoint]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionProto:
    """Generates a Detection protobuf object.

    Category data is stored in the proto as parallel repeated fields
    (`score`, `label_id`, `label`, `display_name`). A score is written for
    every category; the other fields are written only when they are not None.
    If only some categories carry a given field, that field's list is shorter
    than `score` and its entries no longer line up by position.

    Returns:
      A `Detection` protobuf whose `location_data` uses the BOUNDING_BOX
      format.
    """
    labels = []
    label_ids = []
    scores = []
    display_names = []

    for category in self.categories or []:
      scores.append(category.score)
      # Compare against None instead of relying on truthiness: index 0 and
      # empty strings are legitimate values, and dropping them would shift the
      # remaining entries of these parallel lists onto the wrong category.
      if category.index is not None:
        label_ids.append(category.index)
      if category.category_name is not None:
        labels.append(category.category_name)
      if category.display_name is not None:
        display_names.append(category.display_name)

    relative_keypoints = []
    for keypoint in self.keypoints or []:
      relative_keypoint_proto = _LocationDataProto.RelativeKeypoint()
      # Same reasoning as above: 0.0 and "" are valid keypoint values.
      if keypoint.x is not None:
        relative_keypoint_proto.x = keypoint.x
      if keypoint.y is not None:
        relative_keypoint_proto.y = keypoint.y
      if keypoint.label is not None:
        relative_keypoint_proto.keypoint_label = keypoint.label
      if keypoint.score is not None:
        relative_keypoint_proto.score = keypoint.score
      relative_keypoints.append(relative_keypoint_proto)

    location_data = _LocationDataProto(
        format=_LocationDataProto.Format.BOUNDING_BOX,
        relative_keypoints=relative_keypoints)
    # `bounding_box` defaults to None; leave the proto field unset rather than
    # failing with AttributeError on `None.to_pb2()`.
    if self.bounding_box is not None:
      location_data.bounding_box.CopyFrom(self.bounding_box.to_pb2())

    return _DetectionProto(
        label=labels,
        label_id=label_ids,
        score=scores,
        display_name=display_names,
        location_data=location_data)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
    """Creates a `Detection` object from the given protobuf object.

    Args:
      pb2_obj: The `Detection` protobuf to convert.

    Returns:
      An instance of `cls` with one `Category` per entry of `pb2_obj.score`.
      Each category's `index`, `category_name` and `display_name` come from
      the same position in `label_id`, `label` and `display_name`, or are
      None when that repeated field is too short.
    """

    def _value_at(values, idx):
      # The parallel repeated fields may be shorter than `score`.
      return values[idx] if idx < len(values) else None

    categories = [
        category_module.Category(
            score=score,
            index=_value_at(pb2_obj.label_id, idx),
            category_name=_value_at(pb2_obj.label, idx),
            display_name=_value_at(pb2_obj.display_name, idx))
        for idx, score in enumerate(pb2_obj.score)
    ]

    keypoints = [
        keypoint_module.NormalizedKeypoint(
            x=relative_keypoint.x,
            y=relative_keypoint.y,
            label=relative_keypoint.keypoint_label,
            score=relative_keypoint.score)
        for relative_keypoint in pb2_obj.location_data.relative_keypoints
    ]

    return cls(
        bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
            pb2_obj.location_data.bounding_box),
        categories=categories,
        keypoints=keypoints)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Two detections are considered equal when their `to_pb2()` protobufs
    compare equal.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, Detection):
      return False

    return self.to_pb2() == other.to_pb2()
|
|
|
|
|
|
@dataclasses.dataclass
class DetectionResult:
  """Represents the list of detected objects.

  Attributes:
    detections: A list of `Detection` objects.
  """

  detections: List[Detection]

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionListProto:
    """Generates a DetectionList protobuf object.

    Returns:
      A `DetectionList` protobuf with one entry per element of `detections`,
      in the same order.
    """
    return _DetectionListProto(
        detection=[detection.to_pb2() for detection in self.detections])

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult':
    """Creates a `DetectionResult` object from the given protobuf object.

    Args:
      pb2_obj: The `DetectionList` protobuf to convert.

    Returns:
      An instance of `cls` holding one `Detection` per entry of
      `pb2_obj.detection`, in the same order.
    """
    # Use `cls` so subclasses get instances of their own type.
    return cls(detections=[
        Detection.create_from_pb2(detection) for detection in pb2_obj.detection
    ])

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Two results are considered equal when their `to_pb2()` protobufs compare
    equal.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, DetectionResult):
      return False

    return self.to_pb2() == other.to_pb2()
|