Added remaining tests for the GestureRecognizer Python MediaPipe Tasks API

kinaryml 2022-10-25 11:11:15 -07:00
parent 18eb089d39
commit 8762d15c81
8 changed files with 414 additions and 42 deletions

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Landmark Detection Result data class."""
+"""Landmarks Detection Result data class."""

import dataclasses
from typing import Any, Optional

View File: mediapipe/tasks/python/test/vision/BUILD

@@ -56,6 +56,7 @@ py_test(
        "//mediapipe/tasks/python/test:test_utils",
        "//mediapipe/tasks/python/vision:gesture_recognizer",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "@com_google_protobuf//:protobuf_python"
    ],
)

View File: mediapipe/tasks/python/test/vision/gesture_recognizer_test.py

@@ -14,7 +14,9 @@
"""Tests for gesture recognizer."""
import enum
from unittest import mock

import numpy as np
from google.protobuf import text_format
from absl.testing import absltest
from absl.testing import parameterized
@@ -29,10 +31,11 @@ from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import gesture_recognizer
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_BaseOptions = base_options_module.BaseOptions
-_NormalizedRect = rect_module.NormalizedRect
+_Rect = rect_module.Rect
_Classification = classification_module.Classification
_ClassificationList = classification_module.ClassificationList
_Landmark = landmark_module.Landmark
@@ -45,12 +48,19 @@ _GestureRecognizer = gesture_recognizer.GestureRecognizer
_GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
_GestureRecognitionResult = gesture_recognizer.GestureRecognitionResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_GESTURE_RECOGNIZER_MODEL_FILE = 'gesture_recognizer.task'
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
_TWO_HANDS_IMAGE = 'right_hands.jpg'
_THUMB_UP_IMAGE = 'thumb_up.jpg'
-_THUMB_UP_LANDMARKS = "thumb_up_landmarks.pbtxt"
-_THUMB_UP_LABEL = "Thumb_Up"
+_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
+_THUMB_UP_LABEL = 'Thumb_Up'
_THUMB_UP_INDEX = 5
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_POINTING_UP_LABEL = 'Pointing_Up'
_POINTING_UP_INDEX = 3
_LANDMARKS_ERROR_TOLERANCE = 0.03
@@ -89,7 +99,7 @@ class GestureRecognizerTest(parameterized.TestCase):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(_THUMB_UP_IMAGE))
-    self.gesture_recognizer_model_path = test_utils.get_test_data_path(
+    self.model_path = test_utils.get_test_data_path(
        _GESTURE_RECOGNIZER_MODEL_FILE)

  def _assert_actual_result_approximately_matches_expected_result(
@@ -105,8 +115,15 @@
    self.assertLen(actual_result.handedness, len(expected_result.handedness))
    self.assertLen(actual_result.gestures, len(expected_result.gestures))
    # Actual landmarks match expected landmarks.
-    self.assertEqual(actual_result.hand_landmarks,
-                     expected_result.hand_landmarks)
+    self.assertLen(actual_result.hand_landmarks[0].landmarks,
+                   len(expected_result.hand_landmarks[0].landmarks))
    actual_landmarks = actual_result.hand_landmarks[0].landmarks
    expected_landmarks = expected_result.hand_landmarks[0].landmarks
    for i in range(len(actual_landmarks)):
      self.assertAlmostEqual(actual_landmarks[i].x, expected_landmarks[i].x,
                             delta=_LANDMARKS_ERROR_TOLERANCE)
      self.assertAlmostEqual(actual_landmarks[i].y, expected_landmarks[i].y,
                             delta=_LANDMARKS_ERROR_TOLERANCE)
    # Actual handedness matches expected handedness.
    actual_top_handedness = actual_result.handedness[0].classifications[0]
    expected_top_handedness = expected_result.handedness[0].classifications[0]
@@ -118,32 +135,56 @@
    self.assertEqual(actual_top_gesture.index, expected_top_gesture.index)
    self.assertEqual(actual_top_gesture.label, expected_top_gesture.label)
  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _GestureRecognizer.create_from_model_path(self.model_path) as recognizer:
      self.assertIsInstance(recognizer, _GestureRecognizer)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _GestureRecognizerOptions(base_options=base_options)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      self.assertIsInstance(recognizer, _GestureRecognizer)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _GestureRecognizerOptions(base_options=base_options)
      _GestureRecognizer.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _GestureRecognizerOptions(base_options=base_options)
      recognizer = _GestureRecognizer.create_from_options(options)
      self.assertIsInstance(recognizer, _GestureRecognizer)
  @parameterized.parameters(
-      (ModelFileType.FILE_NAME, 0.3, _get_expected_gesture_recognition_result(
+      (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
      )),
-      (ModelFileType.FILE_CONTENT, 0.3, _get_expected_gesture_recognition_result(
+      (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
      )))
-  def test_recognize(self, model_file_type, min_gesture_confidence,
-                     expected_recognition_result):
+  def test_recognize(self, model_file_type, expected_recognition_result):
    # Creates gesture recognizer.
    if model_file_type is ModelFileType.FILE_NAME:
-      gesture_recognizer_base_options = _BaseOptions(
-          model_asset_path=self.gesture_recognizer_model_path)
+      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
-      with open(self.gesture_recognizer_model_path, 'rb') as f:
+      with open(self.model_path, 'rb') as f:
        model_content = f.read()
-      gesture_recognizer_base_options = _BaseOptions(
-          model_asset_buffer=model_content)
+      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

-    options = _GestureRecognizerOptions(
-        base_options=gesture_recognizer_base_options,
-        min_gesture_confidence=min_gesture_confidence
-    )
+    options = _GestureRecognizerOptions(base_options=base_options)
    recognizer = _GestureRecognizer.create_from_options(options)

    # Performs hand gesture recognition on the input.
@@ -151,10 +192,238 @@
    # Comparing results.
    self._assert_actual_result_approximately_matches_expected_result(
        recognition_result, expected_recognition_result)
-    # Closes the gesture recognizer explicitly when the detector is not used in
-    # a context.
+    # Closes the gesture recognizer explicitly when the gesture recognizer is
+    # not used in a context.
    recognizer.close()
  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _get_expected_gesture_recognition_result(
          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
      )),
      (ModelFileType.FILE_CONTENT, _get_expected_gesture_recognition_result(
          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX
      )))
  def test_recognize_in_context(self, model_file_type,
                                expected_recognition_result):
    # Creates gesture recognizer.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _GestureRecognizerOptions(base_options=base_options)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      # Performs hand gesture recognition on the input.
      recognition_result = recognizer.recognize(self.test_image)
      # Comparing results.
      self._assert_actual_result_approximately_matches_expected_result(
          recognition_result, expected_recognition_result)
  def test_recognize_succeeds_with_num_hands(self):
    # Creates gesture recognizer.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _GestureRecognizerOptions(base_options=base_options, num_hands=2)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      # Load the two hands image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_TWO_HANDS_IMAGE))
      # Performs hand gesture recognition on the input.
      recognition_result = recognizer.recognize(test_image)
      # Comparing results.
      self.assertLen(recognition_result.handedness, 2)
  def test_recognize_succeeds_with_rotation(self):
    # Creates gesture recognizer.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      # Load the pointing up rotated image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_POINTING_UP_ROTATED_IMAGE))
      # Set rotation parameters using ImageProcessingOptions.
      image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
      # Performs hand gesture recognition on the input.
      recognition_result = recognizer.recognize(test_image,
                                                image_processing_options)
      expected_recognition_result = _get_expected_gesture_recognition_result(
          _POINTING_UP_LANDMARKS, _POINTING_UP_LABEL, _POINTING_UP_INDEX)
      # Comparing results.
      self._assert_actual_result_approximately_matches_expected_result(
          recognition_result, expected_recognition_result)

  def test_recognize_fails_with_region_of_interest(self):
    # Creates gesture recognizer.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _GestureRecognizerOptions(base_options=base_options, num_hands=1)
    with self.assertRaisesRegex(
        ValueError, "This task doesn't support region-of-interest."):
      with _GestureRecognizer.create_from_options(options) as recognizer:
        # Set the `region_of_interest` parameter using `ImageProcessingOptions`.
        image_processing_options = _ImageProcessingOptions(
            region_of_interest=_Rect(0, 0, 1, 1))
        # Attempt to perform hand gesture recognition on the cropped input.
        recognizer.recognize(self.test_image, image_processing_options)

  def test_empty_recognition_outputs(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path))
    with _GestureRecognizer.create_from_options(options) as recognizer:
      # Load the image with no hands.
      no_hands_test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_NO_HANDS_IMAGE))
      # Performs gesture recognition on the input.
      recognition_result = recognizer.recognize(no_hands_test_image)
      self.assertEmpty(recognition_result.hand_landmarks)
      self.assertEmpty(recognition_result.hand_world_landmarks)
      self.assertEmpty(recognition_result.handedness)
      self.assertEmpty(recognition_result.gestures)
  def test_missing_result_callback(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _GestureRecognizer.create_from_options(options) as unused_recognizer:
        pass

  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
  def test_illegal_result_callback(self, running_mode):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=running_mode,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _GestureRecognizer.create_from_options(options) as unused_recognizer:
        pass

  def test_calling_recognize_for_video_in_image_mode(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        recognizer.recognize_for_video(self.test_image, 0)

  def test_calling_recognize_async_in_image_mode(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        recognizer.recognize_async(self.test_image, 0)

  def test_calling_recognize_in_video_mode(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        recognizer.recognize(self.test_image)

  def test_calling_recognize_async_in_video_mode(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        recognizer.recognize_async(self.test_image, 0)

  def test_recognize_for_video_with_out_of_order_timestamp(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      unused_result = recognizer.recognize_for_video(self.test_image, 1)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        recognizer.recognize_for_video(self.test_image, 0)

  def test_recognize_for_video(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      for timestamp in range(0, 300, 30):
        recognition_result = recognizer.recognize_for_video(self.test_image,
                                                            timestamp)
        expected_recognition_result = _get_expected_gesture_recognition_result(
            _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)
        self._assert_actual_result_approximately_matches_expected_result(
            recognition_result, expected_recognition_result)

  def test_calling_recognize_in_live_stream_mode(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _GestureRecognizer.create_from_options(options) as recognizer:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        recognizer.recognize(self.test_image)

  def test_calling_recognize_for_video_in_live_stream_mode(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _GestureRecognizer.create_from_options(options) as recognizer:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        recognizer.recognize_for_video(self.test_image, 0)

  def test_recognize_async_calls_with_illegal_timestamp(self):
    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _GestureRecognizer.create_from_options(options) as recognizer:
      recognizer.recognize_async(self.test_image, 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        recognizer.recognize_async(self.test_image, 0)
  @parameterized.parameters(
      (_THUMB_UP_IMAGE, _get_expected_gesture_recognition_result(
          _THUMB_UP_LANDMARKS, _THUMB_UP_LABEL, _THUMB_UP_INDEX)),
      (_NO_HANDS_IMAGE, _GestureRecognitionResult([], [], [], [])))
  def test_recognize_async_calls(self, image_path, expected_result):
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(image_path))
    # Track the last timestamp seen by the callback on the test instance so
    # that the monotonicity check below reads the updated value.
    self.observed_timestamp_ms = -1

    def check_result(result: _GestureRecognitionResult, output_image: _Image,
                     timestamp_ms: int):
      if result.hand_landmarks and result.hand_world_landmarks and \
         result.handedness and result.gestures:
        self._assert_actual_result_approximately_matches_expected_result(
            result, expected_result)
      else:
        self.assertEqual(result, expected_result)
      self.assertTrue(
          np.array_equal(output_image.numpy_view(),
                         test_image.numpy_view()))
      # Timestamps passed to the callback must be monotonically increasing.
      self.assertLess(self.observed_timestamp_ms, timestamp_ms)
      self.observed_timestamp_ms = timestamp_ms

    options = _GestureRecognizerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=check_result)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      for timestamp in range(0, 300, 30):
        recognizer.recognize_async(test_image, timestamp)

if __name__ == '__main__':
  absltest.main()
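The helper `_get_expected_gesture_recognition_result` is referenced throughout these tests but folded out of this diff. Based on the imports and aliases above, a plausible shape is sketched below; the `_LandmarksDetectionResult` alias, `create_from_pb2` call, and field names are assumptions, not the actual definition:

    def _get_expected_gesture_recognition_result(
        file_path: str, gesture_label: str,
        gesture_index: int) -> _GestureRecognitionResult:
      # Parse the golden landmarks from the pbtxt test asset.
      with open(test_utils.get_test_data_path(file_path), 'rb') as f:
        landmarks_detection_result_proto = _LandmarksDetectionResultProto()
        text_format.Parse(f.read(), landmarks_detection_result_proto)
      landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
          landmarks_detection_result_proto)
      # Wrap the expected top gesture into a one-element classification list.
      gesture = _ClassificationList(classifications=[
          _Classification(label=gesture_label, index=gesture_index, score=1.0)
      ])
      return _GestureRecognitionResult(
          gestures=[gesture],
          handedness=[landmarks_detection_result.classifications],
          hand_landmarks=[landmarks_detection_result.landmarks],
          hand_world_landmarks=[landmarks_detection_result.world_landmarks])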

View File: mediapipe/tasks/python/vision/BUILD

@@ -52,7 +52,6 @@ py_library(
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_py_pb2",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_py_pb2",
-       "//mediapipe/tasks/python/components/containers:rect",
        "//mediapipe/tasks/python/components/containers:classification",
        "//mediapipe/tasks/python/components/containers:landmark",
        "//mediapipe/tasks/python/components/processors:classifier_options",

View File: mediapipe/tasks/python/vision/core/BUILD

@@ -23,6 +23,14 @@ py_library(
    srcs = ["vision_task_running_mode.py"],
)

py_library(
    name = "image_processing_options",
    srcs = ["image_processing_options.py"],
    deps = [
        "//mediapipe/tasks/python/components/containers:rect",
    ],
)

py_library(
    name = "base_vision_task_api",
    srcs = [
@@ -30,6 +38,7 @@ py_library(
    ],
    deps = [
        ":vision_task_running_mode",
        ":image_processing_options",
        "//mediapipe/framework:calculator_py_pb2",
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/tasks/python/core:optional_dependencies",

View File: mediapipe/tasks/python/vision/core/base_vision_task_api.py

@@ -13,17 +13,22 @@
# limitations under the License.
"""MediaPipe vision task base api."""

import math
from typing import Callable, Mapping, Optional

from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

_TaskRunner = task_runner_module.TaskRunner
_Packet = packet_module.Packet
_NormalizedRect = rect_module.NormalizedRect
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions


class BaseVisionTaskApi(object):
@@ -122,6 +127,50 @@
                       + self._running_mode.name)
    self._runner.send(inputs)
  @staticmethod
  def convert_to_normalized_rect(
      options: _ImageProcessingOptions,
      roi_allowed: bool = True
  ) -> _NormalizedRect:
    """Converts from ImageProcessingOptions to NormalizedRect, performing
    sanity checks on-the-fly.

    If the input ImageProcessingOptions is not present, returns a default
    NormalizedRect covering the whole image with rotation set to 0. If
    'roi_allowed' is False, raises a ValueError if the input
    ImageProcessingOptions has its 'region_of_interest' field set.

    Args:
      options: Options for image processing.
      roi_allowed: Indicates if the `region_of_interest` field is allowed to be
        set. By default, it's set to True.

    Returns:
      The equivalent NormalizedRect.

    Raises:
      ValueError: If the rotation is not a multiple of 90° or the
        region-of-interest is invalid or disallowed.
    """
    normalized_rect = _NormalizedRect(rotation=0, x_center=0.5, y_center=0.5,
                                      width=1, height=1)
    if options is None:
      return normalized_rect

    if options.rotation_degrees % 90 != 0:
      raise ValueError("Expected rotation to be a multiple of 90°.")

    # Convert to radians counter-clockwise.
    normalized_rect.rotation = -options.rotation_degrees * math.pi / 180.0

    if options.region_of_interest:
      if not roi_allowed:
        raise ValueError("This task doesn't support region-of-interest.")
      roi = options.region_of_interest
      if roi.x_center >= roi.width or roi.y_center >= roi.height:
        raise ValueError(
            "Expected Rect with x_center < width and y_center < height.")
      if roi.x_center < 0 or roi.y_center < 0 or roi.width > 1 or roi.height > 1:
        raise ValueError("Expected Rect values to be in [0,1].")
      normalized_rect.x_center = roi.x_center + roi.width / 2.0
      normalized_rect.y_center = roi.y_center + roi.height / 2.0
      normalized_rect.width = roi.width - roi.x_center
      normalized_rect.height = roi.height - roi.y_center
    return normalized_rect
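For illustration, a minimal sketch of what this conversion yields (a sketch only; it relies on the NormalizedRect dataclass imported above being mutable, as the method itself does):

    # A 90° clockwise rotation with no region-of-interest keeps the
    # full-image rect and stores the rotation in radians, counter-clockwise.
    options = _ImageProcessingOptions(rotation_degrees=90)
    rect = BaseVisionTaskApi.convert_to_normalized_rect(options)
    assert (rect.x_center, rect.y_center, rect.width, rect.height) == \
        (0.5, 0.5, 1, 1)
    assert abs(rect.rotation + math.pi / 2) < 1e-6  # -90° in radians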
  def close(self) -> None:
    """Shuts down the mediapipe vision task instance.

View File

@@ -0,0 +1,39 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe vision options for image processing."""

import dataclasses
from typing import Optional

from mediapipe.tasks.python.components.containers import rect as rect_module


@dataclasses.dataclass
class ImageProcessingOptions:
  """Options for image processing.

  If both region-of-interest and rotation are specified, the crop around the
  region-of-interest is extracted first, then the specified rotation is applied
  to the crop.

  Attributes:
    region_of_interest: The optional region-of-interest to crop from the image.
      If not specified, the full image is used. Coordinates must be in [0,1]
      with 'left' < 'right' and 'top' < 'bottom'.
    rotation_degrees: The rotation to apply to the image (or cropped
      region-of-interest), in degrees clockwise. The rotation must be a
      multiple (positive or negative) of 90°.
  """
  region_of_interest: Optional[rect_module.Rect] = None
  rotation_degrees: int = 0
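A short usage sketch (the positional Rect fields are an assumption based on how the tests above construct `_Rect(0, 0, 1, 1)` for the full image):

    # Full-image region-of-interest plus a 90° counter-clockwise rotation.
    roi = rect_module.Rect(0, 0, 1, 1)
    options = ImageProcessingOptions(region_of_interest=roi,
                                     rotation_degrees=-90)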

View File: mediapipe/tasks/python/vision/gesture_recognizer.py

@@ -27,7 +27,6 @@ from mediapipe.tasks.cc.vision.gesture_recognizer.proto import hand_gesture_reco
from mediapipe.tasks.cc.vision.hand_detector.proto import hand_detector_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarks_detector_graph_options_pb2
-from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.components.containers import classification as classification_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.processors import classifier_options
@@ -36,8 +35,8 @@ from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

-_NormalizedRect = rect_module.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_GestureClassifierGraphOptionsProto = gesture_classifier_graph_options_pb2.GestureClassifierGraphOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
@@ -47,6 +46,7 @@ _HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmar
_HandLandmarksDetectorGraphOptionsProto = hand_landmarks_detector_graph_options_pb2.HandLandmarksDetectorGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_TaskRunner = task_runner_module.TaskRunner
@@ -67,11 +67,6 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerG
_MICRO_SECONDS_PER_MILLISECOND = 1000

-def _build_full_image_norm_rect() -> _NormalizedRect:
-  # Builds a NormalizedRect covering the entire image.
-  return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)

@dataclasses.dataclass
class GestureRecognitionResult:
  """The gesture recognition result from GestureRecognizer, where each vector
@@ -278,7 +273,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
  def recognize(
      self,
      image: image_module.Image,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> GestureRecognitionResult:
    """Performs hand gesture recognition on the given image. Only use this
    method when the GestureRecognizer is created with the image running mode.
@@ -289,7 +284,7 @@
    Args:
      image: MediaPipe Image.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.

    Returns:
      The hand gesture recognition results.
@@ -298,11 +293,16 @@
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If gesture recognition failed to run.
    """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options,
+                                                      roi_allowed=False)
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
-            norm_rect.to_pb2())})
+            normalized_rect.to_pb2())})

    if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
      return GestureRecognitionResult([], [], [], [])

    gestures_proto_list = packet_getter.get_proto_list(
        output_packets[_HAND_GESTURE_STREAM_NAME])
    handedness_proto_list = packet_getter.get_proto_list(
@@ -331,7 +331,7 @@
  def recognize_for_video(
      self, image: image_module.Image,
      timestamp_ms: int,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> GestureRecognitionResult:
    """Performs gesture recognition on the provided video frame. Only use this
    method when the GestureRecognizer is created with the video running mode.
@@ -344,7 +344,7 @@
    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.

    Returns:
      The hand gesture recognition results.
@@ -353,14 +353,19 @@
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If gesture recognition failed to run.
    """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options,
+                                                      roi_allowed=False)
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
-            norm_rect.to_pb2()).at(
+            normalized_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })

    if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
      return GestureRecognitionResult([], [], [], [])

    gestures_proto_list = packet_getter.get_proto_list(
        output_packets[_HAND_GESTURE_STREAM_NAME])
    handedness_proto_list = packet_getter.get_proto_list(
@@ -390,7 +395,7 @@
      self,
      image: image_module.Image,
      timestamp_ms: int,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> None:
    """Sends live image data to perform gesture recognition, and the results
    will be available via the "result_callback" provided in the
@@ -415,17 +420,18 @@
    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.

    Raises:
      ValueError: If the current input timestamp is smaller than what the
        gesture recognizer has already processed.
    """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options,
+                                                      roi_allowed=False)
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
-            norm_rect.to_pb2()).at(
+            normalized_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })
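Taken together, these changes replace the old `roi` argument with `image_processing_options` across all three running modes. A minimal sketch of the new call pattern in image mode (a sketch only, reusing the aliases and test assets from the test file above):

    base_options = _BaseOptions(model_asset_path='gesture_recognizer.task')
    options = _GestureRecognizerOptions(base_options=base_options)
    with _GestureRecognizer.create_from_options(options) as recognizer:
      image = _Image.create_from_file('pointing_up_rotated.jpg')
      # Rotate the frame upright before recognition; region-of-interest is
      # rejected for this task (roi_allowed=False in the base API).
      result = recognizer.recognize(
          image, _ImageProcessingOptions(rotation_degrees=-90))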