Added more tests and updated the APIs to use a new constant
This commit is contained in:
		
							parent
							
								
									e250c903f5
								
							
						
					
					
						commit
						cb806071ba
					
				| 
						 | 
				
			
			@ -14,6 +14,7 @@
 | 
			
		|||
"""Tests for image classifier."""
 | 
			
		||||
 | 
			
		||||
import enum
 | 
			
		||||
from unittest import mock
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from absl.testing import absltest
 | 
			
		||||
| 
						 | 
				
			
			@ -41,12 +42,7 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
 | 
			
		|||
 | 
			
		||||
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
 | 
			
		||||
_IMAGE_FILE = 'burger.jpg'
 | 
			
		||||
_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
			
		||||
  classifications=[
 | 
			
		||||
    _Classifications(
 | 
			
		||||
      entries=[
 | 
			
		||||
        _ClassificationEntry(
 | 
			
		||||
          categories=[
 | 
			
		||||
_EXPECTED_CATEGORIES = [
 | 
			
		||||
    _Category(
 | 
			
		||||
      index=934,
 | 
			
		||||
      score=0.7939587831497192,
 | 
			
		||||
| 
						 | 
				
			
			@ -67,7 +63,25 @@ _EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
			
		|||
      score=0.006327860057353973,
 | 
			
		||||
      display_name='',
 | 
			
		||||
      category_name='meat loaf')
 | 
			
		||||
]
 | 
			
		||||
_EXPECTED_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
			
		||||
  classifications=[
 | 
			
		||||
    _Classifications(
 | 
			
		||||
      entries=[
 | 
			
		||||
        _ClassificationEntry(
 | 
			
		||||
          categories=_EXPECTED_CATEGORIES,
 | 
			
		||||
          timestamp_ms=0
 | 
			
		||||
        )
 | 
			
		||||
      ],
 | 
			
		||||
      head_index=0,
 | 
			
		||||
      head_name='probability')
 | 
			
		||||
  ])
 | 
			
		||||
_EMPTY_CLASSIFICATION_RESULT = _ClassificationResult(
 | 
			
		||||
  classifications=[
 | 
			
		||||
    _Classifications(
 | 
			
		||||
      entries=[
 | 
			
		||||
        _ClassificationEntry(
 | 
			
		||||
          categories=[],
 | 
			
		||||
          timestamp_ms=0
 | 
			
		||||
        )
 | 
			
		||||
      ],
 | 
			
		||||
| 
						 | 
				
			
			@ -93,6 +107,36 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
			
		|||
        test_utils.get_test_data_path(_IMAGE_FILE))
 | 
			
		||||
    self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
 | 
			
		||||
 | 
			
		||||
  def test_create_from_file_succeeds_with_valid_model_path(self):
 | 
			
		||||
    # Creates with default option and valid model file successfully.
 | 
			
		||||
    with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
 | 
			
		||||
      self.assertIsInstance(classifier, _ImageClassifier)
 | 
			
		||||
 | 
			
		||||
  def test_create_from_options_succeeds_with_valid_model_path(self):
 | 
			
		||||
    # Creates with options containing model file successfully.
 | 
			
		||||
    base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
			
		||||
    options = _ImageClassifierOptions(base_options=base_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      self.assertIsInstance(classifier, _ImageClassifier)
 | 
			
		||||
 | 
			
		||||
  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
			
		||||
    # Invalid empty model path.
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        ValueError,
 | 
			
		||||
        r"ExternalFile must specify at least one of 'file_content', "
 | 
			
		||||
        r"'file_name' or 'file_descriptor_meta'."):
 | 
			
		||||
      base_options = _BaseOptions(model_asset_path='')
 | 
			
		||||
      options = _ImageClassifierOptions(base_options=base_options)
 | 
			
		||||
      _ImageClassifier.create_from_options(options)
 | 
			
		||||
 | 
			
		||||
  def test_create_from_options_succeeds_with_valid_model_content(self):
 | 
			
		||||
    # Creates with options containing model content successfully.
 | 
			
		||||
    with open(self.model_path, 'rb') as f:
 | 
			
		||||
      base_options = _BaseOptions(model_asset_buffer=f.read())
 | 
			
		||||
      options = _ImageClassifierOptions(base_options=base_options)
 | 
			
		||||
      classifier = _ImageClassifier.create_from_options(options)
 | 
			
		||||
      self.assertIsInstance(classifier, _ImageClassifier)
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
      (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
 | 
			
		||||
      (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
 | 
			
		||||
| 
						 | 
				
			
			@ -122,6 +166,183 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
			
		|||
    # a context.
 | 
			
		||||
    classifier.close()
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters(
 | 
			
		||||
    (ModelFileType.FILE_NAME, 4, _EXPECTED_CLASSIFICATION_RESULT),
 | 
			
		||||
    (ModelFileType.FILE_CONTENT, 4, _EXPECTED_CLASSIFICATION_RESULT))
 | 
			
		||||
  def test_classify_in_context(self, model_file_type, max_results,
 | 
			
		||||
                               expected_classification_result):
 | 
			
		||||
    if model_file_type is ModelFileType.FILE_NAME:
 | 
			
		||||
      base_options = _BaseOptions(model_asset_path=self.model_path)
 | 
			
		||||
    elif model_file_type is ModelFileType.FILE_CONTENT:
 | 
			
		||||
      with open(self.model_path, 'rb') as f:
 | 
			
		||||
        model_content = f.read()
 | 
			
		||||
      base_options = _BaseOptions(model_asset_buffer=model_content)
 | 
			
		||||
    else:
 | 
			
		||||
      # Should never happen
 | 
			
		||||
      raise ValueError('model_file_type is invalid.')
 | 
			
		||||
 | 
			
		||||
    classifier_options = _ClassifierOptions(max_results=max_results)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=base_options, classifier_options=classifier_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      # Performs image classification on the input.
 | 
			
		||||
      image_result = classifier.classify(self.test_image)
 | 
			
		||||
      # Comparing results.
 | 
			
		||||
      self.assertEqual(image_result, expected_classification_result)
 | 
			
		||||
 | 
			
		||||
  def test_score_threshold_option(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      classifier_options=classifier_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      # Performs image classification on the input.
 | 
			
		||||
      image_result = classifier.classify(self.test_image)
 | 
			
		||||
      classifications = image_result.classifications
 | 
			
		||||
 | 
			
		||||
      for classification in classifications:
 | 
			
		||||
        for entry in classification.entries:
 | 
			
		||||
          score = entry.categories[0].score
 | 
			
		||||
          self.assertGreaterEqual(
 | 
			
		||||
            score, _SCORE_THRESHOLD,
 | 
			
		||||
            f'Classification with score lower than threshold found. '
 | 
			
		||||
            f'{classification}')
 | 
			
		||||
 | 
			
		||||
  def test_max_results_option(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(score_threshold=_SCORE_THRESHOLD)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      classifier_options=classifier_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      # Performs image classification on the input.
 | 
			
		||||
      image_result = classifier.classify(self.test_image)
 | 
			
		||||
      categories = image_result.classifications[0].entries[0].categories
 | 
			
		||||
 | 
			
		||||
      self.assertLessEqual(
 | 
			
		||||
        len(categories), _MAX_RESULTS, 'Too many results returned.')
 | 
			
		||||
 | 
			
		||||
  def test_allow_list_option(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(category_allowlist=_ALLOW_LIST)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      classifier_options=classifier_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      # Performs image classification on the input.
 | 
			
		||||
      image_result = classifier.classify(self.test_image)
 | 
			
		||||
      classifications = image_result.classifications
 | 
			
		||||
 | 
			
		||||
      for classification in classifications:
 | 
			
		||||
        for entry in classification.entries:
 | 
			
		||||
          label = entry.categories[0].category_name
 | 
			
		||||
          self.assertIn(label, _ALLOW_LIST,
 | 
			
		||||
                        f'Label {label} found but not in label allow list')
 | 
			
		||||
 | 
			
		||||
  def test_deny_list_option(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      classifier_options=classifier_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      # Performs image classification on the input.
 | 
			
		||||
      image_result = classifier.classify(self.test_image)
 | 
			
		||||
      classifications = image_result.classifications
 | 
			
		||||
 | 
			
		||||
      for classification in classifications:
 | 
			
		||||
        for entry in classification.entries:
 | 
			
		||||
          label = entry.categories[0].category_name
 | 
			
		||||
          self.assertNotIn(label, _DENY_LIST,
 | 
			
		||||
                           f'Label {label} found but in deny list.')
 | 
			
		||||
 | 
			
		||||
  def test_combined_allowlist_and_denylist(self):
 | 
			
		||||
    # Fails with combined allowlist and denylist
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        ValueError,
 | 
			
		||||
        r'`category_allowlist` and `category_denylist` are mutually '
 | 
			
		||||
        r'exclusive options.'):
 | 
			
		||||
      classifier_options = _ClassifierOptions(category_allowlist=['foo'],
 | 
			
		||||
                                              category_denylist=['bar'])
 | 
			
		||||
      options = _ImageClassifierOptions(
 | 
			
		||||
        base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
        classifier_options=classifier_options)
 | 
			
		||||
      with _ImageClassifier.create_from_options(options) as unused_classifier:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
  def test_empty_classification_outputs(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(score_threshold=1)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      classifier_options=classifier_options)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      # Performs image classification on the input.
 | 
			
		||||
      image_result = classifier.classify(self.test_image)
 | 
			
		||||
      self.assertEmpty(image_result.classifications[0].entries[0].categories)
 | 
			
		||||
 | 
			
		||||
  def test_missing_result_callback(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM)
 | 
			
		||||
    with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                r'result callback must be provided'):
 | 
			
		||||
      with _ImageClassifier.create_from_options(options) as unused_classifier:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
 | 
			
		||||
  def test_illegal_result_callback(self, running_mode):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=running_mode,
 | 
			
		||||
      result_callback=mock.MagicMock())
 | 
			
		||||
    with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                r'result callback should not be provided'):
 | 
			
		||||
      with _ImageClassifier.create_from_options(options) as unused_classifier:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_for_video_in_image_mode(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.IMAGE)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                  r'not initialized with the video mode'):
 | 
			
		||||
        classifier.classify_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_async_in_image_mode(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.IMAGE)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                  r'not initialized with the live stream mode'):
 | 
			
		||||
        classifier.classify_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_in_video_mode(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                  r'not initialized with the image mode'):
 | 
			
		||||
        classifier.classify(self.test_image)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_async_in_video_mode(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                  r'not initialized with the live stream mode'):
 | 
			
		||||
        classifier.classify_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_classify_for_video_with_out_of_order_timestamp(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.VIDEO)
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      unused_result = classifier.classify_for_video(self.test_image, 1)
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'Input timestamp must be monotonically increasing'):
 | 
			
		||||
        classifier.classify_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_classify_for_video(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(max_results=4)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
| 
						 | 
				
			
			@ -132,7 +353,78 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
			
		|||
      for timestamp in range(0, 300, 30):
 | 
			
		||||
        classification_result = classifier.classify_for_video(
 | 
			
		||||
            self.test_image, timestamp)
 | 
			
		||||
        self.assertEqual(classification_result, _EXPECTED_CLASSIFICATION_RESULT)
 | 
			
		||||
        expected_classification_result = _ClassificationResult(
 | 
			
		||||
          classifications=[
 | 
			
		||||
            _Classifications(
 | 
			
		||||
              entries=[
 | 
			
		||||
                _ClassificationEntry(
 | 
			
		||||
                  categories=_EXPECTED_CATEGORIES, timestamp_ms=timestamp)
 | 
			
		||||
              ],
 | 
			
		||||
              head_index=0, head_name='probability')
 | 
			
		||||
          ])
 | 
			
		||||
        self.assertEqual(classification_result, expected_classification_result)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_in_live_stream_mode(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      result_callback=mock.MagicMock())
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                  r'not initialized with the image mode'):
 | 
			
		||||
        classifier.classify(self.test_image)
 | 
			
		||||
 | 
			
		||||
  def test_calling_classify_for_video_in_live_stream_mode(self):
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      result_callback=mock.MagicMock())
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      with self.assertRaisesRegex(ValueError,
 | 
			
		||||
                                  r'not initialized with the video mode'):
 | 
			
		||||
        classifier.classify_for_video(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  def test_classify_async_calls_with_illegal_timestamp(self):
 | 
			
		||||
    classifier_options = _ClassifierOptions(max_results=4)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      classifier_options=classifier_options,
 | 
			
		||||
      result_callback=mock.MagicMock())
 | 
			
		||||
    with _ImageClassifier.create_from_options(options) as classifier:
 | 
			
		||||
      classifier.classify_async(self.test_image, 100)
 | 
			
		||||
      with self.assertRaisesRegex(
 | 
			
		||||
          ValueError, r'Input timestamp must be monotonically increasing'):
 | 
			
		||||
        classifier.classify_async(self.test_image, 0)
 | 
			
		||||
 | 
			
		||||
  # TODO: Fix the packet is empty issue.
 | 
			
		||||
  """
 | 
			
		||||
  @parameterized.parameters((0, _EXPECTED_CLASSIFICATION_RESULT),
 | 
			
		||||
                            (1, _EMPTY_CLASSIFICATION_RESULT))
 | 
			
		||||
  def test_classify_async_calls(self, threshold, expected_result):
 | 
			
		||||
    observed_timestamp_ms = -1
 | 
			
		||||
 | 
			
		||||
    def check_result(result: _ClassificationResult, output_image: _Image,
 | 
			
		||||
                     timestamp_ms: int):
 | 
			
		||||
      self.assertEqual(result, expected_result)
 | 
			
		||||
      self.assertTrue(
 | 
			
		||||
        np.array_equal(output_image.numpy_view(),
 | 
			
		||||
                       self.test_image.numpy_view()))
 | 
			
		||||
      self.assertLess(observed_timestamp_ms, timestamp_ms)
 | 
			
		||||
      self.observed_timestamp_ms = timestamp_ms
 | 
			
		||||
 | 
			
		||||
    classifier_options = _ClassifierOptions(
 | 
			
		||||
      max_results=4, score_threshold=threshold)
 | 
			
		||||
    options = _ImageClassifierOptions(
 | 
			
		||||
      base_options=_BaseOptions(model_asset_path=self.model_path),
 | 
			
		||||
      running_mode=_RUNNING_MODE.LIVE_STREAM,
 | 
			
		||||
      classifier_options=classifier_options,
 | 
			
		||||
      result_callback=check_result)
 | 
			
		||||
    classifier = _ImageClassifier.create_from_options(options)
 | 
			
		||||
    for timestamp in range(0, 300, 30):
 | 
			
		||||
      classifier.classify_async(self.test_image, timestamp)
 | 
			
		||||
    classifier.close()
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,6 +43,7 @@ _IMAGE_IN_STREAM_NAME = 'image_in'
 | 
			
		|||
_IMAGE_OUT_STREAM_NAME = 'image_out'
 | 
			
		||||
_IMAGE_TAG = 'IMAGE'
 | 
			
		||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 | 
			
		||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
| 
						 | 
				
			
			@ -91,7 +92,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
			
		|||
    """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.
 | 
			
		||||
 | 
			
		||||
    Note that the created `ImageClassifier` instance is in image mode, for
 | 
			
		||||
    detecting objects on single image inputs.
 | 
			
		||||
    classifying objects on single image inputs.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      model_path: Path to the model.
 | 
			
		||||
| 
						 | 
				
			
			@ -137,7 +138,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
			
		|||
      ])
 | 
			
		||||
      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
 | 
			
		||||
      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
 | 
			
		||||
      options.result_callback(classification_result, image, timestamp)
 | 
			
		||||
      options.result_callback(classification_result, image,
 | 
			
		||||
                              timestamp.value / _MICRO_SECONDS_PER_MILLISECOND)
 | 
			
		||||
 | 
			
		||||
    task_info = _TaskInfo(
 | 
			
		||||
        task_graph=_TASK_GRAPH_NAME,
 | 
			
		||||
| 
						 | 
				
			
			@ -156,7 +158,6 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
			
		|||
            _RunningMode.LIVE_STREAM), options.running_mode,
 | 
			
		||||
        packets_callback if options.result_callback else None)
 | 
			
		||||
 | 
			
		||||
  # TODO: Create an Image class for MediaPipe Tasks.
 | 
			
		||||
  def classify(
 | 
			
		||||
      self,
 | 
			
		||||
      image: image_module.Image,
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +208,8 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
			
		|||
    """
 | 
			
		||||
    output_packets = self._process_video_data({
 | 
			
		||||
        _IMAGE_IN_STREAM_NAME:
 | 
			
		||||
        packet_creator.create_image(image).at(timestamp_ms)
 | 
			
		||||
            packet_creator.create_image(image).at(
 | 
			
		||||
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
 | 
			
		||||
    })
 | 
			
		||||
    classification_result_proto = packet_getter.get_proto(
 | 
			
		||||
      output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])
 | 
			
		||||
| 
						 | 
				
			
			@ -216,3 +218,36 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
 | 
			
		|||
        classifications_module.Classifications.create_from_pb2(classification)
 | 
			
		||||
        for classification in classification_result_proto.classifications
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
  def classify_async(self, image: image_module.Image, timestamp_ms: int) -> None:
 | 
			
		||||
    """Sends live image data (an Image with a unique timestamp) to perform
 | 
			
		||||
    image classification.
 | 
			
		||||
 | 
			
		||||
    Only use this method when the ImageClassifier is created with the live
 | 
			
		||||
    stream running mode. The input timestamps should be monotonically increasing
 | 
			
		||||
    for adjacent calls of this method. This method will return immediately after
 | 
			
		||||
    the input image is accepted. The results will be available via the
 | 
			
		||||
    `result_callback` provided in the `ImageClassifierOptions`. The
 | 
			
		||||
    `classify_async` method is designed to process live stream data such as
 | 
			
		||||
    camera input. To lower the overall latency, image classifier may drop the
 | 
			
		||||
    input images if needed. In other words, it's not guaranteed to have output
 | 
			
		||||
    per input image.
 | 
			
		||||
 | 
			
		||||
    The `result_callback` provides:
 | 
			
		||||
      - A classification result object that contains a list of classifications.
 | 
			
		||||
      - The input image that the image classifier runs on.
 | 
			
		||||
      - The input timestamp in milliseconds.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      image: MediaPipe Image.
 | 
			
		||||
      timestamp_ms: The timestamp of the input image in milliseconds.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: If the current input timestamp is smaller than what the image
 | 
			
		||||
        classifier has already processed.
 | 
			
		||||
    """
 | 
			
		||||
    self._send_live_stream_data({
 | 
			
		||||
        _IMAGE_IN_STREAM_NAME:
 | 
			
		||||
            packet_creator.create_image(image).at(
 | 
			
		||||
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
 | 
			
		||||
    })
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user