Internal change

PiperOrigin-RevId: 522307800
This commit is contained in:
MediaPipe Team 2023-04-06 04:59:26 -07:00 committed by Copybara-Service
parent 2e256bebb5
commit 97bd9c2157
13 changed files with 964 additions and 454 deletions

View File

@ -33,6 +33,7 @@ py_test(
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:object_detector",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)

View File

@ -55,10 +55,10 @@ _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _generate_empty_results() -> ImageClassifierResult:
return ImageClassifierResult(
classifications=[
_Classifications(
categories=[], head_index=0, head_name='probability')
_Classifications(categories=[], head_index=0, head_name='probability')
],
timestamp_ms=0)
timestamp_ms=0,
)
def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult:
@ -129,10 +129,11 @@ class ImageClassifierTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
)
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
@ -148,9 +149,11 @@ class ImageClassifierTest(parameterized.TestCase):
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
model_asset_path='/path/to/invalid/model.tflite'
)
options = _ImageClassifierOptions(base_options=base_options)
_ImageClassifier.create_from_options(options)
@ -164,9 +167,11 @@ class ImageClassifierTest(parameterized.TestCase):
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
def test_classify(self, model_file_type, max_results,
expected_classification_result):
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
)
def test_classify(
self, model_file_type, max_results, expected_classification_result
):
# Creates classifier.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
@ -179,23 +184,27 @@ class ImageClassifierTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _ImageClassifierOptions(
base_options=base_options, max_results=max_results)
base_options=base_options, max_results=max_results
)
classifier = _ImageClassifier.create_from_options(options)
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
test_utils.assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2())
test_utils.assert_proto_equals(
self, image_result.to_pb2(), expected_classification_result.to_pb2()
)
# Closes the classifier explicitly when the classifier is not used in
# a context.
classifier.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
def test_classify_in_context(self, model_file_type, max_results,
expected_classification_result):
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
)
def test_classify_in_context(
self, model_file_type, max_results, expected_classification_result
):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
@ -207,13 +216,15 @@ class ImageClassifierTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _ImageClassifierOptions(
base_options=base_options, max_results=max_results)
base_options=base_options, max_results=max_results
)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
test_utils.assert_proto_equals(self, image_result.to_pb2(),
expected_classification_result.to_pb2())
test_utils.assert_proto_equals(
self, image_result.to_pb2(), expected_classification_result.to_pb2()
)
def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
@ -222,20 +233,110 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
)
)
# Region-of-interest around the soccer ball.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
image_processing_options = _ImageProcessingOptions(roi)
# Performs image classification on the input.
image_result = classifier.classify(test_image, image_processing_options)
# Comparing results.
test_utils.assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results().to_pb2())
test_utils.assert_proto_equals(
self, image_result.to_pb2(), _generate_soccer_ball_results().to_pb2()
)
def test_classify_succeeds_with_rotation(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageClassifierOptions(base_options=base_options, max_results=3)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'burger_rotated.jpg')
)
)
# Specify a 90° anti-clockwise rotation.
image_processing_options = _ImageProcessingOptions(None, -90)
# Performs image classification on the input.
image_result = classifier.classify(test_image, image_processing_options)
# Comparing results.
expected = ImageClassifierResult(
classifications=[
_Classifications(
categories=[
_Category(
index=934,
score=0.754467,
display_name='',
category_name='cheeseburger',
),
_Category(
index=925,
score=0.0288028,
display_name='',
category_name='guacamole',
),
_Category(
index=932,
score=0.0286119,
display_name='',
category_name='bagel',
),
],
head_index=0,
head_name='probability',
)
],
timestamp_ms=0,
)
test_utils.assert_proto_equals(
self, image_result.to_pb2(), expected.to_pb2()
)
def test_classify_succeeds_with_region_of_interest_and_rotation(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageClassifierOptions(base_options=base_options, max_results=1)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects_rotated.jpg')
)
)
# Region-of-interest around the soccer ball, with 90° anti-clockwise
# rotation.
roi = _Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614)
image_processing_options = _ImageProcessingOptions(roi, -90)
# Performs image classification on the input.
image_result = classifier.classify(test_image, image_processing_options)
# Comparing results.
expected = ImageClassifierResult(
classifications=[
_Classifications(
categories=[
_Category(
index=806,
score=0.997684,
display_name='',
category_name='soccer ball',
),
],
head_index=0,
head_name='probability',
)
],
timestamp_ms=0,
)
test_utils.assert_proto_equals(
self, image_result.to_pb2(), expected.to_pb2()
)
def test_score_threshold_option(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
score_threshold=_SCORE_THRESHOLD)
score_threshold=_SCORE_THRESHOLD,
)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -245,26 +346,33 @@ class ImageClassifierTest(parameterized.TestCase):
for category in classification.categories:
score = category.score
self.assertGreaterEqual(
score, _SCORE_THRESHOLD,
f'Classification with score lower than threshold found. '
f'{classification}')
score,
_SCORE_THRESHOLD,
(
'Classification with score lower than threshold found. '
f'{classification}'
),
)
def test_max_results_option(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
score_threshold=_SCORE_THRESHOLD)
score_threshold=_SCORE_THRESHOLD,
)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
categories = image_result.classifications[0].categories
self.assertLessEqual(
len(categories), _MAX_RESULTS, 'Too many results returned.')
len(categories), _MAX_RESULTS, 'Too many results returned.'
)
def test_allow_list_option(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
category_allowlist=_ALLOW_LIST)
category_allowlist=_ALLOW_LIST,
)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -273,13 +381,17 @@ class ImageClassifierTest(parameterized.TestCase):
for classification in classifications:
for category in classification.categories:
label = category.category_name
self.assertIn(label, _ALLOW_LIST,
f'Label {label} found but not in label allow list')
self.assertIn(
label,
_ALLOW_LIST,
f'Label {label} found but not in label allow list',
)
def test_deny_list_option(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
category_denylist=_DENY_LIST)
category_denylist=_DENY_LIST,
)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -288,26 +400,30 @@ class ImageClassifierTest(parameterized.TestCase):
for classification in classifications:
for category in classification.categories:
label = category.category_name
self.assertNotIn(label, _DENY_LIST,
f'Label {label} found but in deny list.')
self.assertNotIn(
label, _DENY_LIST, f'Label {label} found but in deny list.'
)
def test_combined_allowlist_and_denylist(self):
# Fails with combined allowlist and denylist
with self.assertRaisesRegex(
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
r'exclusive options.',
):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
category_allowlist=['foo'],
category_denylist=['bar'])
category_denylist=['bar'],
)
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_empty_classification_outputs(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
score_threshold=1)
score_threshold=1,
)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
@ -316,9 +432,11 @@ class ImageClassifierTest(parameterized.TestCase):
def test_missing_result_callback(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
@ -327,67 +445,81 @@ class ImageClassifierTest(parameterized.TestCase):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_calling_classify_for_video_in_image_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
classifier.classify_for_video(self.test_image, 0)
def test_calling_classify_async_in_image_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
classifier.classify_async(self.test_image, 0)
def test_calling_classify_in_video_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
classifier.classify(self.test_image)
def test_calling_classify_async_in_video_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
classifier.classify_async(self.test_image, 0)
def test_classify_for_video_with_out_of_order_timestamp(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageClassifier.create_from_options(options) as classifier:
unused_result = classifier.classify_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
ValueError, r'Input timestamp must be monotonically increasing'
):
classifier.classify_for_video(self.test_image, 0)
def test_classify_for_video(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
max_results=4)
max_results=4,
)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
self.test_image, timestamp)
self.test_image, timestamp
)
test_utils.assert_proto_equals(
self,
classification_result.to_pb2(),
@ -398,18 +530,22 @@ class ImageClassifierTest(parameterized.TestCase):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
max_results=1)
max_results=1,
)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
)
)
# Region-of-interest around the soccer ball.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
image_processing_options = _ImageProcessingOptions(roi)
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
test_image, timestamp, image_processing_options)
test_image, timestamp, image_processing_options
)
test_utils.assert_proto_equals(
self,
classification_result.to_pb2(),
@ -420,20 +556,24 @@ class ImageClassifierTest(parameterized.TestCase):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
classifier.classify(self.test_image)
def test_calling_classify_for_video_in_live_stream_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
classifier.classify_for_video(self.test_image, 0)
def test_classify_async_calls_with_illegal_timestamp(self):
@ -441,25 +581,32 @@ class ImageClassifierTest(parameterized.TestCase):
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
max_results=4,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
ValueError, r'Input timestamp must be monotonically increasing'
):
classifier.classify_async(self.test_image, 0)
@parameterized.parameters((0, _generate_burger_results()),
(1, _generate_empty_results()))
@parameterized.parameters(
(0, _generate_burger_results()), (1, _generate_empty_results())
)
def test_classify_async_calls(self, threshold, expected_result):
observed_timestamp_ms = -1
def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int):
test_utils.assert_proto_equals(self, result.to_pb2(),
expected_result.to_pb2())
def check_result(
result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
):
test_utils.assert_proto_equals(
self, result.to_pb2(), expected_result.to_pb2()
)
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
np.array_equal(
output_image.numpy_view(), self.test_image.numpy_view()
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
@ -468,7 +615,8 @@ class ImageClassifierTest(parameterized.TestCase):
running_mode=_RUNNING_MODE.LIVE_STREAM,
max_results=4,
score_threshold=threshold,
result_callback=check_result)
result_callback=check_result,
)
with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(self.test_image, 0)
@ -476,14 +624,17 @@ class ImageClassifierTest(parameterized.TestCase):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
)
)
# Region-of-interest around the soccer ball.
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1
def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int):
def check_result(
result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
):
test_utils.assert_proto_equals(
self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2()
)
@ -496,7 +647,8 @@ class ImageClassifierTest(parameterized.TestCase):
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
max_results=1,
result_callback=check_result)
result_callback=check_result,
)
with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(test_image, 100, image_processing_options)

