mediapipe/mediapipe/tasks/python/components/containers/detections.py

171 lines
5.7 KiB
Python

# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detections data class."""
import dataclasses
from typing import Any, List
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthand aliases for the protobuf message classes that the
# containers below convert to (`to_pb2`) and from (`create_from_pb2`).
_LocationDataProto = location_data_pb2.LocationData
_DetectionProto, _DetectionListProto = (
    detection_pb2.Detection,
    detection_pb2.DetectionList,
)
@dataclasses.dataclass
class Detection:
  """Represents one detected object in the object detector's results.

  Attributes:
    bounding_box: A BoundingBox object, or `None` if not set.
    categories: A list of Category objects, or `None` if not set.
    keypoints: A list of NormalizedKeypoint objects, or `None` if not set.
  """

  bounding_box: bounding_box_module.BoundingBox = None
  categories: List[category_module.Category] = None
  keypoints: List[keypoint_module.NormalizedKeypoint] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionProto:
    """Generates a Detection protobuf object.

    The proto's `label_id`, `label` and `display_name` lists are read back
    position-by-position alongside `score` (see `create_from_pb2`). So a
    category attribute is emitted whenever it is not `None`. Values such as
    index `0` or an empty name are kept rather than dropped.

    Returns:
      A `Detection` proto whose `location_data` is in `BOUNDING_BOX` format.
      It holds the bounding box (left unset when `bounding_box` is `None`)
      and one relative keypoint per entry in `keypoints`.
    """
    labels = []
    label_ids = []
    scores = []
    display_names = []
    # `categories` defaults to None on the dataclass; treat that as empty.
    for category in self.categories or []:
      scores.append(category.score)
      # Compare against None, not truthiness: index 0 / an empty name is real
      # data, and skipping it would shift later entries out of alignment with
      # `score`.
      if category.index is not None:
        label_ids.append(category.index)
      if category.category_name is not None:
        labels.append(category.category_name)
      if category.display_name is not None:
        display_names.append(category.display_name)

    relative_keypoints = []
    for keypoint in self.keypoints or []:
      relative_keypoint_proto = _LocationDataProto.RelativeKeypoint()
      # Falsy values (None, 0.0, "") are left unset; `create_from_pb2` reads
      # unset fields back as 0.0 / "", so zero values still round-trip.
      if keypoint.x:
        relative_keypoint_proto.x = keypoint.x
      if keypoint.y:
        relative_keypoint_proto.y = keypoint.y
      if keypoint.label:
        relative_keypoint_proto.keypoint_label = keypoint.label
      if keypoint.score:
        relative_keypoint_proto.score = keypoint.score
      relative_keypoints.append(relative_keypoint_proto)

    location_data = _LocationDataProto(
        format=_LocationDataProto.Format.BOUNDING_BOX,
        relative_keypoints=relative_keypoints)
    # `bounding_box` defaults to None; leave the sub-message unset in that
    # case instead of raising AttributeError.
    if self.bounding_box is not None:
      location_data.bounding_box.CopyFrom(self.bounding_box.to_pb2())

    return _DetectionProto(
        label=labels,
        label_id=label_ids,
        score=scores,
        display_name=display_names,
        location_data=location_data)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
    """Creates a `Detection` object from the given protobuf object.

    Args:
      pb2_obj: The `Detection` proto to convert.

    Returns:
      An instance of `cls` with the following fields:
        categories: One Category per entry in `pb2_obj.score`. When
          `label_id`, `label` or `display_name` is shorter than `score`, the
          missing entries become `None`.
        keypoints: One NormalizedKeypoint per relative keypoint in
          `location_data`.
        bounding_box: Built from `location_data.bounding_box`.
    """
    label_ids = pb2_obj.label_id
    labels = pb2_obj.label
    display_names = pb2_obj.display_name
    categories = [
        category_module.Category(
            score=score,
            index=label_ids[idx] if idx < len(label_ids) else None,
            category_name=labels[idx] if idx < len(labels) else None,
            display_name=(display_names[idx]
                          if idx < len(display_names) else None))
        for idx, score in enumerate(pb2_obj.score)
    ]

    location_data = pb2_obj.location_data
    keypoints = [
        keypoint_module.NormalizedKeypoint(
            x=relative_keypoint.x,
            y=relative_keypoint.y,
            label=relative_keypoint.keypoint_label,
            score=relative_keypoint.score)
        for relative_keypoint in location_data.relative_keypoints
    ]

    # Use `cls` so subclasses get instances of their own type back.
    return cls(
        bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
            location_data.bounding_box),
        categories=categories,
        keypoints=keypoints)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Equality is decided by comparing the `to_pb2()` protos of both objects.

    Args:
      other: The object to be compared with.

    Returns:
      True if `other` is a `Detection` and both convert to equal protos.
    """
    if not isinstance(other, Detection):
      return False
    return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class DetectionResult:
  """Represents the list of detected objects.

  Attributes:
    detections: A list of `Detection` objects.
  """

  detections: List[Detection]

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionListProto:
    """Generates a DetectionList protobuf object.

    Returns:
      A `DetectionList` proto with one `Detection` proto per entry in
      `detections`, in the same order.
    """
    return _DetectionListProto(
        detection=[detection.to_pb2() for detection in self.detections])

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult':
    """Creates a `DetectionResult` object from the given protobuf object.

    Args:
      pb2_obj: The `DetectionList` proto to convert.

    Returns:
      An instance of `cls` holding one `Detection` per proto entry, in order.
    """
    # Use `cls` so subclasses get instances of their own type back.
    return cls(detections=[
        Detection.create_from_pb2(detection) for detection in pb2_obj.detection
    ])

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Equality is decided by comparing the `to_pb2()` protos of both objects.

    Args:
      other: The object to be compared with.

    Returns:
      True if `other` is a `DetectionResult` and both convert to equal protos.
    """
    if not isinstance(other, DetectionResult):
      return False
    return self.to_pb2().__eq__(other.to_pb2())