mediapipe/mediapipe2/python/solutions/drawing_utils.py
2021-06-10 23:01:19 +00:00

228 lines
9.2 KiB
Python

# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe solution drawing utils."""
import math
from typing import List, Optional, Tuple, Union
import cv2
import dataclasses
import numpy as np
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import location_data_pb2
from mediapipe.framework.formats import landmark_pb2
# draw_landmarks skips a landmark whose `presence` field is set and is below
# this value.
PRESENCE_THRESHOLD = 0.5
# Size the third (channel) axis of an input image must have in the draw_*
# functions of this module.
RGB_CHANNELS = 3
# Colors used by draw_axis for the x, y and z arrows respectively; RED_COLOR
# is also the default keypoint/landmark color.
# NOTE(review): (0, 0, 255) is red only under BGR channel order, while the
# docstrings in this module describe images as RGB -- confirm the intended
# channel order against callers.
RED_COLOR = (0, 0, 255)
GREEN_COLOR = (0, 128, 0)
BLUE_COLOR = (255, 0, 0)
# draw_landmarks skips a landmark whose `visibility` field is set and is below
# this value.
VISIBILITY_THRESHOLD = 0.5
@dataclasses.dataclass
class DrawingSpec:
  """Rendering settings for one kind of annotation.

  Attributes:
    color: Three-integer color tuple used to draw the annotation. Defaults to
      (0, 255, 0).
    thickness: Thickness, in pixels, used to draw the annotation. Defaults
      to 2.
    circle_radius: Radius, in pixels, of drawn circles. Defaults to 2.
  """

  color: Tuple[int, int, int] = (0, 255, 0)
  thickness: int = 2
  circle_radius: int = 2
def _normalized_to_pixel_coordinates(
normalized_x: float, normalized_y: float, image_width: int,
image_height: int) -> Union[None, Tuple[int, int]]:
"""Converts normalized value pair to pixel coordinates."""
# Checks if the float value is between 0 and 1.
def is_valid_normalized_value(value: float) -> bool:
return (value > 0 or math.isclose(0, value)) and (value < 1 or
math.isclose(1, value))
if not (is_valid_normalized_value(normalized_x) and
is_valid_normalized_value(normalized_y)):
# TODO: Draw coordinates even if it's outside of the image bounds.
return None
x_px = min(math.floor(normalized_x * image_width), image_width - 1)
y_px = min(math.floor(normalized_y * image_height), image_height - 1)
return x_px, y_px
def draw_detection(
    image: np.ndarray,
    detection: detection_pb2.Detection,
    keypoint_drawing_spec: DrawingSpec = DrawingSpec(color=RED_COLOR),
    bbox_drawing_spec: DrawingSpec = DrawingSpec()):
  """Draws the detection bounding box and keypoints on the image.

  The image is modified in place. A keypoint that cannot be converted to pixel
  coordinates (it lies outside the normalized [0, 1] range) is skipped, and the
  bounding box is skipped when either of its corners cannot be converted.

  Args:
    image: A three channel RGB image represented as numpy ndarray.
    detection: A detection proto message to be annotated on the image.
    keypoint_drawing_spec: A DrawingSpec object that specifies the keypoints'
      drawing settings such as color, line thickness, and circle radius.
    bbox_drawing_spec: A DrawingSpec object that specifies the bounding box's
      drawing settings such as color and line thickness.

  Raises:
    ValueError: If one of the followings:
      a) If the input image is not three channel RGB.
      b) If the location data is not relative data.
  """
  if not detection.location_data:
    return
  # Check the rank first so a 2-D (single channel) image raises the documented
  # ValueError instead of an IndexError from shape[2].
  if image.ndim != 3 or image.shape[2] != RGB_CHANNELS:
    raise ValueError('Input image must contain three channel rgb data.')
  image_rows, image_cols, _ = image.shape

  location = detection.location_data
  if location.format != location_data_pb2.LocationData.RELATIVE_BOUNDING_BOX:
    raise ValueError(
        'LocationData must be relative for this drawing funtion to work.')

  # Draws keypoints.
  for keypoint in location.relative_keypoints:
    keypoint_px = _normalized_to_pixel_coordinates(keypoint.x, keypoint.y,
                                                   image_cols, image_rows)
    # The conversion returns None for out-of-range coordinates; cv2.circle
    # cannot take None as a center point, so skip such keypoints.
    if keypoint_px is None:
      continue
    cv2.circle(image, keypoint_px, keypoint_drawing_spec.circle_radius,
               keypoint_drawing_spec.color, keypoint_drawing_spec.thickness)

  # Draws bounding box if exists.
  if not location.HasField('relative_bounding_box'):
    return
  relative_bounding_box = location.relative_bounding_box
  rect_start_point = _normalized_to_pixel_coordinates(
      relative_bounding_box.xmin, relative_bounding_box.ymin, image_cols,
      image_rows)
  rect_end_point = _normalized_to_pixel_coordinates(
      relative_bounding_box.xmin + relative_bounding_box.width,
      relative_bounding_box.ymin + relative_bounding_box.height, image_cols,
      image_rows)
  # Either corner may be None when the box extends past the image; there is
  # then no valid rectangle to hand to cv2.rectangle.
  if rect_start_point is None or rect_end_point is None:
    return
  cv2.rectangle(image, rect_start_point, rect_end_point,
                bbox_drawing_spec.color, bbox_drawing_spec.thickness)
