Added more tests and updated the APIs to use a new constant
This commit is contained in:
		
							parent
							
								
									e250c903f5
								
							
						
					
					
						commit
						cb806071ba
					
				| 
						 | 
					@ -14,6 +14,7 @@
 | 
				
			||||||
"""Tests for image classifier."""
 | 
					"""Tests for image classifier."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import enum
 | 
					import enum
 | 
				
			||||||
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
from absl.testing import absltest
 | 
					from absl.testing import absltest
 | 
				
			||||||
| 
						 | 
					@ -41,12 +42,7 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
 | 
					_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
 | 
				
			||||||
_IMAGE_FILE = 'burger.jpg'
 | 
					_IMAGE_FILE = 'burger.jpg'
 | 
				
			||||||
_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
					_EXPECTED_CATEGORIES = [
 | 
				
			||||||
  classifications=[
 | 
					 | 
				
			||||||
    _Classifications(
 | 
					 | 
				
			||||||
      entries=[
 | 
					 | 
				
			||||||
        _ClassificationEntry(
 | 
					 | 
				
			||||||
          categories=[
 | 
					 | 
				
			||||||
    _Category(
 | 
					    _Category(
 | 
				
			||||||
      index=934,
 | 
					      index=934,
 | 
				
			||||||
      score=0.7939587831497192,
 | 
					      score=0.7939587831497192,
 | 
				
			||||||
| 
						 | 
					@ -67,7 +63,25 @@ _EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
				
			||||||
      score=0.006327860057353973,
 | 
					      score=0.006327860057353973,
 | 
				
			||||||
      display_name='',
 | 
					      display_name='',
 | 
				
			||||||
      category_name='meat loaf')
 | 
					      category_name='meat loaf')
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
				
			||||||
 | 
					  classifications=[
 | 
				
			||||||
 | 
					    _Classifications(
 | 
				
			||||||
 | 
					      entries=[
 | 
				
			||||||
 | 
					        _ClassificationEntry(
 | 
				
			||||||
 | 
					          categories=_EXPECTED_CATEGORIES,
 | 
				
			||||||
 | 
					          timestamp_ms=0
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
      ],
 | 
					      ],
 | 
				
			||||||
 | 
					      head_index=0,
 | 
				
			||||||
 | 
					      head_name='probability')
 | 
				
			||||||
 | 
					  ])
 | 
				
			||||||
 | 
					_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
				
			||||||
 | 
					  classifications=[
 | 
				
			||||||
 | 
					    _Classifications(
 | 
				
			||||||
 | 
					      entries=[
 | 
				
			||||||
 | 
					        _ClassificationEntry(
 | 
				
			||||||
 | 
					          categories=[],
 | 
				
			||||||
          timestamp_ms=0
 | 
					          timestamp_ms=0
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
      ],
 | 
					      ],
 | 
				
			||||||
| 
						 | 
					@ -93,6 +107,36 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
        test_utils.get_test_data_path(_IMAGE_FILE))
 | 
					        test_utils.get_test_data_path(_IMAGE_FILE))
 | 
				
			||||||
    self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
 | 
					    self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_file_succeeds_with_valid_model_path(self):
 | 
				
			||||||
 | 
					    # Creates with default option and valid model file successfully.
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _ImageClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_options_succeeds_with_valid_model_path(self):
 | 
				
			||||||
 | 
					    # Creates with options containing model file successfully.
 | 
				
			||||||
 | 
					    base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _ImageClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
 | 
					    # Invalid empty model path.
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					        ValueError,
 | 
				
			||||||
 | 
					        r"ExternalFile must specify at least one of 'file_content', "
 | 
				
			||||||
 | 
					        r"'file_name' or 'file_descriptor_meta'."):
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_path='')
 | 
				
			||||||
 | 
					      options = _ImageClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					      _ImageClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_create_from_options_succeeds_with_valid_model_content(self):
 | 
				
			||||||
 | 
					    # Creates with options containing model content successfully.
 | 
				
			||||||
 | 
					    with open(self.model_path, 'rb') as f:
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_buffer=f.read())
 | 
				
			||||||
 | 
					      options = _ImageClassifierOptions(base_options=base_options)
 | 
				
			||||||
 | 
					      classifier = _ImageClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					      self.assertIsInstance(classifier, _ImageClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @parameterized.parameters(
 | 
					  @parameterized.parameters(
 | 
				
			||||||
      (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
 | 
					      (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
 | 
				
			||||||
      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
 | 
					      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
 | 
				
			||||||
| 
						 | 
					@ -122,6 +166,183 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
    # a context.
 | 
					    # a context.
 | 
				
			||||||
    classifier.close()
 | 
					    classifier.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.parameters(
 | 
				
			||||||
 | 
					    (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
 | 
				
			||||||
 | 
					    (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
 | 
				
			||||||
 | 
					  def test_classify_in_context(self, model_file_type, max_results,
 | 
				
			||||||
 | 
					                               expected_classification_result):
 | 
				
			||||||
 | 
					    if model_file_type is ModelFileType.FILE_NAME:
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
				
			||||||
 | 
					    elif model_file_type is ModelFileType.FILE_CONTENT:
 | 
				
			||||||
 | 
					      with open(self.model_path, 'rb') as f:
 | 
				
			||||||
 | 
					        model_content = f.read()
 | 
				
			||||||
 | 
					      base_options = _BaseOptions(model_asset_buffer=model_content)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      # Should never happen
 | 
				
			||||||
 | 
					      raise ValueError('model_file_type is invalid.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(max_results=max_results)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=base_options, classifier_options=classifier_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					      # Comparing results.
 | 
				
			||||||
 | 
					      self.assertEqual(image_result, expected_classification_result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_score_threshold_option(self):
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      classifier_options=classifier_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					      classifications = image_result.classifications
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for classification in classifications:
 | 
				
			||||||
 | 
					        for entry in classification.entries:
 | 
				
			||||||
 | 
					          score = entry.categories[0].score
 | 
				
			||||||
 | 
					          self.assertGreaterEqual(
 | 
				
			||||||
 | 
					            score, _SCORE_THRESHOLD,
 | 
				
			||||||
 | 
					            f'Classification with score lower than threshold found. '
 | 
				
			||||||
 | 
					            f'{classification}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_max_results_option(self):
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      classifier_options=classifier_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					      categories = image_result.classifications[0].entries[0].categories
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      self.assertLessEqual(
 | 
				
			||||||
 | 
					        len(categories), _MAX_RESULTS, 'Too many results returned.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_allow_list_option(self):
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      classifier_options=classifier_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					      classifications = image_result.classifications
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for classification in classifications:
 | 
				
			||||||
 | 
					        for entry in classification.entries:
 | 
				
			||||||
 | 
					          label = entry.categories[0].category_name
 | 
				
			||||||
 | 
					          self.assertIn(label, _ALLOW_LIST,
 | 
				
			||||||
 | 
					                        f'Label {label} found but not in label allow list')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_deny_list_option(self):
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      classifier_options=classifier_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					      classifications = image_result.classifications
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for classification in classifications:
 | 
				
			||||||
 | 
					        for entry in classification.entries:
 | 
				
			||||||
 | 
					          label = entry.categories[0].category_name
 | 
				
			||||||
 | 
					          self.assertNotIn(label, _DENY_LIST,
 | 
				
			||||||
 | 
					                           f'Label {label} found but in deny list.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_combined_allowlist_and_denylist(self):
 | 
				
			||||||
 | 
					    # Fails with combined allowlist and denylist
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					        ValueError,
 | 
				
			||||||
 | 
					        r'`category_allowlist` and `category_denylist` are mutually '
 | 
				
			||||||
 | 
					        r'exclusive options.'):
 | 
				
			||||||
 | 
					      classifier_options = _ClassifierOptions(category_allowlist=['foo'],
 | 
				
			||||||
 | 
					                                              category_denylist=['bar'])
 | 
				
			||||||
 | 
					      options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					        base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					        classifier_options=classifier_options)
 | 
				
			||||||
 | 
					      with _ImageClassifier.create_from_options(options) as unused_classifier:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_empty_classification_outputs(self):
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(score_threshold=1)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      classifier_options=classifier_options)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      # Performs image classification on the input.
 | 
				
			||||||
 | 
					      image_result = classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					      self.assertEmpty(image_result.classifications[0].entries[0].categories)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_missing_result_callback(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.LIVE_STREAM)
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                r'result callback must be provided'):
 | 
				
			||||||
 | 
					      with _ImageClassifier.create_from_options(options) as unused_classifier:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
 | 
				
			||||||
 | 
					  def test_illegal_result_callback(self, running_mode):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=running_mode,
 | 
				
			||||||
 | 
					      result_callback=mock.MagicMock())
 | 
				
			||||||
 | 
					    with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                r'result callback should not be provided'):
 | 
				
			||||||
 | 
					      with _ImageClassifier.create_from_options(options) as unused_classifier:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_calling_classify_for_video_in_image_mode(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.IMAGE)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                  r'not initialized with the video mode'):
 | 
				
			||||||
 | 
					        classifier.classify_for_video(self.test_image, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_calling_classify_async_in_image_mode(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.IMAGE)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                  r'not initialized with the live stream mode'):
 | 
				
			||||||
 | 
					        classifier.classify_async(self.test_image, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_calling_classify_in_video_mode(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.VIDEO)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                  r'not initialized with the image mode'):
 | 
				
			||||||
 | 
					        classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_calling_classify_async_in_video_mode(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.VIDEO)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                  r'not initialized with the live stream mode'):
 | 
				
			||||||
 | 
					        classifier.classify_async(self.test_image, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_classify_for_video_with_out_of_order_timestamp(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.VIDEO)
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      unused_result = classifier.classify_for_video(self.test_image, 1)
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					          ValueError, r'Input timestamp must be monotonically increasing'):
 | 
				
			||||||
 | 
					        classifier.classify_for_video(self.test_image, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_classify_for_video(self):
 | 
					  def test_classify_for_video(self):
 | 
				
			||||||
    classifier_options = _ClassifierOptions(max_results=4)
 | 
					    classifier_options = _ClassifierOptions(max_results=4)
 | 
				
			||||||
    options = _ImageClassifierOptions(
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
| 
						 | 
					@ -132,7 +353,78 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      for timestamp in range(0, 300, 30):
 | 
					      for timestamp in range(0, 300, 30):
 | 
				
			||||||
        classification_result = classifier.classify_for_video(
 | 
					        classification_result = classifier.classify_for_video(
 | 
				
			||||||
            self.test_image, timestamp)
 | 
					            self.test_image, timestamp)
 | 
				
			||||||
        self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT)
 | 
					        expected_classification_result = _ClassificationResult(
 | 
				
			||||||
 | 
					          classifications=[
 | 
				
			||||||
 | 
					            _Classifications(
 | 
				
			||||||
 | 
					              entries=[
 | 
				
			||||||
 | 
					                _ClassificationEntry(
 | 
				
			||||||
 | 
					                  categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp)
 | 
				
			||||||
 | 
					              ],
 | 
				
			||||||
 | 
					              head_index=0, head_name='probability')
 | 
				
			||||||
 | 
					          ])
 | 
				
			||||||
 | 
					        self.assertEqual(classification_result, expected_classification_result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_calling_classify_in_live_stream_mode(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
				
			||||||
 | 
					      result_callback=mock.MagicMock())
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                  r'not initialized with the image mode'):
 | 
				
			||||||
 | 
					        classifier.classify(self.test_image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_calling_classify_for_video_in_live_stream_mode(self):
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
				
			||||||
 | 
					      result_callback=mock.MagicMock())
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(ValueError,
 | 
				
			||||||
 | 
					                                  r'not initialized with the video mode'):
 | 
				
			||||||
 | 
					        classifier.classify_for_video(self.test_image, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_classify_async_calls_with_illegal_timestamp(self):
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(max_results=4)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
				
			||||||
 | 
					      classifier_options=classifier_options,
 | 
				
			||||||
 | 
					      result_callback=mock.MagicMock())
 | 
				
			||||||
 | 
					    with _ImageClassifier.create_from_options(options) as classifier:
 | 
				
			||||||
 | 
					      classifier.classify_async(self.test_image, 100)
 | 
				
			||||||
 | 
					      with self.assertRaisesRegex(
 | 
				
			||||||
 | 
					          ValueError, r'Input timestamp must be monotonically increasing'):
 | 
				
			||||||
 | 
					        classifier.classify_async(self.test_image, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # TODO: Fix the packet is empty issue.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					  @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
 | 
				
			||||||
 | 
					                            (1, _EMPTY_CLASSIFICATION_RESULT))
 | 
				
			||||||
 | 
					  def test_classify_async_calls(self, threshold, expected_result):
 | 
				
			||||||
 | 
					    observed_timestamp_ms = -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_result(result: _ClassificationResult, output_image: _Image,
 | 
				
			||||||
 | 
					                     timestamp_ms: int):
 | 
				
			||||||
 | 
					      self.assertEqual(result, expected_result)
 | 
				
			||||||
 | 
					      self.assertTrue(
 | 
				
			||||||
 | 
					        np.array_equal(output_image.numpy_view(),
 | 
				
			||||||
 | 
					                       self.test_image.numpy_view()))
 | 
				
			||||||
 | 
					      self.assertLess(observed_timestamp_ms, timestamp_ms)
 | 
				
			||||||
 | 
					      self.observed_timestamp_ms = timestamp_ms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    classifier_options = _ClassifierOptions(
 | 
				
			||||||
 | 
					      max_results=4, score_threshold=threshold)
 | 
				
			||||||
 | 
					    options = _ImageClassifierOptions(
 | 
				
			||||||
 | 
					      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
				
			||||||
 | 
					      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
				
			||||||
 | 
					      classifier_options=classifier_options,
 | 
				
			||||||
 | 
					      result_callback=check_result)
 | 
				
			||||||
 | 
					    classifier = _ImageClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					    for timestamp in range(0, 300, 30):
 | 
				
			||||||
 | 
					      classifier.classify_async(self.test_image, timestamp)
 | 
				
			||||||
 | 
					    classifier.close()
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -43,6 +43,7 @@ _IMAGE_IN_STREAM_NAME = 'image_in'
 | 
				
			||||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
 | 
					_IMAGE_OUT_STREAM_NAME = 'image_out'
 | 
				
			||||||
_IMAGE_TAG = 'IMAGE'
 | 
					_IMAGE_TAG = 'IMAGE'
 | 
				
			||||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 | 
					_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 | 
				
			||||||
 | 
					_MICRO_SECONDS_PER_MILLISECOND = 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclasses.dataclass
 | 
					@dataclasses.dataclass
 | 
				
			||||||
| 
						 | 
					@ -91,7 +92,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
				
			||||||
    """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.
 | 
					    """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Note that the created `ImageClassifier` instance is in image mode, for
 | 
					    Note that the created `ImageClassifier` instance is in image mode, for
 | 
				
			||||||
    detecting objects on single image inputs.
 | 
					    classifying objects on single image inputs.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
      model_path: Path to the model.
 | 
					      model_path: Path to the model.
 | 
				
			||||||
| 
						 | 
					@ -137,7 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
				
			||||||
      ])
 | 
					      ])
 | 
				
			||||||
      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
 | 
					      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
 | 
				
			||||||
      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
 | 
					      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
 | 
				
			||||||
      options.result_callback(classification_result, image, timestamp)
 | 
					      options.result_callback(classification_result, image,
 | 
				
			||||||
 | 
					                              timestamp.value / _MICRO_SECONDS_PER_MILLISECOND)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    task_info = _TaskInfo(
 | 
					    task_info = _TaskInfo(
 | 
				
			||||||
        task_graph=_TASK_GRAPH_NAME,
 | 
					        task_graph=_TASK_GRAPH_NAME,
 | 
				
			||||||
| 
						 | 
					@ -156,7 +158,6 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
				
			||||||
            _RunningMode.LIVE_STREAM), options.running_mode,
 | 
					            _RunningMode.LIVE_STREAM), options.running_mode,
 | 
				
			||||||
        packets_callback if options.result_callback else None)
 | 
					        packets_callback if options.result_callback else None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # TODO: Create an Image class for MediaPipe Tasks.
 | 
					 | 
				
			||||||
  def classify(
 | 
					  def classify(
 | 
				
			||||||
      self,
 | 
					      self,
 | 
				
			||||||
      image: image_module.Image,
 | 
					      image: image_module.Image,
 | 
				
			||||||
| 
						 | 
					@ -207,7 +208,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    output_packets = self._process_video_data({
 | 
					    output_packets = self._process_video_data({
 | 
				
			||||||
        _IMAGE_IN_STREAM_NAME:
 | 
					        _IMAGE_IN_STREAM_NAME:
 | 
				
			||||||
        packet_creator.create_image(image).at(timestamp_ms)
 | 
					            packet_creator.create_image(image).at(
 | 
				
			||||||
 | 
					                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    classification_result_proto = packet_getter.get_proto(
 | 
					    classification_result_proto = packet_getter.get_proto(
 | 
				
			||||||
      output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
 | 
					      output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
 | 
				
			||||||
| 
						 | 
					@ -216,3 +218,36 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
				
			||||||
        classifications_module.Classifications.create_from_pb2(classification)
 | 
					        classifications_module.Classifications.create_from_pb2(classification)
 | 
				
			||||||
        for classification in classification_result_proto.classifications
 | 
					        for classification in classification_result_proto.classifications
 | 
				
			||||||
    ])
 | 
					    ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None:
 | 
				
			||||||
 | 
					    """Sends live image data (an Image with a unique timestamp) to perform
 | 
				
			||||||
 | 
					    image classification.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Only use this method when the ImageClassifier is created with the live
 | 
				
			||||||
 | 
					    stream running mode. The input timestamps should be monotonically increasing
 | 
				
			||||||
 | 
					    for adjacent calls of this method. This method will return immediately after
 | 
				
			||||||
 | 
					    the input image is accepted. The results will be available via the
 | 
				
			||||||
 | 
					    `result_callback` provided in the `ImageClassifierOptions`. The
 | 
				
			||||||
 | 
					    `classify_async` method is designed to process live stream data such as
 | 
				
			||||||
 | 
					    camera input. To lower the overall latency, image classifier may drop the
 | 
				
			||||||
 | 
					    input images if needed. In other words, it's not guaranteed to have output
 | 
				
			||||||
 | 
					    per input image.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The `result_callback` provides:
 | 
				
			||||||
 | 
					      - A classification result object that contains a list of classifications.
 | 
				
			||||||
 | 
					      - The input image that the image classifier runs on.
 | 
				
			||||||
 | 
					      - The input timestamp in milliseconds.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      image: MediaPipe Image.
 | 
				
			||||||
 | 
					      timestamp_ms: The timestamp of the input image in milliseconds.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					      ValueError: If the current input timestamp is smaller than what the image
 | 
				
			||||||
 | 
					        classifier has already processed.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    self._send_live_stream_data({
 | 
				
			||||||
 | 
					        _IMAGE_IN_STREAM_NAME:
 | 
				
			||||||
 | 
					            packet_creator.create_image(image).at(
 | 
				
			||||||
 | 
					                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user