Updated landmark_detection_result

This commit is contained in:
kinaryml 2022-11-08 10:37:37 -08:00
parent 9c9519eeb8
commit e8d771baf3

View File

@ -54,23 +54,19 @@ class LandmarksDetectionResult:
"""Generates a LandmarksDetectionResult protobuf object.""" """Generates a LandmarksDetectionResult protobuf object."""
landmarks = _NormalizedLandmarkListProto() landmarks = _NormalizedLandmarkListProto()
landmarks.landmark.extend([
landmark.to_pb2() for landmark in self.landmarks
])
classifications = _ClassificationListProto() classifications = _ClassificationListProto()
classifications.classification.extend([
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name) for category in self.categories
])
world_landmarks = _LandmarkListProto() world_landmarks = _LandmarkListProto()
world_landmarks.landmark.extend([
world_landmark.to_pb2() for world_landmark in self.world_landmarks for landmark in self.landmarks:
]) landmarks.landmark.append(landmark.to_pb2())
for category in self.categories:
classifications.classification.append(
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name))
return _LandmarksDetectionResultProto( return _LandmarksDetectionResultProto(
landmarks=landmarks, landmarks=landmarks,
@ -86,6 +82,9 @@ class LandmarksDetectionResult:
"""Creates a `LandmarksDetectionResult` object from the given protobuf object. """Creates a `LandmarksDetectionResult` object from the given protobuf object.
""" """
categories = [] categories = []
landmarks = []
world_landmarks = []
for classification in pb2_obj.classifications.classification: for classification in pb2_obj.classifications.classification:
categories.append( categories.append(
category_module.Category( category_module.Category(
@ -93,14 +92,15 @@ class LandmarksDetectionResult:
index=classification.index, index=classification.index,
category_name=classification.label, category_name=classification.label,
display_name=classification.display_name)) display_name=classification.display_name))
for landmark in pb2_obj.landmarks.landmark:
landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
for landmark in pb2_obj.world_landmarks.landmark:
world_landmarks.append(_Landmark.create_from_pb2(landmark)
)
return LandmarksDetectionResult( return LandmarksDetectionResult(
landmarks=[ landmarks=landmarks,
_NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmarks.landmark
],
categories=categories, categories=categories,
world_landmarks=[ world_landmarks=world_landmarks,
_Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.world_landmarks.landmark
],
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))