def draw_landmarks(
    image: np.ndarray,
    landmark_list: landmark_pb2.NormalizedLandmarkList,
    connections: Optional[List[Tuple[int, int]]] = None,
    landmark_drawing_spec: DrawingSpec = DrawingSpec(color=RED_COLOR),
    connection_drawing_spec: DrawingSpec = DrawingSpec()):
  """Draws the landmarks and the connections on the image.

  The image is modified in place. A landmark is left out when its `visibility`
  or `presence` field is set and below the corresponding threshold, or when
  its coordinates cannot be converted to pixel coordinates. A connection is
  drawn only when both of its end landmarks were kept.

  Args:
    image: A three channel RGB image represented as numpy ndarray.
    landmark_list: A normalized landmark list proto message to be annotated on
      the image.
    connections: A list of landmark index tuples that specifies how landmarks
      to be connected in the drawing.
    landmark_drawing_spec: A DrawingSpec object that specifies the landmarks'
      drawing settings such as color, line thickness, and circle radius.
    connection_drawing_spec: A DrawingSpec object that specifies the
      connections' drawing settings such as color and line thickness.

  Raises:
    ValueError: If one of the followings:
      a) If the input image is not three channel RGB.
      b) If any connections contain invalid landmark index.
  """
  if not landmark_list:
    return
  # Check the rank first so a 2-D (single channel) image raises the documented
  # ValueError instead of an IndexError from shape[2].
  if image.ndim != 3 or image.shape[2] != RGB_CHANNELS:
    raise ValueError('Input image must contain three channel rgb data.')
  image_rows, image_cols, _ = image.shape

  # Maps the index of every drawable landmark to its pixel position.
  idx_to_coordinates = {}
  for idx, landmark in enumerate(landmark_list.landmark):
    if ((landmark.HasField('visibility') and
         landmark.visibility < VISIBILITY_THRESHOLD) or
        (landmark.HasField('presence') and
         landmark.presence < PRESENCE_THRESHOLD)):
      continue
    landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y,
                                                   image_cols, image_rows)
    # None means the landmark lies outside the image.
    if landmark_px is not None:
      idx_to_coordinates[idx] = landmark_px

  if connections:
    num_landmarks = len(landmark_list.landmark)
    # Draws the connections if the start and end landmarks are both visible.
    for connection in connections:
      start_idx = connection[0]
      end_idx = connection[1]
      if not (0 <= start_idx < num_landmarks and 0 <= end_idx < num_landmarks):
        raise ValueError(f'Landmark index is out of range. Invalid connection '
                         f'from landmark #{start_idx} to landmark #{end_idx}.')
      if start_idx in idx_to_coordinates and end_idx in idx_to_coordinates:
        cv2.line(image, idx_to_coordinates[start_idx],
                 idx_to_coordinates[end_idx], connection_drawing_spec.color,
                 connection_drawing_spec.thickness)

  # Draws landmark points after finishing the connection lines, which is
  # aesthetically better.
  for landmark_px in idx_to_coordinates.values():
    cv2.circle(image, landmark_px, landmark_drawing_spec.circle_radius,
               landmark_drawing_spec.color, landmark_drawing_spec.thickness)
