Internal change

PiperOrigin-RevId: 522307800

parent 2e256bebb5
commit 97bd9c2157
@@ -33,6 +33,7 @@ py_test(
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/test:test_utils",
         "//mediapipe/tasks/python/vision:object_detector",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )

@@ -55,10 +55,10 @@ _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
 def _generate_empty_results() -> ImageClassifierResult:
   return ImageClassifierResult(
       classifications=[
-          _Classifications(
-              categories=[], head_index=0, head_name='probability')
+          _Classifications(categories=[], head_index=0, head_name='probability')
       ],
-      timestamp_ms=0)
+      timestamp_ms=0,
+  )


 def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult:

@@ -129,10 +129,11 @@ class ImageClassifierTest(parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.test_image = _Image.create_from_file(
-        test_utils.get_test_data_path(
-            os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
+        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
+    )
     self.model_path = test_utils.get_test_data_path(
-        os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
+        os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
+    )

   def test_create_from_file_succeeds_with_valid_model_path(self):
     # Creates with default option and valid model file successfully.

@@ -148,9 +149,11 @@ class ImageClassifierTest(parameterized.TestCase):

   def test_create_from_options_fails_with_invalid_model_path(self):
     with self.assertRaisesRegex(
-        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
+        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
+    ):
       base_options = _BaseOptions(
-          model_asset_path='/path/to/invalid/model.tflite')
+          model_asset_path='/path/to/invalid/model.tflite'
+      )
       options = _ImageClassifierOptions(base_options=base_options)
       _ImageClassifier.create_from_options(options)

@@ -164,9 +167,11 @@ class ImageClassifierTest(parameterized.TestCase):

   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _generate_burger_results()),
-      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
-  def test_classify(self, model_file_type, max_results,
-                    expected_classification_result):
+      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
+  )
+  def test_classify(
+      self, model_file_type, max_results, expected_classification_result
+  ):
     # Creates classifier.
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)

@@ -179,23 +184,27 @@ class ImageClassifierTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ImageClassifierOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     classifier = _ImageClassifier.create_from_options(options)

     # Performs image classification on the input.
     image_result = classifier.classify(self.test_image)
     # Comparing results.
-    test_utils.assert_proto_equals(self, image_result.to_pb2(),
-                                   expected_classification_result.to_pb2())
+    test_utils.assert_proto_equals(
+        self, image_result.to_pb2(), expected_classification_result.to_pb2()
+    )
     # Closes the classifier explicitly when the classifier is not used in
     # a context.
     classifier.close()

   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _generate_burger_results()),
-      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()))
-  def test_classify_in_context(self, model_file_type, max_results,
-                               expected_classification_result):
+      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
+  )
+  def test_classify_in_context(
+      self, model_file_type, max_results, expected_classification_result
+  ):
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)
     elif model_file_type is ModelFileType.FILE_CONTENT:

@@ -207,13 +216,15 @@ class ImageClassifierTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ImageClassifierOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
       # Comparing results.
-      test_utils.assert_proto_equals(self, image_result.to_pb2(),
-                                     expected_classification_result.to_pb2())
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), expected_classification_result.to_pb2()
+      )

   def test_classify_succeeds_with_region_of_interest(self):
     base_options = _BaseOptions(model_asset_path=self.model_path)

@@ -222,20 +233,110 @@ class ImageClassifierTest(parameterized.TestCase):
       # Load the test image.
       test_image = _Image.create_from_file(
           test_utils.get_test_data_path(
-              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
+              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
+          )
+      )
       # Region-of-interest around the soccer ball.
       roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
       image_processing_options = _ImageProcessingOptions(roi)
       # Performs image classification on the input.
       image_result = classifier.classify(test_image, image_processing_options)
       # Comparing results.
-      test_utils.assert_proto_equals(self, image_result.to_pb2(),
-                                     _generate_soccer_ball_results().to_pb2())
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), _generate_soccer_ball_results().to_pb2()
+      )

+  def test_classify_succeeds_with_rotation(self):
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _ImageClassifierOptions(base_options=base_options, max_results=3)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Load the test image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(
+              os.path.join(_TEST_DATA_DIR, 'burger_rotated.jpg')
+          )
+      )
+      # Specify a 90° anti-clockwise rotation.
+      image_processing_options = _ImageProcessingOptions(None, -90)
+      # Performs image classification on the input.
+      image_result = classifier.classify(test_image, image_processing_options)
+      # Comparing results.
+      expected = ImageClassifierResult(
+          classifications=[
+              _Classifications(
+                  categories=[
+                      _Category(
+                          index=934,
+                          score=0.754467,
+                          display_name='',
+                          category_name='cheeseburger',
+                      ),
+                      _Category(
+                          index=925,
+                          score=0.0288028,
+                          display_name='',
+                          category_name='guacamole',
+                      ),
+                      _Category(
+                          index=932,
+                          score=0.0286119,
+                          display_name='',
+                          category_name='bagel',
+                      ),
+                  ],
+                  head_index=0,
+                  head_name='probability',
+              )
+          ],
+          timestamp_ms=0,
+      )
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), expected.to_pb2()
+      )
+
+  def test_classify_succeeds_with_region_of_interest_and_rotation(self):
+    base_options = _BaseOptions(model_asset_path=self.model_path)
+    options = _ImageClassifierOptions(base_options=base_options, max_results=1)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Load the test image.
+      test_image = _Image.create_from_file(
+          test_utils.get_test_data_path(
+              os.path.join(_TEST_DATA_DIR, 'multi_objects_rotated.jpg')
+          )
+      )
+      # Region-of-interest around the soccer ball, with 90° anti-clockwise
+      # rotation.
+      roi = _Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614)
+      image_processing_options = _ImageProcessingOptions(roi, -90)
+      # Performs image classification on the input.
+      image_result = classifier.classify(test_image, image_processing_options)
+      # Comparing results.
+      expected = ImageClassifierResult(
+          classifications=[
+              _Classifications(
+                  categories=[
+                      _Category(
+                          index=806,
+                          score=0.997684,
+                          display_name='',
+                          category_name='soccer ball',
+                      ),
+                  ],
+                  head_index=0,
+                  head_name='probability',
+              )
+          ],
+          timestamp_ms=0,
+      )
+      test_utils.assert_proto_equals(
+          self, image_result.to_pb2(), expected.to_pb2()
+      )
+
   def test_score_threshold_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        score_threshold=_SCORE_THRESHOLD,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)

@@ -245,26 +346,33 @@ class ImageClassifierTest(parameterized.TestCase):
       for category in classification.categories:
         score = category.score
         self.assertGreaterEqual(
-            score, _SCORE_THRESHOLD,
-            f'Classification with score lower than threshold found. '
-            f'{classification}')
+            score,
+            _SCORE_THRESHOLD,
+            (
+                'Classification with score lower than threshold found. '
+                f'{classification}'
+            ),
+        )

   def test_max_results_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        score_threshold=_SCORE_THRESHOLD,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)
       categories = image_result.classifications[0].categories

       self.assertLessEqual(
-          len(categories), _MAX_RESULTS, 'Too many results returned.')
+          len(categories), _MAX_RESULTS, 'Too many results returned.'
+      )

   def test_allow_list_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_allowlist=_ALLOW_LIST)
+        category_allowlist=_ALLOW_LIST,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)

@@ -273,13 +381,17 @@ class ImageClassifierTest(parameterized.TestCase):
     for classification in classifications:
       for category in classification.categories:
         label = category.category_name
-        self.assertIn(label, _ALLOW_LIST,
-                      f'Label {label} found but not in label allow list')
+        self.assertIn(
+            label,
+            _ALLOW_LIST,
+            f'Label {label} found but not in label allow list',
+        )

   def test_deny_list_option(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_denylist=_DENY_LIST)
+        category_denylist=_DENY_LIST,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)

@@ -288,26 +400,30 @@ class ImageClassifierTest(parameterized.TestCase):
     for classification in classifications:
       for category in classification.categories:
         label = category.category_name
-        self.assertNotIn(label, _DENY_LIST,
-                         f'Label {label} found but in deny list.')
+        self.assertNotIn(
+            label, _DENY_LIST, f'Label {label} found but in deny list.'
+        )

   def test_combined_allowlist_and_denylist(self):
     # Fails with combined allowlist and denylist
     with self.assertRaisesRegex(
         ValueError,
         r'`category_allowlist` and `category_denylist` are mutually '
-        r'exclusive options.'):
+        r'exclusive options.',
+    ):
       options = _ImageClassifierOptions(
           base_options=_BaseOptions(model_asset_path=self.model_path),
           category_allowlist=['foo'],
-          category_denylist=['bar'])
+          category_denylist=['bar'],
+      )
       with _ImageClassifier.create_from_options(options) as unused_classifier:
         pass

   def test_empty_classification_outputs(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=1)
+        score_threshold=1,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Performs image classification on the input.
       image_result = classifier.classify(self.test_image)

@@ -316,9 +432,11 @@ class ImageClassifierTest(parameterized.TestCase):
   def test_missing_result_callback(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.LIVE_STREAM)
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback must be provided'):
+        running_mode=_RUNNING_MODE.LIVE_STREAM,
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback must be provided'
+    ):
       with _ImageClassifier.create_from_options(options) as unused_classifier:
         pass

@@ -327,67 +445,81 @@ class ImageClassifierTest(parameterized.TestCase):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=running_mode,
-        result_callback=mock.MagicMock())
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback should not be provided'):
+        result_callback=mock.MagicMock(),
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback should not be provided'
+    ):
       with _ImageClassifier.create_from_options(options) as unused_classifier:
         pass

   def test_calling_classify_for_video_in_image_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         classifier.classify_for_video(self.test_image, 0)

   def test_calling_classify_async_in_image_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         classifier.classify_async(self.test_image, 0)

   def test_calling_classify_in_video_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         classifier.classify(self.test_image)

   def test_calling_classify_async_in_video_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         classifier.classify_async(self.test_image, 0)

   def test_classify_for_video_with_out_of_order_timestamp(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       unused_result = classifier.classify_for_video(self.test_image, 1)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         classifier.classify_for_video(self.test_image, 0)

   def test_classify_for_video(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.VIDEO,
-        max_results=4)
+        max_results=4,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       for timestamp in range(0, 300, 30):
         classification_result = classifier.classify_for_video(
-            self.test_image, timestamp)
+            self.test_image, timestamp
+        )
         test_utils.assert_proto_equals(
             self,
             classification_result.to_pb2(),

@@ -398,18 +530,22 @@ class ImageClassifierTest(parameterized.TestCase):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.VIDEO,
-        max_results=1)
+        max_results=1,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       # Load the test image.
       test_image = _Image.create_from_file(
           test_utils.get_test_data_path(
-              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
+              os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
+          )
+      )
       # Region-of-interest around the soccer ball.
       roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
       image_processing_options = _ImageProcessingOptions(roi)
       for timestamp in range(0, 300, 30):
         classification_result = classifier.classify_for_video(
-            test_image, timestamp, image_processing_options)
+            test_image, timestamp, image_processing_options
+        )
         test_utils.assert_proto_equals(
             self,
             classification_result.to_pb2(),

@@ -420,20 +556,24 @@ class ImageClassifierTest(parameterized.TestCase):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         classifier.classify(self.test_image)

   def test_calling_classify_for_video_in_live_stream_mode(self):
     options = _ImageClassifierOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         classifier.classify_for_video(self.test_image, 0)

   def test_classify_async_calls_with_illegal_timestamp(self):

@@ -441,25 +581,32 @@ class ImageClassifierTest(parameterized.TestCase):
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       classifier.classify_async(self.test_image, 100)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         classifier.classify_async(self.test_image, 0)

-  @parameterized.parameters((0, _generate_burger_results()),
-                            (1, _generate_empty_results()))
+  @parameterized.parameters(
+      (0, _generate_burger_results()), (1, _generate_empty_results())
+  )
   def test_classify_async_calls(self, threshold, expected_result):
     observed_timestamp_ms = -1

-    def check_result(result: ImageClassifierResult, output_image: _Image,
-                     timestamp_ms: int):
-      test_utils.assert_proto_equals(self, result.to_pb2(),
-                                     expected_result.to_pb2())
+    def check_result(
+        result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
+    ):
+      test_utils.assert_proto_equals(
+          self, result.to_pb2(), expected_result.to_pb2()
+      )
       self.assertTrue(
-          np.array_equal(output_image.numpy_view(),
-                         self.test_image.numpy_view()))
+          np.array_equal(
+              output_image.numpy_view(), self.test_image.numpy_view()
+          )
+      )
       self.assertLess(observed_timestamp_ms, timestamp_ms)
       self.observed_timestamp_ms = timestamp_ms

@@ -468,7 +615,8 @@ class ImageClassifierTest(parameterized.TestCase):
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
         score_threshold=threshold,
-        result_callback=check_result)
+        result_callback=check_result,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       classifier.classify_async(self.test_image, 0)

@@ -476,14 +624,17 @@ class ImageClassifierTest(parameterized.TestCase):
     # Load the test image.
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path(
-            os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
+            os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
+        )
+    )
     # Region-of-interest around the soccer ball.
     roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
     image_processing_options = _ImageProcessingOptions(roi)
     observed_timestamp_ms = -1

-    def check_result(result: ImageClassifierResult, output_image: _Image,
-                     timestamp_ms: int):
+    def check_result(
+        result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
+    ):
       test_utils.assert_proto_equals(
           self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2()
       )

@@ -496,7 +647,8 @@ class ImageClassifierTest(parameterized.TestCase):
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=1,
-        result_callback=check_result)
+        result_callback=check_result,
+    )
     with _ImageClassifier.create_from_options(options) as classifier:
       classifier.classify_async(test_image, 100, image_processing_options)

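Beyond the mechanical reformatting, the image_classifier_test.py hunks above add new rotation-aware tests. As a rough standalone sketch of the API surface those tests exercise (the ROI values and the `(roi, -90)` anti-clockwise convention come from the diff; the model and image paths here are placeholders, not the repo's test data):

    # Hedged usage sketch mirroring test_classify_succeeds_with_region_of_interest_and_rotation.
    from mediapipe.python._framework_bindings import image as image_module
    from mediapipe.tasks.python.components.containers import rect
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.vision import image_classifier
    from mediapipe.tasks.python.vision.core import image_processing_options as ipo_module

    options = image_classifier.ImageClassifierOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='/path/to/classifier.tflite'  # placeholder path
        ),
        max_results=1,
    )
    with image_classifier.ImageClassifier.create_from_options(options) as classifier:
      # Placeholder image path; any rotated input works the same way.
      image = image_module.Image.create_from_file('/path/to/rotated_input.jpg')
      # ROI around the object of interest, plus a 90° anti-clockwise rotation,
      # matching the (roi, -90) convention used in the new tests.
      roi = rect.Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614)
      result = classifier.classify(image, ipo_module.ImageProcessingOptions(roi, -90))
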
@@ -28,6 +28,7 @@ from mediapipe.tasks.python.components.containers import detections as detections_module
 from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.test import test_utils
 from mediapipe.tasks.python.vision import object_detector
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

 _BaseOptions = base_options_module.BaseOptions

@@ -36,8 +37,10 @@ _BoundingBox = bounding_box_module.BoundingBox
 _Detection = detections_module.Detection
 _DetectionResult = detections_module.DetectionResult
 _Image = image_module.Image
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 _ObjectDetector = object_detector.ObjectDetector
 _ObjectDetectorOptions = object_detector.ObjectDetectorOptions
+
 _RUNNING_MODE = running_mode_module.VisionTaskRunningMode

 _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'

@@ -115,10 +118,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   def setUp(self):
     super().setUp()
     self.test_image = _Image.create_from_file(
-        test_utils.get_test_data_path(
-            os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
+        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
+    )
     self.model_path = test_utils.get_test_data_path(
-        os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
+        os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
+    )

   def test_create_from_file_succeeds_with_valid_model_path(self):
     # Creates with default option and valid model file successfully.

@@ -134,9 +138,11 @@ class ObjectDetectorTest(parameterized.TestCase):

   def test_create_from_options_fails_with_invalid_model_path(self):
     with self.assertRaisesRegex(
-        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
+        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
+    ):
       base_options = _BaseOptions(
-          model_asset_path='/path/to/invalid/model.tflite')
+          model_asset_path='/path/to/invalid/model.tflite'
+      )
       options = _ObjectDetectorOptions(base_options=base_options)
       _ObjectDetector.create_from_options(options)

@@ -150,9 +156,11 @@ class ObjectDetectorTest(parameterized.TestCase):

   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
-      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT))
-  def test_detect(self, model_file_type, max_results,
-                  expected_detection_result):
+      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
+  )
+  def test_detect(
+      self, model_file_type, max_results, expected_detection_result
+  ):
     # Creates detector.
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)

@@ -165,7 +173,8 @@ class ObjectDetectorTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ObjectDetectorOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     detector = _ObjectDetector.create_from_options(options)

     # Performs object detection on the input.

@@ -178,9 +187,11 @@ class ObjectDetectorTest(parameterized.TestCase):

   @parameterized.parameters(
       (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
-      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT))
-  def test_detect_in_context(self, model_file_type, max_results,
-                             expected_detection_result):
+      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
+  )
+  def test_detect_in_context(
+      self, model_file_type, max_results, expected_detection_result
+  ):
     if model_file_type is ModelFileType.FILE_NAME:
       base_options = _BaseOptions(model_asset_path=self.model_path)
     elif model_file_type is ModelFileType.FILE_CONTENT:

@@ -192,7 +203,8 @@ class ObjectDetectorTest(parameterized.TestCase):
       raise ValueError('model_file_type is invalid.')

     options = _ObjectDetectorOptions(
-        base_options=base_options, max_results=max_results)
+        base_options=base_options, max_results=max_results
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)

@@ -202,7 +214,8 @@ class ObjectDetectorTest(parameterized.TestCase):
   def test_score_threshold_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=_SCORE_THRESHOLD)
+        score_threshold=_SCORE_THRESHOLD,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)

@@ -211,25 +224,30 @@ class ObjectDetectorTest(parameterized.TestCase):
     for detection in detections:
       score = detection.categories[0].score
       self.assertGreaterEqual(
-          score, _SCORE_THRESHOLD,
-          f'Detection with score lower than threshold found. {detection}')
+          score,
+          _SCORE_THRESHOLD,
+          f'Detection with score lower than threshold found. {detection}',
+      )

   def test_max_results_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        max_results=_MAX_RESULTS)
+        max_results=_MAX_RESULTS,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)
       detections = detection_result.detections

       self.assertLessEqual(
-          len(detections), _MAX_RESULTS, 'Too many results returned.')
+          len(detections), _MAX_RESULTS, 'Too many results returned.'
+      )

   def test_allow_list_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_allowlist=_ALLOW_LIST)
+        category_allowlist=_ALLOW_LIST,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)

@@ -237,13 +255,17 @@ class ObjectDetectorTest(parameterized.TestCase):

     for detection in detections:
       label = detection.categories[0].category_name
-      self.assertIn(label, _ALLOW_LIST,
-                    f'Label {label} found but not in label allow list')
+      self.assertIn(
+          label,
+          _ALLOW_LIST,
+          f'Label {label} found but not in label allow list',
+      )

   def test_deny_list_option(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        category_denylist=_DENY_LIST)
+        category_denylist=_DENY_LIST,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)

@@ -251,26 +273,30 @@ class ObjectDetectorTest(parameterized.TestCase):

     for detection in detections:
       label = detection.categories[0].category_name
-      self.assertNotIn(label, _DENY_LIST,
-                       f'Label {label} found but in deny list.')
+      self.assertNotIn(
+          label, _DENY_LIST, f'Label {label} found but in deny list.'
+      )

   def test_combined_allowlist_and_denylist(self):
     # Fails with combined allowlist and denylist
     with self.assertRaisesRegex(
         ValueError,
         r'`category_allowlist` and `category_denylist` are mutually '
-        r'exclusive options.'):
+        r'exclusive options.',
+    ):
       options = _ObjectDetectorOptions(
           base_options=_BaseOptions(model_asset_path=self.model_path),
           category_allowlist=['foo'],
-          category_denylist=['bar'])
+          category_denylist=['bar'],
+      )
       with _ObjectDetector.create_from_options(options) as unused_detector:
         pass

   def test_empty_detection_outputs(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        score_threshold=1)
+        score_threshold=1,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       # Performs object detection on the input.
       detection_result = detector.detect(self.test_image)

@@ -279,9 +305,11 @@ class ObjectDetectorTest(parameterized.TestCase):
   def test_missing_result_callback(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.LIVE_STREAM)
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback must be provided'):
+        running_mode=_RUNNING_MODE.LIVE_STREAM,
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback must be provided'
+    ):
       with _ObjectDetector.create_from_options(options) as unused_detector:
         pass

@@ -290,56 +318,68 @@ class ObjectDetectorTest(parameterized.TestCase):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=running_mode,
-        result_callback=mock.MagicMock())
-    with self.assertRaisesRegex(ValueError,
-                                r'result callback should not be provided'):
+        result_callback=mock.MagicMock(),
+    )
+    with self.assertRaisesRegex(
+        ValueError, r'result callback should not be provided'
+    ):
       with _ObjectDetector.create_from_options(options) as unused_detector:
         pass

   def test_calling_detect_for_video_in_image_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         detector.detect_for_video(self.test_image, 0)

   def test_calling_detect_async_in_image_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.IMAGE)
+        running_mode=_RUNNING_MODE.IMAGE,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         detector.detect_async(self.test_image, 0)

   def test_calling_detect_in_video_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         detector.detect(self.test_image)

   def test_calling_detect_async_in_video_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the live stream mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the live stream mode'
+      ):
         detector.detect_async(self.test_image, 0)

   def test_detect_for_video_with_out_of_order_timestamp(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
-        running_mode=_RUNNING_MODE.VIDEO)
+        running_mode=_RUNNING_MODE.VIDEO,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       unused_result = detector.detect_for_video(self.test_image, 1)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         detector.detect_for_video(self.test_image, 0)

   # TODO: Tests how `detect_for_video` handles the temporal data

@@ -348,7 +388,8 @@ class ObjectDetectorTest(parameterized.TestCase):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.VIDEO,
-        max_results=4)
+        max_results=4,
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       for timestamp in range(0, 300, 30):
         detection_result = detector.detect_for_video(self.test_image, timestamp)

@@ -358,20 +399,24 @@ class ObjectDetectorTest(parameterized.TestCase):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the image mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the image mode'
+      ):
         detector.detect(self.test_image)

   def test_calling_detect_for_video_in_live_stream_mode(self):
     options = _ObjectDetectorOptions(
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ObjectDetector.create_from_options(options) as detector:
-      with self.assertRaisesRegex(ValueError,
-                                  r'not initialized with the video mode'):
+      with self.assertRaisesRegex(
+          ValueError, r'not initialized with the video mode'
+      ):
         detector.detect_for_video(self.test_image, 0)

   def test_detect_async_calls_with_illegal_timestamp(self):

@@ -379,24 +424,30 @@ class ObjectDetectorTest(parameterized.TestCase):
         base_options=_BaseOptions(model_asset_path=self.model_path),
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
-        result_callback=mock.MagicMock())
+        result_callback=mock.MagicMock(),
+    )
     with _ObjectDetector.create_from_options(options) as detector:
       detector.detect_async(self.test_image, 100)
       with self.assertRaisesRegex(
-          ValueError, r'Input timestamp must be monotonically increasing'):
+          ValueError, r'Input timestamp must be monotonically increasing'
+      ):
         detector.detect_async(self.test_image, 0)

-  @parameterized.parameters((0, _EXPECTED_DETECTION_RESULT),
-                            (1, _DetectionResult(detections=[])))
+  @parameterized.parameters(
+      (0, _EXPECTED_DETECTION_RESULT), (1, _DetectionResult(detections=[]))
+  )
   def test_detect_async_calls(self, threshold, expected_result):
     observed_timestamp_ms = -1

-    def check_result(result: _DetectionResult, output_image: _Image,
-                     timestamp_ms: int):
+    def check_result(
+        result: _DetectionResult, output_image: _Image, timestamp_ms: int
+    ):
       self.assertEqual(result, expected_result)
       self.assertTrue(
-          np.array_equal(output_image.numpy_view(),
-                         self.test_image.numpy_view()))
+          np.array_equal(
+              output_image.numpy_view(), self.test_image.numpy_view()
+          )
+      )
       self.assertLess(observed_timestamp_ms, timestamp_ms)
       self.observed_timestamp_ms = timestamp_ms

@@ -405,7 +456,8 @@ class ObjectDetectorTest(parameterized.TestCase):
         running_mode=_RUNNING_MODE.LIVE_STREAM,
         max_results=4,
         score_threshold=threshold,
-        result_callback=check_result)
+        result_callback=check_result,
+    )
     detector = _ObjectDetector.create_from_options(options)
     for timestamp in range(0, 300, 30):
       detector.detect_async(self.test_image, timestamp)

@@ -29,10 +29,12 @@ py_library(
         "//mediapipe/python:packet_getter",
         "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_py_pb2",
         "//mediapipe/tasks/python/components/containers:detections",
+        "//mediapipe/tasks/python/components/containers:rect",
         "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )

@@ -71,10 +73,12 @@ py_library(
         "//mediapipe/python:packet_getter",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:rect",
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/core:optional_dependencies",
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )

@@ -17,6 +17,7 @@ import math
 from typing import Callable, Mapping, Optional

 from mediapipe.framework import calculator_pb2
+from mediapipe.python._framework_bindings import image as image_module
 from mediapipe.python._framework_bindings import packet as packet_module
 from mediapipe.python._framework_bindings import task_runner as task_runner_module
 from mediapipe.tasks.python.components.containers import rect as rect_module

@@ -39,8 +40,9 @@ class BaseVisionTaskApi(object):
       self,
       graph_config: calculator_pb2.CalculatorGraphConfig,
       running_mode: _RunningMode,
-      packet_callback: Optional[Callable[[Mapping[str, packet_module.Packet]],
-                                         None]] = None
+      packet_callback: Optional[
+          Callable[[Mapping[str, packet_module.Packet]], None]
+      ] = None,
   ) -> None:
     """Initializes the `BaseVisionTaskApi` object.

@@ -58,16 +60,19 @@ class BaseVisionTaskApi(object):
       if packet_callback is None:
         raise ValueError(
             'The vision task is in live stream mode, a user-defined result '
-            'callback must be provided.')
+            'callback must be provided.'
+        )
     elif packet_callback:
       raise ValueError(
           'The vision task is in image or video mode, a user-defined result '
-          'callback should not be provided.')
+          'callback should not be provided.'
+      )
     self._runner = _TaskRunner.create(graph_config, packet_callback)
     self._running_mode = running_mode

   def _process_image_data(
-      self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
+      self, inputs: Mapping[str, _Packet]
+  ) -> Mapping[str, _Packet]:
     """A synchronous method to process single image inputs.

     The call blocks the current thread until a failure status or a successful

@@ -84,12 +89,14 @@ class BaseVisionTaskApi(object):
     """
     if self._running_mode != _RunningMode.IMAGE:
       raise ValueError(
-          'Task is not initialized with the image mode. Current running mode:' +
-          self._running_mode.name)
+          'Task is not initialized with the image mode. Current running mode:'
+          + self._running_mode.name
+      )
     return self._runner.process(inputs)

   def _process_video_data(
-      self, inputs: Mapping[str, _Packet]) -> Mapping[str, _Packet]:
+      self, inputs: Mapping[str, _Packet]
+  ) -> Mapping[str, _Packet]:
     """A synchronous method to process continuous video frames.

     The call blocks the current thread until a failure status or a successful

@@ -106,8 +113,9 @@ class BaseVisionTaskApi(object):
     """
     if self._running_mode != _RunningMode.VIDEO:
       raise ValueError(
-          'Task is not initialized with the video mode. Current running mode:' +
-          self._running_mode.name)
+          'Task is not initialized with the video mode. Current running mode:'
+          + self._running_mode.name
+      )
     return self._runner.process(inputs)

   def _send_live_stream_data(self, inputs: Mapping[str, _Packet]) -> None:

@@ -124,13 +132,18 @@ class BaseVisionTaskApi(object):
     """
     if self._running_mode != _RunningMode.LIVE_STREAM:
       raise ValueError(
-          'Task is not initialized with the live stream mode. Current running mode:'
-          + self._running_mode.name)
+          'Task is not initialized with the live stream mode. Current running'
+          ' mode:'
+          + self._running_mode.name
+      )
     self._runner.send(inputs)

-  def convert_to_normalized_rect(self,
-                                 options: _ImageProcessingOptions,
-                                 roi_allowed: bool = True) -> _NormalizedRect:
+  def convert_to_normalized_rect(
+      self,
+      options: _ImageProcessingOptions,
+      image: image_module.Image,
+      roi_allowed: bool = True,
+  ) -> _NormalizedRect:
     """Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly.

     If the input ImageProcessingOptions is not present, returns a default

@@ -140,6 +153,7 @@ class BaseVisionTaskApi(object):

     Args:
       options: Options for image processing.
+      image: The image to process.
       roi_allowed: Indicates if the `region_of_interest` field is allowed to be
         set. By default, it's set to True.

@@ -147,7 +161,8 @@ class BaseVisionTaskApi(object):
       A normalized rect proto that represents the image processing options.
     """
     normalized_rect = _NormalizedRect(
-        rotation=0, x_center=0.5, y_center=0.5, width=1, height=1)
+        rotation=0, x_center=0.5, y_center=0.5, width=1, height=1
+    )
     if options is None:
       return normalized_rect

@@ -169,6 +184,20 @@ class BaseVisionTaskApi(object):
       normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
       normalized_rect.width = roi.right - roi.left
       normalized_rect.height = roi.bottom - roi.top
+
+    # For 90° and 270° rotations, we need to swap width and height.
+    # This is due to the internal behavior of ImageToTensorCalculator, which:
+    # - first denormalizes the provided rect by multiplying the rect width or
+    #   height by the image width or height, respectively.
+    # - then rotates the denormalized rect by the provided rotation, and uses
+    #   this for cropping,
+    # - then finally rotates this back.
+    if abs(options.rotation_degrees % 180) != 0:
+      w = normalized_rect.height * image.height / image.width
+      h = normalized_rect.width * image.width / image.height
+      normalized_rect.width = w
+      normalized_rect.height = h
+
     return normalized_rect

   def close(self) -> None:

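The comment block in the hunk above explains why the new `image` argument is needed: ImageToTensorCalculator denormalizes, rotates, and re-normalizes the rect, so for 90°/270° rotations the normalized sizes must be re-expressed against the opposite image dimension. A minimal sketch of that arithmetic, assuming a hypothetical 640x480 image (numbers are illustrative, not from the diff):

    # Illustrative only: mimics the width/height swap in convert_to_normalized_rect.
    image_width, image_height = 640, 480
    rect_width, rect_height = 0.5, 0.5  # normalized rect covering half the image
    rotation_degrees = -90              # 90° anti-clockwise, as in the new tests

    if abs(rotation_degrees % 180) != 0:
      # Denormalized, the rect is 320x240 px. After the 90° rotation the crop
      # is 240x320 px, so each normalized size is recomputed from the other
      # dimension, scaled by the image aspect ratio.
      new_width = rect_height * image_height / image_width   # 0.5 * 480 / 640 = 0.375
      new_height = rect_width * image_width / image_height   # 0.5 * 640 / 480 ≈ 0.667
      rect_width, rect_height = new_width, new_height        # 0.375 * 640 = 240 px, 0.667 * 480 ≈ 320 px
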
@@ -212,7 +212,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
       RuntimeError: If face detection failed to run.
     """
     normalized_rect = self.convert_to_normalized_rect(
-        image_processing_options, roi_allowed=False
+        image_processing_options, image, roi_allowed=False
     )
     output_packets = self._process_image_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),

@@ -261,7 +261,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
       RuntimeError: If face detection failed to run.
     """
     normalized_rect = self.convert_to_normalized_rect(
-        image_processing_options, roi_allowed=False
+        image_processing_options, image, roi_allowed=False
     )
     output_packets = self._process_video_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

@@ -320,7 +320,7 @@ class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
       detector has already processed.
     """
     normalized_rect = self.convert_to_normalized_rect(
-        image_processing_options, roi_allowed=False
+        image_processing_options, image, roi_allowed=False
     )
     self._send_live_stream_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

@@ -401,7 +401,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
       RuntimeError: If face landmarker detection failed to run.
     """
     normalized_rect = self.convert_to_normalized_rect(
-        image_processing_options, roi_allowed=False
+        image_processing_options, image, roi_allowed=False
     )
     output_packets = self._process_image_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),

@@ -444,7 +444,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
       RuntimeError: If face landmarker detection failed to run.
     """
     normalized_rect = self.convert_to_normalized_rect(
-        image_processing_options, roi_allowed=False
+        image_processing_options, image, roi_allowed=False
     )
     output_packets = self._process_video_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

@@ -497,7 +497,7 @@ class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
       face landmarker has already processed.
     """
     normalized_rect = self.convert_to_normalized_rect(
-        image_processing_options, roi_allowed=False
+        image_processing_options, image, roi_allowed=False
    )
     self._send_live_stream_data({
         _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(

@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
|
|||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
|
||||
_GestureRecognizerGraphOptionsProto = (
|
||||
gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
|
||||
)
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
@ -53,7 +55,9 @@ _HAND_LANDMARKS_STREAM_NAME = 'landmarks'
|
|||
_HAND_LANDMARKS_TAG = 'LANDMARKS'
|
||||
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
|
||||
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
|
||||
_TASK_GRAPH_NAME = (
|
||||
'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
|
||||
)
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
_GESTURE_DEFAULT_INDEX = -1
|
||||
|
||||
|
@ -78,17 +82,21 @@ class GestureRecognizerResult:
|
|||
|
||||
|
||||
def _build_recognition_result(
|
||||
output_packets: Mapping[str,
|
||||
packet_module.Packet]) -> GestureRecognizerResult:
|
||||
output_packets: Mapping[str, packet_module.Packet]
|
||||
) -> GestureRecognizerResult:
|
||||
"""Constructs a `GestureRecognizerResult` from output packets."""
|
||||
gestures_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HAND_GESTURE_STREAM_NAME])
|
||||
output_packets[_HAND_GESTURE_STREAM_NAME]
|
||||
)
|
||||
handedness_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_HANDEDNESS_STREAM_NAME])
|
||||
output_packets[_HANDEDNESS_STREAM_NAME]
)
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
output_packets[_HAND_LANDMARKS_STREAM_NAME]
)
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]
)

gesture_results = []
for proto in gestures_proto_list:

@@ -101,7 +109,9 @@ def _build_recognition_result(
index=_GESTURE_DEFAULT_INDEX,
score=gesture.score,
display_name=gesture.display_name,
category_name=gesture.label))
category_name=gesture.label,
)
)
gesture_results.append(gesture_categories)

handedness_results = []

@@ -115,7 +125,9 @@ def _build_recognition_result(
index=handedness.index,
score=handedness.score,
display_name=handedness.display_name,
category_name=handedness.label))
category_name=handedness.label,
)
)
handedness_results.append(handedness_categories)

hand_landmarks_results = []

@@ -125,7 +137,8 @@ def _build_recognition_result(
hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
hand_landmarks_results.append(hand_landmarks_list)

hand_world_landmarks_results = []

@@ -135,12 +148,16 @@ def _build_recognition_result(
hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark))
landmark_module.Landmark.create_from_pb2(hand_world_landmark)
)
hand_world_landmarks_results.append(hand_world_landmarks_list)

return GestureRecognizerResult(gesture_results, handedness_results,
return GestureRecognizerResult(
gesture_results,
handedness_results,
hand_landmarks_results,
hand_world_landmarks_results)
hand_world_landmarks_results,
)


@dataclasses.dataclass

@@ -174,43 +191,62 @@ class GestureRecognizerOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""

base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
canned_gesture_classifier_options: Optional[
_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
custom_gesture_classifier_options: Optional[
_ClassifierOptions] = dataclasses.field(
default_factory=_ClassifierOptions)
result_callback: Optional[Callable[
[GestureRecognizerResult, image_module.Image, int], None]] = None
canned_gesture_classifier_options: Optional[_ClassifierOptions] = (
dataclasses.field(default_factory=_ClassifierOptions)
)
custom_gesture_classifier_options: Optional[_ClassifierOptions] = (
dataclasses.field(default_factory=_ClassifierOptions)
)
result_callback: Optional[
Callable[[GestureRecognizerResult, image_module.Image, int], None]
] = None

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
"""Generates an GestureRecognizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)

# Initialize gesture recognizer options from base options.
gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto)
base_options=base_options_proto
)
# Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
hand_landmarker_options_proto = (
gesture_recognizer_options_proto.hand_landmarker_graph_options
)
hand_landmarker_options_proto.min_tracking_confidence = (
self.min_tracking_confidence
)
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
self.num_hands
)
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
self.min_hand_detection_confidence
)
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_presence_confidence
)

# Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
hand_gesture_recognizer_options_proto = (
gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
)
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.canned_gesture_classifier_options.to_pb2())
self.canned_gesture_classifier_options.to_pb2()
)
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.custom_gesture_classifier_options.to_pb2())
self.custom_gesture_classifier_options.to_pb2()
)

return gesture_recognizer_options_proto
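For orientation alongside this hunk, a minimal usage sketch (not part of the diff) of the dataclass-to-proto path it reformats; the model path 'gesture_recognizer.task' is a hypothetical placeholder:

# Hedged sketch: exercises the reformatted fields and to_pb2() shown above.
options = GestureRecognizerOptions(
    base_options=_BaseOptions(model_asset_path='gesture_recognizer.task'),
    num_hands=2,  # copied into hand_detector_graph_options.num_hands
    min_hand_detection_confidence=0.5,
)
graph_options = options.to_pb2()  # a _GestureRecognizerGraphOptionsProto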
@@ -239,12 +275,14 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = GestureRecognizerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)

@classmethod
def create_from_options(
cls, options: GestureRecognizerOptions) -> 'GestureRecognizer':
cls, options: GestureRecognizerOptions
) -> 'GestureRecognizer':
"""Creates the `GestureRecognizer` object from gesture recognizer options.

Args:

@@ -268,14 +306,19 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
options.result_callback(
GestureRecognizerResult([], [], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
GestureRecognizerResult([], [], [], []),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return

gesture_recognizer_result = _build_recognition_result(output_packets)
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
options.result_callback(gesture_recognizer_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
gesture_recognizer_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)

task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,

@@ -286,23 +329,27 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
output_streams=[
':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]),
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG,
_HAND_LANDMARKS_STREAM_NAME]), ':'.join([
_HAND_WORLD_LANDMARKS_TAG,
_HAND_WORLD_LANDMARKS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
':'.join(
[_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)

def recognize(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> GestureRecognizerResult:
"""Performs hand gesture recognition on the given image.

@@ -325,12 +372,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If gesture recognition failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})

if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():

@@ -342,7 +390,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> GestureRecognizerResult:
"""Performs gesture recognition on the provided video frame.

@@ -367,14 +415,15 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If gesture recognition failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():

@@ -386,7 +435,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data to perform gesture recognition.

@@ -419,12 +468,13 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
gesture recognizer has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
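Beyond formatting, the functional change in these hunks is that convert_to_normalized_rect now also receives the input image, presumably so the rect computation can account for the image's dimensions. A hedged live-stream sketch under that API; the model path and the 33 ms timestamp are illustrative assumptions:

def on_result(result, image, timestamp_ms):
  # Invoked asynchronously; timestamp_ms is derived from the packet timestamp.
  print(timestamp_ms, result)

options = GestureRecognizerOptions(
    base_options=_BaseOptions(model_asset_path='gesture_recognizer.task'),  # placeholder
    running_mode=_RunningMode.LIVE_STREAM,
    result_callback=on_result,
)
recognizer = GestureRecognizer.create_from_options(options)
recognizer.recognize_async(frame, timestamp_ms=33)  # frame: a MediaPipe Image
recognizer.close()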
@@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import image_processing_options as image
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_BaseOptions = base_options_module.BaseOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_HandLandmarkerGraphOptionsProto = (
hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

@@ -56,6 +58,7 @@ _MICRO_SECONDS_PER_MILLISECOND = 1000

class HandLandmark(enum.IntEnum):
"""The 21 hand landmarks."""

WRIST = 0
THUMB_CMC = 1
THUMB_MCP = 2

@@ -95,14 +98,18 @@ class HandLandmarkerResult:


def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult:
output_packets: Mapping[str, packet_module.Packet]
) -> HandLandmarkerResult:
"""Constructs a `HandLandmarksDetectionResult` from output packets."""
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
output_packets[_HANDEDNESS_STREAM_NAME]
)
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
output_packets[_HAND_LANDMARKS_STREAM_NAME]
)
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME]
)

handedness_results = []
for proto in handedness_proto_list:

@@ -115,7 +122,9 @@ def _build_landmarker_result(
index=handedness.index,
score=handedness.score,
display_name=handedness.display_name,
category_name=handedness.label))
category_name=handedness.label,
)
)
handedness_results.append(handedness_categories)

hand_landmarks_results = []

@@ -125,7 +134,8 @@ def _build_landmarker_result(
hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
hand_landmarks_results.append(hand_landmarks_list)

hand_world_landmarks_results = []

@@ -135,11 +145,13 @@ def _build_landmarker_result(
hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark))
landmark_module.Landmark.create_from_pb2(hand_world_landmark)
)
hand_world_landmarks_results.append(hand_world_landmarks_list)

return HandLandmarkerResult(handedness_results, hand_landmarks_results,
hand_world_landmarks_results)
return HandLandmarkerResult(
handedness_results, hand_landmarks_results, hand_world_landmarks_results
)


@dataclasses.dataclass

@@ -167,28 +179,41 @@ class HandLandmarkerOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""

base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
result_callback: Optional[Callable[
[HandLandmarkerResult, image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[HandLandmarkerResult, image_module.Image, int], None]
] = None

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
"""Generates an HandLandmarkerGraphOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)

# Initialize the hand landmarker options from base options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
base_options=base_options_proto)
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
base_options=base_options_proto
)
hand_landmarker_options_proto.min_tracking_confidence = (
self.min_tracking_confidence
)
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
self.num_hands
)
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
self.min_hand_detection_confidence
)
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_presence_confidence
)
return hand_landmarker_options_proto
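A small, hedged check of the field mapping this hunk reformats; the values are arbitrary and the model path is a placeholder:

options = HandLandmarkerOptions(
    base_options=_BaseOptions(model_asset_path='hand_landmarker.task'),  # placeholder
    min_tracking_confidence=0.5,
)
proto = options.to_pb2()
# Mirrors the assignments above: top-level tracking confidence, nested detector options.
assert proto.min_tracking_confidence == 0.5
assert proto.hand_detector_graph_options.num_hands == 1  # dataclass default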
@@ -216,12 +241,14 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = HandLandmarkerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)

@classmethod
def create_from_options(cls,
options: HandLandmarkerOptions) -> 'HandLandmarker':
def create_from_options(
cls, options: HandLandmarkerOptions
) -> 'HandLandmarker':
"""Creates the `HandLandmarker` object from hand landmarker options.

Args:

@@ -245,14 +272,19 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
options.result_callback(
HandLandmarkerResult([], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
HandLandmarkerResult([], [], []),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return

hand_landmarks_detection_result = _build_landmarker_result(output_packets)
timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
options.result_callback(hand_landmarks_detection_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
hand_landmarks_detection_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)

task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,

@@ -263,21 +295,26 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
output_streams=[
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
':'.join([
_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join(
[_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)

def detect(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the given image.

@@ -300,12 +337,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If hand landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})

if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():

@@ -317,7 +355,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the provided video frame.

@@ -342,14 +380,15 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If hand landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():

@@ -361,7 +400,7 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data to perform hand landmarks detection.

@@ -394,12 +433,13 @@ class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
hand landmarker has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
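A hedged end-to-end sketch of the reformatted entry points; 'hand_landmarker.task' and mp_image are placeholders, and the parallel-list result layout follows _build_landmarker_result above:

options = HandLandmarkerOptions(
    base_options=_BaseOptions(model_asset_path='hand_landmarker.task')  # placeholder
)
landmarker = HandLandmarker.create_from_options(options)
result = landmarker.detect(mp_image)  # mp_image: a MediaPipe Image
for hand_landmarks in result.hand_landmarks:
  pass  # one list of 21 NormalizedLandmark entries per detected hand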
@@ -35,7 +35,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageClassifierResult = classification_result_module.ClassificationResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ImageClassifierGraphOptionsProto = (
image_classifier_graph_options_pb2.ImageClassifierGraphOptions
)
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

@@ -48,7 +50,9 @@ _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
_TASK_GRAPH_NAME = (
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
)
_MICRO_SECONDS_PER_MILLISECOND = 1000


@@ -81,6 +85,7 @@ class ImageClassifierOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""

base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
display_names_locale: Optional[str] = None

@@ -88,24 +93,29 @@ class ImageClassifierOptions:
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
result_callback: Optional[Callable[
[ImageClassifierResult, image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[ImageClassifierResult, image_module.Image, int], None]
] = None

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
"""Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results)
max_results=self.max_results,
)

return _ImageClassifierGraphOptionsProto(
base_options=base_options_proto,
classifier_options=classifier_options_proto)
classifier_options=classifier_options_proto,
)


class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):

@@ -165,12 +175,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageClassifierOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)

@classmethod
def create_from_options(cls,
options: ImageClassifierOptions) -> 'ImageClassifier':
def create_from_options(
cls, options: ImageClassifierOptions
) -> 'ImageClassifier':
"""Creates the `ImageClassifier` object from image classifier options.

Args:

@@ -191,12 +203,15 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):

classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(
ImageClassifierResult.create_from_pb2(classification_result_proto),
image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)

task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,

@@ -206,19 +221,23 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)

def classify(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageClassifierResult:
"""Performs image classification on the provided MediaPipe Image.

@@ -233,17 +252,20 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})

classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)

return ImageClassifierResult.create_from_pb2(classification_result_proto)
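Note that classify() calls convert_to_normalized_rect without roi_allowed=False, so unlike the hand tasks a region of interest is accepted here. A hedged sketch; the crop coordinates are arbitrary placeholders and rect.Rect is assumed from the containers import aliased above:

# Classify only a center crop; coordinates are normalized to [0, 1].
roi_options = _ImageProcessingOptions(
    region_of_interest=rect.Rect(left=0.25, top=0.25, right=0.75, bottom=0.75)
)
result = classifier.classify(mp_image, roi_options)
top_category = result.classifications[0].categories[0]  # categories are score-sorted in practice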
@@ -251,7 +273,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageClassifierResult:
"""Performs image classification on the provided video frames.

@@ -272,19 +294,22 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]))
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)

return ImageClassifierResult.create_from_pb2(classification_result_proto)

@@ -292,7 +317,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image classification.

@@ -320,12 +345,14 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
ValueError: If the current input timestamp is smaller than what the image
classifier has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
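The async path requires monotonically increasing timestamps, per the ValueError note above. A hedged sketch, with the model path and frame times as placeholders:

def on_classify(result, image, timestamp_ms):
  ...  # assumed to run on the task's callback thread

options = ImageClassifierOptions(
    base_options=_BaseOptions(model_asset_path='classifier.tflite'),  # placeholder
    running_mode=_RunningMode.LIVE_STREAM,
    max_results=3,
    result_callback=on_classify,
)
classifier = ImageClassifier.create_from_options(options)
classifier.classify_async(frame, 0)
classifier.classify_async(frame, 33)  # must be strictly increasing, else ValueError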
@@ -34,7 +34,9 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode as runni

ImageEmbedderResult = embedding_result_module.EmbeddingResult
_BaseOptions = base_options_module.BaseOptions
_ImageEmbedderGraphOptionsProto = image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
_ImageEmbedderGraphOptionsProto = (
image_embedder_graph_options_pb2.ImageEmbedderGraphOptions
)
_EmbedderOptionsProto = embedder_options_pb2.EmbedderOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo

@@ -74,24 +76,29 @@ class ImageEmbedderOptions:
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""

base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
l2_normalize: Optional[bool] = None
quantize: Optional[bool] = None
result_callback: Optional[Callable[
[ImageEmbedderResult, image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[ImageEmbedderResult, image_module.Image, int], None]
] = None

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageEmbedderGraphOptionsProto:
"""Generates an ImageEmbedderOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
embedder_options_proto = _EmbedderOptionsProto(
l2_normalize=self.l2_normalize, quantize=self.quantize)
l2_normalize=self.l2_normalize, quantize=self.quantize
)

return _ImageEmbedderGraphOptionsProto(
base_options=base_options_proto,
embedder_options=embedder_options_proto)
base_options=base_options_proto, embedder_options=embedder_options_proto
)
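A hedged sketch of the options this hunk reformats; l2_normalize=True makes downstream cosine comparisons cheaper since unit-length vectors reduce cosine similarity to a dot product (the model path is a placeholder):

options = ImageEmbedderOptions(
    base_options=_BaseOptions(model_asset_path='embedder.tflite'),  # placeholder
    l2_normalize=True,  # unit-norm embeddings
    quantize=False,
)
graph_options = options.to_pb2()  # embedder_options carries both flags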

class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):

@@ -135,12 +142,14 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageEmbedderOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)

@classmethod
def create_from_options(cls,
options: ImageEmbedderOptions) -> 'ImageEmbedder':
def create_from_options(
cls, options: ImageEmbedderOptions
) -> 'ImageEmbedder':
"""Creates the `ImageEmbedder` object from image embedder options.

Args:

@@ -161,13 +170,16 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):

embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)

image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(
ImageEmbedderResult.create_from_pb2(embedding_result_proto), image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
ImageEmbedderResult.create_from_pb2(embedding_result_proto),
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)

task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,

@@ -177,19 +189,23 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
],
output_streams=[
':'.join([_EMBEDDINGS_TAG, _EMBEDDINGS_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)

def embed(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided MediaPipe Image.

@@ -207,17 +223,20 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image embedder failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})

embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)

return ImageEmbedderResult.create_from_pb2(embedding_result_proto)

@@ -225,7 +244,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageEmbedderResult:
"""Performs image embedding extraction on the provided video frames.

@@ -249,18 +268,21 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image embedder failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
embedding_result_proto = embeddings_pb2.EmbeddingResult()
embedding_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME]))
packet_getter.get_proto(output_packets[_EMBEDDINGS_OUT_STREAM_NAME])
)

return ImageEmbedderResult.create_from_pb2(embedding_result_proto)

@@ -268,7 +290,7 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data to embedder.

@@ -301,19 +323,24 @@ class ImageEmbedder(base_vision_task_api.BaseVisionTaskApi):
ValueError: If the current input timestamp is smaller than what the image
embedder has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(image_processing_options)
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})

@classmethod
def cosine_similarity(cls, u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding) -> float:
def cosine_similarity(
cls,
u: embedding_result_module.Embedding,
v: embedding_result_module.Embedding,
) -> float:
"""Utility function to compute cosine similarity between two embedding entries.

May return an InvalidArgumentError if e.g. the feature vectors are of
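A hedged usage sketch for the reformatted classmethod; both embeddings are assumed to come from the same embedder so the vectors are comparable:

result_a = embedder.embed(image_a)
result_b = embedder.embed(image_b)
similarity = ImageEmbedder.cosine_similarity(
    result_a.embeddings[0], result_b.embeddings[0]
)  # scalar in [-1.0, 1.0]; mismatched vectors are rejected per the docstring above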
@@ -23,16 +23,23 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode

ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
_ImageSegmenterGraphOptionsProto = (
image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
)
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'

@@ -40,6 +47,8 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000


@@ -77,19 +86,24 @@ class ImageSegmenterOptions:
running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
activation: Optional[Activation] = Activation.NONE
result_callback: Optional[Callable[
[List[image_module.Image], image_module.Image, int], None]] = None
result_callback: Optional[
Callable[[ImageSegmenterResult, image_module.Image, int], None]
] = None

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value)
output_type=self.output_type.value, activation=self.activation.value
)
return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto,
segmenter_options=segmenter_options_proto)
segmenter_options=segmenter_options_proto,
)


class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):

@@ -138,12 +152,14 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageSegmenterOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)

@classmethod
def create_from_options(cls,
options: ImageSegmenterOptions) -> 'ImageSegmenter':
def create_from_options(
cls, options: ImageSegmenterOptions
) -> 'ImageSegmenter':
"""Creates the `ImageSegmenter` object from image segmenter options.

Args:

@@ -162,31 +178,47 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
options.result_callback(segmentation_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
options.result_callback(
segmentation_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)

task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])],
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options)
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)

def segment(self, image: image_module.Image) -> List[image_module.Image]:
def segment(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageSegmenterResult:
"""Performs the actual segmentation task on the provided MediaPipe Image.

Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.

Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is

@@ -199,14 +231,26 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
output_packets = self._process_image_data(
{_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
return segmentation_result
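With the default OutputType.CATEGORY_MASK, each returned image holds per-pixel category indices, per the docstring's category-mask description. A hedged sketch; mp_image is a placeholder input:

masks = segmenter.segment(mp_image)  # ImageSegmenterResult: List[image_module.Image]
category_mask = masks[0].numpy_view()  # assumed uint8 array of category indices per pixel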

def segment_for_video(self, image: image_module.Image,
timestamp_ms: int) -> List[image_module.Image]:
def segment_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> ImageSegmenterResult:
"""Performs segmentation on the provided video frames.

Only use this method when the ImageSegmenter is created with the video

@@ -217,6 +261,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.

Returns:
If the output_type is CATEGORY_MASK, the returned vector of images is

@@ -229,16 +274,28 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
ValueError: If any of the input arguments is invalid.
RuntimeError: If image segmentation failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME])
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
)
return segmentation_result

def segment_async(self, image: image_module.Image, timestamp_ms: int) -> None:
def segment_async(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image segmentation.

Only use this method when the ImageSegmenter is created with the live stream

@@ -260,13 +317,20 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.

Raises:
ValueError: If the current input timestamp is smaller than what the image
segmenter has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
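Since this change routes a NORM_RECT packet into the segmenter graph, rotation can now be applied through image_processing_options even though an ROI stays disallowed (roi_allowed=False). A hedged sketch; frame and the timestamp are placeholders:

rotate = _ImageProcessingOptions(rotation_degrees=90)  # expected to be a multiple of 90
segmenter.segment_async(frame, 33, rotate)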
@@ -22,15 +22,20 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.object_detector.proto import object_detector_options_pb2
from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

ObjectDetectorResult = detections_module.DetectionResult
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ObjectDetectorOptionsProto = object_detector_options_pb2.ObjectDetectorOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

_DETECTIONS_OUT_STREAM_NAME = 'detections_out'

@@ -38,7 +43,10 @@ _DETECTIONS_TAG = 'DETECTIONS'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ObjectDetectorGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
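This is the same millisecond-to-microsecond factor used by the other tasks above: a frame at timestamp_ms = 33 becomes 33 * 1000 = 33000 µs on the input packet via .at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND).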
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -48,11 +56,10 @@ class ObjectDetectorOptions:
|
|||
Attributes:
|
||||
base_options: Base options for the object detector task.
|
||||
running_mode: The running mode of the task. Default to the image mode.
|
||||
Object detector task has three running modes:
|
||||
1) The image mode for detecting objects on single image inputs.
|
||||
2) The video mode for detecting objects on the decoded frames of a video.
|
||||
3) The live stream mode for detecting objects on a live stream of input
|
||||
data, such as from camera.
|
||||
Object detector task has three running modes: 1) The image mode for
|
||||
detecting objects on single image inputs. 2) The video mode for detecting
|
||||
objects on the decoded frames of a video. 3) The live stream mode for
|
||||
detecting objects on a live stream of input data, such as from camera.
|
||||
display_names_locale: The locale to use for display names specified through
|
||||
the TFLite Model Metadata.
|
||||
max_results: The maximum number of top-scored classification results to
|
||||
|
@ -71,6 +78,7 @@ class ObjectDetectorOptions:
|
|||
data. The result callback should only be specified when the running mode
|
||||
is set to the live stream mode.
|
||||
"""
|
||||
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
display_names_locale: Optional[str] = None
|
||||
|
@@ -79,14 +87,16 @@ class ObjectDetectorOptions:
   category_allowlist: Optional[List[str]] = None
   category_denylist: Optional[List[str]] = None
   result_callback: Optional[
-      Callable[[detections_module.DetectionResult, image_module.Image, int],
-               None]] = None
+      Callable[[ObjectDetectorResult, image_module.Image, int], None]
+  ] = None
 
   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _ObjectDetectorOptionsProto:
     """Generates an ObjectDetectorOptions protobuf object."""
     base_options_proto = self.base_options.to_pb2()
-    base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
+    base_options_proto.use_stream_mode = (
+        False if self.running_mode == _RunningMode.IMAGE else True
+    )
     return _ObjectDetectorOptionsProto(
         base_options=base_options_proto,
         display_names_locale=self.display_names_locale,
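
The callback type above now names ObjectDetectorResult, which the new module-level alias binds to detections_module.DetectionResult. Below is a minimal sketch of options wired for the live-stream mode under that signature; the model path and the print body are illustrative placeholders, not part of this diff:

# Sketch only: 'model.tflite' is a placeholder path.
def on_result(result: ObjectDetectorResult, image, timestamp_ms: int) -> None:
  # ObjectDetectorResult aliases detections_module.DetectionResult,
  # so the detections are available as result.detections.
  print(f'{timestamp_ms} ms: {len(result.detections)} detections')

options = ObjectDetectorOptions(
    base_options=_BaseOptions(model_asset_path='model.tflite'),
    running_mode=_RunningMode.LIVE_STREAM,
    result_callback=on_result,
)
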
@@ -163,12 +173,14 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
     """
     base_options = _BaseOptions(model_asset_path=model_path)
     options = ObjectDetectorOptions(
-        base_options=base_options, running_mode=_RunningMode.IMAGE)
+        base_options=base_options, running_mode=_RunningMode.IMAGE
+    )
     return cls.create_from_options(options)
 
   @classmethod
-  def create_from_options(cls,
-                          options: ObjectDetectorOptions) -> 'ObjectDetector':
+  def create_from_options(
+      cls, options: ObjectDetectorOptions
+  ) -> 'ObjectDetector':
     """Creates the `ObjectDetector` object from object detector options.
 
     Args:
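
For context, create_from_model_path is shorthand for building image-mode options and delegating to the reflowed create_from_options above. A usage sketch, assuming a hypothetical model path:

# Both calls build the same image-mode detector; the path is hypothetical.
detector = ObjectDetector.create_from_model_path('/tmp/efficientdet_lite0.tflite')
detector = ObjectDetector.create_from_options(
    ObjectDetectorOptions(
        base_options=_BaseOptions(model_asset_path='/tmp/efficientdet_lite0.tflite'),
        running_mode=_RunningMode.IMAGE,
    )
)
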
@@ -187,32 +199,45 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
       if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
         return
       detection_proto_list = packet_getter.get_proto_list(
-          output_packets[_DETECTIONS_OUT_STREAM_NAME])
-      detection_result = detections_module.DetectionResult([
+          output_packets[_DETECTIONS_OUT_STREAM_NAME]
+      )
+      detection_result = ObjectDetectorResult(
+          [
              detections_module.Detection.create_from_pb2(result)
              for result in detection_proto_list
-      ])
+          ]
+      )
       image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
       timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
       options.result_callback(detection_result, image, timestamp)
 
     task_info = _TaskInfo(
         task_graph=_TASK_GRAPH_NAME,
-        input_streams=[':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME])],
+        input_streams=[
+            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
+            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
+        ],
         output_streams=[
             ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]),
-            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
+            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
         ],
-        task_options=options)
+        task_options=options,
+    )
     return cls(
         task_info.generate_graph_config(
-            enable_flow_limiting=options.running_mode ==
-            _RunningMode.LIVE_STREAM), options.running_mode,
-        packets_callback if options.result_callback else None)
+            enable_flow_limiting=options.running_mode
+            == _RunningMode.LIVE_STREAM
+        ),
+        options.running_mode,
+        packets_callback if options.result_callback else None,
+    )
 
   # TODO: Create an Image class for MediaPipe Tasks.
-  def detect(self,
-             image: image_module.Image) -> detections_module.DetectionResult:
+  def detect(
+      self,
+      image: image_module.Image,
+      image_processing_options: Optional[_ImageProcessingOptions] = None,
+  ) -> ObjectDetectorResult:
     """Performs object detection on the provided MediaPipe Image.
 
     Only use this method when the ObjectDetector is created with the image
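
The task graph now declares a second input stream for the normalized rect. Each ':'.join call simply builds a 'TAG:stream_name' specifier; spelled out with the constants defined earlier in this file:

# Stream specifiers produced by the ':'.join calls above.
assert ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]) == 'IMAGE:image_in'
assert ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]) == 'NORM_RECT:norm_rect_in'
assert ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]) == 'DETECTIONS:detections_out'
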
@@ -220,6 +245,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
 
     Args:
       image: MediaPipe Image.
+      image_processing_options: Options for image processing.
 
     Returns:
       A detection result object that contains a list of detections, each
@@ -231,17 +257,31 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If object detection failed to run.
     """
-    output_packets = self._process_image_data(
-        {_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image)})
+    normalized_rect = self.convert_to_normalized_rect(
+        image_processing_options, image, roi_allowed=False
+    )
+    output_packets = self._process_image_data({
+        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
+        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+            normalized_rect.to_pb2()
+        ),
+    })
     detection_proto_list = packet_getter.get_proto_list(
-        output_packets[_DETECTIONS_OUT_STREAM_NAME])
-    return detections_module.DetectionResult([
+        output_packets[_DETECTIONS_OUT_STREAM_NAME]
+    )
+    return ObjectDetectorResult(
+        [
            detections_module.Detection.create_from_pb2(result)
            for result in detection_proto_list
-    ])
+        ]
+    )
 
-  def detect_for_video(self, image: image_module.Image,
-                       timestamp_ms: int) -> detections_module.DetectionResult:
+  def detect_for_video(
+      self,
+      image: image_module.Image,
+      timestamp_ms: int,
+      image_processing_options: Optional[_ImageProcessingOptions] = None,
+  ) -> ObjectDetectorResult:
     """Performs object detection on the provided video frames.
 
     Only use this method when the ObjectDetector is created with the video
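
detect now converts the optional image_processing_options into a NormalizedRect (with roi_allowed=False, so only rotation is honored for this task) and sends it on the NORM_RECT stream alongside the image packet. A usage sketch; the file name is a placeholder:

# Sketch: detect on a rotated input; 'photo.jpg' is a placeholder path.
import mediapipe as mp

image = mp.Image.create_from_file('photo.jpg')
result = detector.detect(
    image,
    image_processing_options=_ImageProcessingOptions(rotation_degrees=90),
)
for detection in result.detections:
  print(detection.categories[0].category_name, detection.bounding_box)
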
@@ -252,6 +292,7 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input video frame in milliseconds.
+      image_processing_options: Options for image processing.
 
     Returns:
       A detection result object that contains a list of detections, each
@@ -263,18 +304,33 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If object detection failed to run.
     """
+    normalized_rect = self.convert_to_normalized_rect(
+        image_processing_options, image, roi_allowed=False
+    )
     output_packets = self._process_video_data({
-        _IMAGE_IN_STREAM_NAME:
-            packet_creator.create_image(image).at(timestamp_ms)
+        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
+            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
+        ),
+        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+            normalized_rect.to_pb2()
+        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
     })
     detection_proto_list = packet_getter.get_proto_list(
-        output_packets[_DETECTIONS_OUT_STREAM_NAME])
-    return detections_module.DetectionResult([
+        output_packets[_DETECTIONS_OUT_STREAM_NAME]
+    )
+    return ObjectDetectorResult(
+        [
            detections_module.Detection.create_from_pb2(result)
            for result in detection_proto_list
-    ])
+        ]
+    )
 
-  def detect_async(self, image: image_module.Image, timestamp_ms: int) -> None:
+  def detect_async(
+      self,
+      image: image_module.Image,
+      timestamp_ms: int,
+      image_processing_options: Optional[_ImageProcessingOptions] = None,
+  ) -> None:
     """Sends live image data (an Image with a unique timestamp) to perform object detection.
 
     Only use this method when the ObjectDetector is created with the live stream
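
Besides threading image_processing_options through, detect_for_video now stamps packets with timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND, since MediaPipe packet timestamps are in microseconds while the public API takes milliseconds. A sketch of driving video mode from OpenCV; the clip path is hypothetical and the BGR-to-RGB conversion is elided:

import cv2
import mediapipe as mp

cap = cv2.VideoCapture('clip.mp4')  # hypothetical input clip
fps = cap.get(cv2.CAP_PROP_FPS) or 30
index = 0
while cap.isOpened():
  ok, frame = cap.read()
  if not ok:
    break
  mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
  # Pass milliseconds; the task converts to microseconds internally.
  detector.detect_for_video(mp_image, int(index * 1000 / fps))
  index += 1
cap.release()
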
@@ -298,12 +354,20 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input image in milliseconds.
+      image_processing_options: Options for image processing.
 
     Raises:
       ValueError: If the current input timestamp is smaller than what the object
         detector has already processed.
     """
+    normalized_rect = self.convert_to_normalized_rect(
+        image_processing_options, image, roi_allowed=False
+    )
     self._send_live_stream_data({
-        _IMAGE_IN_STREAM_NAME:
-            packet_creator.create_image(image).at(timestamp_ms)
+        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
+            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
+        ),
+        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+            normalized_rect.to_pb2()
+        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
     })
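
detect_async receives the same normalized-rect packet and microsecond conversion; detections are delivered asynchronously to the result_callback configured on the options. A minimal live-stream sketch, where camera_frames is a hypothetical frame source:

timestamp_ms = 0
for frame in camera_frames():  # hypothetical generator of RGB frames
  mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
  detector.detect_async(mp_image, timestamp_ms)
  timestamp_ms += 33  # timestamps must be strictly increasing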