From 97bd9c2157e9ee99e570074c396800d973e4f249 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 6 Apr 2023 04:59:26 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 522307800 --- mediapipe/tasks/python/test/vision/BUILD | 1 + .../test/vision/image_classifier_test.py | 320 +++++++++++++----- .../test/vision/object_detector_test.py | 178 ++++++---- mediapipe/tasks/python/vision/BUILD | 4 + .../vision/core/base_vision_task_api.py | 63 +++- .../tasks/python/vision/face_detector.py | 6 +- .../tasks/python/vision/face_landmarker.py | 6 +- .../tasks/python/vision/gesture_recognizer.py | 190 +++++++---- .../tasks/python/vision/hand_landmarker.py | 148 +++++--- .../tasks/python/vision/image_classifier.py | 109 +++--- .../tasks/python/vision/image_embedder.py | 115 ++++--- .../tasks/python/vision/image_segmenter.py | 128 +++++-- .../tasks/python/vision/object_detector.py | 150 +++++--- 13 files changed, 964 insertions(+), 454 deletions(-) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index a3285a6d6..704e1af5c 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -33,6 +33,7 @@ py_test( "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:object_detector", + "//mediapipe/tasks/python/vision/core:image_processing_options", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index b47efb32b..f1bbc1285 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -55,10 +55,10 @@ _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' def _generate_empty_results() -> ImageClassifierResult: return ImageClassifierResult( classifications=[ - _Classifications( - categories=[], head_index=0, head_name='probability') + _Classifications(categories=[], head_index=0, head_name='probability') ], - timestamp_ms=0) + timestamp_ms=0, + ) def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult: @@ -129,10 +129,11 @@ class ImageClassifierTest(parameterized.TestCase): def setUp(self): super().setUp() self.test_image = _Image.create_from_file( - test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))) + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)) + ) self.model_path = test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) + os.path.join(_TEST_DATA_DIR, _MODEL_FILE) + ) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. 
@@ -148,9 +149,11 @@ class ImageClassifierTest(parameterized.TestCase): def test_create_from_options_fails_with_invalid_model_path(self): with self.assertRaisesRegex( - RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): base_options = _BaseOptions( - model_asset_path='/path/to/invalid/model.tflite') + model_asset_path='/path/to/invalid/model.tflite' + ) options = _ImageClassifierOptions(base_options=base_options) _ImageClassifier.create_from_options(options) @@ -164,9 +167,11 @@ class ImageClassifierTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _generate_burger_results()), - (ModelFileType.FILE_CONTENT, 4, _generate_burger_results())) - def test_classify(self, model_file_type, max_results, - expected_classification_result): + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()), + ) + def test_classify( + self, model_file_type, max_results, expected_classification_result + ): # Creates classifier. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(model_asset_path=self.model_path) @@ -179,23 +184,27 @@ class ImageClassifierTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageClassifierOptions( - base_options=base_options, max_results=max_results) + base_options=base_options, max_results=max_results + ) classifier = _ImageClassifier.create_from_options(options) # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - test_utils.assert_proto_equals(self, image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assert_proto_equals( + self, image_result.to_pb2(), expected_classification_result.to_pb2() + ) # Closes the classifier explicitly when the classifier is not used in # a context. classifier.close() @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _generate_burger_results()), - (ModelFileType.FILE_CONTENT, 4, _generate_burger_results())) - def test_classify_in_context(self, model_file_type, max_results, - expected_classification_result): + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()), + ) + def test_classify_in_context( + self, model_file_type, max_results, expected_classification_result + ): if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: @@ -207,13 +216,15 @@ class ImageClassifierTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageClassifierOptions( - base_options=base_options, max_results=max_results) + base_options=base_options, max_results=max_results + ) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) # Comparing results. - test_utils.assert_proto_equals(self, image_result.to_pb2(), - expected_classification_result.to_pb2()) + test_utils.assert_proto_equals( + self, image_result.to_pb2(), expected_classification_result.to_pb2() + ) def test_classify_succeeds_with_region_of_interest(self): base_options = _BaseOptions(model_asset_path=self.model_path) @@ -222,20 +233,110 @@ class ImageClassifierTest(parameterized.TestCase): # Load the test image. 
test_image = _Image.create_from_file( test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg'))) + os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg') + ) + ) # Region-of-interest around the soccer ball. roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345) image_processing_options = _ImageProcessingOptions(roi) # Performs image classification on the input. image_result = classifier.classify(test_image, image_processing_options) # Comparing results. - test_utils.assert_proto_equals(self, image_result.to_pb2(), - _generate_soccer_ball_results().to_pb2()) + test_utils.assert_proto_equals( + self, image_result.to_pb2(), _generate_soccer_ball_results().to_pb2() + ) + + def test_classify_succeeds_with_rotation(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _ImageClassifierOptions(base_options=base_options, max_results=3) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, 'burger_rotated.jpg') + ) + ) + # Specify a 90° anti-clockwise rotation. + image_processing_options = _ImageProcessingOptions(None, -90) + # Performs image classification on the input. + image_result = classifier.classify(test_image, image_processing_options) + # Comparing results. + expected = ImageClassifierResult( + classifications=[ + _Classifications( + categories=[ + _Category( + index=934, + score=0.754467, + display_name='', + category_name='cheeseburger', + ), + _Category( + index=925, + score=0.0288028, + display_name='', + category_name='guacamole', + ), + _Category( + index=932, + score=0.0286119, + display_name='', + category_name='bagel', + ), + ], + head_index=0, + head_name='probability', + ) + ], + timestamp_ms=0, + ) + test_utils.assert_proto_equals( + self, image_result.to_pb2(), expected.to_pb2() + ) + + def test_classify_succeeds_with_region_of_interest_and_rotation(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _ImageClassifierOptions(base_options=base_options, max_results=1) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, 'multi_objects_rotated.jpg') + ) + ) + # Region-of-interest around the soccer ball, with 90° anti-clockwise + # rotation. + roi = _Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614) + image_processing_options = _ImageProcessingOptions(roi, -90) + # Performs image classification on the input. + image_result = classifier.classify(test_image, image_processing_options) + # Comparing results. + expected = ImageClassifierResult( + classifications=[ + _Classifications( + categories=[ + _Category( + index=806, + score=0.997684, + display_name='', + category_name='soccer ball', + ), + ], + head_index=0, + head_name='probability', + ) + ], + timestamp_ms=0, + ) + test_utils.assert_proto_equals( + self, image_result.to_pb2(), expected.to_pb2() + ) def test_score_threshold_option(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - score_threshold=_SCORE_THRESHOLD) + score_threshold=_SCORE_THRESHOLD, + ) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -245,26 +346,33 @@ class ImageClassifierTest(parameterized.TestCase): for category in classification.categories: score = category.score self.assertGreaterEqual( - score, _SCORE_THRESHOLD, - f'Classification with score lower than threshold found. ' - f'{classification}') + score, + _SCORE_THRESHOLD, + ( + 'Classification with score lower than threshold found. ' + f'{classification}' + ), + ) def test_max_results_option(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - score_threshold=_SCORE_THRESHOLD) + score_threshold=_SCORE_THRESHOLD, + ) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) categories = image_result.classifications[0].categories self.assertLessEqual( - len(categories), _MAX_RESULTS, 'Too many results returned.') + len(categories), _MAX_RESULTS, 'Too many results returned.' + ) def test_allow_list_option(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - category_allowlist=_ALLOW_LIST) + category_allowlist=_ALLOW_LIST, + ) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -273,13 +381,17 @@ class ImageClassifierTest(parameterized.TestCase): for classification in classifications: for category in classification.categories: label = category.category_name - self.assertIn(label, _ALLOW_LIST, - f'Label {label} found but not in label allow list') + self.assertIn( + label, + _ALLOW_LIST, + f'Label {label} found but not in label allow list', + ) def test_deny_list_option(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - category_denylist=_DENY_LIST) + category_denylist=_DENY_LIST, + ) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. image_result = classifier.classify(self.test_image) @@ -288,26 +400,30 @@ class ImageClassifierTest(parameterized.TestCase): for classification in classifications: for category in classification.categories: label = category.category_name - self.assertNotIn(label, _DENY_LIST, - f'Label {label} found but in deny list.') + self.assertNotIn( + label, _DENY_LIST, f'Label {label} found but in deny list.' + ) def test_combined_allowlist_and_denylist(self): # Fails with combined allowlist and denylist with self.assertRaisesRegex( ValueError, r'`category_allowlist` and `category_denylist` are mutually ' - r'exclusive options.'): + r'exclusive options.', + ): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), category_allowlist=['foo'], - category_denylist=['bar']) + category_denylist=['bar'], + ) with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_empty_classification_outputs(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - score_threshold=1) + score_threshold=1, + ) with _ImageClassifier.create_from_options(options) as classifier: # Performs image classification on the input. 
image_result = classifier.classify(self.test_image) @@ -316,9 +432,11 @@ class ImageClassifierTest(parameterized.TestCase): def test_missing_result_callback(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM) - with self.assertRaisesRegex(ValueError, - r'result callback must be provided'): + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): with _ImageClassifier.create_from_options(options) as unused_classifier: pass @@ -327,67 +445,81 @@ class ImageClassifierTest(parameterized.TestCase): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=running_mode, - result_callback=mock.MagicMock()) - with self.assertRaisesRegex(ValueError, - r'result callback should not be provided'): + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): with _ImageClassifier.create_from_options(options) as unused_classifier: pass def test_calling_classify_for_video_in_image_mode(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ImageClassifier.create_from_options(options) as classifier: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): classifier.classify_for_video(self.test_image, 0) def test_calling_classify_async_in_image_mode(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ImageClassifier.create_from_options(options) as classifier: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): classifier.classify_async(self.test_image, 0) def test_calling_classify_in_video_mode(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageClassifier.create_from_options(options) as classifier: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): classifier.classify(self.test_image) def test_calling_classify_async_in_video_mode(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageClassifier.create_from_options(options) as classifier: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): classifier.classify_async(self.test_image, 0) def test_classify_for_video_with_out_of_order_timestamp(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageClassifier.create_from_options(options) as classifier: unused_result = classifier.classify_for_video(self.test_image, 1) 
with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): classifier.classify_for_video(self.test_image, 0) def test_classify_for_video(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - max_results=4) + max_results=4, + ) with _ImageClassifier.create_from_options(options) as classifier: for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( - self.test_image, timestamp) + self.test_image, timestamp + ) test_utils.assert_proto_equals( self, classification_result.to_pb2(), @@ -398,18 +530,22 @@ class ImageClassifierTest(parameterized.TestCase): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - max_results=1) + max_results=1, + ) with _ImageClassifier.create_from_options(options) as classifier: # Load the test image. test_image = _Image.create_from_file( test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg'))) + os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg') + ) + ) # Region-of-interest around the soccer ball. roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345) image_processing_options = _ImageProcessingOptions(roi) for timestamp in range(0, 300, 30): classification_result = classifier.classify_for_video( - test_image, timestamp, image_processing_options) + test_image, timestamp, image_processing_options + ) test_utils.assert_proto_equals( self, classification_result.to_pb2(), @@ -420,20 +556,24 @@ class ImageClassifierTest(parameterized.TestCase): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageClassifier.create_from_options(options) as classifier: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): classifier.classify(self.test_image) def test_calling_classify_for_video_in_live_stream_mode(self): options = _ImageClassifierOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageClassifier.create_from_options(options) as classifier: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): classifier.classify_for_video(self.test_image, 0) def test_classify_async_calls_with_illegal_timestamp(self): @@ -441,25 +581,32 @@ class ImageClassifierTest(parameterized.TestCase): base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=4, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 100) with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): classifier.classify_async(self.test_image, 0) - @parameterized.parameters((0, _generate_burger_results()), - (1, _generate_empty_results())) + 
@parameterized.parameters( + (0, _generate_burger_results()), (1, _generate_empty_results()) + ) def test_classify_async_calls(self, threshold, expected_result): observed_timestamp_ms = -1 - def check_result(result: ImageClassifierResult, output_image: _Image, - timestamp_ms: int): - test_utils.assert_proto_equals(self, result.to_pb2(), - expected_result.to_pb2()) + def check_result( + result: ImageClassifierResult, output_image: _Image, timestamp_ms: int + ): + test_utils.assert_proto_equals( + self, result.to_pb2(), expected_result.to_pb2() + ) self.assertTrue( - np.array_equal(output_image.numpy_view(), - self.test_image.numpy_view())) + np.array_equal( + output_image.numpy_view(), self.test_image.numpy_view() + ) + ) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms @@ -468,7 +615,8 @@ class ImageClassifierTest(parameterized.TestCase): running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=4, score_threshold=threshold, - result_callback=check_result) + result_callback=check_result, + ) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(self.test_image, 0) @@ -476,14 +624,17 @@ class ImageClassifierTest(parameterized.TestCase): # Load the test image. test_image = _Image.create_from_file( test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg'))) + os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg') + ) + ) # Region-of-interest around the soccer ball. roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345) image_processing_options = _ImageProcessingOptions(roi) observed_timestamp_ms = -1 - def check_result(result: ImageClassifierResult, output_image: _Image, - timestamp_ms: int): + def check_result( + result: ImageClassifierResult, output_image: _Image, timestamp_ms: int + ): test_utils.assert_proto_equals( self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2() ) @@ -496,7 +647,8 @@ class ImageClassifierTest(parameterized.TestCase): base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=1, - result_callback=check_result) + result_callback=check_result, + ) with _ImageClassifier.create_from_options(options) as classifier: classifier.classify_async(test_image, 100, image_processing_options) diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index 2bb9b0214..5a11246a4 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -28,6 +28,7 @@ from mediapipe.tasks.python.components.containers import detections as detection from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import object_detector +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions @@ -36,8 +37,10 @@ _BoundingBox = bounding_box_module.BoundingBox _Detection = detections_module.Detection _DetectionResult = detections_module.DetectionResult _Image = image_module.Image +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ObjectDetector = object_detector.ObjectDetector _ObjectDetectorOptions = object_detector.ObjectDetectorOptions + _RUNNING_MODE = 
running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' @@ -115,10 +118,11 @@ class ObjectDetectorTest(parameterized.TestCase): def setUp(self): super().setUp() self.test_image = _Image.create_from_file( - test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))) + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)) + ) self.model_path = test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) + os.path.join(_TEST_DATA_DIR, _MODEL_FILE) + ) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. @@ -134,9 +138,11 @@ class ObjectDetectorTest(parameterized.TestCase): def test_create_from_options_fails_with_invalid_model_path(self): with self.assertRaisesRegex( - RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): base_options = _BaseOptions( - model_asset_path='/path/to/invalid/model.tflite') + model_asset_path='/path/to/invalid/model.tflite' + ) options = _ObjectDetectorOptions(base_options=base_options) _ObjectDetector.create_from_options(options) @@ -150,9 +156,11 @@ class ObjectDetectorTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT)) - def test_detect(self, model_file_type, max_results, - expected_detection_result): + (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT), + ) + def test_detect( + self, model_file_type, max_results, expected_detection_result + ): # Creates detector. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(model_asset_path=self.model_path) @@ -165,7 +173,8 @@ class ObjectDetectorTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ObjectDetectorOptions( - base_options=base_options, max_results=max_results) + base_options=base_options, max_results=max_results + ) detector = _ObjectDetector.create_from_options(options) # Performs object detection on the input. @@ -178,9 +187,11 @@ class ObjectDetectorTest(parameterized.TestCase): @parameterized.parameters( (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT), - (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT)) - def test_detect_in_context(self, model_file_type, max_results, - expected_detection_result): + (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT), + ) + def test_detect_in_context( + self, model_file_type, max_results, expected_detection_result + ): if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(model_asset_path=self.model_path) elif model_file_type is ModelFileType.FILE_CONTENT: @@ -192,7 +203,8 @@ class ObjectDetectorTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ObjectDetectorOptions( - base_options=base_options, max_results=max_results) + base_options=base_options, max_results=max_results + ) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. 
detection_result = detector.detect(self.test_image) @@ -202,7 +214,8 @@ class ObjectDetectorTest(parameterized.TestCase): def test_score_threshold_option(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - score_threshold=_SCORE_THRESHOLD) + score_threshold=_SCORE_THRESHOLD, + ) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. detection_result = detector.detect(self.test_image) @@ -211,25 +224,30 @@ class ObjectDetectorTest(parameterized.TestCase): for detection in detections: score = detection.categories[0].score self.assertGreaterEqual( - score, _SCORE_THRESHOLD, - f'Detection with score lower than threshold found. {detection}') + score, + _SCORE_THRESHOLD, + f'Detection with score lower than threshold found. {detection}', + ) def test_max_results_option(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - max_results=_MAX_RESULTS) + max_results=_MAX_RESULTS, + ) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. detection_result = detector.detect(self.test_image) detections = detection_result.detections self.assertLessEqual( - len(detections), _MAX_RESULTS, 'Too many results returned.') + len(detections), _MAX_RESULTS, 'Too many results returned.' + ) def test_allow_list_option(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - category_allowlist=_ALLOW_LIST) + category_allowlist=_ALLOW_LIST, + ) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. detection_result = detector.detect(self.test_image) @@ -237,13 +255,17 @@ class ObjectDetectorTest(parameterized.TestCase): for detection in detections: label = detection.categories[0].category_name - self.assertIn(label, _ALLOW_LIST, - f'Label {label} found but not in label allow list') + self.assertIn( + label, + _ALLOW_LIST, + f'Label {label} found but not in label allow list', + ) def test_deny_list_option(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - category_denylist=_DENY_LIST) + category_denylist=_DENY_LIST, + ) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. detection_result = detector.detect(self.test_image) @@ -251,26 +273,30 @@ class ObjectDetectorTest(parameterized.TestCase): for detection in detections: label = detection.categories[0].category_name - self.assertNotIn(label, _DENY_LIST, - f'Label {label} found but in deny list.') + self.assertNotIn( + label, _DENY_LIST, f'Label {label} found but in deny list.' + ) def test_combined_allowlist_and_denylist(self): # Fails with combined allowlist and denylist with self.assertRaisesRegex( ValueError, r'`category_allowlist` and `category_denylist` are mutually ' - r'exclusive options.'): + r'exclusive options.', + ): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), category_allowlist=['foo'], - category_denylist=['bar']) + category_denylist=['bar'], + ) with _ObjectDetector.create_from_options(options) as unused_detector: pass def test_empty_detection_outputs(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - score_threshold=1) + score_threshold=1, + ) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. 
detection_result = detector.detect(self.test_image) @@ -279,9 +305,11 @@ class ObjectDetectorTest(parameterized.TestCase): def test_missing_result_callback(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM) - with self.assertRaisesRegex(ValueError, - r'result callback must be provided'): + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): with _ObjectDetector.create_from_options(options) as unused_detector: pass @@ -290,56 +318,68 @@ class ObjectDetectorTest(parameterized.TestCase): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=running_mode, - result_callback=mock.MagicMock()) - with self.assertRaisesRegex(ValueError, - r'result callback should not be provided'): + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): with _ObjectDetector.create_from_options(options) as unused_detector: pass def test_calling_detect_for_video_in_image_mode(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ObjectDetector.create_from_options(options) as detector: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): detector.detect_for_video(self.test_image, 0) def test_calling_detect_async_in_image_mode(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ObjectDetector.create_from_options(options) as detector: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): detector.detect_async(self.test_image, 0) def test_calling_detect_in_video_mode(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ObjectDetector.create_from_options(options) as detector: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): detector.detect(self.test_image) def test_calling_detect_async_in_video_mode(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ObjectDetector.create_from_options(options) as detector: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): detector.detect_async(self.test_image, 0) def test_detect_for_video_with_out_of_order_timestamp(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ObjectDetector.create_from_options(options) as detector: unused_result = detector.detect_for_video(self.test_image, 1) with self.assertRaisesRegex( - ValueError, r'Input timestamp 
must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): detector.detect_for_video(self.test_image, 0) # TODO: Tests how `detect_for_video` handles the temporal data @@ -348,7 +388,8 @@ class ObjectDetectorTest(parameterized.TestCase): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.VIDEO, - max_results=4) + max_results=4, + ) with _ObjectDetector.create_from_options(options) as detector: for timestamp in range(0, 300, 30): detection_result = detector.detect_for_video(self.test_image, timestamp) @@ -358,20 +399,24 @@ class ObjectDetectorTest(parameterized.TestCase): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ObjectDetector.create_from_options(options) as detector: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): detector.detect(self.test_image) def test_calling_detect_for_video_in_live_stream_mode(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ObjectDetector.create_from_options(options) as detector: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): detector.detect_for_video(self.test_image, 0) def test_detect_async_calls_with_illegal_timestamp(self): @@ -379,24 +424,30 @@ class ObjectDetectorTest(parameterized.TestCase): base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=4, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ObjectDetector.create_from_options(options) as detector: detector.detect_async(self.test_image, 100) with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): detector.detect_async(self.test_image, 0) - @parameterized.parameters((0, _EXPECTED_DETECTION_RESULT), - (1, _DetectionResult(detections=[]))) + @parameterized.parameters( + (0, _EXPECTED_DETECTION_RESULT), (1, _DetectionResult(detections=[])) + ) def test_detect_async_calls(self, threshold, expected_result): observed_timestamp_ms = -1 - def check_result(result: _DetectionResult, output_image: _Image, - timestamp_ms: int): + def check_result( + result: _DetectionResult, output_image: _Image, timestamp_ms: int + ): self.assertEqual(result, expected_result) self.assertTrue( - np.array_equal(output_image.numpy_view(), - self.test_image.numpy_view())) + np.array_equal( + output_image.numpy_view(), self.test_image.numpy_view() + ) + ) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms @@ -405,7 +456,8 @@ class ObjectDetectorTest(parameterized.TestCase): running_mode=_RUNNING_MODE.LIVE_STREAM, max_results=4, score_threshold=threshold, - result_callback=check_result) + result_callback=check_result, + ) detector = _ObjectDetector.create_from_options(options) for timestamp in range(0, 300, 30): detector.detect_async(self.test_image, timestamp) diff --git 
a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 89a988be9..2c0053b11 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -29,10 +29,12 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_py_pb2", "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) @@ -71,10 +73,12 @@ py_library( "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", + "//mediapipe/tasks/python/components/containers:rect", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index 768d392f1..eb976153e 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -17,6 +17,7 @@ import math from typing import Callable, Mapping, Optional from mediapipe.framework import calculator_pb2 +from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.tasks.python.components.containers import rect as rect_module @@ -39,8 +40,9 @@ class BaseVisionTaskApi(object): self, graph_config: calculator_pb2.CalculatorGraphConfig, running_mode: _RunningMode, - packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]], - None]] = None + packet_callback: Optional[ + Callable[[Mapping[str, packet_module.Packet]], None] + ] = None, ) -> None: """Initializes the `BaseVisionTaskApi` object. @@ -48,7 +50,7 @@ class BaseVisionTaskApi(object): graph_config: The mediapipe vision task graph config proto. running_mode: The running mode of the mediapipe vision task. packet_callback: The optional packet callback for getting results - asynchronously in the live stream mode. + asynchronously in the live stream mode. Raises: ValueError: The packet callback is not properly set based on the task's @@ -58,16 +60,19 @@ class BaseVisionTaskApi(object): if packet_callback is None: raise ValueError( 'The vision task is in live stream mode, a user-defined result ' - 'callback must be provided.') + 'callback must be provided.' + ) elif packet_callback: raise ValueError( 'The vision task is in image or video mode, a user-defined result ' - 'callback should not be provided.') + 'callback should not be provided.' 
+ ) self._runner = _TaskRunner.create(graph_config, packet_callback) self._running_mode = running_mode def _process_image_data( - self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]: + self, inputs: Mapping[str, _Packet] + ) -> Mapping[str, _Packet]: """A synchronous method to process single image inputs. The call blocks the current thread until a failure status or a successful result is returned. Args: inputs: A dict contains (stream name, image packet) pairs. Returns: A dict contains (stream name, output packet) pairs. Raises: ValueError: If the task's running mode is not set to the image mode. """ if self._running_mode != _RunningMode.IMAGE: raise ValueError( - 'Task is not initialized with the image mode. Current running mode:' + - self._running_mode.name) + 'Task is not initialized with the image mode. Current running mode:' + + self._running_mode.name + ) return self._runner.process(inputs) def _process_video_data( - self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]: + self, inputs: Mapping[str, _Packet] + ) -> Mapping[str, _Packet]: """A synchronous method to process continuous video frames. The call blocks the current thread until a failure status or a successful result is returned. Args: inputs: A dict contains (stream name, image packet) pairs. Returns: A dict contains (stream name, output packet) pairs. Raises: ValueError: If the task's running mode is not set to the video mode. """ if self._running_mode != _RunningMode.VIDEO: raise ValueError( - 'Task is not initialized with the video mode. Current running mode:' + - self._running_mode.name) + 'Task is not initialized with the video mode. Current running mode:' + + self._running_mode.name + ) return self._runner.process(inputs) def _send_live_stream_data(self, inputs: Mapping[str, _Packet]) -> None: @@ -124,13 +132,18 @@ class BaseVisionTaskApi(object): """ if self._running_mode != _RunningMode.LIVE_STREAM: raise ValueError( - 'Task is not initialized with the live stream mode. Current running mode:' - + self._running_mode.name) + 'Task is not initialized with the live stream mode. Current running' + ' mode:' + + self._running_mode.name + ) self._runner.send(inputs) - def convert_to_normalized_rect(self, - options: _ImageProcessingOptions, - roi_allowed: bool = True) -> _NormalizedRect: + def convert_to_normalized_rect( + self, + options: _ImageProcessingOptions, + image: image_module.Image, + roi_allowed: bool = True, + ) -> _NormalizedRect: """Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly. If the input ImageProcessingOptions is not present, returns a default @@ -140,6 +153,7 @@ class BaseVisionTaskApi(object): Args: options: Options for image processing. + image: The image to process. roi_allowed: Indicates if the `region_of_interest` field is allowed to be set. By default, it's set to True. Returns: A normalized rect proto that represents the image processing options. """ normalized_rect = _NormalizedRect( - rotation=0, x_center=0.5, y_center=0.5, width=1, height=1) + rotation=0, x_center=0.5, y_center=0.5, width=1, height=1 + ) if options is None: return normalized_rect @@ -169,6 +184,20 @@ class BaseVisionTaskApi(object): normalized_rect.y_center = (roi.top + roi.bottom) / 2.0 normalized_rect.width = roi.right - roi.left normalized_rect.height = roi.bottom - roi.top + + # For 90° and 270° rotations, we need to swap width and height. + # This is due to the internal behavior of ImageToTensorCalculator, which: + # - first denormalizes the provided rect by multiplying the rect width or + # height by the image width or height, respectively. + # - then rotates this denormalized rect by the provided rotation, and + # uses this for cropping, + # - then finally rotates this back.
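# Editorial worked example for the swap below, with illustrative numbers that
# are not part of the patch or its test data: for a full-image rect of
# width = height = 1.0 on a 640x480 image with rotation_degrees=-90, the
# branch computes w = 1.0 * 480 / 640 = 0.75 and h = 1.0 * 640 / 480 ≈ 1.33.
# Denormalizing those against the original image then gives 0.75 * 640 = 480
# and 1.33 * 480 = 640, i.e. exactly the 480x640 dimensions of the rotated
# crop, which is why the width and height must be exchanged here.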
+ if abs(options.rotation_degrees % 180) != 0: + w = normalized_rect.height * image.height / image.width + h = normalized_rect.width * image.width / image.height + normalized_rect.width = w + normalized_rect.height = h + return normalized_rect def close(self) -> None: diff --git a/mediapipe/tasks/python/vision/face_detector.py b/mediapipe/tasks/python/vision/face_detector.py index f0ce4d1f1..cf09a378d 100644 --- a/mediapipe/tasks/python/vision/face_detector.py +++ b/mediapipe/tasks/python/vision/face_detector.py @@ -212,7 +212,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If face detection failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False + image_processing_options, image, roi_allowed=False ) output_packets = self._process_image_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), @@ -261,7 +261,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If face detection failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False + image_processing_options, image, roi_allowed=False ) output_packets = self._process_video_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( @@ -320,7 +320,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi): detector has already processed. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False + image_processing_options, image, roi_allowed=False ) self._send_live_stream_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py index 41faf6d91..3e43e8a7f 100644 --- a/mediapipe/tasks/python/vision/face_landmarker.py +++ b/mediapipe/tasks/python/vision/face_landmarker.py @@ -401,7 +401,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If face landmarker detection failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False + image_processing_options, image, roi_allowed=False ) output_packets = self._process_image_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), @@ -444,7 +444,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If face landmarker detection failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False + image_processing_options, image, roi_allowed=False ) output_packets = self._process_video_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( @@ -497,7 +497,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi): face landmarker has already processed. 
""" normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False + image_processing_options, image, roi_allowed=False ) self._send_live_stream_data({ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 227203a0d..7d480c95f 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions -_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions +_GestureRecognizerGraphOptionsProto = ( + gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions +) _ClassifierOptions = classifier_options.ClassifierOptions _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -53,7 +55,9 @@ _HAND_LANDMARKS_STREAM_NAME = 'landmarks' _HAND_LANDMARKS_TAG = 'LANDMARKS' _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks' _HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS' -_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph' +_TASK_GRAPH_NAME = ( + 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph' +) _MICRO_SECONDS_PER_MILLISECOND = 1000 _GESTURE_DEFAULT_INDEX = -1 @@ -78,17 +82,21 @@ class GestureRecognizerResult: def _build_recognition_result( - output_packets: Mapping[str, - packet_module.Packet]) -> GestureRecognizerResult: + output_packets: Mapping[str, packet_module.Packet] +) -> GestureRecognizerResult: """Constructs a `GestureRecognizerResult` from output packets.""" gestures_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_GESTURE_STREAM_NAME]) + output_packets[_HAND_GESTURE_STREAM_NAME] + ) handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) + output_packets[_HANDEDNESS_STREAM_NAME] + ) hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) + output_packets[_HAND_LANDMARKS_STREAM_NAME] + ) hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME] + ) gesture_results = [] for proto in gestures_proto_list: @@ -101,7 +109,9 @@ def _build_recognition_result( index=_GESTURE_DEFAULT_INDEX, score=gesture.score, display_name=gesture.display_name, - category_name=gesture.label)) + category_name=gesture.label, + ) + ) gesture_results.append(gesture_categories) handedness_results = [] @@ -115,7 +125,9 @@ def _build_recognition_result( index=handedness.index, score=handedness.score, display_name=handedness.display_name, - category_name=handedness.label)) + category_name=handedness.label, + ) + ) handedness_results.append(handedness_categories) hand_landmarks_results = [] @@ -125,7 +137,8 @@ def _build_recognition_result( hand_landmarks_list = [] for hand_landmark in hand_landmarks.landmark: hand_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)) + landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + ) hand_landmarks_results.append(hand_landmarks_list) hand_world_landmarks_results = [] @@ -135,12 
+148,16 @@ def _build_recognition_result( hand_world_landmarks_list = [] for hand_world_landmark in hand_world_landmarks.landmark: hand_world_landmarks_list.append( - landmark_module.Landmark.create_from_pb2(hand_world_landmark)) + landmark_module.Landmark.create_from_pb2(hand_world_landmark) + ) hand_world_landmarks_results.append(hand_world_landmarks_list) - return GestureRecognizerResult(gesture_results, handedness_results, - hand_landmarks_results, - hand_world_landmarks_results) + return GestureRecognizerResult( + gesture_results, + handedness_results, + hand_landmarks_results, + hand_world_landmarks_results, + ) @dataclasses.dataclass @@ -174,43 +191,62 @@ class GestureRecognizerOptions: data. The result callback should only be specified when the running mode is set to the live stream mode. """ + base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE num_hands: Optional[int] = 1 min_hand_detection_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 - canned_gesture_classifier_options: Optional[ - _ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) - custom_gesture_classifier_options: Optional[ - _ClassifierOptions] = dataclasses.field( - default_factory=_ClassifierOptions) - result_callback: Optional[Callable[ - [GestureRecognizerResult, image_module.Image, int], None]] = None + canned_gesture_classifier_options: Optional[_ClassifierOptions] = ( + dataclasses.field(default_factory=_ClassifierOptions) + ) + custom_gesture_classifier_options: Optional[_ClassifierOptions] = ( + dataclasses.field(default_factory=_ClassifierOptions) + ) + result_callback: Optional[ + Callable[[GestureRecognizerResult, image_module.Image, int], None] + ] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _GestureRecognizerGraphOptionsProto: """Generates a GestureRecognizerOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) # Initialize gesture recognizer options from base options. gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto( - base_options=base_options_proto) + base_options=base_options_proto + ) # Configure hand detector and hand landmarker options.
- hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options - hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence - hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands - hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence - hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence + hand_landmarker_options_proto = ( + gesture_recognizer_options_proto.hand_landmarker_graph_options + ) + hand_landmarker_options_proto.min_tracking_confidence = ( + self.min_tracking_confidence + ) + hand_landmarker_options_proto.hand_detector_graph_options.num_hands = ( + self.num_hands + ) + hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = ( + self.min_hand_detection_confidence + ) + hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = ( + self.min_hand_presence_confidence + ) # Configure hand gesture recognizer options. - hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options + hand_gesture_recognizer_options_proto = ( + gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options + ) hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom( - self.canned_gesture_classifier_options.to_pb2()) + self.canned_gesture_classifier_options.to_pb2() + ) hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom( - self.custom_gesture_classifier_options.to_pb2()) + self.custom_gesture_classifier_options.to_pb2() + ) return gesture_recognizer_options_proto @@ -239,12 +275,14 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): """ base_options = _BaseOptions(model_asset_path=model_path) options = GestureRecognizerOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE) + base_options=base_options, running_mode=_RunningMode.IMAGE + ) return cls.create_from_options(options) @classmethod def create_from_options( - cls, options: GestureRecognizerOptions) -> 'GestureRecognizer': + cls, options: GestureRecognizerOptions + ) -> 'GestureRecognizer': """Creates the `GestureRecognizer` object from gesture recognizer options. 
Args: @@ -268,14 +306,19 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME] options.result_callback( - GestureRecognizerResult([], [], [], []), image, - empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + GestureRecognizerResult([], [], [], []), + image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) return gesture_recognizer_result = _build_recognition_result(output_packets) timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp - options.result_callback(gesture_recognizer_result, image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + options.result_callback( + gesture_recognizer_result, + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -286,23 +329,27 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): output_streams=[ ':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]), ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]), - ':'.join([_HAND_LANDMARKS_TAG, - _HAND_LANDMARKS_STREAM_NAME]), ':'.join([ - _HAND_WORLD_LANDMARKS_TAG, - _HAND_WORLD_LANDMARKS_STREAM_NAME - ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]), + ':'.join( + [_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME] + ), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ], - task_options=options) + task_options=options, + ) return cls( task_info.generate_graph_config( - enable_flow_limiting=options.running_mode == - _RunningMode.LIVE_STREAM), options.running_mode, - packets_callback if options.result_callback else None) + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) def recognize( self, image: image_module.Image, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> GestureRecognizerResult: """Performs hand gesture recognition on the given image. @@ -325,12 +372,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If gesture recognition failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False) + image_processing_options, image, roi_allowed=False + ) output_packets = self._process_image_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), }) if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): @@ -342,7 +390,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> GestureRecognizerResult: """Performs gesture recognition on the provided video frame. @@ -367,14 +415,15 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If gesture recognition failed to run. 
""" normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False) + image_processing_options, image, roi_allowed=False + ) output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): @@ -386,7 +435,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> None: """Sends live image data to perform gesture recognition. @@ -419,12 +468,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi): gesture recognizer has already processed. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False) + image_processing_options, image, roi_allowed=False + ) self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index a0cd99a83..e6fcca2e2 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module _BaseOptions = base_options_module.BaseOptions -_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions +_HandLandmarkerGraphOptionsProto = ( + hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions +) _RunningMode = running_mode_module.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo @@ -56,6 +58,7 @@ _MICRO_SECONDS_PER_MILLISECOND = 1000 class HandLandmark(enum.IntEnum): """The 21 hand landmarks.""" + WRIST = 0 THUMB_CMC = 1 THUMB_MCP = 2 @@ -95,14 +98,18 @@ class HandLandmarkerResult: def _build_landmarker_result( - output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult: + output_packets: Mapping[str, packet_module.Packet] +) -> HandLandmarkerResult: """Constructs a `HandLandmarksDetectionResult` from output packets.""" handedness_proto_list = packet_getter.get_proto_list( - output_packets[_HANDEDNESS_STREAM_NAME]) + output_packets[_HANDEDNESS_STREAM_NAME] + ) hand_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_LANDMARKS_STREAM_NAME]) + 
output_packets[_HAND_LANDMARKS_STREAM_NAME] + ) hand_world_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) + output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME] + ) handedness_results = [] for proto in handedness_proto_list: @@ -115,7 +122,9 @@ def _build_landmarker_result( index=handedness.index, score=handedness.score, display_name=handedness.display_name, - category_name=handedness.label)) + category_name=handedness.label, + ) + ) handedness_results.append(handedness_categories) hand_landmarks_results = [] @@ -125,7 +134,8 @@ def _build_landmarker_result( hand_landmarks_list = [] for hand_landmark in hand_landmarks.landmark: hand_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)) + landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) + ) hand_landmarks_results.append(hand_landmarks_list) hand_world_landmarks_results = [] @@ -135,11 +145,13 @@ def _build_landmarker_result( hand_world_landmarks_list = [] for hand_world_landmark in hand_world_landmarks.landmark: hand_world_landmarks_list.append( - landmark_module.Landmark.create_from_pb2(hand_world_landmark)) + landmark_module.Landmark.create_from_pb2(hand_world_landmark) + ) hand_world_landmarks_results.append(hand_world_landmarks_list) - return HandLandmarkerResult(handedness_results, hand_landmarks_results, - hand_world_landmarks_results) + return HandLandmarkerResult( + handedness_results, hand_landmarks_results, hand_world_landmarks_results + ) @dataclasses.dataclass @@ -167,28 +179,41 @@ class HandLandmarkerOptions: data. The result callback should only be specified when the running mode is set to the live stream mode. """ + base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE num_hands: Optional[int] = 1 min_hand_detection_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 - result_callback: Optional[Callable[ - [HandLandmarkerResult, image_module.Image, int], None]] = None + result_callback: Optional[ + Callable[[HandLandmarkerResult, image_module.Image, int], None] + ] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _HandLandmarkerGraphOptionsProto: """Generates an HandLandmarkerGraphOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) # Initialize the hand landmarker options from base options. 
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto( - base_options=base_options_proto) - hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence - hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands - hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence - hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence + base_options=base_options_proto + ) + hand_landmarker_options_proto.min_tracking_confidence = ( + self.min_tracking_confidence + ) + hand_landmarker_options_proto.hand_detector_graph_options.num_hands = ( + self.num_hands + ) + hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = ( + self.min_hand_detection_confidence + ) + hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = ( + self.min_hand_presence_confidence + ) return hand_landmarker_options_proto @@ -216,12 +241,14 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): """ base_options = _BaseOptions(model_asset_path=model_path) options = HandLandmarkerOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE) + base_options=base_options, running_mode=_RunningMode.IMAGE + ) return cls.create_from_options(options) @classmethod - def create_from_options(cls, - options: HandLandmarkerOptions) -> 'HandLandmarker': + def create_from_options( + cls, options: HandLandmarkerOptions + ) -> 'HandLandmarker': """Creates the `HandLandmarker` object from hand landmarker options. Args: @@ -245,14 +272,19 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME] options.result_callback( - HandLandmarkerResult([], [], []), image, - empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + HandLandmarkerResult([], [], []), + image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) return hand_landmarks_detection_result = _build_landmarker_result(output_packets) timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp - options.result_callback(hand_landmarks_detection_result, image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + options.result_callback( + hand_landmarks_detection_result, + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -263,21 +295,26 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): output_streams=[ ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]), ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]), - ':'.join([ - _HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME - ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ':'.join( + [_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME] + ), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ], - task_options=options) + task_options=options, + ) return cls( task_info.generate_graph_config( - enable_flow_limiting=options.running_mode == - _RunningMode.LIVE_STREAM), options.running_mode, - packets_callback if options.result_callback else None) + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) def detect( self, image: image_module.Image, - image_processing_options: 
Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> HandLandmarkerResult: """Performs hand landmarks detection on the given image. @@ -300,12 +337,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If hand landmarker detection failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False) + image_processing_options, image, roi_allowed=False + ) output_packets = self._process_image_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), }) if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): @@ -317,7 +355,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> HandLandmarkerResult: """Performs hand landmarks detection on the provided video frame. @@ -342,14 +380,15 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If hand landmarker detection failed to run. """ normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False) + image_processing_options, image, roi_allowed=False + ) output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): @@ -361,7 +400,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> None: """Sends live image data to perform hand landmarks detection. @@ -394,12 +433,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi): hand landmarker has already processed. 
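A live-stream sketch of detect_async() and the callback contract described above; a placeholder still image stands in for camera frames, and the timestamps must increase monotonically across calls:

import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

def on_result(result, image, timestamp_ms):
  # result is a HandLandmarkerResult; its lists are empty when no hand is
  # detected in the frame.
  print(timestamp_ms, len(result.hand_landmarks))

options = vision.HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path='hand_landmarker.task'),
    running_mode=vision.RunningMode.LIVE_STREAM,
    result_callback=on_result,
)
with vision.HandLandmarker.create_from_options(options) as landmarker:
  frame = mp.Image.create_from_file('frame.jpg')  # placeholder camera frame
  for ts_ms in (0, 33, 66):  # monotonically increasing milliseconds
    landmarker.detect_async(frame, ts_ms)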
""" normalized_rect = self.convert_to_normalized_rect( - image_processing_options, roi_allowed=False) + image_processing_options, image, roi_allowed=False + ) self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index b60d18e31..eda348fc7 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -35,7 +35,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode ImageClassifierResult = classification_result_module.ClassificationResult _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions -_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions +_ImageClassifierGraphOptionsProto = ( + image_classifier_graph_options_pb2.ImageClassifierGraphOptions +) _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions @@ -48,7 +50,9 @@ _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' _NORM_RECT_STREAM_NAME = 'norm_rect_in' _NORM_RECT_TAG = 'NORM_RECT' -_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' +_TASK_GRAPH_NAME = ( + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' +) _MICRO_SECONDS_PER_MILLISECOND = 1000 @@ -81,6 +85,7 @@ class ImageClassifierOptions: data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" + base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE display_names_locale: Optional[str] = None @@ -88,24 +93,29 @@ class ImageClassifierOptions: score_threshold: Optional[float] = None category_allowlist: Optional[List[str]] = None category_denylist: Optional[List[str]] = None - result_callback: Optional[Callable[ - [ImageClassifierResult, image_module.Image, int], None]] = None + result_callback: Optional[ + Callable[[ImageClassifierResult, image_module.Image, int], None] + ] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageClassifierGraphOptionsProto: """Generates an ImageClassifierOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) classifier_options_proto = _ClassifierOptionsProto( score_threshold=self.score_threshold, category_allowlist=self.category_allowlist, category_denylist=self.category_denylist, display_names_locale=self.display_names_locale, - max_results=self.max_results) + max_results=self.max_results, + ) return _ImageClassifierGraphOptionsProto( base_options=base_options_proto, - classifier_options=classifier_options_proto) + classifier_options=classifier_options_proto, + ) class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): @@ -165,12 +175,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): """ base_options = _BaseOptions(model_asset_path=model_path) options = ImageClassifierOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE) + base_options=base_options, running_mode=_RunningMode.IMAGE + ) return cls.create_from_options(options) @classmethod - def create_from_options(cls, - options: ImageClassifierOptions) -> 'ImageClassifier': + def create_from_options( + cls, options: ImageClassifierOptions + ) -> 'ImageClassifier': """Creates the `ImageClassifier` object from image classifier options. 
Args: @@ -191,12 +203,15 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( - packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]) + ) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback( ImageClassifierResult.create_from_pb2(classification_result_proto), - image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -206,19 +221,23 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ], output_streams=[ ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ], - task_options=options) + task_options=options, + ) return cls( task_info.generate_graph_config( - enable_flow_limiting=options.running_mode == - _RunningMode.LIVE_STREAM), options.running_mode, - packets_callback if options.result_callback else None) + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) def classify( self, image: image_module.Image, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> ImageClassifierResult: """Performs image classification on the provided MediaPipe Image. @@ -233,17 +252,20 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. """ - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) output_packets = self._process_image_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), }) classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( - packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]) + ) return ImageClassifierResult.create_from_pb2(classification_result_proto) @@ -251,7 +273,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> ImageClassifierResult: """Performs image classification on the provided video frames. @@ -272,19 +294,22 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image classification failed to run. 
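Unlike the hand tasks above, the image classifier does allow a region of interest through image_processing_options. A sketch with a normalized ROI, paths again placeholders:

import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components.containers.rect import Rect
from mediapipe.tasks.python.core.base_options import BaseOptions
from mediapipe.tasks.python.vision.core.image_processing_options import ImageProcessingOptions

options = vision.ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='classifier.tflite'),
    max_results=4,
)
with vision.ImageClassifier.create_from_options(options) as classifier:
  image = mp.Image.create_from_file('burger.jpg')
  # Rect coordinates are normalized to [0, 1]; this crops the right half.
  roi = Rect(left=0.5, top=0.0, right=1.0, bottom=1.0)
  result = classifier.classify(
      image, ImageProcessingOptions(region_of_interest=roi)
  )
  top = result.classifications[0].categories[0]
  print(top.category_name, top.score)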
""" - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto.CopyFrom( - packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]) + ) return ImageClassifierResult.create_from_pb2(classification_result_proto) @@ -292,7 +317,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> None: """Sends live image data (an Image with a unique timestamp) to perform image classification. @@ -320,12 +345,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): ValueError: If the current input timestamp is smaller than what the image classifier has already processed. """ - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index 0bae21bda..511fc3c56 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni ImageEmbedderResult = embedding_result_module.EmbeddingResult _BaseOptions = base_options_module.BaseOptions -_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions +_ImageEmbedderGraphOptionsProto = ( + image_embedder_graph_options_pb2.ImageEmbedderGraphOptions +) _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _RunningMode = running_mode_module.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -74,24 +76,29 @@ class ImageEmbedderOptions: data. The result callback should only be specified when the running mode is set to the live stream mode. 
""" + base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE l2_normalize: Optional[bool] = None quantize: Optional[bool] = None - result_callback: Optional[Callable[ - [ImageEmbedderResult, image_module.Image, int], None]] = None + result_callback: Optional[ + Callable[[ImageEmbedderResult, image_module.Image, int], None] + ] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageEmbedderGraphOptionsProto: """Generates an ImageEmbedderOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) embedder_options_proto = _EmbedderOptionsProto( - l2_normalize=self.l2_normalize, quantize=self.quantize) + l2_normalize=self.l2_normalize, quantize=self.quantize + ) return _ImageEmbedderGraphOptionsProto( - base_options=base_options_proto, - embedder_options=embedder_options_proto) + base_options=base_options_proto, embedder_options=embedder_options_proto + ) class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): @@ -135,12 +142,14 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): """ base_options = _BaseOptions(model_asset_path=model_path) options = ImageEmbedderOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE) + base_options=base_options, running_mode=_RunningMode.IMAGE + ) return cls.create_from_options(options) @classmethod - def create_from_options(cls, - options: ImageEmbedderOptions) -> 'ImageEmbedder': + def create_from_options( + cls, options: ImageEmbedderOptions + ) -> 'ImageEmbedder': """Creates the `ImageEmbedder` object from image embedder options. Args: @@ -161,13 +170,16 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): embedding_result_proto = embeddings_pb2.EmbeddingResult() embedding_result_proto.CopyFrom( - packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) + ) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback( - ImageEmbedderResult.create_from_pb2(embedding_result_proto), image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + ImageEmbedderResult.create_from_pb2(embedding_result_proto), + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, @@ -177,19 +189,23 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): ], output_streams=[ ':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ], - task_options=options) + task_options=options, + ) return cls( task_info.generate_graph_config( - enable_flow_limiting=options.running_mode == - _RunningMode.LIVE_STREAM), options.running_mode, - packets_callback if options.result_callback else None) + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) def embed( self, image: image_module.Image, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> ImageEmbedderResult: """Performs image embedding extraction on the provided MediaPipe Image. 
@@ -207,17 +223,20 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image embedder failed to run. """ - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) output_packets = self._process_image_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), }) embedding_result_proto = embeddings_pb2.EmbeddingResult() embedding_result_proto.CopyFrom( - packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) + ) return ImageEmbedderResult.create_from_pb2(embedding_result_proto) @@ -225,7 +244,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> ImageEmbedderResult: """Performs image embedding extraction on the provided video frames. @@ -249,18 +268,21 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image embedder failed to run. """ - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) embedding_result_proto = embeddings_pb2.EmbeddingResult() embedding_result_proto.CopyFrom( - packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) + packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]) + ) return ImageEmbedderResult.create_from_pb2(embedding_result_proto) @@ -268,7 +290,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): self, image: image_module.Image, timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None + image_processing_options: Optional[_ImageProcessingOptions] = None, ) -> None: """Sends live image data to embedder. @@ -301,19 +323,24 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): ValueError: If the current input timestamp is smaller than what the image embedder has already processed. 
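Putting embed() together with the cosine_similarity utility defined at the end of this file, a pairwise-comparison sketch (placeholder paths):

import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

options = vision.ImageEmbedderOptions(
    base_options=BaseOptions(model_asset_path='embedder.tflite'),
    l2_normalize=True,
)
with vision.ImageEmbedder.create_from_options(options) as embedder:
  first = embedder.embed(mp.Image.create_from_file('a.jpg'))
  second = embedder.embed(mp.Image.create_from_file('b.jpg'))
  # Compare the first (often only) embedding head of each result; values
  # close to 1.0 mean near-identical direction.
  print(vision.ImageEmbedder.cosine_similarity(
      first.embeddings[0], second.embeddings[0]))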
""" - normalized_rect = self.convert_to_normalized_rect(image_processing_options) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - _NORM_RECT_STREAM_NAME: - packet_creator.create_proto(normalized_rect.to_pb2()).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) @classmethod - def cosine_similarity(cls, u: embedding_result_module.Embedding, - v: embedding_result_module.Embedding) -> float: + def cosine_similarity( + cls, + u: embedding_result_module.Embedding, + v: embedding_result_module.Embedding, + ) -> float: """Utility function to compute cosine similarity between two embedding entries. May return an InvalidArgumentError if e.g. the feature vectors are of diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 22a37cb3e..e50ffbf79 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -23,16 +23,23 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 +from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode +ImageSegmenterResult = List[image_module.Image] +_NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions -_ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions +_ImageSegmenterGraphOptionsProto = ( + image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions +) _RunningMode = vision_task_running_mode.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo _SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' @@ -40,6 +47,8 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 @@ -77,19 +86,24 @@ class ImageSegmenterOptions: running_mode: _RunningMode = _RunningMode.IMAGE output_type: Optional[OutputType] = OutputType.CATEGORY_MASK activation: Optional[Activation] = Activation.NONE - result_callback: Optional[Callable[ - [List[image_module.Image], image_module.Image, int], None]] = None + result_callback: Optional[ + 
Callable[[ImageSegmenterResult, image_module.Image, int], None] + ] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: """Generates an ImageSegmenterOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value, activation=self.activation.value) + output_type=self.output_type.value, activation=self.activation.value + ) return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, - segmenter_options=segmenter_options_proto) + segmenter_options=segmenter_options_proto, + ) class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): @@ -138,12 +152,14 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): """ base_options = _BaseOptions(model_asset_path=model_path) options = ImageSegmenterOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE) + base_options=base_options, running_mode=_RunningMode.IMAGE + ) return cls.create_from_options(options) @classmethod - def create_from_options(cls, - options: ImageSegmenterOptions) -> 'ImageSegmenter': + def create_from_options( + cls, options: ImageSegmenterOptions + ) -> 'ImageSegmenter': """Creates the `ImageSegmenter` object from image segmenter options. Args: @@ -162,31 +178,47 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME]) + output_packets[_SEGMENTATION_OUT_STREAM_NAME] + ) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp - options.result_callback(segmentation_result, image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + options.result_callback( + segmentation_result, + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, - input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], output_streams=[ ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ], - task_options=options) + task_options=options, + ) return cls( task_info.generate_graph_config( - enable_flow_limiting=options.running_mode == - _RunningMode.LIVE_STREAM), options.running_mode, - packets_callback if options.result_callback else None) + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) - def segment(self, image: image_module.Image) -> List[image_module.Image]: + def segment( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> ImageSegmenterResult: """Performs the actual segmentation task on the provided MediaPipe Image. Args: image: MediaPipe Image. + image_processing_options: Options for image processing. 
Returns: If the output_type is CATEGORY_MASK, the returned vector of images is @@ -199,14 +231,26 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image segmentation failed to run. """ - output_packets = self._process_image_data( - {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME]) + output_packets[_SEGMENTATION_OUT_STREAM_NAME] + ) return segmentation_result - def segment_for_video(self, image: image_module.Image, - timestamp_ms: int) -> List[image_module.Image]: + def segment_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> ImageSegmenterResult: """Performs segmentation on the provided video frames. Only use this method when the ImageSegmenter is created with the video @@ -217,6 +261,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. Returns: If the output_type is CATEGORY_MASK, the returned vector of images is @@ -229,16 +274,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If image segmentation failed to run. """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME]) + output_packets[_SEGMENTATION_OUT_STREAM_NAME] + ) return segmentation_result - def segment_async(self, image: image_module.Image, timestamp_ms: int) -> None: + def segment_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> None: """Sends live image data (an Image with a unique timestamp) to perform image segmentation. Only use this method when the ImageSegmenter is created with the live stream @@ -260,13 +317,20 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. Raises: ValueError: If the current input timestamp is smaller than what the image segmenter has already processed. 
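With the NORM_RECT stream added in this file, all three segmenter entry points now accept image_processing_options. A video-mode sketch under the same assumptions as the earlier examples; a still image stands in for decoded frames, and rotation is the only supported transform because roi_allowed=False:

import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions
from mediapipe.tasks.python.vision.core.image_processing_options import ImageProcessingOptions

options = vision.ImageSegmenterOptions(
    base_options=BaseOptions(model_asset_path='deeplab_v3.tflite'),
    running_mode=vision.RunningMode.VIDEO,
)
with vision.ImageSegmenter.create_from_options(options) as segmenter:
  frame = mp.Image.create_from_file('frame.jpg')  # placeholder video frame
  for ts_ms in (0, 33, 66):
    masks = segmenter.segment_for_video(
        frame, ts_ms, ImageProcessingOptions(rotation_degrees=90)
    )
    print(ts_ms, len(masks), masks[0].width, masks[0].height)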
""" + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) diff --git a/mediapipe/tasks/python/vision/object_detector.py b/mediapipe/tasks/python/vision/object_detector.py index 7c9993d62..023467234 100644 --- a/mediapipe/tasks/python/vision/object_detector.py +++ b/mediapipe/tasks/python/vision/object_detector.py @@ -22,15 +22,20 @@ from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.tasks.cc.vision.object_detector.proto import object_detector_options_pb2 from mediapipe.tasks.python.components.containers import detections as detections_module +from mediapipe.tasks.python.components.containers import rect from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module +ObjectDetectorResult = detections_module.DetectionResult +_NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _ObjectDetectorOptionsProto = object_detector_options_pb2.ObjectDetectorOptions _RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo _DETECTIONS_OUT_STREAM_NAME = 'detections_out' @@ -38,7 +43,10 @@ _DETECTIONS_TAG = 'DETECTIONS' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ObjectDetectorGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 @dataclasses.dataclass @@ -48,11 +56,10 @@ class ObjectDetectorOptions: Attributes: base_options: Base options for the object detector task. running_mode: The running mode of the task. Default to the image mode. - Object detector task has three running modes: - 1) The image mode for detecting objects on single image inputs. - 2) The video mode for detecting objects on the decoded frames of a video. - 3) The live stream mode for detecting objects on a live stream of input - data, such as from camera. + Object detector task has three running modes: 1) The image mode for + detecting objects on single image inputs. 2) The video mode for detecting + objects on the decoded frames of a video. 3) The live stream mode for + detecting objects on a live stream of input data, such as from camera. display_names_locale: The locale to use for display names specified through the TFLite Model Metadata. max_results: The maximum number of top-scored classification results to @@ -71,6 +78,7 @@ class ObjectDetectorOptions: data. 
The result callback should only be specified when the running mode is set to the live stream mode. """ + base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE display_names_locale: Optional[str] = None @@ -79,14 +87,16 @@ class ObjectDetectorOptions: category_allowlist: Optional[List[str]] = None category_denylist: Optional[List[str]] = None result_callback: Optional[ - Callable[[detections_module.DetectionResult, image_module.Image, int], - None]] = None + Callable[[ObjectDetectorResult, image_module.Image, int], None] + ] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _ObjectDetectorOptionsProto: """Generates an ObjectDetectorOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) return _ObjectDetectorOptionsProto( base_options=base_options_proto, display_names_locale=self.display_names_locale, @@ -163,12 +173,14 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): """ base_options = _BaseOptions(model_asset_path=model_path) options = ObjectDetectorOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE) + base_options=base_options, running_mode=_RunningMode.IMAGE + ) return cls.create_from_options(options) @classmethod - def create_from_options(cls, - options: ObjectDetectorOptions) -> 'ObjectDetector': + def create_from_options( + cls, options: ObjectDetectorOptions + ) -> 'ObjectDetector': """Creates the `ObjectDetector` object from object detector options. Args: @@ -187,32 +199,45 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return detection_proto_list = packet_getter.get_proto_list( - output_packets[_DETECTIONS_OUT_STREAM_NAME]) - detection_result = detections_module.DetectionResult([ - detections_module.Detection.create_from_pb2(result) - for result in detection_proto_list - ]) + output_packets[_DETECTIONS_OUT_STREAM_NAME] + ) + detection_result = ObjectDetectorResult( + [ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ] + ) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback(detection_result, image, timestamp) task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, - input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])], + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], output_streams=[ ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), ], - task_options=options) + task_options=options, + ) return cls( task_info.generate_graph_config( - enable_flow_limiting=options.running_mode == - _RunningMode.LIVE_STREAM), options.running_mode, - packets_callback if options.result_callback else None) + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) # TODO: Create an Image class for MediaPipe Tasks. 
- def detect(self, - image: image_module.Image) -> detections_module.DetectionResult: + def detect( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> ObjectDetectorResult: """Performs object detection on the provided MediaPipe Image. Only use this method when the ObjectDetector is created with the image @@ -220,6 +245,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. + image_processing_options: Options for image processing. Returns: A detection result object that contains a list of detections, each @@ -231,17 +257,31 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If object detection failed to run. """ - output_packets = self._process_image_data( - {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)}) + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) detection_proto_list = packet_getter.get_proto_list( - output_packets[_DETECTIONS_OUT_STREAM_NAME]) - return detections_module.DetectionResult([ - detections_module.Detection.create_from_pb2(result) - for result in detection_proto_list - ]) + output_packets[_DETECTIONS_OUT_STREAM_NAME] + ) + return ObjectDetectorResult( + [ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ] + ) - def detect_for_video(self, image: image_module.Image, - timestamp_ms: int) -> detections_module.DetectionResult: + def detect_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> ObjectDetectorResult: """Performs object detection on the provided video frames. Only use this method when the ObjectDetector is created with the video @@ -252,6 +292,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. Returns: A detection result object that contains a list of detections, each @@ -263,18 +304,33 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): ValueError: If any of the input arguments is invalid. RuntimeError: If object detection failed to run. 
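An image-mode sketch of the detect() signature above (placeholder model and image paths):

import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

options = vision.ObjectDetectorOptions(
    base_options=BaseOptions(model_asset_path='efficientdet_lite0.tflite'),
    score_threshold=0.5,
    max_results=5,
)
with vision.ObjectDetector.create_from_options(options) as detector:
  image = mp.Image.create_from_file('cats_and_dogs.jpg')
  result = detector.detect(image)
  for detection in result.detections:
    box = detection.bounding_box  # pixel coords: origin_x/y, width, height
    category = detection.categories[0]
    print(category.category_name, round(category.score, 2),
          (box.origin_x, box.origin_y, box.width, box.height))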
""" + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at(timestamp_ms) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) detection_proto_list = packet_getter.get_proto_list( - output_packets[_DETECTIONS_OUT_STREAM_NAME]) - return detections_module.DetectionResult([ - detections_module.Detection.create_from_pb2(result) - for result in detection_proto_list - ]) + output_packets[_DETECTIONS_OUT_STREAM_NAME] + ) + return ObjectDetectorResult( + [ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ] + ) - def detect_async(self, image: image_module.Image, timestamp_ms: int) -> None: + def detect_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> None: """Sends live image data (an Image with a unique timestamp) to perform object detection. Only use this method when the ObjectDetector is created with the live stream @@ -298,12 +354,20 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. Raises: ValueError: If the current input timestamp is smaller than what the object detector has already processed. """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image, roi_allowed=False + ) self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: - packet_creator.create_image(image).at(timestamp_ms) + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), })