View File

@ -28,6 +28,7 @@ from mediapipe.tasks.python.components.containers import detections as detection
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import object_detector
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
@ -36,8 +37,10 @@ _BoundingBox = bounding_box_module.BoundingBox
_Detection = detections_module.Detection
_DetectionResult = detections_module.DetectionResult
_Image = image_module.Image
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_ObjectDetector = object_detector.ObjectDetector
_ObjectDetectorOptions = object_detector.ObjectDetectorOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
@ -115,10 +118,11 @@ class ObjectDetectorTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
)
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
@ -134,9 +138,11 @@ class ObjectDetectorTest(parameterized.TestCase):
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
model_asset_path='/path/to/invalid/model.tflite'
)
options = _ObjectDetectorOptions(base_options=base_options)
_ObjectDetector.create_from_options(options)
@ -150,9 +156,11 @@ class ObjectDetectorTest(parameterized.TestCase):
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT))
def test_detect(self, model_file_type, max_results,
expected_detection_result):
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
)
def test_detect(
self, model_file_type, max_results, expected_detection_result
):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
@ -165,7 +173,8 @@ class ObjectDetectorTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _ObjectDetectorOptions(
base_options=base_options, max_results=max_results)
base_options=base_options, max_results=max_results
)
detector = _ObjectDetector.create_from_options(options)
# Performs object detection on the input.
@ -178,9 +187,11 @@ class ObjectDetectorTest(parameterized.TestCase):
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT))
def test_detect_in_context(self, model_file_type, max_results,
expected_detection_result):
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
)
def test_detect_in_context(
self, model_file_type, max_results, expected_detection_result
):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
@ -192,7 +203,8 @@ class ObjectDetectorTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.')
options = _ObjectDetectorOptions(
base_options=base_options, max_results=max_results)
base_options=base_options, max_results=max_results
)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
detection_result = detector.detect(self.test_image)
@ -202,7 +214,8 @@ class ObjectDetectorTest(parameterized.TestCase):
def test_score_threshold_option(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
score_threshold=_SCORE_THRESHOLD)
score_threshold=_SCORE_THRESHOLD,
)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
detection_result = detector.detect(self.test_image)
@ -211,25 +224,30 @@ class ObjectDetectorTest(parameterized.TestCase):
for detection in detections:
score = detection.categories[0].score
self.assertGreaterEqual(
score, _SCORE_THRESHOLD,
f'Detection with score lower than threshold found. {detection}')
score,
_SCORE_THRESHOLD,
f'Detection with score lower than threshold found. {detection}',
)
def test_max_results_option(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
max_results=_MAX_RESULTS)
max_results=_MAX_RESULTS,
)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
detection_result = detector.detect(self.test_image)
detections = detection_result.detections
self.assertLessEqual(
len(detections), _MAX_RESULTS, 'Too many results returned.')
len(detections), _MAX_RESULTS, 'Too many results returned.'
)
def test_allow_list_option(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
category_allowlist=_ALLOW_LIST)
category_allowlist=_ALLOW_LIST,
)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
detection_result = detector.detect(self.test_image)
@ -237,13 +255,17 @@ class ObjectDetectorTest(parameterized.TestCase):
for detection in detections:
label = detection.categories[0].category_name
self.assertIn(label, _ALLOW_LIST,
f'Label {label} found but not in label allow list')
self.assertIn(
label,
_ALLOW_LIST,
f'Label {label} found but not in label allow list',
)
def test_deny_list_option(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
category_denylist=_DENY_LIST)
category_denylist=_DENY_LIST,
)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
detection_result = detector.detect(self.test_image)
@ -251,26 +273,30 @@ class ObjectDetectorTest(parameterized.TestCase):
for detection in detections:
label = detection.categories[0].category_name
self.assertNotIn(label, _DENY_LIST,
f'Label {label} found but in deny list.')
self.assertNotIn(
label, _DENY_LIST, f'Label {label} found but in deny list.'
)
def test_combined_allowlist_and_denylist(self):
# Fails with combined allowlist and denylist
with self.assertRaisesRegex(
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
r'exclusive options.',
):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
category_allowlist=['foo'],
category_denylist=['bar'])
category_denylist=['bar'],
)
with _ObjectDetector.create_from_options(options) as unused_detector:
pass
def test_empty_detection_outputs(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
score_threshold=1)
score_threshold=1,
)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
detection_result = detector.detect(self.test_image)
@ -279,9 +305,11 @@ class ObjectDetectorTest(parameterized.TestCase):
def test_missing_result_callback(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _ObjectDetector.create_from_options(options) as unused_detector:
pass
@ -290,56 +318,68 @@ class ObjectDetectorTest(parameterized.TestCase):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _ObjectDetector.create_from_options(options) as unused_detector:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
running_mode=_RUNNING_MODE.IMAGE,
)
with _ObjectDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
detector.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
running_mode=_RUNNING_MODE.IMAGE,
)
with _ObjectDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
detector.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ObjectDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
detector.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ObjectDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
detector.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
running_mode=_RUNNING_MODE.VIDEO,
)
with _ObjectDetector.create_from_options(options) as detector:
unused_result = detector.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
ValueError, r'Input timestamp must be monotonically increasing'
):
detector.detect_for_video(self.test_image, 0)
# TODO: Tests how `detect_for_video` handles the temporal data
@ -348,7 +388,8 @@ class ObjectDetectorTest(parameterized.TestCase):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
max_results=4)
max_results=4,
)
with _ObjectDetector.create_from_options(options) as detector:
for timestamp in range(0, 300, 30):
detection_result = detector.detect_for_video(self.test_image, timestamp)
@ -358,20 +399,24 @@ class ObjectDetectorTest(parameterized.TestCase):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ObjectDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
detector.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ObjectDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
detector.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
@ -379,24 +424,30 @@ class ObjectDetectorTest(parameterized.TestCase):
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
max_results=4,
result_callback=mock.MagicMock())
result_callback=mock.MagicMock(),
)
with _ObjectDetector.create_from_options(options) as detector:
detector.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
ValueError, r'Input timestamp must be monotonically increasing'
):
detector.detect_async(self.test_image, 0)
@parameterized.parameters((0, _EXPECTED_DETECTION_RESULT),
(1, _DetectionResult(detections=[])))
@parameterized.parameters(
(0, _EXPECTED_DETECTION_RESULT), (1, _DetectionResult(detections=[]))
)
def test_detect_async_calls(self, threshold, expected_result):
observed_timestamp_ms = -1
def check_result(result: _DetectionResult, output_image: _Image,
timestamp_ms: int):
def check_result(
result: _DetectionResult, output_image: _Image, timestamp_ms: int
):
self.assertEqual(result, expected_result)
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
np.array_equal(
output_image.numpy_view(), self.test_image.numpy_view()
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
@ -405,7 +456,8 @@ class ObjectDetectorTest(parameterized.TestCase):
running_mode=_RUNNING_MODE.LIVE_STREAM,
max_results=4,
score_threshold=threshold,
result_callback=check_result)
result_callback=check_result,
)
detector = _ObjectDetector.create_from_options(options)
for timestamp in range(0, 300, 30):
detector.detect_async(self.test_image, timestamp)

