diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 5bb479d7a..d2140f1da 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -476,6 +476,32 @@ class ImageClassifierTest(parameterized.TestCase): classifier.classify_async(self.test_image, timestamp) classifier.close() + def test_classify_async_succeeds_with_region_of_interest(self): + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, unused_output_image: _Image, + timestamp_ms: int): + self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms)) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=classifier_options, + result_callback=check_result) + classifier = _ImageClassifier.create_from_options(options) + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) + for timestamp in range(0, 300, 30): + classifier.classify_async(test_image, timestamp, roi) + classifier.close() + if __name__ == '__main__': absltest.main()