# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mediapipe.python.solutions.pose."""

import json
import os
import tempfile
from typing import NamedTuple

from absl.testing import absltest
from absl.testing import parameterized
import cv2
import numpy as np
import numpy.testing as npt

# resources dependency
# undeclared dependency
from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose

# Path (relative to the MediaPipe repo root) of the directory holding the
# test images/videos used below.
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'

# Maximum allowed per-coordinate difference between predicted and expected
# landmark positions in the single-image test.
DIFF_THRESHOLD = 15  # pixels

# Expected (x, y) pixel coordinates of the 33 pose landmarks detected on
# testdata/pose.jpg, used as the golden reference in test_on_image.
EXPECTED_POSE_LANDMARKS = np.array([[460, 283], [467, 273], [471, 273],
                                    [474, 273], [465, 273], [465, 273],
                                    [466, 273], [491, 277], [480, 277],
                                    [470, 294], [465, 294], [545, 319],
                                    [453, 329], [622, 323], [375, 316],
                                    [696, 316], [299, 307], [719, 316],
                                    [278, 306], [721, 311], [274, 304],
                                    [713, 313], [283, 306], [520, 476],
                                    [467, 471], [612, 550], [358, 490],
                                    [701, 613], [349, 611], [709, 624],
                                    [363, 630], [730, 633], [303, 628]])
class PoseTest(parameterized.TestCase):
  """Tests for the MediaPipe Pose solution Python API."""

  def _landmarks_list_to_array(self, landmark_list, image_shape):
    """Converts a normalized-landmark proto list to a pixel-space array.

    Args:
      landmark_list: A landmark-list message whose `landmark` entries carry
        normalized `x`, `y`, `z` fields.
      image_shape: The `(rows, cols, channels)` shape of the source image.

    Returns:
      A `(num_landmarks, 3)` float array where x and z are scaled by the
      image width and y by the image height.
    """
    rows, cols, _ = image_shape
    # Note: z is scaled by image width, matching x, per the Pose solution's
    # normalization convention.
    return np.asarray([(lmk.x * cols, lmk.y * rows, lmk.z * cols)
                       for lmk in landmark_list.landmark])

  def _assert_diff_less(self, array1, array2, threshold):
    """Asserts that |array1 - array2| < threshold elementwise."""
    npt.assert_array_less(np.abs(array1 - array2), threshold)

  def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int):
    """Draws pose landmarks on `frame` and saves it to a temp PNG for debugging.

    The output file name embeds the current test id and frame index so
    multiple frames from one test do not collide.
    """
    mp_drawing.draw_landmarks(frame, results.pose_landmarks,
                              mp_pose.POSE_CONNECTIONS)
    path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] +
                        '_frame_{}.png'.format(idx))
    cv2.imwrite(path, frame)

  def test_invalid_image_shape(self):
    """Pose.process must reject images that are not three-channel RGB."""
    with mp_pose.Pose() as pose:
      with self.assertRaisesRegex(
          ValueError, 'Input image must contain three channel rgb data.'):
        pose.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))

  def test_blank_image(self):
    """A blank (all-white) image should yield no pose landmarks."""
    with mp_pose.Pose() as pose:
      image = np.zeros([100, 100, 3], dtype=np.uint8)
      image.fill(255)
      results = pose.process(image)
      self.assertIsNone(results.pose_landmarks)

  @parameterized.named_parameters(('static_lite', True, 0, 3),
                                  ('static_full', True, 1, 3),
                                  ('static_heavy', True, 2, 3),
                                  ('video_lite', False, 0, 3),
                                  ('video_full', False, 1, 3),
                                  ('video_heavy', False, 2, 3))
  def test_on_image(self, static_image_mode, model_complexity, num_frames):
    """Checks predicted landmarks on a still image against golden values.

    The same image is processed `num_frames` times to exercise both static
    and tracking (video) modes at each model complexity.
    """
    image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg')
    image = cv2.imread(image_path)
    with mp_pose.Pose(static_image_mode=static_image_mode,
                      model_complexity=model_complexity) as pose:
      for idx in range(num_frames):
        # OpenCV loads BGR; the solution expects RGB.
        results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Dump an annotated copy to the temp dir for manual inspection.
        self._annotate(image.copy(), results, idx)
        # Only (x, y) is compared; z has no golden reference here.
        self._assert_diff_less(
            self._landmarks_list_to_array(results.pose_landmarks,
                                          image.shape)[:, :2],
            EXPECTED_POSE_LANDMARKS, DIFF_THRESHOLD)

  @parameterized.named_parameters(
      ('full', 1, 'pose_squats.full.npz'))
  def test_on_video(self, model_complexity, expected_name):
    """Tests pose models on a video."""
    # If set to `True` will dump actual predictions to .npz and JSON files.
    dump_predictions = False

    # Set threshold for comparing actual and expected predictions in pixels.
    # Looser than DIFF_THRESHOLD because tracking over video accumulates
    # more variation than a single still image.
    diff_threshold = 50

    video_path = os.path.join(os.path.dirname(__file__),
                              'testdata/pose_squats.mp4')
    expected_path = os.path.join(os.path.dirname(__file__),
                                 'testdata/{}'.format(expected_name))

    # Predict pose landmarks for each frame.
    video_cap = cv2.VideoCapture(video_path)
    actual_per_frame = []
    frame_idx = 0
    with mp_pose.Pose(static_image_mode=False,
                      model_complexity=model_complexity) as pose:
      while True:
        # Get next frame of the video.
        success, input_frame = video_cap.read()
        if not success:
          break

        # Run pose tracker.
        input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
        result = pose.process(image=input_frame)
        pose_landmarks = self._landmarks_list_to_array(result.pose_landmarks,
                                                       input_frame.shape)

        actual_per_frame.append(pose_landmarks)

        # Convert back to BGR before annotating/saving with OpenCV.
        input_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
        self._annotate(input_frame, result, frame_idx)
        frame_idx += 1
    actual = np.asarray(actual_per_frame)

    if dump_predictions:
      # Dump .npz
      with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        np.savez(tmp_file, predictions=np.array(actual))
        print('Predictions saved as .npz to {}'.format(tmp_file.name))

      # Dump JSON
      with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        with open(tmp_file.name, 'w') as fl:
          dump_data = {'predictions': np.around(actual, 3).tolist()}
          fl.write(json.dumps(dump_data, indent=2, separators=(',', ': ')))
        print('Predictions saved as JSON to {}'.format(tmp_file.name))

    # Validate actual vs. expected predictions.
    expected = np.load(expected_path)['predictions']
    assert actual.shape == expected.shape, (
        'Unexpected shape of predictions: {} instead of {}'.format(
            actual.shape, expected.shape))
    # Only (x, y) is compared; z is excluded from the golden data.
    self._assert_diff_less(
        actual[..., :2], expected[..., :2], threshold=diff_threshold)
# Run the absl test runner when executed as a script.
if __name__ == '__main__':
  absltest.main()