View File

@ -29,10 +29,12 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_py_pb2",
"//mediapipe/tasks/python/components/containers:detections",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
@ -71,10 +73,12 @@ py_library(
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)

View File

@ -17,6 +17,7 @@ import math
from typing import Callable, Mapping, Optional
from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.python.components.containers import rect as rect_module
@ -39,8 +40,9 @@ class BaseVisionTaskApi(object):
self,
graph_config: calculator_pb2.CalculatorGraphConfig,
running_mode: _RunningMode,
packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]],
None]] = None
packet_callback: Optional[
Callable[[Mapping[str, packet_module.Packet]], None]
] = None,
) -> None:
"""Initializes the `BaseVisionTaskApi` object.
@ -48,7 +50,7 @@ class BaseVisionTaskApi(object):
graph_config: The mediapipe vision task graph config proto.
running_mode: The running mode of the mediapipe vision task.
packet_callback: The optional packet callback for getting results
asynchronously in the live stream mode.
asynchronously in the live stream mode.
Raises:
ValueError: The packet callback is not properly set based on the task's
@ -58,16 +60,19 @@ class BaseVisionTaskApi(object):
if packet_callback is None:
raise ValueError(
'The vision task is in live stream mode, a user-defined result '
'callback must be provided.')
'callback must be provided.'
)
elif packet_callback:
raise ValueError(
'The vision task is in image or video mode, a user-defined result '
'callback should not be provided.')
'callback should not be provided.'
)
self._runner = _TaskRunner.create(graph_config, packet_callback)
self._running_mode = running_mode
def _process_image_data(
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
self, inputs: Mapping[str, _Packet]
) -> Mapping[str, _Packet]:
"""A synchronous method to process single image inputs.
The call blocks the current thread until a failure status or a successful
@ -84,12 +89,14 @@ class BaseVisionTaskApi(object):
"""
if self._running_mode != _RunningMode.IMAGE:
raise ValueError(
'Task is not initialized with the image mode. Current running mode:' +
self._running_mode.name)
'Task is not initialized with the image mode. Current running mode:'
+ self._running_mode.name
)
return self._runner.process(inputs)
def _process_video_data(
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
self, inputs: Mapping[str, _Packet]
) -> Mapping[str, _Packet]:
"""A synchronous method to process continuous video frames.
The call blocks the current thread until a failure status or a successful
@ -106,8 +113,9 @@ class BaseVisionTaskApi(object):
"""
if self._running_mode != _RunningMode.VIDEO:
raise ValueError(
'Task is not initialized with the video mode. Current running mode:' +
self._running_mode.name)
'Task is not initialized with the video mode. Current running mode:'
+ self._running_mode.name
)
return self._runner.process(inputs)
def _send_live_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
@ -124,13 +132,18 @@ class BaseVisionTaskApi(object):
"""
if self._running_mode != _RunningMode.LIVE_STREAM:
raise ValueError(
'Task is not initialized with the live stream mode. Current running mode:'
+ self._running_mode.name)
'Task is not initialized with the live stream mode. Current running'
' mode:'
+ self._running_mode.name
)
self._runner.send(inputs)
def convert_to_normalized_rect(self,
options: _ImageProcessingOptions,
roi_allowed: bool = True) -> _NormalizedRect:
def convert_to_normalized_rect(
self,
options: _ImageProcessingOptions,
image: image_module.Image,
roi_allowed: bool = True,
) -> _NormalizedRect:
"""Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly.
If the input ImageProcessingOptions is not present, returns a default
@ -140,6 +153,7 @@ class BaseVisionTaskApi(object):
Args:
options: Options for image processing.
image: The image to process.
roi_allowed: Indicates if the `region_of_interest` field is allowed to be
set. By default, it's set to True.
@ -147,7 +161,8 @@ class BaseVisionTaskApi(object):
A normalized rect proto that represents the image processing options.
"""
normalized_rect = _NormalizedRect(
rotation=0, x_center=0.5, y_center=0.5, width=1, height=1)
rotation=0, x_center=0.5, y_center=0.5, width=1, height=1
)
if options is None:
return normalized_rect
@ -169,6 +184,20 @@ class BaseVisionTaskApi(object):
normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
normalized_rect.width = roi.right - roi.left
normalized_rect.height = roi.bottom - roi.top
# For 90° and 270° rotations, we need to swap width and height.
# This is due to the internal behavior of ImageToTensorCalculator, which:
# - first denormalizes the provided rect by multiplying the rect width or
# height by the image width or height, repectively.
# - then rotates this by denormalized rect by the provided rotation, and
# uses this for cropping,
# - then finally rotates this back.
if abs(options.rotation_degrees % 180) != 0:
w = normalized_rect.height * image.height / image.width
h = normalized_rect.width * image.width / image.height
normalized_rect.width = w
normalized_rect.height = h
return normalized_rect
def close(self) -> None:

