Updated landmark_detection_result

This commit is contained in:
kinaryml 2022-11-08 10:37:37 -08:00
parent 9c9519eeb8
commit e8d771baf3

View File

@ -54,23 +54,19 @@ class LandmarksDetectionResult:
"""Generates a LandmarksDetectionResult protobuf object."""
landmarks = _NormalizedLandmarkListProto()
landmarks.landmark.extend([
landmark.to_pb2() for landmark in self.landmarks
])
classifications = _ClassificationListProto()
classifications.classification.extend([
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name) for category in self.categories
])
world_landmarks = _LandmarkListProto()
world_landmarks.landmark.extend([
world_landmark.to_pb2() for world_landmark in self.world_landmarks
])
for landmark in self.landmarks:
landmarks.landmark.append(landmark.to_pb2())
for category in self.categories:
classifications.classification.append(
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name))
return _LandmarksDetectionResultProto(
landmarks=landmarks,
@ -86,6 +82,9 @@ class LandmarksDetectionResult:
"""Creates a `LandmarksDetectionResult` object from the given protobuf object.
"""
categories = []
landmarks = []
world_landmarks = []
for classification in pb2_obj.classifications.classification:
categories.append(
category_module.Category(
@ -93,14 +92,15 @@ class LandmarksDetectionResult:
index=classification.index,
category_name=classification.label,
display_name=classification.display_name))
for landmark in pb2_obj.landmarks.landmark:
landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
for landmark in pb2_obj.world_landmarks.landmark:
world_landmarks.append(_Landmark.create_from_pb2(landmark)
)
return LandmarksDetectionResult(
landmarks=[
_NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmarks.landmark
],
landmarks=landmarks,
categories=categories,
world_landmarks=[
_Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.world_landmarks.landmark
],
world_landmarks=world_landmarks,
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))