Updated object_detector_test.py to correctly verify results
This commit is contained in:
parent
89e6b824ae
commit
d6103b41bc
|
@ -113,6 +113,8 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(
|
|||
),
|
||||
]
|
||||
)
|
||||
_SCORE_DIFF_TOLERANCE = 1e-6
|
||||
_PIXEL_DIFF_TOLERANCE = 5
|
||||
_ALLOW_LIST = ['cat', 'dog']
|
||||
_DENY_LIST = ['cat']
|
||||
_SCORE_THRESHOLD = 0.3
|
||||
|
@ -136,6 +138,32 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
|
||||
)
|
||||
|
||||
def _assert_approximately_equals_bounding_box(self, bbox, expected_bbox):
|
||||
self.assertEqual(bbox.width, expected_bbox.width)
|
||||
self.assertEqual(bbox.height, expected_bbox.height)
|
||||
self.assertAlmostEqual(
|
||||
bbox.origin_x, expected_bbox.origin_x, _PIXEL_DIFF_TOLERANCE
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
bbox.origin_y, expected_bbox.origin_y, _PIXEL_DIFF_TOLERANCE
|
||||
)
|
||||
|
||||
def _assert_approximately_equals_category(self, category, expected_category):
|
||||
self.assertEqual(category.category_name, expected_category.category_name)
|
||||
self.assertAlmostEqual(
|
||||
category.score, expected_category.score, _SCORE_DIFF_TOLERANCE
|
||||
)
|
||||
|
||||
def _assert_approximately_equals_detection_result(self, result, expected_result):
|
||||
self.assertLen(result.detections, 4)
|
||||
for i, detection in enumerate(result.detections):
|
||||
bbox = detection.bounding_box
|
||||
category = detection.categories[0]
|
||||
expected_bbox = expected_result.detections[i].bounding_box
|
||||
expected_category = expected_result.detections[i].categories[0]
|
||||
self._assert_approximately_equals_bounding_box(bbox, expected_bbox)
|
||||
self._assert_approximately_equals_category(category, expected_category)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _ObjectDetector.create_from_model_path(self.model_path) as detector:
|
||||
|
@ -192,7 +220,9 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
# Performs object detection on the input.
|
||||
detection_result = detector.detect(self.test_image)
|
||||
# Comparing results.
|
||||
self.assertEqual(detection_result, expected_detection_result)
|
||||
self._assert_approximately_equals_detection_result(
|
||||
detection_result, expected_detection_result
|
||||
)
|
||||
# Closes the detector explicitly when the detector is not used in
|
||||
# a context.
|
||||
detector.close()
|
||||
|
@ -405,7 +435,9 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
for timestamp in range(0, 300, 30):
|
||||
detection_result = detector.detect_for_video(self.test_image, timestamp)
|
||||
self.assertEqual(detection_result, _EXPECTED_DETECTION_RESULT)
|
||||
self._assert_approximately_equals_detection_result(
|
||||
detection_result, _EXPECTED_DETECTION_RESULT
|
||||
)
|
||||
|
||||
def test_calling_detect_in_live_stream_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
|
@ -454,7 +486,12 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
def check_result(
|
||||
result: _DetectionResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
self.assertEqual(result, expected_result)
|
||||
if result.detections:
|
||||
self._assert_approximately_equals_detection_result(
|
||||
result, expected_result
|
||||
)
|
||||
else:
|
||||
self.assertEqual(result, expected_result)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
output_image.numpy_view(), self.test_image.numpy_view()
|
||||
|
|
Loading…
Reference in New Issue
Block a user