def draw_axis(
    image: np.ndarray,
    rotation: np.ndarray,
    translation: np.ndarray,
    focal_length: Tuple[float, float] = (1.0, 1.0),
    principal_point: Tuple[float, float] = (0.0, 0.0),
    axis_length: float = 0.1,
    axis_drawing_spec: DrawingSpec = DrawingSpec()):
  """Draws the 3D axis on the image.

  The image is modified in place. The x, y and z axes are drawn as arrows from
  the projected origin using RED_COLOR, GREEN_COLOR and BLUE_COLOR
  respectively; only the thickness of `axis_drawing_spec` is used.

  Args:
    image: A three channel RGB image represented as numpy ndarray.
    rotation: Rotation matrix from object to camera coordinate frame.
    translation: Translation vector from object to camera coordinate frame.
    focal_length: camera focal length along x and y directions.
    principal_point: camera principal point in x and y.
    axis_length: length of the axis in the drawing.
    axis_drawing_spec: A DrawingSpec object that specifies the xyz axis
      drawing settings such as line thickness.

  Raises:
    ValueError: If one of the followings:
      a) If the input image is not three channel RGB.
  """
  # Check the rank first so a 2-D (single channel) image raises the documented
  # ValueError instead of an IndexError from shape[2].
  if image.ndim != 3 or image.shape[2] != RGB_CHANNELS:
    raise ValueError('Input image must contain three channel rgb data.')
  image_rows, image_cols, _ = image.shape

  # Create axis points in camera coordinate frame: the origin followed by the
  # tips of the x, y and z axes.
  axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
  axis_cam = np.matmul(rotation, axis_length * axis_world.T).T + translation
  x = axis_cam[..., 0]
  y = axis_cam[..., 1]
  z = axis_cam[..., 2]

  # Project 3D points to NDC space, clamped to [-1, 1]. The 1e-5 offset is
  # presumably there to avoid dividing by a depth of exactly zero.
  fx, fy = focal_length
  px, py = principal_point
  x_ndc = np.clip(-fx * x / (z + 1e-5) + px, -1., 1.)
  y_ndc = np.clip(-fy * y / (z + 1e-5) + py, -1., 1.)

  # Convert from NDC space to image space.
  x_im = np.int32((1 + x_ndc) * 0.5 * image_cols)
  y_im = np.int32((1 - y_ndc) * 0.5 * image_rows)

  # Build the four end points as tuples of plain Python ints so they do not
  # depend on how the cv2 bindings treat NumPy scalar types.
  origin, x_axis, y_axis, z_axis = (
      (int(x_im[i]), int(y_im[i])) for i in range(4))

  # Draw xyz axis on the image.
  cv2.arrowedLine(image, origin, x_axis, RED_COLOR,
                  axis_drawing_spec.thickness)
  cv2.arrowedLine(image, origin, y_axis, GREEN_COLOR,
                  axis_drawing_spec.thickness)
  cv2.arrowedLine(image, origin, z_axis, BLUE_COLOR,
                  axis_drawing_spec.thickness)