diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index d2140f1da..783718d06 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -477,11 +477,19 @@ class ImageClassifierTest(parameterized.TestCase): classifier.close() def test_classify_async_succeeds_with_region_of_interest(self): + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, + height=0.427) observed_timestamp_ms = -1 - def check_result(result: _ClassificationResult, unused_output_image: _Image, + def check_result(result: _ClassificationResult, output_image: _Image, timestamp_ms: int): self.assertEqual(result, _generate_soccer_ball_results(timestamp_ms)) + self.assertEqual(output_image.width, test_image.width) + self.assertEqual(output_image.height, test_image.height) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms @@ -492,12 +500,6 @@ class ImageClassifierTest(parameterized.TestCase): classifier_options=classifier_options, result_callback=check_result) classifier = _ImageClassifier.create_from_options(options) - # Load the test image. - test_image = _Image.create_from_file( - test_utils.get_test_data_path('multi_objects.jpg')) - # NormalizedRect around the soccer ball. - roi = _NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, - height=0.427) for timestamp in range(0, 300, 30): classifier.classify_async(test_image, timestamp, roi) classifier.close()