diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index adeddafd7..9ffacef02 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -114,6 +114,8 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult( ), ] ) +_SCORE_DIFF_TOLERANCE = 1e-6 +_PIXEL_DIFF_TOLERANCE = 5 _ALLOW_LIST = ['cat', 'dog'] _DENY_LIST = ['cat'] _SCORE_THRESHOLD = 0.3 @@ -137,6 +139,32 @@ class ObjectDetectorTest(parameterized.TestCase): os.path.join(_TEST_DATA_DIR, _MODEL_FILE) ) + def _expect_bounding_box_correct(self, bbox, expected_bbox): + self.assertEqual(bbox.width, expected_bbox.width) + self.assertEqual(bbox.height, expected_bbox.height) + self.assertAlmostEqual( + bbox.origin_x, expected_bbox.origin_x, _PIXEL_DIFF_TOLERANCE + ) + self.assertAlmostEqual( + bbox.origin_y, expected_bbox.origin_y, _PIXEL_DIFF_TOLERANCE + ) + + def _expect_category_correct(self, category, expected_category): + self.assertEqual(category.category_name, expected_category.category_name) + self.assertAlmostEqual( + category.score, expected_category.score, _SCORE_DIFF_TOLERANCE + ) + + def _expect_detection_result_correct(self, result, expected_result): + self.assertLen(result.detections, len(expected_result.detections)) + for i, detection in enumerate(result.detections): + bbox = detection.bounding_box + category = detection.categories[0] + expected_bbox = expected_result.detections[i].bounding_box + expected_category = expected_result.detections[i].categories[0] + self._expect_bounding_box_correct(bbox, expected_bbox) + self._expect_category_correct(category, expected_category) + def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. with _ObjectDetector.create_from_model_path(self.model_path) as detector: @@ -193,7 +221,9 @@ class ObjectDetectorTest(parameterized.TestCase): # Performs object detection on the input. detection_result = detector.detect(self.test_image) # Comparing results. - self.assertEqual(detection_result, expected_detection_result) + self._expect_detection_result_correct( + detection_result, expected_detection_result + ) # Closes the detector explicitly when the detector is not used in # a context. detector.close() @@ -418,7 +448,9 @@ class ObjectDetectorTest(parameterized.TestCase): with _ObjectDetector.create_from_options(options) as detector: for timestamp in range(0, 300, 30): detection_result = detector.detect_for_video(self.test_image, timestamp) - self.assertEqual(detection_result, _EXPECTED_DETECTION_RESULT) + self._expect_detection_result_correct( + detection_result, _EXPECTED_DETECTION_RESULT + ) def test_calling_detect_in_live_stream_mode(self): options = _ObjectDetectorOptions( @@ -467,7 +499,12 @@ class ObjectDetectorTest(parameterized.TestCase): def check_result( result: _DetectionResult, output_image: _Image, timestamp_ms: int ): - self.assertEqual(result, expected_result) + if result.detections: + self._expect_detection_result_correct( + result, expected_result + ) + else: + self.assertEqual(result, expected_result) self.assertTrue( np.array_equal( output_image.numpy_view(), self.test_image.numpy_view()