Added a test to run classify_async in region of interest mode

This commit is contained in:
kinaryml 2022-10-11 22:42:50 -07:00
parent 8ea0018397
commit 7726205d85

View File

@ -476,6 +476,32 @@ class ImageClassifierTest(parameterized.TestCase):
classifier.classify_async(self.test_image, timestamp) classifier.classify_async(self.test_image, timestamp)
classifier.close() classifier.close()
def test_classify_async_succeeds_with_region_of_interest(self):
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, unused_output_image: _Image,
timestamp_ms: int):
self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=classifier_options,
result_callback=check_result)
classifier = _ImageClassifier.create_from_options(options)
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164,
height=0.427)
for timestamp in range(0, 300, 30):
classifier.classify_async(test_image, timestamp, roi)
classifier.close()
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()