Added more tests and updated the APIs to use a new constant

This commit is contained in:
kinaryml 2022-10-07 22:26:49 -07:00
parent e250c903f5
commit cb806071ba
2 changed files with 355 additions and 28 deletions

View File

@ -14,6 +14,7 @@
"""Tests for image classifier.""" """Tests for image classifier."""
import enum import enum
from unittest import mock
import numpy as np import numpy as np
from absl.testing import absltest from absl.testing import absltest
@ -41,12 +42,7 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg' _IMAGE_FILE = 'burger.jpg'
_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult( _EXPECTED_CATEGORIES = [
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category( _Category(
index=934, index=934,
score=0.7939587831497192, score=0.7939587831497192,
@ -67,7 +63,25 @@ _EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
score=0.006327860057353973, score=0.006327860057353973,
display_name='', display_name='',
category_name='meat loaf') category_name='meat loaf')
]
_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=_EXPECTED_CATEGORIES,
timestamp_ms=0
)
], ],
head_index=0,
head_name='probability')
])
_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[],
timestamp_ms=0 timestamp_ms=0
) )
], ],
@ -93,6 +107,36 @@ class ImageClassifierTest(parameterized.TestCase):
test_utils.get_test_data_path(_IMAGE_FILE)) test_utils.get_test_data_path(_IMAGE_FILE))
self.model_path = test_utils.get_test_data_path(_MODEL_FILE) self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
self.assertIsInstance(classifier, _ImageClassifier)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageClassifierOptions(base_options=base_options)
with _ImageClassifier.create_from_options(options) as classifier:
self.assertIsInstance(classifier, _ImageClassifier)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
ValueError,
r"ExternalFile must specify at least one of 'file_content', "
r"'file_name' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='')
options = _ImageClassifierOptions(base_options=base_options)
_ImageClassifier.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _ImageClassifierOptions(base_options=base_options)
classifier = _ImageClassifier.create_from_options(options)
self.assertIsInstance(classifier, _ImageClassifier)
@parameterized.parameters( @parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT), (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT)) (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
@ -122,6 +166,183 @@ class ImageClassifierTest(parameterized.TestCase):
# a context. # a context.
classifier.close() classifier.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
(ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
def test_classify_in_context(self, model_file_type, max_results,
expected_classification_result):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
self.assertEqual(image_result, expected_classification_result)
def test_score_threshold_option(self):
classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
score = entry.categories[0].score
self.assertGreaterEqual(
score, _SCORE_THRESHOLD,
f'Classification with score lower than threshold found. '
f'{classification}')
def test_max_results_option(self):
classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
categories = image_result.classifications[0].entries[0].categories
self.assertLessEqual(
len(categories), _MAX_RESULTS, 'Too many results returned.')
def test_allow_list_option(self):
classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
label = entry.categories[0].category_name
self.assertIn(label, _ALLOW_LIST,
f'Label {label} found but not in label allow list')
def test_deny_list_option(self):
classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
label = entry.categories[0].category_name
self.assertNotIn(label, _DENY_LIST,
f'Label {label} found but in deny list.')
def test_combined_allowlist_and_denylist(self):
# Fails with combined allowlist and denylist
with self.assertRaisesRegex(
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
classifier_options = _ClassifierOptions(category_allowlist=['foo'],
category_denylist=['bar'])
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_empty_classification_outputs(self):
classifier_options = _ClassifierOptions(score_threshold=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
self.assertEmpty(image_result.classifications[0].entries[0].categories)
def test_missing_result_callback(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_calling_classify_for_video_in_image_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
classifier.classify_for_video(self.test_image, 0)
def test_calling_classify_async_in_image_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
classifier.classify_async(self.test_image, 0)
def test_calling_classify_in_video_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
classifier.classify(self.test_image)
def test_calling_classify_async_in_video_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
classifier.classify_async(self.test_image, 0)
def test_classify_for_video_with_out_of_order_timestamp(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageClassifier.create_from_options(options) as classifier:
unused_result = classifier.classify_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_for_video(self.test_image, 0)
def test_classify_for_video(self): def test_classify_for_video(self):
classifier_options = _ClassifierOptions(max_results=4) classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions( options = _ImageClassifierOptions(
@ -132,7 +353,78 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
self.test_image, timestamp) self.test_image, timestamp)
self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT) expected_classification_result = _ClassificationResult(
classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp)
],
head_index=0, head_name='probability')
])
self.assertEqual(classification_result, expected_classification_result)
def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
classifier.classify(self.test_image)
def test_calling_classify_for_video_in_live_stream_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
classifier.classify_for_video(self.test_image, 0)
def test_classify_async_calls_with_illegal_timestamp(self):
classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=classifier_options,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self.test_image, 0)
# TODO: Fix the packet is empty issue.
"""
@parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
(1, _EMPTY_CLASSIFICATION_RESULT))
def test_classify_async_calls(self, threshold, expected_result):
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
self.assertEqual(result, expected_result)
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
classifier_options = _ClassifierOptions(
max_results=4, score_threshold=threshold)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=classifier_options,
result_callback=check_result)
classifier = _ImageClassifier.create_from_options(options)
for timestamp in range(0, 300, 30):
classifier.classify_async(self.test_image, timestamp)
classifier.close()
"""
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -43,6 +43,7 @@ _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass @dataclasses.dataclass
@ -91,7 +92,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
"""Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`. """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.
Note that the created `ImageClassifier` instance is in image mode, for Note that the created `ImageClassifier` instance is in image mode, for
detecting objects on single image inputs. classifying objects on single image inputs.
Args: Args:
model_path: Path to the model. model_path: Path to the model.
@ -137,7 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
]) ])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image, timestamp) options.result_callback(classification_result, image,
timestamp.value / _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
@ -156,7 +158,6 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
_RunningMode.LIVE_STREAM), options.running_mode, _RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None) packets_callback if options.result_callback else None)
# TODO: Create an Image class for MediaPipe Tasks.
def classify( def classify(
self, self,
image: image_module.Image, image: image_module.Image,
@ -207,7 +208,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
""" """
output_packets = self._process_video_data({ output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: _IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(timestamp_ms) packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
}) })
classification_result_proto = packet_getter.get_proto( classification_result_proto = packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]) output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
@ -216,3 +218,36 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
classifications_module.Classifications.create_from_pb2(classification) classifications_module.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications for classification in classification_result_proto.classifications
]) ])
def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform
image classification.
Only use this method when the ImageClassifier is created with the live
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input image is accepted. The results will be available via the
`result_callback` provided in the `ImageClassifierOptions`. The
`classify_async` method is designed to process live stream data such as
camera input. To lower the overall latency, image classifier may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` provides:
- A classification result object that contains a list of classifications.
- The input image that the image classifier runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
Raises:
ValueError: If the current input timestamp is smaller than what the image
classifier has already processed.
"""
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})