This commit is contained in:
Lian Hui Lui 2023-12-17 18:41:09 +00:00 committed by GitHub
commit fc17266fca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,7 +58,6 @@ def _normalized_to_pixel_coordinates(
if not (is_valid_normalized_value(normalized_x) and if not (is_valid_normalized_value(normalized_x) and
is_valid_normalized_value(normalized_y)): is_valid_normalized_value(normalized_y)):
# TODO: Draw coordinates even if it's outside of the image bounds.
return None return None
x_px = min(math.floor(normalized_x * image_width), image_width - 1) x_px = min(math.floor(normalized_x * image_width), image_width - 1)
y_px = min(math.floor(normalized_y * image_height), image_height - 1) y_px = min(math.floor(normalized_y * image_height), image_height - 1)
@ -85,6 +84,20 @@ def draw_detection(
a) If the input image is not three channel BGR. a) If the input image is not three channel BGR.
b) If the location data is not relative data. b) If the location data is not relative data.
""" """
def _normalized_to_pixel_coordinates_unsafe(
normalized_x: float, normalized_y: float, image_width: int,
image_height: int) -> Tuple[int, int]:
"""Converts normalized coordinates to pixel coordinates.
Normalized coordinates can be out of image bounds, therefore the result
may also be outside image bounds."""
x_px = math.floor(normalized_x * image_width)
y_px = math.floor(normalized_y * image_height)
return x_px, y_px
if not detection.location_data: if not detection.location_data:
return return
if image.shape[2] != _BGR_CHANNELS: if image.shape[2] != _BGR_CHANNELS:
@ -97,7 +110,7 @@ def draw_detection(
'LocationData must be relative for this drawing funtion to work.') 'LocationData must be relative for this drawing funtion to work.')
# Draws keypoints. # Draws keypoints.
for keypoint in location.relative_keypoints: for keypoint in location.relative_keypoints:
keypoint_px = _normalized_to_pixel_coordinates(keypoint.x, keypoint.y, keypoint_px = _normalized_to_pixel_coordinates_unsafe(keypoint.x, keypoint.y,
image_cols, image_rows) image_cols, image_rows)
cv2.circle(image, keypoint_px, keypoint_drawing_spec.circle_radius, cv2.circle(image, keypoint_px, keypoint_drawing_spec.circle_radius,
keypoint_drawing_spec.color, keypoint_drawing_spec.thickness) keypoint_drawing_spec.color, keypoint_drawing_spec.thickness)
@ -105,10 +118,10 @@ def draw_detection(
if not location.HasField('relative_bounding_box'): if not location.HasField('relative_bounding_box'):
return return
relative_bounding_box = location.relative_bounding_box relative_bounding_box = location.relative_bounding_box
rect_start_point = _normalized_to_pixel_coordinates( rect_start_point = _normalized_to_pixel_coordinates_unsafe(
relative_bounding_box.xmin, relative_bounding_box.ymin, image_cols, relative_bounding_box.xmin, relative_bounding_box.ymin, image_cols,
image_rows) image_rows)
rect_end_point = _normalized_to_pixel_coordinates( rect_end_point = _normalized_to_pixel_coordinates_unsafe(
relative_bounding_box.xmin + relative_bounding_box.width, relative_bounding_box.xmin + relative_bounding_box.width,
relative_bounding_box.ymin + relative_bounding_box.height, image_cols, relative_bounding_box.ymin + relative_bounding_box.height, image_cols,
image_rows) image_rows)