View File

@ -212,7 +212,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
@ -261,7 +261,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
@ -320,7 +320,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
detector has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

View File

@ -401,7 +401,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
@ -444,7 +444,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
@ -497,7 +497,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
face landmarker has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

View File

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
_GestureRecognizerGraphOptionsProto = (
gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
)
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -53,7 +55,9 @@ _HAND_LANDMARKS_STREAM_NAME = 'landmarks'
_HAND_LANDMARKS_TAG = 'LANDMARKS'
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
_TASK_GRAPH_NAME = (
'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
)
_MICRO_SECONDS_PER_MILLISECOND = 1000
_GESTURE_DEFAULT_INDEX = -1
@ -78,17 +82,21 @@ class GestureRecognizerResult:
def _build_recognition_result(
output_packets: Mapping[str,
packet_module.Packet]) -> GestureRecognizerResult:
output_packets: Mapping[str, packet_module.Packet]
) -> GestureRecognizerResult:
"""Constructs a `GestureRecognizerResult` from output packets."""
gestures_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_GESTURE_STREAM_NAME])
output_packets[_HAND_GESTURE_STREAM_NAME]
)
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
output_packets[_HANDEDNESS_STREAM_NAME]
)
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
output_packets[_HAND_LANDMARKS_STREAM_NAME]
)
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]
)
gesture_results = []
for proto in gestures_proto_list:
@ -101,7 +109,9 @@ def _build_recognition_result(
index=_GESTURE_DEFAULT_INDEX,
score=gesture.score,
display_name=gesture.display_name,
category_name=gesture.label))
category_name=gesture.label,
)
)
gesture_results.append(gesture_categories)
handedness_results = []
@ -115,7 +125,9 @@ def _build_recognition_result(
index=handedness.index,
score=handedness.score,
display_name=handedness.display_name,
category_name=handedness.label))
category_name=handedness.label,
)
)
handedness_results.append(handedness_categories)
hand_landmarks_results = []
@ -125,7 +137,8 @@ def _build_recognition_result(
hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
hand_landmarks_results.append(hand_landmarks_list)
hand_world_landmarks_results = []
@ -135,12 +148,16 @@ def _build_recognition_result(
hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark))
landmark_module.Landmark.create_from_pb2(hand_world_landmark)
)
hand_world_landmarks_results.append(hand_world_landmarks_list)
return GestureRecognizerResult(gesture_results, handedness_results,
hand_landmarks_results,
hand_world_landmarks_results)
return GestureRecognizerResult(
gesture_results,
handedness_results,
hand_landmarks_results,
hand_world_landmarks_results,
)
@dataclasses.dataclass
@ -174,43 +191,62 @@ class GestureRecognizerOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
canned_gesture_classifier_options: Optional[
_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
custom_gesture_classifier_options: Optional[
_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
result_callback: Optional[Callable[
[GestureRecognizerResult, image_module.Image, int], None]] = None
canned_gesture_classifier_options: Optional[_ClassifierOptions] = (
dataclasses.field(default_factory=_ClassifierOptions)
)
custom_gesture_classifier_options: Optional[_ClassifierOptions] = (
dataclasses.field(default_factory=_ClassifierOptions)
)
result_callback: Optional[
Callable[[GestureRecognizerResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
"""Generates an GestureRecognizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
# Initialize gesture recognizer options from base options.
gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto)
base_options=base_options_proto
)
# Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
hand_landmarker_options_proto = (
gesture_recognizer_options_proto.hand_landmarker_graph_options
)
hand_landmarker_options_proto.min_tracking_confidence = (
self.min_tracking_confidence
)
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
self.num_hands
)
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
self.min_hand_detection_confidence
)
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_presence_confidence
)
# Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
hand_gesture_recognizer_options_proto = (
gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
)
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.canned_gesture_classifier_options.to_pb2())
self.canned_gesture_classifier_options.to_pb2()
)
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.custom_gesture_classifier_options.to_pb2())
self.custom_gesture_classifier_options.to_pb2()
)
return gesture_recognizer_options_proto
@ -239,12 +275,14 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = GestureRecognizerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(
cls, options: GestureRecognizerOptions) -> 'GestureRecognizer':
cls, options: GestureRecognizerOptions
) -> 'GestureRecognizer':
"""Creates the `GestureRecognizer` object from gesture recognizer options.
Args:
@ -268,14 +306,19 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
options.result_callback(
GestureRecognizerResult([], [], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
GestureRecognizerResult([], [], [], []),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return
gesture_recognizer_result = _build_recognition_result(output_packets)
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
options.result_callback(gesture_recognizer_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
gesture_recognizer_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -286,23 +329,27 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
output_streams=[
':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]),
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG,
_HAND_LANDMARKS_STREAM_NAME]), ':'.join([
_HAND_WORLD_LANDMARKS_TAG,
_HAND_WORLD_LANDMARKS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
':'.join(
[_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def recognize(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> GestureRecognizerResult:
"""Performs hand gesture recognition on the given image.
@ -325,12 +372,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If gesture recognition failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
@ -342,7 +390,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> GestureRecognizerResult:
"""Performs gesture recognition on the provided video frame.
@ -367,14 +415,15 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If gesture recognition failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
@ -386,7 +435,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data to perform gesture recognition.
@ -419,12 +468,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
gesture recognizer has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

View File

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_HandLandmarkerGraphOptionsProto = (
hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
@ -56,6 +58,7 @@ _MICRO_SECONDS_PER_MILLISECOND = 1000
class HandLandmark(enum.IntEnum):
"""The 21 hand landmarks."""
WRIST = 0
THUMB_CMC = 1
THUMB_MCP = 2
@ -95,14 +98,18 @@ class HandLandmarkerResult:
def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult:
output_packets: Mapping[str, packet_module.Packet]
) -> HandLandmarkerResult:
"""Constructs a `HandLandmarksDetectionResult` from output packets."""
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
output_packets[_HANDEDNESS_STREAM_NAME]
)
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
output_packets[_HAND_LANDMARKS_STREAM_NAME]
)
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]
)
handedness_results = []
for proto in handedness_proto_list:
@ -115,7 +122,9 @@ def _build_landmarker_result(
index=handedness.index,
score=handedness.score,
display_name=handedness.display_name,
category_name=handedness.label))
category_name=handedness.label,
)
)
handedness_results.append(handedness_categories)
hand_landmarks_results = []
@ -125,7 +134,8 @@ def _build_landmarker_result(
hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
hand_landmarks_results.append(hand_landmarks_list)
hand_world_landmarks_results = []
@ -135,11 +145,13 @@ def _build_landmarker_result(
hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark))
landmark_module.Landmark.create_from_pb2(hand_world_landmark)
)
hand_world_landmarks_results.append(hand_world_landmarks_list)
return HandLandmarkerResult(handedness_results, hand_landmarks_results,
hand_world_landmarks_results)
return HandLandmarkerResult(
handedness_results, hand_landmarks_results, hand_world_landmarks_results
)
@dataclasses.dataclass
@ -167,28 +179,41 @@ class HandLandmarkerOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
result_callback: Optional[Callable[
[HandLandmarkerResult, image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[HandLandmarkerResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
"""Generates an HandLandmarkerGraphOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
# Initialize the hand landmarker options from base options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
base_options=base_options_proto)
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
base_options=base_options_proto
)
hand_landmarker_options_proto.min_tracking_confidence = (
self.min_tracking_confidence
)
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
self.num_hands
)
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
self.min_hand_detection_confidence
)
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_presence_confidence
)
return hand_landmarker_options_proto
@ -216,12 +241,14 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = HandLandmarkerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: HandLandmarkerOptions) -> 'HandLandmarker':
def create_from_options(
cls, options: HandLandmarkerOptions
) -> 'HandLandmarker':
"""Creates the `HandLandmarker` object from hand landmarker options.
Args:
@ -245,14 +272,19 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
options.result_callback(
HandLandmarkerResult([], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
HandLandmarkerResult([], [], []),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return
hand_landmarks_detection_result = _build_landmarker_result(output_packets)
timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
options.result_callback(hand_landmarks_detection_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
hand_landmarks_detection_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -263,21 +295,26 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
output_streams=[
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
':'.join([
_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join(
[_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def detect(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the given image.
@ -300,12 +337,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If hand landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
@ -317,7 +355,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the provided video frame.
@ -342,14 +380,15 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If hand landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
@ -361,7 +400,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data to perform hand landmarks detection.
@ -394,12 +433,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
hand landmarker has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

View File

@ -35,7 +35,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ImageClassifierGraphOptionsProto = (
image_classifier_graph_options_pb2.ImageClassifierGraphOptions
)
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -48,7 +50,9 @@ _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
_TASK_GRAPH_NAME = (
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
)
_MICRO_SECONDS_PER_MILLISECOND = 1000
@ -81,6 +85,7 @@ class ImageClassifierOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
display_names_locale: Optional[str] = None
@ -88,24 +93,29 @@ class ImageClassifierOptions:
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[
[ImageClassifierResult, image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[ImageClassifierResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
"""Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
max_results=self.max_results,
)
return _ImageClassifierGraphOptionsProto(
base_options=base_options_proto,
classifier_options=classifier_options_proto)
classifier_options=classifier_options_proto,
)
class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
@ -165,12 +175,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageClassifierOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: ImageClassifierOptions) -> 'ImageClassifier':
def create_from_options(
cls, options: ImageClassifierOptions
) -> 'ImageClassifier':
"""Creates the `ImageClassifier` object from image classifier options.
Args:
@ -191,12 +203,15 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(
ImageClassifierResult.create_from_pb2(classification_result_proto),
image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -206,19 +221,23 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def classify(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageClassifierResult:
"""Performs image classification on the provided MediaPipe Image.
@ -233,17 +252,20 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
return ImageClassifierResult.create_from_pb2(classification_result_proto)
@ -251,7 +273,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageClassifierResult:
"""Performs image classification on the provided video frames.
@ -272,19 +294,22 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
return ImageClassifierResult.create_from_pb2(classification_result_proto)
@ -292,7 +317,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image classification.
@ -320,12 +345,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If the current input timestamp is smaller than what the image
classifier has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

View File

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_ImageEmbedderGraphOptionsProto = (
image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
)
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
@ -74,24 +76,29 @@ class ImageEmbedderOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
result_callback: Optional[Callable[
[ImageEmbedderResult, image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[ImageEmbedderResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageEmbedderGraphOptionsProto:
"""Generates an ImageEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
l2_normalize=self.l2_normalize, quantize=self.quantize
)
return _ImageEmbedderGraphOptionsProto(
base_options=base_options_proto,
embedder_options=embedder_options_proto)
base_options=base_options_proto, embedder_options=embedder_options_proto
)
class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
@ -135,12 +142,14 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageEmbedderOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: ImageEmbedderOptions) -> 'ImageEmbedder':
def create_from_options(
cls, options: ImageEmbedderOptions
) -> 'ImageEmbedder':
"""Creates the `ImageEmbedder` object from image embedder options.
Args:
@ -161,13 +170,16 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(
ImageEmbedderResult.create_from_pb2(embedding_result_proto), image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
ImageEmbedderResult.create_from_pb2(embedding_result_proto),
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
@ -177,19 +189,23 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def embed(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided MediaPipe Image.
@ -207,17 +223,20 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image embedder failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
@ -225,7 +244,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided video frames.
@ -249,18 +268,21 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image embedder failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)
return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
@ -268,7 +290,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data to embedder.
@ -301,19 +323,24 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If the current input timestamp is smaller than what the image
embedder has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
@classmethod
def cosine_similarity(cls, u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding) -> float:
def cosine_similarity(
cls,
u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding,
) -> float:
"""Utility function to compute cosine similarity between two embedding entries.
May return an InvalidArgumentError if e.g. the feature vectors are of

View File

@ -23,16 +23,23 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
_ImageSegmenterGraphOptionsProto = (
image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
)
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
@ -40,6 +47,8 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@ -77,19 +86,24 @@ class ImageSegmenterOptions:
running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
activation: Optional[Activation] = Activation.NONE
result_callback: Optional[Callable[
[List[image_module.Image], image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[ImageSegmenterResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value)
output_type=self.output_type.value, activation=self.activation.value
)
return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto,
segmenter_options=segmenter_options_proto)
segmenter_options=segmenter_options_proto,
)
class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
@ -138,12 +152,14 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageSegmenterOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: ImageSegmenterOptions) -> 'ImageSegmenter':
def create_from_options(
cls, options: ImageSegmenterOptions
) -> 'ImageSegmenter':
"""Creates the `ImageSegmenter` object from image segmenter options.
Args:
@ -162,31 +178,47 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
options.result_callback(segmentation_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
segmentation_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])],
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def segment(self, image: image_module.Image) -> List[image_module.Image]:
def segment(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageSegmenterResult:
"""Performs the actual segmentation task on the provided MediaPipe Image.
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is
@ -199,14 +231,26 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
return segmentation_result
def segment_for_video(self, image: image_module.Image,
timestamp_ms: int) -> List[image_module.Image]:
def segment_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageSegmenterResult:
"""Performs segmentation on the provided video frames.
Only use this method when the ImageSegmenter is created with the video
@ -217,6 +261,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is
@ -229,16 +274,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
return segmentation_result
def segment_async(self, image: image_module.Image, timestamp_ms: int) -> None:
def segment_async(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image segmentation.
Only use this method when the ImageSegmenter is created with the live stream
@ -260,13 +317,20 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.
Raises:
ValueError: If the current input timestamp is smaller than what the image
segmenter has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

View File

@ -22,15 +22,20 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.object_detector.proto import object_detector_options_pb2
from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
ObjectDetectorResult = detections_module.DetectionResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ObjectDetectorOptionsProto = object_detector_options_pb2.ObjectDetectorOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_DETECTIONS_OUT_STREAM_NAME = 'detections_out'
@ -38,7 +43,10 @@ _DETECTIONS_TAG = 'DETECTIONS'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ObjectDetectorGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
@ -48,11 +56,10 @@ class ObjectDetectorOptions:
Attributes:
base_options: Base options for the object detector task.
running_mode: The running mode of the task. Default to the image mode.
Object detector task has three running modes:
1) The image mode for detecting objects on single image inputs.
2) The video mode for detecting objects on the decoded frames of a video.
3) The live stream mode for detecting objects on a live stream of input
data, such as from camera.
Object detector task has three running modes: 1) The image mode for
detecting objects on single image inputs. 2) The video mode for detecting
objects on the decoded frames of a video. 3) The live stream mode for
detecting objects on a live stream of input data, such as from camera.
display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
@ -71,6 +78,7 @@ class ObjectDetectorOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
display_names_locale: Optional[str] = None
@ -79,14 +87,16 @@ class ObjectDetectorOptions:
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[
Callable[[detections_module.DetectionResult, image_module.Image, int],
None]] = None
Callable[[ObjectDetectorResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ObjectDetectorOptionsProto:
"""Generates an ObjectDetectorOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
return _ObjectDetectorOptionsProto(
base_options=base_options_proto,
display_names_locale=self.display_names_locale,
@ -163,12 +173,14 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ObjectDetectorOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: ObjectDetectorOptions) -> 'ObjectDetector':
def create_from_options(
cls, options: ObjectDetectorOptions
) -> 'ObjectDetector':
"""Creates the `ObjectDetector` object from object detector options.
Args:
@ -187,32 +199,45 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME])
detection_result = detections_module.DetectionResult([
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
])
output_packets[_DETECTIONS_OUT_STREAM_NAME]
)
detection_result = ObjectDetectorResult(
[
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(detection_result, image, timestamp)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])],
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
# TODO: Create an Image class for MediaPipe Tasks.
def detect(self,
image: image_module.Image) -> detections_module.DetectionResult:
def detect(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ObjectDetectorResult:
"""Performs object detection on the provided MediaPipe Image.
Only use this method when the ObjectDetector is created with the image
@ -220,6 +245,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
A detection result object that contains a list of detections, each
@ -231,17 +257,31 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If object detection failed to run.
"""
output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME])
return detections_module.DetectionResult([
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
])
output_packets[_DETECTIONS_OUT_STREAM_NAME]
)
return ObjectDetectorResult(
[
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
]
)
def detect_for_video(self, image: image_module.Image,
timestamp_ms: int) -> detections_module.DetectionResult:
def detect_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ObjectDetectorResult:
"""Performs object detection on the provided video frames.
Only use this method when the ObjectDetector is created with the video
@ -252,6 +292,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns:
A detection result object that contains a list of detections, each
@ -263,18 +304,33 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If object detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(timestamp_ms)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME])
return detections_module.DetectionResult([
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
])
output_packets[_DETECTIONS_OUT_STREAM_NAME]
)
return ObjectDetectorResult(
[
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
]
)
def detect_async(self, image: image_module.Image, timestamp_ms: int) -> None:
def detect_async(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform object detection.
Only use this method when the ObjectDetector is created with the live stream
@ -298,12 +354,20 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.
Raises:
ValueError: If the current input timestamp is smaller than what the object
detector has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(timestamp_ms)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})