Internal change

PiperOrigin-RevId: 522307800
MediaPipe Team, 2023-04-06 04:59:26 -07:00, committed by Copybara-Service
commit 97bd9c2157 (parent 2e256bebb5)
13 changed files with 964 additions and 454 deletions

View File

@@ -33,6 +33,7 @@ py_test(
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/test:test_utils",
         "//mediapipe/tasks/python/vision:object_detector",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )

View File

@@ -55,10 +55,10 @@ _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
 def _generate_empty_results() -> ImageClassifierResult:
   return ImageClassifierResult(
       classifications=[
-          _Classifications(
-              categories=[], head_index=0, head_name='probability')
+          _Classifications(categories=[], head_index=0, head_name='probability')
       ],
-      timestamp_ms=0)
+      timestamp_ms=0,
+  )

 def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult:
@@ -129,10 +129,11 @@ class ImageClassifierTest(parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.test_image = _Image.create_from_file(
-        test_utils.get_test_data_path(
-            os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
+        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
+    )
     self.model_path = test_utils.get_test_data_path(
-        os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
+        os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
+    )

   def test_create_from_file_succeeds_with_valid_model_path(self):
     # Creates with default option and valid model file successfully.
@@ -148,9 +149,11 @@ class ImageClassifierTest(parameterized.TestCase):
   def test_create_from_options_fails_with_invalid_model_path(self):
     with self.assertRaisesRegex(
-        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
+        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
+    ):
       base_options = _BaseOptions(
-          model_asset_path='/path/to/invalid/model.tflite')
+          model_asset_path='/path/to/invalid/model.tflite'
+      )
       options = _ImageClassifierOptions(base_options=base_options)
       _ImageClassifier.create_from_options(options)
@@ -164,9 +167,11 @@ class ImageClassifierTest(parameterized.TestCase):
   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _generate_burger_results()),
-      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
-  def test_classify(self, model_file_type, max_results,
-                    expected_classification_result):
+      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
+  )
+  def test_classify(
+      self, model_file_type, max_results, expected_classification_result
+  ):
     # Creates classifier.
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)
@@ -179,23 +184,27 @@ class ImageClassifierTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ImageClassifierOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     classifier = _ImageClassifier.create_from_options(options)

     # Performs image classification on the input.
     image_result = classifier.classify(self.test_image)
     # Comparing results.
-    test_utils.assert_proto_equals(self, image_result.to_pb2(),
-                                   expected_classification_result.to_pb2())
+    test_utils.assert_proto_equals(
+        self, image_result.to_pb2(), expected_classification_result.to_pb2()
+    )
     # Closes the classifier explicitly when the classifier is not used in
     # a context.
     classifier.close()

   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _generate_burger_results()),
-      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
-  def test_classify_in_context(self, model_file_type, max_results,
-                               expected_classification_result):
+      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
+  )
+  def test_classify_in_context(
+      self, model_file_type, max_results, expected_classification_result
+  ):
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)
     elif model_file_type is ModelFileType.FILE_CONTENT:
@@ -207,13 +216,15 @@ class ImageClassifierTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ImageClassifierOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
       # Comparing results.
-      test_utils.assert_proto_equals(self, image_result.to_pb2(),
-                                     expected_classification_result.to_pb2())
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), expected_classification_result.to_pb2()
+      )

   def test_classify_succeeds_with_region_of_interest(self):
     base_options = _BaseOptions(model_asset_path=self.model_path)
@@ -222,20 +233,110 @@ class ImageClassifierTest(parameterized.TestCase):
       # Load the test image.
       test_image = _Image.create_from_file(
           test_utils.get_test_data_path(
-              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
+              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
+          )
+      )
       # Region-of-interest around the soccer ball.
       roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
       image_processing_options = _ImageProcessingOptions(roi)
       # Performs image classification on the input.
       image_result = classifier.classify(test_image, image_processing_options)
       # Comparing results.
-      test_utils.assert_proto_equals(self, image_result.to_pb2(),
-                                     _generate_soccer_ball_results().to_pb2())
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), _generate_soccer_ball_results().to_pb2()
+      )
+
+  def test_classify_succeeds_with_rotation(self):
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _ImageClassifierOptions(base_options=base_options, max_results=3)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Load the test image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(
+              os.path.join(_TEST_DATA_DIR, 'burger_rotated.jpg')
+          )
+      )
+      # Specify a 90° anti-clockwise rotation.
+      image_processing_options = _ImageProcessingOptions(None, -90)
+      # Performs image classification on the input.
+      image_result = classifier.classify(test_image, image_processing_options)
+      # Comparing results.
+      expected = ImageClassifierResult(
+          classifications=[
+              _Classifications(
+                  categories=[
+                      _Category(
+                          index=934,
+                          score=0.754467,
+                          display_name='',
+                          category_name='cheeseburger',
+                      ),
+                      _Category(
+                          index=925,
+                          score=0.0288028,
+                          display_name='',
+                          category_name='guacamole',
+                      ),
+                      _Category(
+                          index=932,
+                          score=0.0286119,
+                          display_name='',
+                          category_name='bagel',
+                      ),
+                  ],
+                  head_index=0,
+                  head_name='probability',
+              )
+          ],
+          timestamp_ms=0,
+      )
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), expected.to_pb2()
+      )
+
+  def test_classify_succeeds_with_region_of_interest_and_rotation(self):
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _ImageClassifierOptions(base_options=base_options, max_results=1)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Load the test image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(
+              os.path.join(_TEST_DATA_DIR, 'multi_objects_rotated.jpg')
+          )
+      )
+      # Region-of-interest around the soccer ball, with 90° anti-clockwise
+      # rotation.
+      roi = _Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614)
+      image_processing_options = _ImageProcessingOptions(roi, -90)
+      # Performs image classification on the input.
+      image_result = classifier.classify(test_image, image_processing_options)
+      # Comparing results.
+      expected = ImageClassifierResult(
+          classifications=[
+              _Classifications(
+                  categories=[
+                      _Category(
+                          index=806,
+                          score=0.997684,
+                          display_name='',
+                          category_name='soccer ball',
+                      ),
+                  ],
+                  head_index=0,
+                  head_name='probability',
+              )
+          ],
+          timestamp_ms=0,
+      )
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), expected.to_pb2()
+      )

   def test_score_threshold_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        score_threshold=_SCORE_THRESHOLD,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
@@ -245,26 +346,33 @@ class ImageClassifierTest(parameterized.TestCase):
       for category in classification.categories:
         score = category.score
         self.assertGreaterEqual(
-            score, _SCORE_THRESHOLD,
-            f'Classification with score lower than threshold found. '
-            f'{classification}')
+            score,
+            _SCORE_THRESHOLD,
+            (
+                'Classification with score lower than threshold found. '
+                f'{classification}'
+            ),
+        )

   def test_max_results_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        score_threshold=_SCORE_THRESHOLD,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
       categories = image_result.classifications[0].categories

       self.assertLessEqual(
-          len(categories), _MAX_RESULTS, 'Too many results returned.')
+          len(categories), _MAX_RESULTS, 'Too many results returned.'
+      )

   def test_allow_list_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_allowlist=_ALLOW_LIST)
+        category_allowlist=_ALLOW_LIST,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
@@ -273,13 +381,17 @@ class ImageClassifierTest(parameterized.TestCase):
       for classification in classifications:
         for category in classification.categories:
           label = category.category_name
-          self.assertIn(label, _ALLOW_LIST,
-                        f'Label {label} found but not in label allow list')
+          self.assertIn(
+              label,
+              _ALLOW_LIST,
+              f'Label {label} found but not in label allow list',
+          )

   def test_deny_list_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_denylist=_DENY_LIST)
+        category_denylist=_DENY_LIST,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
@@ -288,26 +400,30 @@ class ImageClassifierTest(parameterized.TestCase):
       for classification in classifications:
         for category in classification.categories:
           label = category.category_name
-          self.assertNotIn(label, _DENY_LIST,
-                           f'Label {label} found but in deny list.')
+          self.assertNotIn(
+              label, _DENY_LIST, f'Label {label} found but in deny list.'
+          )

   def test_combined_allowlist_and_denylist(self):
     # Fails with combined allowlist and denylist
     with self.assertRaisesRegex(
         ValueError,
         r'`category_allowlist` and `category_denylist` are mutually '
-        r'exclusive options.'):
+        r'exclusive options.',
+    ):
       options = _ImageClassifierOptions(
           base_options=_BaseOptions(model_asset_path=self.model_path),
           category_allowlist=['foo'],
-          category_denylist=['bar'])
+          category_denylist=['bar'],
+      )
       with _ImageClassifier.create_from_options(options) as unused_classifier:
         pass

   def test_empty_classification_outputs(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=1)
+        score_threshold=1,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
@@ -316,9 +432,11 @@ class ImageClassifierTest(parameterized.TestCase):
   def test_missing_result_callback(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.LIVE_STREAM)
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback must be provided'):
+        running_mode=_RUNNING_MODE.LIVE_STREAM,
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback must be provided'
+    ):
       with _ImageClassifier.create_from_options(options) as unused_classifier:
         pass
@@ -327,67 +445,81 @@ class ImageClassifierTest(parameterized.TestCase):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=running_mode,
-        result_callback=mock.MagicMock())
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback should not be provided'):
+        result_callback=mock.MagicMock(),
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback should not be provided'
+    ):
       with _ImageClassifier.create_from_options(options) as unused_classifier:
         pass

   def test_calling_classify_for_video_in_image_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         classifier.classify_for_video(self.test_image, 0)

   def test_calling_classify_async_in_image_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         classifier.classify_async(self.test_image, 0)

   def test_calling_classify_in_video_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         classifier.classify(self.test_image)

   def test_calling_classify_async_in_video_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         classifier.classify_async(self.test_image, 0)

   def test_classify_for_video_with_out_of_order_timestamp(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       unused_result = classifier.classify_for_video(self.test_image, 1)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         classifier.classify_for_video(self.test_image, 0)

   def test_classify_for_video(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.VIDEO,
-        max_results=4)
+        max_results=4,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       for timestamp in range(0, 300, 30):
         classification_result = classifier.classify_for_video(
-            self.test_image, timestamp)
+            self.test_image, timestamp
+        )
         test_utils.assert_proto_equals(
             self,
             classification_result.to_pb2(),
@@ -398,18 +530,22 @@ class ImageClassifierTest(parameterized.TestCase):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.VIDEO,
-        max_results=1)
+        max_results=1,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Load the test image.
       test_image = _Image.create_from_file(
           test_utils.get_test_data_path(
-              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
+              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
+          )
+      )
       # Region-of-interest around the soccer ball.
       roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
       image_processing_options = _ImageProcessingOptions(roi)
       for timestamp in range(0, 300, 30):
         classification_result = classifier.classify_for_video(
-            test_image, timestamp, image_processing_options)
+            test_image, timestamp, image_processing_options
+        )
         test_utils.assert_proto_equals(
             self,
             classification_result.to_pb2(),
@@ -420,20 +556,24 @@ class ImageClassifierTest(parameterized.TestCase):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         classifier.classify(self.test_image)

   def test_calling_classify_for_video_in_live_stream_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         classifier.classify_for_video(self.test_image, 0)

   def test_classify_async_calls_with_illegal_timestamp(self):
@@ -441,25 +581,32 @@ class ImageClassifierTest(parameterized.TestCase):
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       classifier.classify_async(self.test_image, 100)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         classifier.classify_async(self.test_image, 0)

-  @parameterized.parameters((0, _generate_burger_results()),
-                            (1, _generate_empty_results()))
+  @parameterized.parameters(
+      (0, _generate_burger_results()), (1, _generate_empty_results())
+  )
   def test_classify_async_calls(self, threshold, expected_result):
     observed_timestamp_ms = -1

-    def check_result(result: ImageClassifierResult, output_image: _Image,
-                     timestamp_ms: int):
-      test_utils.assert_proto_equals(self, result.to_pb2(),
-                                     expected_result.to_pb2())
+    def check_result(
+        result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
+    ):
+      test_utils.assert_proto_equals(
+          self, result.to_pb2(), expected_result.to_pb2()
+      )
       self.assertTrue(
-          np.array_equal(output_image.numpy_view(),
-                         self.test_image.numpy_view()))
+          np.array_equal(
+              output_image.numpy_view(), self.test_image.numpy_view()
+          )
+      )
       self.assertLess(observed_timestamp_ms, timestamp_ms)
       self.observed_timestamp_ms = timestamp_ms
@@ -468,7 +615,8 @@ class ImageClassifierTest(parameterized.TestCase):
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
         score_threshold=threshold,
-        result_callback=check_result)
+        result_callback=check_result,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       classifier.classify_async(self.test_image, 0)
@@ -476,14 +624,17 @@ class ImageClassifierTest(parameterized.TestCase):
     # Load the test image.
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path(
-            os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
+            os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
+        )
+    )
     # Region-of-interest around the soccer ball.
     roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
     image_processing_options = _ImageProcessingOptions(roi)
     observed_timestamp_ms = -1

-    def check_result(result: ImageClassifierResult, output_image: _Image,
-                     timestamp_ms: int):
+    def check_result(
+        result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
+    ):
       test_utils.assert_proto_equals(
           self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2()
       )
@@ -496,7 +647,8 @@ class ImageClassifierTest(parameterized.TestCase):
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=1,
-        result_callback=check_result)
+        result_callback=check_result,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       classifier.classify_async(test_image, 100, image_processing_options)
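
The new rotation tests above exercise the `rotation_degrees` field of `ImageProcessingOptions`, where a negative value such as `-90` compensates a 90° anti-clockwise rotation of the input, per the test comments. As a rough usage sketch outside the test harness (a minimal example, assuming the public `mediapipe.tasks.python.vision` aliases for the same classes; the model and image paths are placeholders):

```python
import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module

# Placeholder paths; substitute a real classifier model and input image.
options = vision.ImageClassifierOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='classifier.tflite'
    ),
    max_results=3,
)
with vision.ImageClassifier.create_from_options(options) as classifier:
  image = mp.Image.create_from_file('multi_objects_rotated.jpg')
  # Crop to a normalized region of interest and undo a 90° anti-clockwise
  # rotation, mirroring the new test above.
  roi = rect.Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614)
  result = classifier.classify(
      image,
      vision.ImageProcessingOptions(
          region_of_interest=roi, rotation_degrees=-90
      ),
  )
```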

View File

@@ -28,6 +28,7 @@ from mediapipe.tasks.python.components.containers import detections as detections_module
 from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.test import test_utils
 from mediapipe.tasks.python.vision import object_detector
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

 _BaseOptions = base_options_module.BaseOptions
@@ -36,8 +37,10 @@ _BoundingBox = bounding_box_module.BoundingBox
 _Detection = detections_module.Detection
 _DetectionResult = detections_module.DetectionResult
 _Image = image_module.Image
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 _ObjectDetector = object_detector.ObjectDetector
 _ObjectDetectorOptions = object_detector.ObjectDetectorOptions
 _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
 _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
@@ -115,10 +118,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.test_image = _Image.create_from_file(
-        test_utils.get_test_data_path(
-            os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
+        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
+    )
     self.model_path = test_utils.get_test_data_path(
-        os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
+        os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
+    )

   def test_create_from_file_succeeds_with_valid_model_path(self):
     # Creates with default option and valid model file successfully.
@@ -134,9 +138,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   def test_create_from_options_fails_with_invalid_model_path(self):
     with self.assertRaisesRegex(
-        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
+        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
+    ):
       base_options = _BaseOptions(
-          model_asset_path='/path/to/invalid/model.tflite')
+          model_asset_path='/path/to/invalid/model.tflite'
+      )
       options = _ObjectDetectorOptions(base_options=base_options)
       _ObjectDetector.create_from_options(options)
@@ -150,9 +156,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
-      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT))
-  def test_detect(self, model_file_type, max_results,
-                  expected_detection_result):
+      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
+  )
+  def test_detect(
+      self, model_file_type, max_results, expected_detection_result
+  ):
     # Creates detector.
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)
@@ -165,7 +173,8 @@ class ObjectDetectorTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ObjectDetectorOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     detector = _ObjectDetector.create_from_options(options)

     # Performs object detection on the input.
@@ -178,9 +187,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
-      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT))
-  def test_detect_in_context(self, model_file_type, max_results,
-                             expected_detection_result):
+      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
+  )
+  def test_detect_in_context(
+      self, model_file_type, max_results, expected_detection_result
+  ):
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)
     elif model_file_type is ModelFileType.FILE_CONTENT:
@@ -192,7 +203,8 @@ class ObjectDetectorTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ObjectDetectorOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
@@ -202,7 +214,8 @@ class ObjectDetectorTest(parameterized.TestCase):
   def test_score_threshold_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        score_threshold=_SCORE_THRESHOLD,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
@@ -211,25 +224,30 @@ class ObjectDetectorTest(parameterized.TestCase):
       for detection in detections:
         score = detection.categories[0].score
         self.assertGreaterEqual(
-            score, _SCORE_THRESHOLD,
-            f'Detection with score lower than threshold found. {detection}')
+            score,
+            _SCORE_THRESHOLD,
+            f'Detection with score lower than threshold found. {detection}',
+        )

   def test_max_results_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        max_results=_MAX_RESULTS)
+        max_results=_MAX_RESULTS,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
       detections = detection_result.detections

       self.assertLessEqual(
-          len(detections), _MAX_RESULTS, 'Too many results returned.')
+          len(detections), _MAX_RESULTS, 'Too many results returned.'
+      )

   def test_allow_list_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_allowlist=_ALLOW_LIST)
+        category_allowlist=_ALLOW_LIST,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
@@ -237,13 +255,17 @@ class ObjectDetectorTest(parameterized.TestCase):
       for detection in detections:
         label = detection.categories[0].category_name
-        self.assertIn(label, _ALLOW_LIST,
-                      f'Label {label} found but not in label allow list')
+        self.assertIn(
+            label,
+            _ALLOW_LIST,
+            f'Label {label} found but not in label allow list',
+        )

   def test_deny_list_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_denylist=_DENY_LIST)
+        category_denylist=_DENY_LIST,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
@@ -251,26 +273,30 @@ class ObjectDetectorTest(parameterized.TestCase):
       for detection in detections:
         label = detection.categories[0].category_name
-        self.assertNotIn(label, _DENY_LIST,
-                         f'Label {label} found but in deny list.')
+        self.assertNotIn(
+            label, _DENY_LIST, f'Label {label} found but in deny list.'
+        )

   def test_combined_allowlist_and_denylist(self):
     # Fails with combined allowlist and denylist
     with self.assertRaisesRegex(
         ValueError,
         r'`category_allowlist` and `category_denylist` are mutually '
-        r'exclusive options.'):
+        r'exclusive options.',
+    ):
       options = _ObjectDetectorOptions(
           base_options=_BaseOptions(model_asset_path=self.model_path),
           category_allowlist=['foo'],
-          category_denylist=['bar'])
+          category_denylist=['bar'],
+      )
       with _ObjectDetector.create_from_options(options) as unused_detector:
         pass

   def test_empty_detection_outputs(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=1)
+        score_threshold=1,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
@@ -279,9 +305,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   def test_missing_result_callback(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.LIVE_STREAM)
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback must be provided'):
+        running_mode=_RUNNING_MODE.LIVE_STREAM,
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback must be provided'
+    ):
       with _ObjectDetector.create_from_options(options) as unused_detector:
         pass
@@ -290,56 +318,68 @@ class ObjectDetectorTest(parameterized.TestCase):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=running_mode,
-        result_callback=mock.MagicMock())
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback should not be provided'):
+        result_callback=mock.MagicMock(),
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback should not be provided'
+    ):
       with _ObjectDetector.create_from_options(options) as unused_detector:
         pass

   def test_calling_detect_for_video_in_image_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         detector.detect_for_video(self.test_image, 0)

   def test_calling_detect_async_in_image_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         detector.detect_async(self.test_image, 0)

   def test_calling_detect_in_video_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         detector.detect(self.test_image)

   def test_calling_detect_async_in_video_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         detector.detect_async(self.test_image, 0)

   def test_detect_for_video_with_out_of_order_timestamp(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       unused_result = detector.detect_for_video(self.test_image, 1)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         detector.detect_for_video(self.test_image, 0)

   # TODO: Tests how `detect_for_video` handles the temporal data
@@ -348,7 +388,8 @@ class ObjectDetectorTest(parameterized.TestCase):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.VIDEO,
-        max_results=4)
+        max_results=4,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       for timestamp in range(0, 300, 30):
         detection_result = detector.detect_for_video(self.test_image, timestamp)
@@ -358,20 +399,24 @@ class ObjectDetectorTest(parameterized.TestCase):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         detector.detect(self.test_image)

   def test_calling_detect_for_video_in_live_stream_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         detector.detect_for_video(self.test_image, 0)

   def test_detect_async_calls_with_illegal_timestamp(self):
@@ -379,24 +424,30 @@ class ObjectDetectorTest(parameterized.TestCase):
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       detector.detect_async(self.test_image, 100)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         detector.detect_async(self.test_image, 0)

-  @parameterized.parameters((0, _EXPECTED_DETECTION_RESULT),
-                            (1, _DetectionResult(detections=[])))
+  @parameterized.parameters(
+      (0, _EXPECTED_DETECTION_RESULT), (1, _DetectionResult(detections=[]))
+  )
   def test_detect_async_calls(self, threshold, expected_result):
     observed_timestamp_ms = -1

-    def check_result(result: _DetectionResult, output_image: _Image,
-                     timestamp_ms: int):
+    def check_result(
+        result: _DetectionResult, output_image: _Image, timestamp_ms: int
+    ):
       self.assertEqual(result, expected_result)
       self.assertTrue(
-          np.array_equal(output_image.numpy_view(),
-                         self.test_image.numpy_view()))
+          np.array_equal(
+              output_image.numpy_view(), self.test_image.numpy_view()
+          )
+      )
       self.assertLess(observed_timestamp_ms, timestamp_ms)
       self.observed_timestamp_ms = timestamp_ms
@@ -405,7 +456,8 @@ class ObjectDetectorTest(parameterized.TestCase):
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
         score_threshold=threshold,
-        result_callback=check_result)
+        result_callback=check_result,
+    )
     detector = _ObjectDetector.create_from_options(options)
     for timestamp in range(0, 300, 30):
       detector.detect_async(self.test_image, timestamp)
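
For context, the live-stream tests above drive the detector through `detect_async` plus a result callback; a minimal sketch of the same flow outside the test harness might look like this (model path and frame source are placeholders; the callback signature follows the `check_result` helpers above):

```python
import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core import base_options as base_options_module


def on_result(result, output_image, timestamp_ms):
  # Invoked asynchronously; timestamps passed to detect_async must be
  # monotonically increasing, as the tests above verify.
  del output_image  # unused in this sketch
  print(timestamp_ms,
        [d.categories[0].category_name for d in result.detections])


options = vision.ObjectDetectorOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='detector.tflite'  # placeholder
    ),
    running_mode=vision.RunningMode.LIVE_STREAM,
    max_results=4,
    result_callback=on_result,
)
with vision.ObjectDetector.create_from_options(options) as detector:
  frame = mp.Image.create_from_file('frame.jpg')  # placeholder frame source
  for timestamp_ms in range(0, 300, 30):
    detector.detect_async(frame, timestamp_ms)
```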

View File

@@ -29,10 +29,12 @@ py_library(
         "//mediapipe/python:packet_getter",
         "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_py_pb2",
         "//mediapipe/tasks/python/components/containers:detections",
+        "//mediapipe/tasks/python/components/containers:rect",
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/core:optional_dependencies",
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )
@@ -71,10 +73,12 @@ py_library(
         "//mediapipe/python:packet_getter",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:rect",
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/core:optional_dependencies",
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )

View File

@@ -17,6 +17,7 @@ import math
 from typing import Callable, Mapping, Optional

 from mediapipe.framework import calculator_pb2
+from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import packet as packet_module
 from mediapipe.python._framework_bindings import task_runner as task_runner_module
 from mediapipe.tasks.python.components.containers import rect as rect_module
@@ -39,8 +40,9 @@ class BaseVisionTaskApi(object):
       self,
       graph_config: calculator_pb2.CalculatorGraphConfig,
       running_mode: _RunningMode,
-      packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]],
-                                         None]] = None
+      packet_callback: Optional[
+          Callable[[Mapping[str, packet_module.Packet]], None]
+      ] = None,
   ) -> None:
     """Initializes the `BaseVisionTaskApi` object.
@@ -48,7 +50,7 @@ class BaseVisionTaskApi(object):
       graph_config: The mediapipe vision task graph config proto.
       running_mode: The running mode of the mediapipe vision task.
       packet_callback: The optional packet callback for getting results
-          asynchronously in the live stream mode.
+        asynchronously in the live stream mode.

     Raises:
       ValueError: The packet callback is not properly set based on the task's
if packet_callback is None: if packet_callback is None:
raise ValueError( raise ValueError(
'The vision task is in live stream mode, a user-defined result ' 'The vision task is in live stream mode, a user-defined result '
'callback must be provided.') 'callback must be provided.'
)
elif packet_callback: elif packet_callback:
raise ValueError( raise ValueError(
'The vision task is in image or video mode, a user-defined result ' 'The vision task is in image or video mode, a user-defined result '
'callback should not be provided.') 'callback should not be provided.'
)
self._runner = _TaskRunner.create(graph_config, packet_callback) self._runner = _TaskRunner.create(graph_config, packet_callback)
self._running_mode = running_mode self._running_mode = running_mode
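The two ValueError branches above pin down a simple contract: a result callback is required in live-stream mode and rejected in image or video mode. A minimal sketch of both sides, using the public ImageClassifier wrapper (the model path is a placeholder):

from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

def on_result(result, image, timestamp_ms):
  # Invoked asynchronously with each result in live-stream mode.
  print(timestamp_ms, result)

# Live-stream mode: omitting result_callback makes create_from_options raise
# ValueError.
live_options = vision.ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='classifier.tflite'),  # placeholder
    running_mode=vision.RunningMode.LIVE_STREAM,
    result_callback=on_result,
)

# Image mode: supplying result_callback here would raise ValueError instead.
image_options = vision.ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='classifier.tflite'),
    running_mode=vision.RunningMode.IMAGE,
)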
def _process_image_data( def _process_image_data(
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]: self, inputs: Mapping[str, _Packet]
) -> Mapping[str, _Packet]:
"""A synchronous method to process single image inputs. """A synchronous method to process single image inputs.
The call blocks the current thread until a failure status or a successful The call blocks the current thread until a failure status or a successful
@ -84,12 +89,14 @@ class BaseVisionTaskApi(object):
""" """
if self._running_mode != _RunningMode.IMAGE: if self._running_mode != _RunningMode.IMAGE:
raise ValueError( raise ValueError(
'Task is not initialized with the image mode. Current running mode:' + 'Task is not initialized with the image mode. Current running mode:'
self._running_mode.name) + self._running_mode.name
)
return self._runner.process(inputs) return self._runner.process(inputs)
def _process_video_data( def _process_video_data(
self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]: self, inputs: Mapping[str, _Packet]
) -> Mapping[str, _Packet]:
"""A synchronous method to process continuous video frames. """A synchronous method to process continuous video frames.
The call blocks the current thread until a failure status or a successful The call blocks the current thread until a failure status or a successful
@ -106,8 +113,9 @@ class BaseVisionTaskApi(object):
""" """
if self._running_mode != _RunningMode.VIDEO: if self._running_mode != _RunningMode.VIDEO:
raise ValueError( raise ValueError(
'Task is not initialized with the video mode. Current running mode:' + 'Task is not initialized with the video mode. Current running mode:'
self._running_mode.name) + self._running_mode.name
)
return self._runner.process(inputs) return self._runner.process(inputs)
def _send_live_stream_data(self, inputs: Mapping[str, _Packet]) -> None: def _send_live_stream_data(self, inputs: Mapping[str, _Packet]) -> None:
@ -124,13 +132,18 @@ class BaseVisionTaskApi(object):
""" """
if self._running_mode != _RunningMode.LIVE_STREAM: if self._running_mode != _RunningMode.LIVE_STREAM:
raise ValueError( raise ValueError(
'Task is not initialized with the live stream mode. Current running mode:' 'Task is not initialized with the live stream mode. Current running'
+ self._running_mode.name) ' mode:'
+ self._running_mode.name
)
self._runner.send(inputs) self._runner.send(inputs)
def convert_to_normalized_rect(self, def convert_to_normalized_rect(
options: _ImageProcessingOptions, self,
roi_allowed: bool = True) -> _NormalizedRect: options: _ImageProcessingOptions,
image: image_module.Image,
roi_allowed: bool = True,
) -> _NormalizedRect:
"""Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly. """Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly.
If the input ImageProcessingOptions is not present, returns a default If the input ImageProcessingOptions is not present, returns a default
@ -140,6 +153,7 @@ class BaseVisionTaskApi(object):
Args: Args:
options: Options for image processing. options: Options for image processing.
image: The image to process.
roi_allowed: Indicates if the `region_of_interest` field is allowed to be roi_allowed: Indicates if the `region_of_interest` field is allowed to be
set. By default, it's set to True. set. By default, it's set to True.
@ -147,7 +161,8 @@ class BaseVisionTaskApi(object):
A normalized rect proto that represents the image processing options. A normalized rect proto that represents the image processing options.
""" """
normalized_rect = _NormalizedRect( normalized_rect = _NormalizedRect(
rotation=0, x_center=0.5, y_center=0.5, width=1, height=1) rotation=0, x_center=0.5, y_center=0.5, width=1, height=1
)
if options is None: if options is None:
return normalized_rect return normalized_rect
@ -169,6 +184,20 @@ class BaseVisionTaskApi(object):
normalized_rect.y_center = (roi.top + roi.bottom) / 2.0 normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
normalized_rect.width = roi.right - roi.left normalized_rect.width = roi.right - roi.left
normalized_rect.height = roi.bottom - roi.top normalized_rect.height = roi.bottom - roi.top
# For 90° and 270° rotations, we need to swap width and height.
# This is due to the internal behavior of ImageToTensorCalculator, which:
# - first denormalizes the provided rect by multiplying the rect width or
#   height by the image width or height, respectively.
# - then rotates this denormalized rect by the provided rotation, and
# uses this for cropping,
# - then finally rotates this back.
if abs(options.rotation_degrees % 180) != 0:
w = normalized_rect.height * image.height / image.width
h = normalized_rect.width * image.width / image.height
normalized_rect.width = w
normalized_rect.height = h
return normalized_rect return normalized_rect
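A worked check of the swap, assuming a hypothetical 640x480 image: rotating the full-frame rect (width=1, height=1) by 90° yields width 480/640 = 0.75 and height 640/480 ≈ 1.33, which denormalize back to 480x640, the rotated frame's size. The same arithmetic as a standalone sketch:

def swapped_dimensions(rect_w, rect_h, image_w, image_h, rotation_degrees):
  # Mirrors the 90/270-degree swap above: width and height trade places and
  # are rescaled by the image aspect ratio so the rect stays normalized with
  # respect to the original (unrotated) image.
  if abs(rotation_degrees % 180) != 0:
    return rect_h * image_h / image_w, rect_w * image_w / image_h
  return rect_w, rect_h

print(swapped_dimensions(1.0, 1.0, 640, 480, 90))  # (0.75, 1.3333...)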
def close(self) -> None: def close(self) -> None:

View File

@ -212,7 +212,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face detection failed to run. RuntimeError: If face detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False image_processing_options, image, roi_allowed=False
) )
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
@ -261,7 +261,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face detection failed to run. RuntimeError: If face detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False image_processing_options, image, roi_allowed=False
) )
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
@ -320,7 +320,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
detector has already processed. detector has already processed.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False image_processing_options, image, roi_allowed=False
) )
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

View File

@ -401,7 +401,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face landmarker detection failed to run. RuntimeError: If face landmarker detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False image_processing_options, image, roi_allowed=False
) )
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
@ -444,7 +444,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If face landmarker detection failed to run. RuntimeError: If face landmarker detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False image_processing_options, image, roi_allowed=False
) )
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
@ -497,7 +497,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
face landmarker has already processed. face landmarker has already processed.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False image_processing_options, image, roi_allowed=False
) )
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

View File

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions _GestureRecognizerGraphOptionsProto = (
gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
)
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -53,7 +55,9 @@ _HAND_LANDMARKS_STREAM_NAME = 'landmarks'
_HAND_LANDMARKS_TAG = 'LANDMARKS' _HAND_LANDMARKS_TAG = 'LANDMARKS'
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks' _HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS' _HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph' _TASK_GRAPH_NAME = (
'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
)
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
_GESTURE_DEFAULT_INDEX = -1 _GESTURE_DEFAULT_INDEX = -1
@ -78,17 +82,21 @@ class GestureRecognizerResult:
def _build_recognition_result( def _build_recognition_result(
output_packets: Mapping[str, output_packets: Mapping[str, packet_module.Packet]
packet_module.Packet]) -> GestureRecognizerResult: ) -> GestureRecognizerResult:
"""Constructs a `GestureRecognizerResult` from output packets.""" """Constructs a `GestureRecognizerResult` from output packets."""
gestures_proto_list = packet_getter.get_proto_list( gestures_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_GESTURE_STREAM_NAME]) output_packets[_HAND_GESTURE_STREAM_NAME]
)
handedness_proto_list = packet_getter.get_proto_list( handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME]) output_packets[_HANDEDNESS_STREAM_NAME]
)
hand_landmarks_proto_list = packet_getter.get_proto_list( hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME]) output_packets[_HAND_LANDMARKS_STREAM_NAME]
)
hand_world_landmarks_proto_list = packet_getter.get_proto_list( hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]
)
gesture_results = [] gesture_results = []
for proto in gestures_proto_list: for proto in gestures_proto_list:
@ -101,7 +109,9 @@ def _build_recognition_result(
index=_GESTURE_DEFAULT_INDEX, index=_GESTURE_DEFAULT_INDEX,
score=gesture.score, score=gesture.score,
display_name=gesture.display_name, display_name=gesture.display_name,
category_name=gesture.label)) category_name=gesture.label,
)
)
gesture_results.append(gesture_categories) gesture_results.append(gesture_categories)
handedness_results = [] handedness_results = []
@ -115,7 +125,9 @@ def _build_recognition_result(
index=handedness.index, index=handedness.index,
score=handedness.score, score=handedness.score,
display_name=handedness.display_name, display_name=handedness.display_name,
category_name=handedness.label)) category_name=handedness.label,
)
)
handedness_results.append(handedness_categories) handedness_results.append(handedness_categories)
hand_landmarks_results = [] hand_landmarks_results = []
@ -125,7 +137,8 @@ def _build_recognition_result(
hand_landmarks_list = [] hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark: for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append( hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)) landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
hand_landmarks_results.append(hand_landmarks_list) hand_landmarks_results.append(hand_landmarks_list)
hand_world_landmarks_results = [] hand_world_landmarks_results = []
@ -135,12 +148,16 @@ def _build_recognition_result(
hand_world_landmarks_list = [] hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark: for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append( hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark)) landmark_module.Landmark.create_from_pb2(hand_world_landmark)
)
hand_world_landmarks_results.append(hand_world_landmarks_list) hand_world_landmarks_results.append(hand_world_landmarks_list)
return GestureRecognizerResult(gesture_results, handedness_results, return GestureRecognizerResult(
hand_landmarks_results, gesture_results,
hand_world_landmarks_results) handedness_results,
hand_landmarks_results,
hand_world_landmarks_results,
)
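For reference, the four lists built above land in the result in that order, so consuming it looks roughly like the sketch below (`recognizer` and `mp_image` are assumed to already exist):

result = recognizer.recognize(mp_image)
for hand_index, gestures in enumerate(result.gestures):
  top_gesture = gestures[0]  # first (typically highest-scoring) category
  handedness = result.handedness[hand_index][0]
  landmarks = result.hand_landmarks[hand_index]  # 21 normalized landmarks
  print(handedness.category_name, top_gesture.category_name,
        round(top_gesture.score, 3), len(landmarks))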
@dataclasses.dataclass @dataclasses.dataclass
@ -174,43 +191,62 @@ class GestureRecognizerOptions:
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1 num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5 min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5
canned_gesture_classifier_options: Optional[ canned_gesture_classifier_options: Optional[_ClassifierOptions] = (
_ClassifierOptions] = dataclasses.field( dataclasses.field(default_factory=_ClassifierOptions)
default_factory=_ClassifierOptions) )
custom_gesture_classifier_options: Optional[ custom_gesture_classifier_options: Optional[_ClassifierOptions] = (
_ClassifierOptions] = dataclasses.field( dataclasses.field(default_factory=_ClassifierOptions)
default_factory=_ClassifierOptions) )
result_callback: Optional[Callable[ result_callback: Optional[
[GestureRecognizerResult, image_module.Image, int], None]] = None Callable[[GestureRecognizerResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _GestureRecognizerGraphOptionsProto: def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
"""Generates an GestureRecognizerOptions protobuf object.""" """Generates an GestureRecognizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
# Initialize gesture recognizer options from base options. # Initialize gesture recognizer options from base options.
gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto( gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto) base_options=base_options_proto
)
# Configure hand detector and hand landmarker options. # Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options hand_landmarker_options_proto = (
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence gesture_recognizer_options_proto.hand_landmarker_graph_options
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands )
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence hand_landmarker_options_proto.min_tracking_confidence = (
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence self.min_tracking_confidence
)
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
self.num_hands
)
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
self.min_hand_detection_confidence
)
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_presence_confidence
)
# Configure hand gesture recognizer options. # Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options hand_gesture_recognizer_options_proto = (
gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
)
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom( hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.canned_gesture_classifier_options.to_pb2()) self.canned_gesture_classifier_options.to_pb2()
)
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom( hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.custom_gesture_classifier_options.to_pb2()) self.custom_gesture_classifier_options.to_pb2()
)
return gesture_recognizer_options_proto return gesture_recognizer_options_proto
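The proto plumbing above is driven entirely by the dataclass fields, so a typical construction stays small. A sketch (model path and allowlist labels are illustrative):

from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core.base_options import BaseOptions

options = vision.GestureRecognizerOptions(
    base_options=BaseOptions(model_asset_path='gesture_recognizer.task'),
    num_hands=2,
    min_hand_detection_confidence=0.6,
    canned_gesture_classifier_options=classifier_options.ClassifierOptions(
        category_allowlist=['Thumb_Up', 'Victory']),  # illustrative labels
)
recognizer = vision.GestureRecognizer.create_from_options(options)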
@ -239,12 +275,14 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
""" """
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = GestureRecognizerOptions( options = GestureRecognizerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE) base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options) return cls.create_from_options(options)
@classmethod @classmethod
def create_from_options( def create_from_options(
cls, options: GestureRecognizerOptions) -> 'GestureRecognizer': cls, options: GestureRecognizerOptions
) -> 'GestureRecognizer':
"""Creates the `GestureRecognizer` object from gesture recognizer options. """Creates the `GestureRecognizer` object from gesture recognizer options.
Args: Args:
@ -268,14 +306,19 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME] empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
options.result_callback( options.result_callback(
GestureRecognizerResult([], [], [], []), image, GestureRecognizerResult([], [], [], []),
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return return
gesture_recognizer_result = _build_recognition_result(output_packets) gesture_recognizer_result = _build_recognition_result(output_packets)
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
options.result_callback(gesture_recognizer_result, image, options.result_callback(
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) gesture_recognizer_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
@ -286,23 +329,27 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
output_streams=[ output_streams=[
':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]), ':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]),
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]), ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG, ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
_HAND_LANDMARKS_STREAM_NAME]), ':'.join([ ':'.join(
_HAND_WORLD_LANDMARKS_TAG, [_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME]
_HAND_WORLD_LANDMARKS_STREAM_NAME ),
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
], ],
task_options=options) task_options=options,
)
return cls( return cls(
task_info.generate_graph_config( task_info.generate_graph_config(
enable_flow_limiting=options.running_mode == enable_flow_limiting=options.running_mode
_RunningMode.LIVE_STREAM), options.running_mode, == _RunningMode.LIVE_STREAM
packets_callback if options.result_callback else None) ),
options.running_mode,
packets_callback if options.result_callback else None,
)
def recognize( def recognize(
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> GestureRecognizerResult: ) -> GestureRecognizerResult:
"""Performs hand gesture recognition on the given image. """Performs hand gesture recognition on the given image.
@ -325,12 +372,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If gesture recognition failed to run. RuntimeError: If gesture recognition failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False) image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
packet_creator.create_image(image), _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
_NORM_RECT_STREAM_NAME: normalized_rect.to_pb2()
packet_creator.create_proto(normalized_rect.to_pb2()) ),
}) })
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
@ -342,7 +390,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> GestureRecognizerResult: ) -> GestureRecognizerResult:
"""Performs gesture recognition on the provided video frame. """Performs gesture recognition on the provided video frame.
@ -367,14 +415,15 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If gesture recognition failed to run. RuntimeError: If gesture recognition failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False) image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ),
_NORM_RECT_STREAM_NAME: _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
packet_creator.create_proto(normalized_rect.to_pb2()).at( normalized_rect.to_pb2()
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
@ -386,7 +435,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None: ) -> None:
"""Sends live image data to perform gesture recognition. """Sends live image data to perform gesture recognition.
@ -419,12 +468,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
gesture recognizer has already processed. gesture recognizer has already processed.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False) image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ),
_NORM_RECT_STREAM_NAME: _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
packet_creator.create_proto(normalized_rect.to_pb2()).at( normalized_rect.to_pb2()
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
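Putting the live-stream pieces together: recognize_async only enqueues a frame, and results arrive on the callback configured at creation time. A sketch of a capture loop, assuming a hypothetical camera_frames() source of RGB numpy arrays:

import time

import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

def on_result(result, image, timestamp_ms):
  print(f'{timestamp_ms} ms: {len(result.gestures)} hand(s)')

options = vision.GestureRecognizerOptions(
    base_options=BaseOptions(model_asset_path='gesture_recognizer.task'),
    running_mode=vision.RunningMode.LIVE_STREAM,
    result_callback=on_result,
)

with vision.GestureRecognizer.create_from_options(options) as recognizer:
  start = time.monotonic()
  for frame in camera_frames():  # hypothetical frame source
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
    # Timestamps must strictly increase; the call returns immediately.
    recognizer.recognize_async(
        mp_image, int((time.monotonic() - start) * 1000))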

View File

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions _HandLandmarkerGraphOptionsProto = (
hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -56,6 +58,7 @@ _MICRO_SECONDS_PER_MILLISECOND = 1000
class HandLandmark(enum.IntEnum): class HandLandmark(enum.IntEnum):
"""The 21 hand landmarks.""" """The 21 hand landmarks."""
WRIST = 0 WRIST = 0
THUMB_CMC = 1 THUMB_CMC = 1
THUMB_MCP = 2 THUMB_MCP = 2
@ -95,14 +98,18 @@ class HandLandmarkerResult:
def _build_landmarker_result( def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult: output_packets: Mapping[str, packet_module.Packet]
) -> HandLandmarkerResult:
"""Constructs a `HandLandmarksDetectionResult` from output packets.""" """Constructs a `HandLandmarksDetectionResult` from output packets."""
handedness_proto_list = packet_getter.get_proto_list( handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME]) output_packets[_HANDEDNESS_STREAM_NAME]
)
hand_landmarks_proto_list = packet_getter.get_proto_list( hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME]) output_packets[_HAND_LANDMARKS_STREAM_NAME]
)
hand_world_landmarks_proto_list = packet_getter.get_proto_list( hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]) output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]
)
handedness_results = [] handedness_results = []
for proto in handedness_proto_list: for proto in handedness_proto_list:
@ -115,7 +122,9 @@ def _build_landmarker_result(
index=handedness.index, index=handedness.index,
score=handedness.score, score=handedness.score,
display_name=handedness.display_name, display_name=handedness.display_name,
category_name=handedness.label)) category_name=handedness.label,
)
)
handedness_results.append(handedness_categories) handedness_results.append(handedness_categories)
hand_landmarks_results = [] hand_landmarks_results = []
@ -125,7 +134,8 @@ def _build_landmarker_result(
hand_landmarks_list = [] hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark: for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append( hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)) landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
hand_landmarks_results.append(hand_landmarks_list) hand_landmarks_results.append(hand_landmarks_list)
hand_world_landmarks_results = [] hand_world_landmarks_results = []
@ -135,11 +145,13 @@ def _build_landmarker_result(
hand_world_landmarks_list = [] hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark: for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append( hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark)) landmark_module.Landmark.create_from_pb2(hand_world_landmark)
)
hand_world_landmarks_results.append(hand_world_landmarks_list) hand_world_landmarks_results.append(hand_world_landmarks_list)
return HandLandmarkerResult(handedness_results, hand_landmarks_results, return HandLandmarkerResult(
hand_world_landmarks_results) handedness_results, hand_landmarks_results, hand_world_landmarks_results
)
@dataclasses.dataclass @dataclasses.dataclass
@ -167,28 +179,41 @@ class HandLandmarkerOptions:
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1 num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5 min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5 min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5
result_callback: Optional[Callable[ result_callback: Optional[
[HandLandmarkerResult, image_module.Image, int], None]] = None Callable[[HandLandmarkerResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _HandLandmarkerGraphOptionsProto: def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
"""Generates an HandLandmarkerGraphOptions protobuf object.""" """Generates an HandLandmarkerGraphOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
# Initialize the hand landmarker options from base options. # Initialize the hand landmarker options from base options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto( hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
base_options=base_options_proto) base_options=base_options_proto
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence )
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands hand_landmarker_options_proto.min_tracking_confidence = (
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence self.min_tracking_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence )
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
self.num_hands
)
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
self.min_hand_detection_confidence
)
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_presence_confidence
)
return hand_landmarker_options_proto return hand_landmarker_options_proto
@ -216,12 +241,14 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
""" """
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = HandLandmarkerOptions( options = HandLandmarkerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE) base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options) return cls.create_from_options(options)
@classmethod @classmethod
def create_from_options(cls, def create_from_options(
options: HandLandmarkerOptions) -> 'HandLandmarker': cls, options: HandLandmarkerOptions
) -> 'HandLandmarker':
"""Creates the `HandLandmarker` object from hand landmarker options. """Creates the `HandLandmarker` object from hand landmarker options.
Args: Args:
@ -245,14 +272,19 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME] empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
options.result_callback( options.result_callback(
HandLandmarkerResult([], [], []), image, HandLandmarkerResult([], [], []),
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return return
hand_landmarks_detection_result = _build_landmarker_result(output_packets) hand_landmarks_detection_result = _build_landmarker_result(output_packets)
timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
options.result_callback(hand_landmarks_detection_result, image, options.result_callback(
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) hand_landmarks_detection_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
@ -263,21 +295,26 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
output_streams=[ output_streams=[
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]), ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]), ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
':'.join([ ':'.join(
_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME [_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME]
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
], ],
task_options=options) task_options=options,
)
return cls( return cls(
task_info.generate_graph_config( task_info.generate_graph_config(
enable_flow_limiting=options.running_mode == enable_flow_limiting=options.running_mode
_RunningMode.LIVE_STREAM), options.running_mode, == _RunningMode.LIVE_STREAM
packets_callback if options.result_callback else None) ),
options.running_mode,
packets_callback if options.result_callback else None,
)
def detect( def detect(
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HandLandmarkerResult: ) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the given image. """Performs hand landmarks detection on the given image.
@ -300,12 +337,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If hand landmarker detection failed to run. RuntimeError: If hand landmarker detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False) image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
packet_creator.create_image(image), _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
_NORM_RECT_STREAM_NAME: normalized_rect.to_pb2()
packet_creator.create_proto(normalized_rect.to_pb2()) ),
}) })
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
@ -317,7 +355,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HandLandmarkerResult: ) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the provided video frame. """Performs hand landmarks detection on the provided video frame.
@ -342,14 +380,15 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If hand landmarker detection failed to run. RuntimeError: If hand landmarker detection failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False) image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ),
_NORM_RECT_STREAM_NAME: _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
packet_creator.create_proto(normalized_rect.to_pb2()).at( normalized_rect.to_pb2()
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty(): if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
@ -361,7 +400,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None: ) -> None:
"""Sends live image data to perform hand landmarks detection. """Sends live image data to perform hand landmarks detection.
@ -394,12 +433,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
hand landmarker has already processed. hand landmarker has already processed.
""" """
normalized_rect = self.convert_to_normalized_rect( normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False) image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ),
_NORM_RECT_STREAM_NAME: _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
packet_creator.create_proto(normalized_rect.to_pb2()).at( normalized_rect.to_pb2()
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
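The video-mode variant is synchronous: the caller supplies the frame timestamp in milliseconds (converted to microseconds internally, as above) and gets the result back on the same call. A sketch decoding a clip with OpenCV ('clip.mp4' is a placeholder; OpenCV frames are BGR, so they are converted first):

import cv2
import mediapipe as mp
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

options = vision.HandLandmarkerOptions(
    base_options=BaseOptions(model_asset_path='hand_landmarker.task'),
    running_mode=vision.RunningMode.VIDEO,
)

cap = cv2.VideoCapture('clip.mp4')  # placeholder input clip
with vision.HandLandmarker.create_from_options(options) as landmarker:
  while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
      break
    mp_image = mp.Image(
        image_format=mp.ImageFormat.SRGB,
        data=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    timestamp_ms = int(cap.get(cv2.CAP_PROP_POS_MSEC))
    result = landmarker.detect_for_video(mp_image, timestamp_ms)
cap.release()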

View File

@ -35,7 +35,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions _ImageClassifierGraphOptionsProto = (
image_classifier_graph_options_pb2.ImageClassifierGraphOptions
)
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions _ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
@ -48,7 +50,9 @@ _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in' _NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT' _NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' _TASK_GRAPH_NAME = (
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
)
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
@ -81,6 +85,7 @@ class ImageClassifierOptions:
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
display_names_locale: Optional[str] = None display_names_locale: Optional[str] = None
@ -88,24 +93,29 @@ class ImageClassifierOptions:
score_threshold: Optional[float] = None score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[ result_callback: Optional[
[ImageClassifierResult, image_module.Image, int], None]] = None Callable[[ImageClassifierResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierGraphOptionsProto: def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
"""Generates an ImageClassifierOptions protobuf object.""" """Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
classifier_options_proto = _ClassifierOptionsProto( classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold, score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist, category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist, category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale, display_names_locale=self.display_names_locale,
max_results=self.max_results) max_results=self.max_results,
)
return _ImageClassifierGraphOptionsProto( return _ImageClassifierGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
classifier_options=classifier_options_proto) classifier_options=classifier_options_proto,
)
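As the mapping above shows, every classifier knob travels inside a single ClassifierOptions proto; only base_options and running_mode live outside it. A minimal construction (model path and denylist label are placeholders):

from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

options = vision.ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='classifier.tflite'),
    max_results=3,
    score_threshold=0.2,
    category_denylist=['background'],  # placeholder label
)
classifier = vision.ImageClassifier.create_from_options(options)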
class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
@ -165,12 +175,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
""" """
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = ImageClassifierOptions( options = ImageClassifierOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE) base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options) return cls.create_from_options(options)
@classmethod @classmethod
def create_from_options(cls, def create_from_options(
options: ImageClassifierOptions) -> 'ImageClassifier': cls, options: ImageClassifierOptions
) -> 'ImageClassifier':
"""Creates the `ImageClassifier` object from image classifier options. """Creates the `ImageClassifier` object from image classifier options.
Args: Args:
@ -191,12 +203,15 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback( options.result_callback(
ImageClassifierResult.create_from_pb2(classification_result_proto), ImageClassifierResult.create_from_pb2(classification_result_proto),
image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
@ -206,19 +221,23 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
], ],
output_streams=[ output_streams=[
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]), ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
], ],
task_options=options) task_options=options,
)
return cls( return cls(
task_info.generate_graph_config( task_info.generate_graph_config(
enable_flow_limiting=options.running_mode == enable_flow_limiting=options.running_mode
_RunningMode.LIVE_STREAM), options.running_mode, == _RunningMode.LIVE_STREAM
packets_callback if options.result_callback else None) ),
options.running_mode,
packets_callback if options.result_callback else None,
)
def classify( def classify(
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageClassifierResult: ) -> ImageClassifierResult:
"""Performs image classification on the provided MediaPipe Image. """Performs image classification on the provided MediaPipe Image.
@ -233,17 +252,20 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run. RuntimeError: If image classification failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect(image_processing_options) normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_image_data({ output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
packet_creator.create_image(image), _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
_NORM_RECT_STREAM_NAME: normalized_rect.to_pb2()
packet_creator.create_proto(normalized_rect.to_pb2()) ),
}) })
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
return ImageClassifierResult.create_from_pb2(classification_result_proto) return ImageClassifierResult.create_from_pb2(classification_result_proto)
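Note that classify() calls convert_to_normalized_rect without roi_allowed=False, so unlike the detector and landmarker tasks above it honors a region of interest. A sketch classifying only the left half of an image, assuming the containers Rect takes normalized [0, 1] coordinates and that classifier and mp_image already exist:

from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.vision.core import image_processing_options as ipo_module

# Left half of the image, in normalized coordinates.
roi = rect.Rect(left=0.0, top=0.0, right=0.5, bottom=1.0)
ipo = ipo_module.ImageProcessingOptions(region_of_interest=roi)

result = classifier.classify(mp_image, ipo)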
@ -251,7 +273,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageClassifierResult: ) -> ImageClassifierResult:
"""Performs image classification on the provided video frames. """Performs image classification on the provided video frames.
@ -272,19 +294,22 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run. RuntimeError: If image classification failed to run.
""" """
normalized_rect = self.convert_to_normalized_rect(image_processing_options) normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ),
_NORM_RECT_STREAM_NAME: _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
packet_creator.create_proto(normalized_rect.to_pb2()).at( normalized_rect.to_pb2()
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
classification_result_proto = classifications_pb2.ClassificationResult() classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom( classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])) packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
return ImageClassifierResult.create_from_pb2(classification_result_proto) return ImageClassifierResult.create_from_pb2(classification_result_proto)
@ -292,7 +317,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None: ) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image classification. """Sends live image data (an Image with a unique timestamp) to perform image classification.
@ -320,12 +345,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If the current input timestamp is smaller than what the image ValueError: If the current input timestamp is smaller than what the image
classifier has already processed. classifier has already processed.
""" """
normalized_rect = self.convert_to_normalized_rect(image_processing_options) normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
self._send_live_stream_data({ self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
packet_creator.create_image(image).at( timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ),
_NORM_RECT_STREAM_NAME: _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
packet_creator.create_proto(normalized_rect.to_pb2()).at( normalized_rect.to_pb2()
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })

View File

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni
ImageEmbedderResult = embedding_result_module.EmbeddingResult ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions _ImageEmbedderGraphOptionsProto = (
image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
)
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions _EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
@ -74,24 +76,29 @@ class ImageEmbedderOptions:
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
l2_normalize: Optional[bool] = None l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None quantize: Optional[bool] = None
result_callback: Optional[Callable[ result_callback: Optional[
[ImageEmbedderResult, image_module.Image, int], None]] = None Callable[[ImageEmbedderResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageEmbedderGraphOptionsProto: def to_pb2(self) -> _ImageEmbedderGraphOptionsProto:
"""Generates an ImageEmbedderOptions protobuf object.""" """Generates an ImageEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
embedder_options_proto = _EmbedderOptionsProto( embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize) l2_normalize=self.l2_normalize, quantize=self.quantize
)
return _ImageEmbedderGraphOptionsProto( return _ImageEmbedderGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto, embedder_options=embedder_options_proto
embedder_options=embedder_options_proto) )
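A sketch of the embedder's typical round trip, assuming image_a and image_b are pre-loaded mp.Image objects; cosine_similarity is the utility shipped on the ImageEmbedder class:

from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core.base_options import BaseOptions

options = vision.ImageEmbedderOptions(
    base_options=BaseOptions(model_asset_path='embedder.tflite'),  # placeholder
    l2_normalize=True,  # unit-norm embedding vectors
    quantize=False,     # keep float values
)

with vision.ImageEmbedder.create_from_options(options) as embedder:
  result_a = embedder.embed(image_a)
  result_b = embedder.embed(image_b)
  similarity = vision.ImageEmbedder.cosine_similarity(
      result_a.embeddings[0], result_b.embeddings[0])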
class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi): class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
@ -135,12 +142,14 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
""" """
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = ImageEmbedderOptions( options = ImageEmbedderOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE) base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options) return cls.create_from_options(options)
@classmethod @classmethod
def create_from_options(cls, def create_from_options(
options: ImageEmbedderOptions) -> 'ImageEmbedder': cls, options: ImageEmbedderOptions
) -> 'ImageEmbedder':
"""Creates the `ImageEmbedder` object from image embedder options. """Creates the `ImageEmbedder` object from image embedder options.
Args: Args:
@ -161,13 +170,16 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
embedding_result_proto = embeddings_pb2.EmbeddingResult() embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom( embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])) packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback( options.result_callback(
ImageEmbedderResult.create_from_pb2(embedding_result_proto), image, ImageEmbedderResult.create_from_pb2(embedding_result_proto),
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,

@@ -177,19 +189,23 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
        ],
        output_streams=[
            ':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_OUT_STREAM_NAME]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options,
    )
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode
            == _RunningMode.LIVE_STREAM
        ),
        options.running_mode,
        packets_callback if options.result_callback else None,
    )
  def embed(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> ImageEmbedderResult:
    """Performs image embedding extraction on the provided MediaPipe Image.

@@ -207,17 +223,20 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If image embedder failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image
    )
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ),
    })
    embedding_result_proto = embeddings_pb2.EmbeddingResult()
    embedding_result_proto.CopyFrom(
        packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
    )
    return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
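A usage sketch for the image mode, assuming a local model file and test image (both paths are placeholders); the region of interest is passed through ImageProcessingOptions and feeds the NORM_RECT stream above:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.vision import image_embedder
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

embedder = image_embedder.ImageEmbedder.create_from_model_path(
    '/path/to/embedder.tflite'  # placeholder
)
image = image_module.Image.create_from_file('/path/to/image.jpg')  # placeholder
# Embed only the center crop; Rect coordinates are normalized to [0, 1].
processing_options = image_processing_options_module.ImageProcessingOptions(
    region_of_interest=rect.Rect(left=0.25, top=0.25, right=0.75, bottom=0.75)
)
result = embedder.embed(image, processing_options)
embedding = result.embeddings[0]  # one entry per model head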
@@ -225,7 +244,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> ImageEmbedderResult:
    """Performs image embedding extraction on the provided video frames.

@@ -249,18 +268,21 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If image embedder failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image
    )
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
    embedding_result_proto = embeddings_pb2.EmbeddingResult()
    embedding_result_proto.CopyFrom(
        packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
    )
    return ImageEmbedderResult.create_from_pb2(embedding_result_proto)
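A sketch of the video mode, assuming `embedder` was created with running_mode=VIDEO and that `frames` and `fps` come from the caller's decoder; the millisecond timestamps are converted to microsecond packet timestamps internally:

# `frames` (an iterable of MediaPipe Image objects) and `fps` are assumed to
# come from a video decoder; timestamps must increase monotonically.
for index, frame in enumerate(frames):
  timestamp_ms = int(index * 1000 / fps)
  result = embedder.embed_for_video(frame, timestamp_ms)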
@@ -268,7 +290,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> None:
    """Sends live image data to embedder.

@@ -301,19 +323,24 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If the current input timestamp is smaller than what the image
        embedder has already processed.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image
    )
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })

  @classmethod
  def cosine_similarity(
      cls,
      u: embedding_result_module.Embedding,
      v: embedding_result_module.Embedding,
  ) -> float:
    """Utility function to compute cosine similarity between two embedding entries.

    May return an InvalidArgumentError if e.g. the feature vectors are of
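A sketch of how this utility is typically used, reusing the hypothetical `embedder` from above; it operates on individual Embedding entries rather than whole results:

result_a = embedder.embed(image_a)  # image_a and image_b are placeholder inputs
result_b = embedder.embed(image_b)
similarity = image_embedder.ImageEmbedder.cosine_similarity(
    result_a.embeddings[0], result_b.embeddings[0]
)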
View File
@@ -23,16 +23,23 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode

ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterGraphOptionsProto = (
    image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
)
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
@@ -40,6 +47,8 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@@ -77,19 +86,24 @@ class ImageSegmenterOptions:
  running_mode: _RunningMode = _RunningMode.IMAGE
  output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
  activation: Optional[Activation] = Activation.NONE
  result_callback: Optional[
      Callable[[ImageSegmenterResult, image_module.Image, int], None]
  ] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
    """Generates an ImageSegmenterOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = (
        False if self.running_mode == _RunningMode.IMAGE else True
    )
    segmenter_options_proto = _SegmenterOptionsProto(
        output_type=self.output_type.value, activation=self.activation.value
    )
    return _ImageSegmenterGraphOptionsProto(
        base_options=base_options_proto,
        segmenter_options=segmenter_options_proto,
    )

class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):

@@ -138,12 +152,14 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = ImageSegmenterOptions(
        base_options=base_options, running_mode=_RunningMode.IMAGE
    )
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(
      cls, options: ImageSegmenterOptions
  ) -> 'ImageSegmenter':
    """Creates the `ImageSegmenter` object from image segmenter options.

    Args:
@@ -162,31 +178,47 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
        return
      segmentation_result = packet_getter.get_image_list(
          output_packets[_SEGMENTATION_OUT_STREAM_NAME]
      )
      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
      timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
      options.result_callback(
          segmentation_result,
          image,
          timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
      )

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options,
    )
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode
            == _RunningMode.LIVE_STREAM
        ),
        options.running_mode,
        packets_callback if options.result_callback else None,
    )
  def segment(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> ImageSegmenterResult:
    """Performs the actual segmentation task on the provided MediaPipe Image.

    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      If the output_type is CATEGORY_MASK, the returned vector of images is

@@ -199,14 +231,26 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If image segmentation failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image, roi_allowed=False
    )
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ),
    })
    segmentation_result = packet_getter.get_image_list(
        output_packets[_SEGMENTATION_OUT_STREAM_NAME]
    )
    return segmentation_result
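A usage sketch for the image mode (model and image paths are placeholders). Since convert_to_normalized_rect is called with roi_allowed=False here, setting a region_of_interest raises ValueError, but rotation is still honored:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

segmenter = image_segmenter.ImageSegmenter.create_from_model_path(
    '/path/to/segmenter.tflite'  # placeholder
)
image = image_module.Image.create_from_file('/path/to/image.jpg')  # placeholder
# Rotation must be a multiple of 90 degrees; region_of_interest must be unset.
processing_options = image_processing_options_module.ImageProcessingOptions(
    rotation_degrees=90
)
masks = segmenter.segment(image, processing_options)
category_mask = masks[0].numpy_view()  # CATEGORY_MASK: one label per pixel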
  def segment_for_video(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> ImageSegmenterResult:
    """Performs segmentation on the provided video frames.

    Only use this method when the ImageSegmenter is created with the video

@@ -217,6 +261,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
      image_processing_options: Options for image processing.

    Returns:
      If the output_type is CATEGORY_MASK, the returned vector of images is

@@ -229,16 +274,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If image segmentation failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image, roi_allowed=False
    )
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
    segmentation_result = packet_getter.get_image_list(
        output_packets[_SEGMENTATION_OUT_STREAM_NAME]
    )
    return segmentation_result
  def segment_async(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> None:
    """Sends live image data (an Image with a unique timestamp) to perform image segmentation.

    Only use this method when the ImageSegmenter is created with the live stream

@@ -260,13 +317,20 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
      image_processing_options: Options for image processing.

    Raises:
      ValueError: If the current input timestamp is smaller than what the image
        segmenter has already processed.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image, roi_allowed=False
    )
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
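A live-stream sketch, assuming frames arrive from a hypothetical camera loop (the model path and the camera_frames() helper are placeholders); results are delivered through the result_callback configured in the options:

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode

def on_segmentation(result, output_image, timestamp_ms):
  # In live-stream mode the masks arrive here rather than as a return value.
  print(timestamp_ms, 'ms:', len(result), 'mask(s)')

options = image_segmenter.ImageSegmenterOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='/path/to/segmenter.tflite'  # placeholder
    ),
    running_mode=vision_task_running_mode.VisionTaskRunningMode.LIVE_STREAM,
    result_callback=on_segmentation,
)
with image_segmenter.ImageSegmenter.create_from_options(options) as segmenter:
  for timestamp_ms, frame in camera_frames():  # hypothetical frame source
    segmenter.segment_async(frame, timestamp_ms)  # returns immediately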
View File
@@ -22,15 +22,20 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.object_detector.proto import object_detector_options_pb2
from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

ObjectDetectorResult = detections_module.DetectionResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ObjectDetectorOptionsProto = object_detector_options_pb2.ObjectDetectorOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_DETECTIONS_OUT_STREAM_NAME = 'detections_out'
@@ -38,7 +43,10 @@ _DETECTIONS_TAG = 'DETECTIONS'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ObjectDetectorGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000


@dataclasses.dataclass
@@ -48,11 +56,10 @@ class ObjectDetectorOptions:
  Attributes:
    base_options: Base options for the object detector task.
    running_mode: The running mode of the task. Default to the image mode.
      Object detector task has three running modes: 1) The image mode for
      detecting objects on single image inputs. 2) The video mode for detecting
      objects on the decoded frames of a video. 3) The live stream mode for
      detecting objects on a live stream of input data, such as from camera.
    display_names_locale: The locale to use for display names specified through
      the TFLite Model Metadata.
    max_results: The maximum number of top-scored classification results to
@@ -71,6 +78,7 @@ class ObjectDetectorOptions:
      data. The result callback should only be specified when the running mode
      is set to the live stream mode.
  """

  base_options: _BaseOptions
  running_mode: _RunningMode = _RunningMode.IMAGE
  display_names_locale: Optional[str] = None
@@ -79,14 +87,16 @@ class ObjectDetectorOptions:
  category_allowlist: Optional[List[str]] = None
  category_denylist: Optional[List[str]] = None
  result_callback: Optional[
      Callable[[ObjectDetectorResult, image_module.Image, int], None]
  ] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ObjectDetectorOptionsProto:
    """Generates an ObjectDetectorOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = (
        False if self.running_mode == _RunningMode.IMAGE else True
    )
    return _ObjectDetectorOptionsProto(
        base_options=base_options_proto,
        display_names_locale=self.display_names_locale,

@@ -163,12 +173,14 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
""" """
base_options = _BaseOptions(model_asset_path=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = ObjectDetectorOptions( options = ObjectDetectorOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE) base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options) return cls.create_from_options(options)
@classmethod @classmethod
def create_from_options(cls, def create_from_options(
options: ObjectDetectorOptions) -> 'ObjectDetector': cls, options: ObjectDetectorOptions
) -> 'ObjectDetector':
"""Creates the `ObjectDetector` object from object detector options. """Creates the `ObjectDetector` object from object detector options.
Args: Args:
@ -187,32 +199,45 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
        return
      detection_proto_list = packet_getter.get_proto_list(
          output_packets[_DETECTIONS_OUT_STREAM_NAME]
      )
      detection_result = ObjectDetectorResult(
          [
              detections_module.Detection.create_from_pb2(result)
              for result in detection_proto_list
          ]
      )
      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
      options.result_callback(detection_result, image, timestamp)

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options,
    )
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode
            == _RunningMode.LIVE_STREAM
        ),
        options.running_mode,
        packets_callback if options.result_callback else None,
    )
  # TODO: Create an Image class for MediaPipe Tasks.
  def detect(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> ObjectDetectorResult:
    """Performs object detection on the provided MediaPipe Image.

    Only use this method when the ObjectDetector is created with the image

@@ -220,6 +245,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      A detection result object that contains a list of detections, each

@@ -231,17 +257,31 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If object detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image, roi_allowed=False
    )
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ),
    })
    detection_proto_list = packet_getter.get_proto_list(
        output_packets[_DETECTIONS_OUT_STREAM_NAME]
    )
    return ObjectDetectorResult(
        [
            detections_module.Detection.create_from_pb2(result)
            for result in detection_proto_list
        ]
    )
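A usage sketch for the image mode (both paths are placeholders); the result is an ObjectDetectorResult whose detections carry bounding boxes and scored categories:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.vision import object_detector

detector = object_detector.ObjectDetector.create_from_model_path(
    '/path/to/detector.tflite'  # placeholder
)
image = image_module.Image.create_from_file('/path/to/image.jpg')  # placeholder
result = detector.detect(image)
for detection in result.detections:
  category = detection.categories[0]  # highest-scoring category first
  print(category.category_name, category.score, detection.bounding_box)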
  def detect_for_video(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> ObjectDetectorResult:
    """Performs object detection on the provided video frames.

    Only use this method when the ObjectDetector is created with the video

@@ -252,6 +292,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
      image_processing_options: Options for image processing.

    Returns:
      A detection result object that contains a list of detections, each

@@ -263,18 +304,33 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If object detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image, roi_allowed=False
    )
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
    detection_proto_list = packet_getter.get_proto_list(
        output_packets[_DETECTIONS_OUT_STREAM_NAME]
    )
    return ObjectDetectorResult(
        [
            detections_module.Detection.create_from_pb2(result)
            for result in detection_proto_list
        ]
    )
  def detect_async(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> None:
    """Sends live image data (an Image with a unique timestamp) to perform object detection.

    Only use this method when the ObjectDetector is created with the live stream

@@ -298,12 +354,20 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
      image_processing_options: Options for image processing.

    Raises:
      ValueError: If the current input timestamp is smaller than what the object
        detector has already processed.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image, roi_allowed=False
    )
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
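Finally, a live-stream sketch for the detector, again with a placeholder model path and a hypothetical camera_frames() loop; timestamps must increase monotonically across calls:

from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import object_detector
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

def on_detection(result, output_image, timestamp_ms):
  print(timestamp_ms, 'ms:', len(result.detections), 'detection(s)')

options = object_detector.ObjectDetectorOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='/path/to/detector.tflite'  # placeholder
    ),
    running_mode=running_mode_module.VisionTaskRunningMode.LIVE_STREAM,
    score_threshold=0.5,
    result_callback=on_detection,
)
with object_detector.ObjectDetector.create_from_options(options) as detector:
  for timestamp_ms, frame in camera_frames():  # hypothetical frame source
    detector.detect_async(frame, timestamp_ms)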